vastdb 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
vastdb/schema.py CHANGED
@@ -6,11 +6,16 @@ It is possible to list and access VAST snapshots generated over a bucket.
 
 import logging
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
 
 import pyarrow as pa
 
 from . import bucket, errors, schema, table
 
+if TYPE_CHECKING:
+    from .table import Table
+
+
 log = logging.getLogger(__name__)
 
 
@@ -26,7 +31,7 @@ class Schema:
         """VAST transaction used for this schema."""
         return self.bucket.tx
 
-    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True) -> "table.Table":
+    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True) -> "Table":
         """Create a new table under this schema."""
         if current := self.table(table_name, fail_if_missing=False):
             if fail_if_exists:
@@ -35,9 +40,9 @@ class Schema:
            return current
         self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
         log.info("Created table: %s", table_name)
-        return self.table(table_name)
+        return self.table(table_name)  # type: ignore[return-value]
 
-    def table(self, name: str, fail_if_missing=True) -> "table.Table":
+    def table(self, name: str, fail_if_missing=True) -> Optional["table.Table"]:
         """Get a specific table under this schema."""
         t = self.tables(table_name=name)
         if not t:
@@ -49,14 +54,14 @@ class Schema:
         log.debug("Found table: %s", t[0])
         return t[0]
 
-    def tables(self, table_name=None) -> ["table.Table"]:
+    def tables(self, table_name=None) -> List["Table"]:
        """List all tables under this schema."""
        tables = []
        next_key = 0
        name_prefix = table_name if table_name else ""
        exact_match = bool(table_name)
        while True:
-            bucket_name, schema_name, curr_tables, next_key, is_truncated, _ = \
+            _bucket_name, _schema_name, curr_tables, next_key, is_truncated, _ = \
                self.tx._rpc.api.list_tables(
                    bucket=self.bucket.name, schema=self.name, next_key=next_key, txid=self.tx.txid,
                    exact_match=exact_match, name_prefix=name_prefix, include_list_stats=exact_match)
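
With these changes, Schema.table() is annotated as returning Optional["table.Table"], so callers passing fail_if_missing=False should handle a None result explicitly. The following is a minimal sketch of that calling pattern under the 0.1.2 annotations; the session object and the bucket, schema, and table names are illustrative placeholders, not taken from this diff:

    import pyarrow as pa

    columns = pa.schema([('a', pa.int32())])

    # `session` is assumed to be an already-constructed vastdb session (placeholder).
    with session.transaction() as tx:
        schema = tx.bucket("my-bucket").create_schema("s1")
        # With fail_if_missing=False, table() may return None instead of raising.
        table = schema.table("t1", fail_if_missing=False)
        if table is None:
            table = schema.create_table("t1", columns)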
vastdb/table.py CHANGED
@@ -5,18 +5,12 @@ import queue
 from dataclasses import dataclass, field
 from math import ceil
 from threading import Event
-from typing import List, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import ibis
 import pyarrow as pa
 
-from . import errors, schema
-from .internal_commands import (
-    TABULAR_INVALID_ROW_ID,
-    VastdbApi,
-    build_query_data_request,
-    parse_query_data_response,
-)
+from . import errors, internal_commands, schema
 
 log = logging.getLogger(__name__)
 
@@ -27,18 +21,20 @@ MAX_ROWS_PER_BATCH = 512 * 1024
 # for example insert of 512k uint8 result in 512k*8bytes response since row_ids are uint64
 MAX_INSERT_ROWS_PER_PATCH = 512 * 1024
 
+
 @dataclass
 class TableStats:
     num_rows: int
     size_in_bytes: int
     is_external_rowid_alloc: bool = False
-    endpoints: List[str] = None
+    endpoints: Tuple[str, ...] = ()
+
 
 @dataclass
 class QueryConfig:
     num_sub_splits: int = 4
     num_splits: int = 1
-    data_endpoints: [str] = None
+    data_endpoints: Optional[List[str]] = None
     limit_rows_per_sub_split: int = 128 * 1024
     num_row_groups_per_sub_split: int = 8
     use_semi_sorted_projections: bool = True
@@ -50,15 +46,16 @@ class QueryConfig:
 class ImportConfig:
     import_concurrency: int = 2
 
+
 class SelectSplitState():
-    def __init__(self, query_data_request, table : "Table", split_id : int, config: QueryConfig) -> None:
+    def __init__(self, query_data_request, table: "Table", split_id: int, config: QueryConfig) -> None:
         self.split_id = split_id
         self.subsplits_state = {i: 0 for i in range(config.num_sub_splits)}
         self.config = config
         self.query_data_request = query_data_request
         self.table = table
 
-    def batches(self, api : VastdbApi):
+    def batches(self, api: internal_commands.VastdbApi):
         while not self.done:
             response = api.query_data(
                 bucket=self.table.bucket.name,
@@ -72,20 +69,21 @@ class SelectSplitState():
                 limit_rows=self.config.limit_rows_per_sub_split,
                 sub_split_start_row_ids=self.subsplits_state.items(),
                 enable_sorted_projections=self.config.use_semi_sorted_projections)
-            pages_iter = parse_query_data_response(
+            pages_iter = internal_commands.parse_query_data_response(
                 conn=response.raw,
                 schema=self.query_data_request.response_schema,
-                start_row_ids=self.subsplits_state)
+                start_row_ids=self.subsplits_state,
+                parser=self.query_data_request.response_parser)
 
             for page in pages_iter:
                 for batch in page.to_batches():
                     if len(batch) > 0:
                         yield batch
 
-
     @property
     def done(self):
-        return all(row_id == TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+        return all(row_id == internal_commands.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+
 
 @dataclass
 class Table:
@@ -93,12 +91,10 @@ class Table:
     schema: "schema.Schema"
     handle: int
     stats: TableStats
-    properties: dict = None
     arrow_schema: pa.Schema = field(init=False, compare=False)
     _ibis_table: ibis.Schema = field(init=False, compare=False)
 
     def __post_init__(self):
-        self.properties = self.properties or {}
         self.arrow_schema = self.columns()
 
         table_path = f'{self.schema.bucket.name}/{self.schema.name}/{self.name}'
@@ -136,13 +132,13 @@ class Table:
         log.debug("Found projection: %s", projs[0])
         return projs[0]
 
-    def projections(self, projection_name=None) -> ["Projection"]:
+    def projections(self, projection_name=None) -> List["Projection"]:
         projections = []
         next_key = 0
         name_prefix = projection_name if projection_name else ""
         exact_match = bool(projection_name)
         while True:
-            bucket_name, schema_name, table_name, curr_projections, next_key, is_truncated, _ = \
+            _bucket_name, _schema_name, _table_name, curr_projections, next_key, is_truncated, _ = \
                self.tx._rpc.api.list_projections(
                    bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=next_key, txid=self.tx.txid,
                    exact_match=exact_match, name_prefix=name_prefix)
@@ -153,7 +149,7 @@
                 break
         return [_parse_projection_info(projection, self) for projection in projections]
 
-    def import_files(self, files_to_import: [str], config: ImportConfig = None) -> None:
+    def import_files(self, files_to_import: List[str], config: Optional[ImportConfig] = None) -> None:
         source_files = {}
         for f in files_to_import:
             bucket_name, object_path = _parse_bucket_and_object_names(f)
@@ -161,7 +157,7 @@
 
         self._execute_import(source_files, config=config)
 
-    def import_partitioned_files(self, files_and_partitions: {str: pa.RecordBatch}, config: ImportConfig = None) -> None:
+    def import_partitioned_files(self, files_and_partitions: Dict[str, pa.RecordBatch], config: Optional[ImportConfig] = None) -> None:
         source_files = {}
         for f, record_batch in files_and_partitions.items():
             bucket_name, object_path = _parse_bucket_and_object_names(f)
@@ -209,7 +205,7 @@
                 max_workers=config.import_concurrency, thread_name_prefix='import_thread') as pool:
             try:
                 for endpoint in endpoints:
-                    session = VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
+                    session = internal_commands.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
                     futures.append(pool.submit(import_worker, files_queue, session))
 
                 log.debug("Waiting for import workers to finish")
@@ -218,24 +214,30 @@
             finally:
                 stop_event.set()
                 # ThreadPoolExecutor will be joined at the end of the context
-    def refresh_stats(self):
+
+    def get_stats(self) -> TableStats:
         stats_tuple = self.tx._rpc.api.get_table_stats(
             bucket=self.bucket.name, schema=self.schema.name, name=self.name, txid=self.tx.txid)
-        self.stats = TableStats(**stats_tuple._asdict())
+        return TableStats(**stats_tuple._asdict())
 
-    def select(self, columns: [str] = None,
+    def select(self, columns: Optional[List[str]] = None,
                predicate: ibis.expr.types.BooleanColumn = None,
-               config: QueryConfig = None,
+               config: Optional[QueryConfig] = None,
                *,
               internal_row_id: bool = False) -> pa.RecordBatchReader:
        if config is None:
            config = QueryConfig()
 
-        self.refresh_stats()
+        # Take a snapshot of enpoints
+        stats = self.get_stats()
+        endpoints = stats.endpoints if config.data_endpoints is None else config.data_endpoints
 
-        if self.stats.num_rows > config.rows_per_split and config.num_splits is None:
-            config.num_splits = self.stats.num_rows // config.rows_per_split
-            log.debug(f"num_rows={self.stats.num_rows} rows_per_splits={config.rows_per_split} num_splits={config.num_splits} ")
+        if stats.num_rows > config.rows_per_split and config.num_splits is None:
+            config.num_splits = stats.num_rows // config.rows_per_split
+            log.debug(f"num_rows={stats.num_rows} rows_per_splits={config.rows_per_split} num_splits={config.num_splits} ")
+
+        if columns is None:
+            columns = [f.name for f in self.arrow_schema]
 
         query_schema = self.arrow_schema
         if internal_row_id:
@@ -244,12 +246,12 @@
             query_schema = pa.schema(queried_fields)
             columns.append(INTERNAL_ROW_ID)
 
-        query_data_request = build_query_data_request(
+        query_data_request = internal_commands.build_query_data_request(
             schema=query_schema,
             predicate=predicate,
             field_names=columns)
 
-        splits_queue = queue.Queue()
+        splits_queue: queue.Queue[int] = queue.Queue()
 
        for split in range(config.num_splits):
            splits_queue.put(split)
@@ -257,8 +259,10 @@
         # this queue shouldn't be large it is marely a pipe through which the results
         # are sent to the main thread. Most of the pages actually held in the
         # threads that fetch the pages.
-        record_batches_queue = queue.Queue(maxsize=2)
+        record_batches_queue: queue.Queue[pa.RecordBatch] = queue.Queue(maxsize=2)
+
         stop_event = Event()
+
         class StoppedException(Exception):
             pass
 
@@ -266,9 +270,9 @@
             if stop_event.is_set():
                 raise StoppedException
 
-        def single_endpoint_worker(endpoint : str):
+        def single_endpoint_worker(endpoint: str):
             try:
-                host_api = VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
+                host_api = internal_commands.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
                 while True:
                     check_stop()
                     try:
@@ -293,12 +297,11 @@
                 log.debug("exiting")
                 record_batches_queue.put(None)
 
-        # Take a snapshot of enpoints
-        endpoints = list(self.stats.endpoints) if config.data_endpoints is None else list(config.data_endpoints)
-
         def batches_iterator():
-            def propagate_first_exception(futures : List[concurrent.futures.Future], block = False):
+            def propagate_first_exception(futures: List[concurrent.futures.Future], block=False):
                 done, not_done = concurrent.futures.wait(futures, None if block else 0, concurrent.futures.FIRST_EXCEPTION)
+                if self.tx.txid is None:
+                    raise errors.MissingTransaction()
                 for future in done:
                     future.result()
                 return not_done
@@ -308,7 +311,7 @@
             if config.query_id:
                 threads_prefix = threads_prefix + "-" + config.query_id
 
-            with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp: # TODO: concurrency == enpoints is just a heuristic
+            with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp:  # TODO: concurrency == enpoints is just a heuristic
                 futures = [tp.submit(single_endpoint_worker, endpoint) for endpoint in endpoints]
                 tasks_running = len(futures)
                 try:
@@ -340,16 +343,11 @@
 
     def insert(self, rows: pa.RecordBatch) -> pa.RecordBatch:
         serialized_slices = self.tx._rpc.api._record_batch_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
-        row_ids = []
        for slice in serialized_slices:
-            res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
+            self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
                                               txid=self.tx.txid)
-            (batch,) = pa.RecordBatchStreamReader(res.raw)
-            row_ids.append(batch[INTERNAL_ROW_ID])
-
-        return pa.chunked_array(row_ids)
 
-    def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: list = None) -> None:
+    def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: Optional[List[str]] = None) -> None:
         if columns is not None:
             update_fields = [(INTERNAL_ROW_ID, pa.uint64())]
             update_values = [self._combine_chunks(rows[INTERNAL_ROW_ID])]
@@ -417,7 +415,6 @@ class Projection:
     table: Table
     handle: int
     stats: TableStats
-    properties: dict = None
 
     @property
     def bucket(self):
@@ -438,7 +435,7 @@
         columns = []
         next_key = 0
         while True:
-            curr_columns, next_key, is_truncated, count, _ = \
+            curr_columns, next_key, is_truncated, _count, _ = \
                 self.tx._rpc.api.list_projection_columns(
                     self.bucket.name, self.schema.name, self.table.name, self.name, txid=self.table.tx.txid, next_key=next_key)
             if not curr_columns:
@@ -467,9 +464,9 @@ def _parse_projection_info(projection_info, table: "Table"):
     return Projection(name=projection_info.name, table=table, stats=stats, handle=int(projection_info.handle))
 
 
-def _parse_bucket_and_object_names(path: str) -> (str, str):
+def _parse_bucket_and_object_names(path: str) -> Tuple[str, str]:
     if not path.startswith('/'):
-        raise errors.InvalidArgumentError(f"Path {path} must start with a '/'")
+        raise errors.InvalidArgument(f"Path {path} must start with a '/'")
     components = path.split(os.path.sep)
     bucket_name = components[1]
     object_path = os.path.sep.join(components[2:])
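
Two behavioral changes in table.py are worth noting: refresh_stats(), which mutated self.stats, is replaced by get_stats(), which returns a fresh TableStats snapshot; and insert() no longer builds and returns a chunked array of row IDs. A minimal usage sketch under the 0.1.2 semantics follows; it assumes an open vastdb transaction named tx, and the bucket/schema/table names are placeholders rather than anything defined in this diff:

    import pyarrow as pa

    # `tx` is assumed to be an open vastdb transaction (placeholder).
    table = tx.bucket("my-bucket").schema("s1").table("t1")

    # get_stats() returns a TableStats snapshot instead of updating table.stats in place.
    stats = table.get_stats()
    print(stats.num_rows, stats.size_in_bytes, stats.endpoints)

    # insert() now just writes the rows; it no longer returns the inserted row IDs.
    table.insert(pa.RecordBatch.from_pydict({'a': [1, 2, 3]}))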
@@ -0,0 +1,61 @@
+import logging
+
+import duckdb
+import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
+
+from ..table import QueryConfig
+from .util import prepare_data
+
+log = logging.getLogger(__name__)
+
+
+def test_duckdb(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int32()),
+        ('b', pa.float64()),
+    ])
+    data = pa.table(schema=columns, data=[
+        [111, 222, 333],
+        [0.5, 1.5, 2.5],
+    ])
+    with prepare_data(session, clean_bucket_name, 's', 't', data) as t:
+        conn = duckdb.connect()
+        batches = t.select(columns=['a'], predicate=(t['b'] < 2))  # noqa: F841
+        actual = conn.execute('SELECT max(a) as "a_max" FROM batches').arrow()
+        expected = (data
+                    .filter(pc.field('b') < 2)
+                    .group_by([])
+                    .aggregate([('a', 'max')]))
+        assert actual == expected
+
+
+def test_closed_tx(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int64()),
+    ])
+    data = pa.table(schema=columns, data=[
+        list(range(10000)),
+    ])
+
+    with session.transaction() as tx:
+        t = tx.bucket(clean_bucket_name).create_schema("s1").create_table("t1", columns)
+        t.insert(data)
+
+        config = QueryConfig(
+            num_sub_splits=1,
+            num_splits=1,
+            num_row_groups_per_sub_split=1,
+            limit_rows_per_sub_split=100)
+        batches = t.select(config=config)  # noqa: F841
+        first = next(batches)  # make sure that HTTP response processing has started
+        assert first['a'].to_pylist() == list(range(100))
+
+        conn = duckdb.connect()
+        res = conn.execute('SELECT a FROM batches')
+        log.debug("closing tx=%s after first batch=%s", t.tx, first)
+
+    # transaction is closed, collecting the result should fail
+    with pytest.raises(duckdb.InvalidInputException, match="Detail: Python exception: MissingTransaction"):
+        res.arrow()
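
The new DuckDB tests above rely on DuckDB's replacement scans: an Arrow object that is reachable as a local Python variable (a pyarrow Table or RecordBatchReader, such as the batches object the tests obtain from Table.select()) can be referenced by that variable's name directly in SQL. A standalone sketch of the pattern, independent of vastdb:

    import duckdb
    import pyarrow as pa

    # A RecordBatchReader standing in for the reader returned by Table.select().
    data = pa.table({'a': [111, 222, 333], 'b': [0.5, 1.5, 2.5]})
    batches = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())

    conn = duckdb.connect()
    # DuckDB resolves `batches` from the enclosing Python scope (replacement scan).
    # Note that a RecordBatchReader can be consumed only once.
    print(conn.execute('SELECT max(a) AS a_max FROM batches').arrow())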
@@ -4,6 +4,7 @@ import pyarrow as pa
 
 log = logging.getLogger(__name__)
 
+
 def test_basic_projections(session, clean_bucket_name):
     with session.transaction() as tx:
         s = tx.bucket(clean_bucket_name).create_schema('s1')
@@ -57,10 +57,10 @@ def test_version_extraction():
            return f"vast {version}" if version else "vast"
 
        def log_message(self, format, *args):
-            log.debug(format,*args)
+            log.debug(format, *args)
 
    # start the server on localhost on some available port port
-    server_address =('localhost', 0)
+    server_address = ('localhost', 0)
    httpd = HTTPServer(server_address, MockOptionsHandler)
 
    def start_http_server_in_thread():
@@ -50,12 +50,13 @@ def test_commits_and_rollbacks(session, clean_bucket_name):
         b = tx.bucket(clean_bucket_name)
         b.schema("s3").drop()
         assert b.schemas() == []
-        1/0  # rollback schema dropping
+        1 / 0  # rollback schema dropping
 
     with session.transaction() as tx:
         b = tx.bucket(clean_bucket_name)
         assert b.schemas() != []
 
+
 def test_list_snapshots(session, clean_bucket_name):
     with session.transaction() as tx:
         b = tx.bucket(clean_bucket_name)