vastdb 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vastdb/bench/test_perf.py CHANGED
@@ -1,7 +1,6 @@
 import logging
 import time
 
-import pyarrow as pa
 import pytest
 
 from vastdb import util
@@ -20,7 +19,7 @@ def test_bench(session, clean_bucket_name, parquets_path, crater_path):
     t = util.create_table_from_files(s, 't1', files, config=ImportConfig(import_concurrency=8))
     config = QueryConfig(num_splits=8, num_sub_splits=4)
     s = time.time()
-    pa_table = pa.Table.from_batches(t.select(columns=['sid'], predicate=t['sid'] == 10033007, config=config))
+    pa_table = t.select(columns=['sid'], predicate=t['sid'] == 10033007, config=config).read_all()
     e = time.time()
     log.info("'SELECT sid from TABLE WHERE sid = 10033007' returned in %s seconds.", e - s)
     if crater_path:
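
Note: the change above reflects 0.1.7's select() API, which returns a pyarrow.RecordBatchReader, so results can be materialized with read_all() instead of wrapping the iterator in pa.Table.from_batches(). A minimal sketch, assuming a reachable endpoint and existing bucket/schema/table (all names and credentials below are placeholders):

    import pyarrow as pa
    from vastdb.session import Session

    session = Session(access="...", secret="...", endpoint="http://vip-pool.example:80")
    with session.transaction() as tx:
        table = tx.bucket("my-bucket").schema("my-schema").table("my-table")
        reader = table.select(columns=["sid"], predicate=table["sid"] == 10033007)
        pa_table: pa.Table = reader.read_all()  # or iterate the reader batch by batch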
vastdb/bucket.py CHANGED
@@ -5,8 +5,8 @@ It is possible to list and access VAST snapshots generated over a bucket.
 """
 
 import logging
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Iterable, Optional
 
 from . import errors, schema, transaction
 
@@ -22,48 +22,23 @@ class Bucket:
 
     name: str
     tx: "transaction.Transaction"
+    _root_schema: "Schema" = field(init=False, compare=False, repr=False)
 
-    def create_schema(self, path: str, fail_if_exists=True) -> "Schema":
+    def __post_init__(self):
+        """Root schema is empty."""
+        self._root_schema = schema.Schema(name="", bucket=self)
+
+    def create_schema(self, name: str, fail_if_exists=True) -> "Schema":
         """Create a new schema (a container of tables) under this bucket."""
-        if current := self.schema(path, fail_if_missing=False):
-            if fail_if_exists:
-                raise errors.SchemaExists(self.name, path)
-            else:
-                return current
-        self.tx._rpc.api.create_schema(self.name, path, txid=self.tx.txid)
-        log.info("Created schema: %s", path)
-        return self.schema(path) # type: ignore[return-value]
+        return self._root_schema.create_schema(name=name, fail_if_exists=fail_if_exists)
 
-    def schema(self, path: str, fail_if_missing=True) -> Optional["Schema"]:
+    def schema(self, name: str, fail_if_missing=True) -> Optional["Schema"]:
         """Get a specific schema (a container of tables) under this bucket."""
-        s = self.schemas(path)
-        log.debug("schema: %s", s)
-        if not s:
-            if fail_if_missing:
-                raise errors.MissingSchema(self.name, path)
-            else:
-                return None
-        assert len(s) == 1, f"Expected to receive only a single schema, but got: {len(s)}. ({s})"
-        log.debug("Found schema: %s", s[0].name)
-        return s[0]
+        return self._root_schema.schema(name=name, fail_if_missing=fail_if_missing)
 
-    def schemas(self, name: Optional[str] = None) -> List["Schema"]:
+    def schemas(self, batch_size=None):
        """List bucket's schemas."""
-        schemas = []
-        next_key = 0
-        exact_match = bool(name)
-        log.debug("list schemas param: schema=%s, exact_match=%s", name, exact_match)
-        while True:
-            _bucket_name, curr_schemas, next_key, is_truncated, _ = \
-                self.tx._rpc.api.list_schemas(bucket=self.name, next_key=next_key, txid=self.tx.txid,
-                                              name_prefix=name, exact_match=exact_match)
-            if not curr_schemas:
-                break
-            schemas.extend(curr_schemas)
-            if not is_truncated:
-                break
-
-        return [schema.Schema(name=name, bucket=self) for name, *_ in schemas]
+        return self._root_schema.schemas(batch_size=batch_size)
 
     def snapshot(self, name, fail_if_missing=True) -> Optional["Bucket"]:
         """Get snapshot by name (if exists)."""
@@ -80,7 +55,7 @@ class Bucket:
 
         return Bucket(name=f'{self.name}/{expected_name}', tx=self.tx)
 
-    def snapshots(self) -> List["Bucket"]:
+    def snapshots(self) -> Iterable["Bucket"]:
         """List bucket's snapshots."""
         snapshots = []
         next_key = 0
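
Note: Bucket now holds a hidden root Schema (name="") and delegates create_schema()/schema()/schemas() to it, which is what enables nested schemas. A hedged sketch of the resulting usage, with placeholder names and assuming an open session:

    with session.transaction() as tx:
        bucket = tx.bucket("my-bucket")
        parent = bucket.create_schema("parent")    # served by the hidden root schema
        child = parent.create_schema("child")      # nested; its full name is "parent/child"
        top_level = list(bucket.schemas())         # direct children of the bucket
        children = list(parent.schemas())          # direct children of "parent"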
vastdb/conftest.py CHANGED
@@ -30,14 +30,23 @@ def test_bucket_name(request):
     return request.config.getoption("--tabular-bucket-name")
 
 
+def iter_schemas(s):
+    """Recursively scan all schemas."""
+    children = s.schemas()
+    for c in children:
+        yield from iter_schemas(c)
+    yield s
+
+
 @pytest.fixture(scope="function")
 def clean_bucket_name(request, test_bucket_name, session):
     with session.transaction() as tx:
         b = tx.bucket(test_bucket_name)
-        for s in b.schemas():
-            for t in s.tables():
-                t.drop()
-            s.drop()
+        for top_schema in b.schemas():
+            for s in iter_schemas(top_schema):
+                for t in s.tables():
+                    t.drop()
+                s.drop()
     return test_bucket_name
 
 
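Note: iter_schemas() is a post-order traversal, so every child schema is yielded before its parent and the fixture drops nested schemas bottom-up. A standalone illustration of that ordering, using a hypothetical FakeSchema stand-in that only mimics the .schemas() method the helper relies on:

    class FakeSchema:
        def __init__(self, name, children=()):
            self.name = name
            self._children = list(children)

        def schemas(self):
            return self._children

    top = FakeSchema("top", [FakeSchema("top/a", [FakeSchema("top/a/b")])])
    print([s.name for s in iter_schemas(top)])  # ['top/a/b', 'top/a', 'top'] - children first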
vastdb/errors.py CHANGED
@@ -89,7 +89,11 @@ class InvalidArgument(Exception):
     pass
 
 
-class TooWideRow(InvalidArgument):
+class TooLargeRequest(InvalidArgument):
+    pass
+
+
+class TooWideRow(TooLargeRequest):
     pass
 
 
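Note: since TooWideRow now derives from the new TooLargeRequest, a single handler covers both oversized rows and oversized QueryData requests (0.1.7's Table.select() raises TooLargeRequest when the serialized request exceeds util.MAX_QUERY_DATA_REQUEST_SIZE). A hedged sketch, with table, very_many_columns and log as placeholders:

    from vastdb import errors

    try:
        reader = table.select(columns=very_many_columns)
    except errors.TooLargeRequest as e:  # also catches errors.TooWideRow
        log.warning("request too large, narrowing the column list: %s", e)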
vastdb/schema.py CHANGED
@@ -6,7 +6,7 @@ It is possible to list and access VAST snapshots generated over a bucket.
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
 import pyarrow as pa
 
@@ -31,14 +31,63 @@ class Schema:
         """VAST transaction used for this schema."""
         return self.bucket.tx
 
-    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True) -> "Table":
+    def _subschema_full_name(self, name: str) -> str:
+        return f"{self.name}/{name}" if self.name else name
+
+    def create_schema(self, name: str, fail_if_exists=True) -> "Schema":
+        """Create a new schema (a container of tables) under this schema."""
+        if current := self.schema(name, fail_if_missing=False):
+            if fail_if_exists:
+                raise errors.SchemaExists(self.bucket.name, name)
+            else:
+                return current
+        full_name = self._subschema_full_name(name)
+        self.tx._rpc.api.create_schema(self.bucket.name, full_name, txid=self.tx.txid)
+        log.info("Created schema: %s", full_name)
+        return self.schema(name) # type: ignore[return-value]
+
+    def schema(self, name: str, fail_if_missing=True) -> Optional["Schema"]:
+        """Get a specific schema (a container of tables) under this schema."""
+        _bucket_name, schemas, _next_key, _is_truncated, _ = \
+            self.tx._rpc.api.list_schemas(bucket=self.bucket.name, schema=self.name, next_key=0, txid=self.tx.txid,
+                                          name_prefix=name, exact_match=True, max_keys=1)
+        names = [name for name, *_ in schemas]
+        log.debug("Found schemas: %s", names)
+        if not names:
+            if fail_if_missing:
+                raise errors.MissingSchema(self.bucket.name, self._subschema_full_name(name))
+            else:
+                return None
+
+        assert len(names) == 1, f"Expected to receive only a single schema, but got {len(schemas)}: ({schemas})"
+        return schema.Schema(name=self._subschema_full_name(names[0]), bucket=self.bucket)
+
+    def schemas(self, batch_size=None) -> Iterable["Schema"]:
+        """List child schemas."""
+        next_key = 0
+        if not batch_size:
+            batch_size = 1000
+        result: List["Schema"] = []
+        while True:
+            _bucket_name, curr_schemas, next_key, is_truncated, _ = \
+                self.tx._rpc.api.list_schemas(bucket=self.bucket.name, schema=self.name, next_key=next_key, max_keys=batch_size, txid=self.tx.txid)
+            result.extend(schema.Schema(name=self._subschema_full_name(name), bucket=self.bucket) for name, *_ in curr_schemas)
+            if not is_truncated:
+                break
+        return result
+
+    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True, use_external_row_ids_allocation=False) -> "Table":
         """Create a new table under this schema."""
         if current := self.table(table_name, fail_if_missing=False):
             if fail_if_exists:
                 raise errors.TableExists(self.bucket.name, self.name, table_name)
             else:
                 return current
-        self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
+        if use_external_row_ids_allocation:
+            self.tx._rpc.features.check_external_row_ids_allocation()
+
+        self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid,
+                                      use_external_row_ids_allocation=use_external_row_ids_allocation)
         log.info("Created table: %s", table_name)
         return self.table(table_name) # type: ignore[return-value]
 
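Note: create_table() gains a use_external_row_ids_allocation flag, gated by the new 5.1+ feature check. A minimal sketch, assuming an open session and placeholder names:

    import pyarrow as pa

    columns = pa.schema([("id", pa.int64()), ("value", pa.float64())])
    with session.transaction() as tx:
        s = tx.bucket("my-bucket").schema("my-schema")
        # Raises errors.NotSupportedVersion on pre-5.1 servers:
        t = s.create_table("events", columns, use_external_row_ids_allocation=True)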
vastdb/session.py CHANGED
@@ -7,11 +7,16 @@ For more details see:
 - [Tabular identity policy with the proper permissions](https://support.vastdata.com/s/article/UUID-14322b60-d6a2-89ac-3df0-3dfbb6974182)
 """
 
+import logging
 import os
+from typing import Optional
 
 import boto3
 
-from . import errors, internal_commands, transaction
+from . import _internal, errors, transaction
+from ._internal import BackoffConfig
+
+log = logging.getLogger()
 
 
 class Features:
@@ -21,21 +26,41 @@ class Features:
         """Save the server version."""
         self.vast_version = vast_version
 
-    def check_imports_table(self):
-        """Check if the feature that support imports table is supported."""
-        if self.vast_version < (5, 2):
-            raise errors.NotSupportedVersion("import_table requires 5.2+", self.vast_version)
+        self.check_imports_table = self._check(
+            "Imported objects' table feature requires 5.2+ VAST release",
+            vast_version >= (5, 2))
+
+        self.check_return_row_ids = self._check(
+            "Returning row IDs requires 5.1+ VAST release",
+            vast_version >= (5, 1))
+
+        self.check_enforce_semisorted_projection = self._check(
+            "Semi-sorted projection enforcement requires 5.1+ VAST release",
+            vast_version >= (5, 1))
+
+        self.check_external_row_ids_allocation = self._check(
+            "External row IDs allocation requires 5.1+ VAST release",
+            vast_version >= (5, 1))
+
+    def _check(self, msg, supported):
+        log.debug("%s (current version is %s): supported=%s", msg, self.vast_version, supported)
+        if not supported:
+            def fail():
+                raise errors.NotSupportedVersion(msg, self.vast_version)
+            return fail
 
-    def check_return_row_ids(self):
-        """Check if insert/update/delete can return the row_ids."""
-        if self.vast_version < (5, 1):
-            raise errors.NotSupportedVersion("return_row_ids requires 5.1+", self.vast_version)
+        def noop():
+            pass
+        return noop
 
 
 class Session:
     """VAST database session."""
 
-    def __init__(self, access=None, secret=None, endpoint=None, ssl_verify=True):
+    def __init__(self, access=None, secret=None, endpoint=None,
+                 *,
+                 ssl_verify=True,
+                 backoff_config: Optional[BackoffConfig] = None):
         """Connect to a VAST Database endpoint, using specified credentials."""
         if access is None:
             access = os.environ['AWS_ACCESS_KEY_ID']
@@ -44,9 +69,13 @@ class Session:
 
         if endpoint is None:
            endpoint = os.environ['AWS_S3_ENDPOINT_URL']
-        self.api = internal_commands.VastdbApi(endpoint, access, secret, ssl_verify=ssl_verify)
-        version_tuple = tuple(int(part) for part in self.api.vast_version.split('.'))
-        self.features = Features(version_tuple)
+        self.api = _internal.VastdbApi(
+            endpoint=endpoint,
+            access_key=access,
+            secret_key=secret,
+            ssl_verify=ssl_verify,
+            backoff_config=backoff_config)
+        self.features = Features(self.api.vast_version)
         self.s3 = boto3.client('s3',
                                aws_access_key_id=access,
                                aws_secret_access_key=secret,
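
Note: in 0.1.7 ssl_verify and backoff_config become keyword-only, and retry behaviour moves from QueryConfig.backoff_func into a per-session BackoffConfig handled inside vastdb._internal. A hedged sketch; the credential values are placeholders (when omitted they come from the AWS_* environment variables), and BackoffConfig is assumed here to be default-constructible since its fields are not shown in this diff:

    from vastdb.session import Session
    from vastdb._internal import BackoffConfig

    session = Session(
        access="AKIA...",
        secret="...",
        endpoint="http://vip-pool.example:80",
        ssl_verify=True,
        backoff_config=BackoffConfig(),  # assumption: default-constructible
    )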
vastdb/table.py CHANGED
@@ -7,14 +7,14 @@ import queue
 from dataclasses import dataclass, field
 from math import ceil
 from threading import Event
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import backoff
 import ibis
 import pyarrow as pa
 import requests
 
-from . import errors, internal_commands, schema, util
+from . import _internal, errors, schema, util
 
 log = logging.getLogger(__name__)
 
@@ -54,7 +54,8 @@ class QueryConfig:
    num_sub_splits: int = 4
 
    # used to split the table into disjoint subsets of rows, to be processed concurrently using multiple RPCs
-    num_splits: int = 1
+    # will be estimated from the table's row count, if not explicitly set
+    num_splits: Optional[int] = None
 
    # each endpoint will be handled by a separate worker thread
    # a single endpoint can be specified more than once to benefit from multithreaded execution
@@ -64,19 +65,27 @@ class QueryConfig:
    limit_rows_per_sub_split: int = 128 * 1024
 
    # each fiber will read the following number of rowgroups continuously before skipping
-    # in order to use semi-sorted projections this value must be 8
+    # in order to use semi-sorted projections this value must be 8 (the hard-coded number of row groups per row block).
    num_row_groups_per_sub_split: int = 8
 
    # can be disabled for benchmarking purposes
    use_semi_sorted_projections: bool = True
 
+    # enforce using a specific semi-sorted projection (if enabled above)
+    semi_sorted_projection_name: Optional[str] = None
+
    # used to estimate the number of splits, given the table rows' count
    rows_per_split: int = 4000000
 
    # used for worker threads' naming
    query_id: str = ""
 
-    # allows retrying QueryData when the server is overloaded
+    # non-negative integer, used for server-side prioritization of queued requests:
+    # - requests with lower values will be served before requests with higher values.
+    # - if unset, the request will be added to the queue's end.
+    queue_priority: Optional[int] = None
+
+    # DEPRECATED: will be removed in a future release
    backoff_func: Any = field(default=backoff.on_exception(backoff.expo, RETRIABLE_ERRORS, max_tries=10))
 
 
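Note: the new QueryConfig fields in one place, as a hedged sketch (values are illustrative and table is a placeholder for an open Table handle):

    from vastdb.table import QueryConfig

    config = QueryConfig(
        num_splits=None,                   # None: estimated as max(1, num_rows // rows_per_split)
        num_sub_splits=4,
        semi_sorted_projection_name="p1",  # enforce a projection; needs a 5.1+ VAST release
        queue_priority=0,                  # lower values are served first; None joins the queue's end
    )
    reader = table.select(columns=["a", "b"], predicate=table["b"] < 5, config=config)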
@@ -98,14 +107,13 @@ class SelectSplitState:
         self.query_data_request = query_data_request
         self.table = table
 
-    def batches(self, api: internal_commands.VastdbApi):
+    def batches(self, api: _internal.VastdbApi):
         """Execute QueryData request, and yield parsed RecordBatch objects.
 
         Can be called repeatedly, to allow pagination.
         """
         while not self.done:
-            query_with_backoff = self.config.backoff_func(api.query_data)
-            response = query_with_backoff(
+            response = api.query_data(
                 bucket=self.table.bucket.name,
                 schema=self.table.schema.name,
                 table=self.table.name,
@@ -116,9 +124,11 @@
                 txid=self.table.tx.txid,
                 limit_rows=self.config.limit_rows_per_sub_split,
                 sub_split_start_row_ids=self.subsplits_state.items(),
+                schedule_id=self.config.queue_priority,
                 enable_sorted_projections=self.config.use_semi_sorted_projections,
-                query_imports_table=self.table._imports_table)
-            pages_iter = internal_commands.parse_query_data_response(
+                query_imports_table=self.table._imports_table,
+                projection=self.config.semi_sorted_projection_name)
+            pages_iter = _internal.parse_query_data_response(
                 conn=response.raw,
                 schema=self.query_data_request.response_schema,
                 start_row_ids=self.subsplits_state,
@@ -132,7 +142,7 @@
     @property
     def done(self):
         """Returns True iff the pagination is over."""
-        return all(row_id == internal_commands.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+        return all(row_id == _internal.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
 
 
 @dataclass
@@ -182,14 +192,14 @@ class Table:
         """Get a specific semi-sorted projection of this table."""
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
-        projs = self.projections(projection_name=name)
+        projs = tuple(self.projections(projection_name=name))
         if not projs:
             raise errors.MissingProjection(self.bucket.name, self.schema.name, self.name, name)
         assert len(projs) == 1, f"Expected to receive only a single projection, but got: {len(projs)}. projections: {projs}"
         log.debug("Found projection: %s", projs[0])
         return projs[0]
 
-    def projections(self, projection_name=None) -> List["Projection"]:
+    def projections(self, projection_name=None) -> Iterable["Projection"]:
         """List all semi-sorted projections of this table."""
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
@@ -209,7 +219,7 @@
                 break
         return [_parse_projection_info(projection, self) for projection in projections]
 
-    def import_files(self, files_to_import: List[str], config: Optional[ImportConfig] = None) -> None:
+    def import_files(self, files_to_import: Iterable[str], config: Optional[ImportConfig] = None) -> None:
         """Import a list of Parquet files into this table.
 
         The files must be on VAST S3 server and be accessible using current credentials.
@@ -278,7 +288,7 @@
                 max_workers=config.import_concurrency, thread_name_prefix='import_thread') as pool:
             try:
                 for endpoint in endpoints:
-                    session = internal_commands.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
+                    session = _internal.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
                     futures.append(pool.submit(import_worker, files_queue, session))
 
                 log.debug("Waiting for import workers to finish")
@@ -313,11 +323,16 @@
 
         # Take a snapshot of endpoints
         stats = self.get_stats()
+        log.debug("stats: %s", stats)
         endpoints = stats.endpoints if config.data_endpoints is None else config.data_endpoints
+        log.debug("endpoints: %s", endpoints)
+
+        if config.num_splits is None:
+            config.num_splits = max(1, stats.num_rows // config.rows_per_split)
+        log.debug("config: %s", config)
 
-        if stats.num_rows > config.rows_per_split and config.num_splits is None:
-            config.num_splits = stats.num_rows // config.rows_per_split
-            log.debug(f"num_rows={stats.num_rows} rows_per_splits={config.rows_per_split} num_splits={config.num_splits} ")
+        if config.semi_sorted_projection_name:
+            self.tx._rpc.features.check_enforce_semisorted_projection()
 
         if columns is None:
             columns = [f.name for f in self.arrow_schema]
@@ -332,16 +347,18 @@
         if predicate is True:
             predicate = None
         if predicate is False:
-            response_schema = internal_commands.get_response_schema(schema=query_schema, field_names=columns)
+            response_schema = _internal.get_response_schema(schema=query_schema, field_names=columns)
             return pa.RecordBatchReader.from_batches(response_schema, [])
 
         if isinstance(predicate, ibis.common.deferred.Deferred):
             predicate = predicate.resolve(self._ibis_table) # may raise if the predicate is invalid (e.g. wrong types / missing column)
 
-        query_data_request = internal_commands.build_query_data_request(
+        query_data_request = _internal.build_query_data_request(
             schema=query_schema,
             predicate=predicate,
             field_names=columns)
+        if len(query_data_request.serialized) > util.MAX_QUERY_DATA_REQUEST_SIZE:
+            raise errors.TooLargeRequest(f"{len(query_data_request.serialized)} bytes")
 
         splits_queue: queue.Queue[int] = queue.Queue()
 
@@ -364,7 +381,7 @@
 
         def single_endpoint_worker(endpoint: str):
             try:
-                host_api = internal_commands.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
+                host_api = _internal.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
                 while True:
                     check_stop()
                     try:
@@ -461,7 +478,7 @@
         for slice in serialized_slices:
             res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
                                                txid=self.tx.txid)
-            (batch,) = pa.RecordBatchStreamReader(res.raw)
+            (batch,) = pa.RecordBatchStreamReader(res.content)
             row_ids.append(batch[INTERNAL_ROW_ID])
         try:
             self.tx._rpc.features.check_return_row_ids()
@@ -497,6 +514,8 @@
         else:
             update_rows_rb = rows
 
+        update_rows_rb = util.sort_record_batch_if_needed(update_rows_rb, INTERNAL_ROW_ID)
+
         serialized_slices = util.iter_serialized_slices(update_rows_rb, MAX_ROWS_PER_BATCH)
         for slice in serialized_slices:
             self.tx._rpc.api.update_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
@@ -516,6 +535,8 @@
         delete_rows_rb = pa.record_batch(schema=pa.schema([(INTERNAL_ROW_ID, pa.uint64())]),
                                          data=[_combine_chunks(rows_chunk)])
 
+        delete_rows_rb = util.sort_record_batch_if_needed(delete_rows_rb, INTERNAL_ROW_ID)
+
         serialized_slices = util.iter_serialized_slices(delete_rows_rb, MAX_ROWS_PER_BATCH)
         for slice in serialized_slices:
             self.tx._rpc.api.delete_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
@@ -581,7 +602,7 @@
         return self.imports_table() # type: ignore[return-value]
 
     def imports_table(self) -> Optional["Table"]:
-        """Get the imports table under of this table."""
+        """Get the imports table of this table."""
         self.tx._rpc.features.check_imports_table()
         return Table(name=self.name, schema=self.schema, handle=int(self.handle), stats=self.stats, _imports_table=True)
 
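Note: two behavioural changes show up in the hunks above: update() and delete() now run the row-ID batch through util.sort_record_batch_if_needed, so callers no longer need to pre-sort by row ID, and insert() returns row IDs only when the server passes the check_return_row_ids feature check (5.1+). A heavily hedged sketch; the table layout is a placeholder, and insert() is assumed here to return an Arrow array of row IDs, which is not spelled out in this diff:

    import pyarrow as pa
    from vastdb.table import INTERNAL_ROW_ID

    with session.transaction() as tx:
        t = tx.bucket("my-bucket").schema("my-schema").table("events")
        row_ids = t.insert(pa.RecordBatch.from_pydict({"value": [10, 20, 30]}))  # row IDs on 5.1+
        # The IDs may be passed in any order; 0.1.7 sorts them internally before deleting:
        t.delete(pa.RecordBatch.from_pydict({INTERNAL_ROW_ID: list(reversed(row_ids.to_pylist()))}))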
@@ -56,6 +56,6 @@ def test_closed_tx(session, clean_bucket_name):
     res = conn.execute('SELECT a FROM batches')
     log.debug("closing tx=%s after first batch=%s", t.tx, first)
 
-    # transaction is closed, collecting the result should fail
-    with pytest.raises(duckdb.InvalidInputException, match="Detail: Python exception: MissingTransaction"):
+    # transaction is closed, collecting the result should fail internally in DuckDB
+    with pytest.raises(duckdb.InvalidInputException):
         res.arrow()
@@ -38,13 +38,13 @@ def test_parallel_imports(session, clean_bucket_name, s3):
     t.create_imports_table()
     log.info("Starting import of %d files", num_files)
     t.import_files(files)
-    arrow_table = pa.Table.from_batches(t.select(columns=['num']))
+    arrow_table = t.select(columns=['num']).read_all()
     assert arrow_table.num_rows == num_rows * num_files
-    arrow_table = pa.Table.from_batches(t.select(columns=['num'], predicate=t['num'] == 100))
+    arrow_table = t.select(columns=['num'], predicate=t['num'] == 100).read_all()
     assert arrow_table.num_rows == num_files
     import_table = t.imports_table()
     # checking all imports are on the imports table:
-    objects_name = pa.Table.from_batches(import_table.select(columns=["ObjectName"]))
+    objects_name = import_table.select(columns=["ObjectName"]).read_all()
     objects_name = objects_name.to_pydict()
     object_names = set(objects_name['ObjectName'])
     prefix = 'prq'
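
Note: this test exercises the imports-table flow end to end. A hedged sketch of the same flow outside the test harness (names and object paths are placeholders; the imports table requires a 5.2+ VAST release):

    with session.transaction() as tx:
        t = tx.bucket("my-bucket").schema("my-schema").table("events")
        t.create_imports_table()
        t.import_files(["/my-bucket/prq/part-0001.parquet"])  # objects must live on the VAST S3 server
        imported = t.imports_table().select(columns=["ObjectName"]).read_all()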
@@ -22,13 +22,13 @@ def test_nested_select(session, clean_bucket_name):
     ])
 
     with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
-        actual = pa.Table.from_batches(t.select())
+        actual = t.select().read_all()
         assert actual == expected
 
         names = [f.name for f in columns]
         for n in range(len(names) + 1):
             for cols in itertools.permutations(names, n):
-                actual = pa.Table.from_batches(t.select(columns=cols))
+                actual = t.select(columns=cols).read_all()
                 assert actual == expected.select(cols)
 
 
@@ -53,7 +53,7 @@ def test_nested_filter(session, clean_bucket_name):
     ])
 
     with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
-        actual = pa.Table.from_batches(t.select())
+        actual = t.select().read_all()
         assert actual == expected
 
         names = list('xyzw')
@@ -62,7 +62,7 @@ def test_nested_filter(session, clean_bucket_name):
            ibis_predicate = functools.reduce(
                operator.and_,
                (t[col] > 2 for col in cols))
-            actual = pa.Table.from_batches(t.select(predicate=ibis_predicate), t.arrow_schema)
+            actual = t.select(predicate=ibis_predicate).read_all()
 
            arrow_predicate = functools.reduce(
                operator.and_,
@@ -1,7 +1,10 @@
 import logging
+import time
 
 import pyarrow as pa
 
+from vastdb.table import QueryConfig
+
 log = logging.getLogger(__name__)
 
 
@@ -41,3 +44,78 @@ def test_basic_projections(session, clean_bucket_name):
         projs = t.projections()
         assert len(projs) == 1
         assert projs[0].name == 'p_new'
+
+
+def test_query_data_with_projection(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int64()),
+        ('b', pa.int64()),
+        ('s', pa.utf8()),
+    ])
+    # needs to be large enough in order for projections to be considered
+
+    GROUP_SIZE = 128 * 1024
+    expected = pa.table(schema=columns, data=[
+        [i for i in range(GROUP_SIZE)],
+        [i for i in reversed(range(GROUP_SIZE))],
+        [f's{i}' for i in range(GROUP_SIZE)],
+    ])
+
+    expected_projection_p1 = pa.table(schema=columns, data=[
+        [i for i in reversed(range(GROUP_SIZE - 5, GROUP_SIZE))],
+        [i for i in range(5)],
+        [f's{i}' for i in reversed(range(GROUP_SIZE - 5, GROUP_SIZE))],
+    ])
+
+    expected_projection_p2 = pa.table(schema=columns, data=[
+        [i for i in range(GROUP_SIZE - 5, GROUP_SIZE)],
+        [i for i in reversed(range(5))],
+        [f's{i}' for i in range(GROUP_SIZE - 5, GROUP_SIZE)],
+    ])
+
+    schema_name = "schema"
+    table_name = "table"
+    with session.transaction() as tx:
+        s = tx.bucket(clean_bucket_name).create_schema(schema_name)
+        t = s.create_table(table_name, expected.schema)
+
+        sorted_columns = ['b']
+        unsorted_columns = ['a', 's']
+        t.create_projection('p1', sorted_columns, unsorted_columns)
+
+        sorted_columns = ['a']
+        unsorted_columns = ['b', 's']
+        t.create_projection('p2', sorted_columns, unsorted_columns)
+
+    with session.transaction() as tx:
+        s = tx.bucket(clean_bucket_name).schema(schema_name)
+        t = s.table(table_name)
+        t.insert(expected)
+        actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's']))
+        assert actual == expected
+
+    time.sleep(3)
+
+    with session.transaction() as tx:
+        config = QueryConfig()
+        # the NFS mock server uses one row group per row block, so adjust the config accordingly
+        config.num_row_groups_per_sub_split = 1
+
+        s = tx.bucket(clean_bucket_name).schema(schema_name)
+        t = s.table(table_name)
+        projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
+        # no projection supplied - expected to use the p1 projection
+        assert expected_projection_p1 == projection_actual
+
+        config.semi_sorted_projection_name = 'p1'
+        projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
+        # expecting results of projection p1 since we asked for it specifically
+        assert expected_projection_p1 == projection_actual
+
+        config.semi_sorted_projection_name = 'p2'
+        projection_actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's'], predicate=(t['b'] < 5), config=config))
+        # expecting results of projection p2 since we asked for it specifically
+        assert expected_projection_p2 == projection_actual
+
+        t.drop()
+        s.drop()