vastdb 0.0.5.1__py3-none-any.whl → 0.0.5.3__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
vastdb/api.py CHANGED
@@ -71,8 +71,6 @@ import vast_flatbuf.tabular.S3File as tabular_s3_file
71
71
  import vast_flatbuf.tabular.CreateProjectionRequest as tabular_create_projection
72
72
  import vast_flatbuf.tabular.Column as tabular_projecion_column
73
73
  import vast_flatbuf.tabular.ColumnType as tabular_proj_column_type
74
- import vast_protobuf.tabular.rpc_pb2 as rpc_pb
75
- import vast_protobuf.substrait.type_pb2 as type_pb
76
74
 
77
75
  from vast_flatbuf.org.apache.arrow.computeir.flatbuf.Deref import Deref
78
76
  from vast_flatbuf.org.apache.arrow.computeir.flatbuf.ExpressionImpl import ExpressionImpl
@@ -759,6 +757,7 @@ class VastdbApi:
759
757
  if not port:
760
758
  port = 443 if secure else 80
761
759
 
760
+ self.default_max_list_columns_page_size = 1000
762
761
  self.session = requests.Session()
763
762
  self.session.verify = False
764
763
  self.session.headers['user-agent'] = "VastData Tabular API 1.0 - 2022 (c)"
@@ -932,6 +931,7 @@ class VastdbApi:
932
931
  headers['tabular-list-count-only'] = str(count_only)
933
932
 
934
933
  schemas = []
934
+ schema = schema or ""
935
935
  res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, command="schema"), headers=headers, stream=True)
936
936
  self._check_res(res, "list_schemas", expected_retvals)
937
937
  if res.status_code == 200:
@@ -1208,7 +1208,7 @@ class VastdbApi:
1208
1208
  data=serialized_schema, headers=headers)
1209
1209
  return self._check_res(res, "drop_columns", expected_retvals)
1210
1210
 
1211
- def list_columns(self, bucket, schema, table, *, txid=0, client_tags=None, max_keys=1000, next_key=0,
1211
+ def list_columns(self, bucket, schema, table, *, txid=0, client_tags=None, max_keys=None, next_key=0,
1212
1212
  count_only=False, name_prefix="", exact_match=False,
1213
1213
  expected_retvals=None, bc_list_internals=False):
1214
1214
  """
@@ -1219,6 +1219,7 @@ class VastdbApi:
1219
1219
  tabular-max-keys: 1000
1220
1220
  tabular-next-key: NextColumnId
1221
1221
  """
1222
+ max_keys = max_keys or self.default_max_list_columns_page_size
1222
1223
  client_tags = client_tags or []
1223
1224
  expected_retvals = expected_retvals or []
1224
1225
 
@@ -1393,16 +1394,15 @@ class VastdbApi:
1393
1394
  data=params, headers=headers, stream=True)
1394
1395
  return self._check_res(res, "query_data", expected_retvals)
1395
1396
 
1396
- def _list_table_columns(self, bucket, schema, table, filters=None, field_names=None):
1397
+ def _list_table_columns(self, bucket, schema, table, filters=None, field_names=None, txid=0):
1397
1398
  # build a list of the queried column names
1398
1399
  queried_columns = []
1399
1400
  # get all columns from the table
1400
1401
  all_listed_columns = []
1401
1402
  next_key = 0
1402
1403
  while True:
1403
- cur_columns, next_key, is_truncated, count = self.list_columns(bucket=bucket, schema=schema,
1404
- table=table,
1405
- next_key=next_key)
1404
+ cur_columns, next_key, is_truncated, count = self.list_columns(
1405
+ bucket=bucket, schema=schema, table=table, next_key=next_key, txid=txid)
1406
1406
  if not cur_columns:
1407
1407
  break
1408
1408
  all_listed_columns.extend(cur_columns)
@@ -1454,20 +1454,29 @@ class VastdbApi:
1454
1454
 
1455
1455
  return txid, created_txid
1456
1456
 
1457
- def _prepare_query(self, bucket, schema, table, num_sub_splits, filters=None, field_names=None, queried_columns=None):
1458
- if not queried_columns:
1459
- queried_columns = self._list_table_columns(bucket, schema, table, filters, field_names)
1460
- arrow_schema = pa.schema([(column[0], column[1]) for column in queried_columns])
1461
- _logger.debug(f'_prepare_query: arrow_schema = {arrow_schema}')
1462
- query_data_request = build_query_data_request(schema=arrow_schema, filters=filters, field_names=field_names)
1463
- if self.executor_hosts:
1464
- executor_hosts = self.executor_hosts
1465
- else:
1466
- executor_hosts = [self.host]
1467
- executor_sessions = [VastdbApi(executor_hosts[i], self.access_key, self.secret_key, self.username,
1468
- self.password, self.port, self.secure, self.auth_type) for i in range(len(executor_hosts))]
1457
+ def _prepare_query(self, bucket, schema, table, num_sub_splits, filters=None, field_names=None,
1458
+ queried_columns=None, response_row_id=False, txid=0):
1459
+ queried_fields = []
1460
+ if response_row_id:
1461
+ queried_fields.append(pa.field('$row_id', pa.uint64()))
1462
+
1463
+ if not queried_columns:
1464
+ queried_columns = self._list_table_columns(bucket, schema, table, filters, field_names, txid=txid)
1469
1465
 
1470
- return queried_columns, arrow_schema, query_data_request, executor_sessions
1466
+ queried_fields.extend(pa.field(column[0], column[1]) for column in queried_columns)
1467
+ arrow_schema = pa.schema(queried_fields)
1468
+
1469
+ _logger.debug(f'_prepare_query: arrow_schema = {arrow_schema}')
1470
+
1471
+ query_data_request = build_query_data_request(schema=arrow_schema, filters=filters, field_names=field_names)
1472
+ if self.executor_hosts:
1473
+ executor_hosts = self.executor_hosts
1474
+ else:
1475
+ executor_hosts = [self.host]
1476
+ executor_sessions = [VastdbApi(executor_hosts[i], self.access_key, self.secret_key, self.username,
1477
+ self.password, self.port, self.secure, self.auth_type) for i in range(len(executor_hosts))]
1478
+
1479
+ return queried_columns, arrow_schema, query_data_request, executor_sessions
1471
1480
 
1472
1481
  def _more_pages_exist(self, start_row_ids):
1473
1482
  for row_id in start_row_ids.values():
@@ -1561,7 +1570,7 @@ class VastdbApi:
1561
1570
  try:
1562
1571
  # prepare query
1563
1572
  queried_columns, arrow_schema, query_data_request, executor_sessions = \
1564
- self._prepare_query(bucket, schema, table, num_sub_splits, filters, field_names)
1573
+ self._prepare_query(bucket, schema, table, num_sub_splits, filters, field_names, response_row_id=response_row_id, txid=txid)
1565
1574
 
1566
1575
  # define the per split threaded query func
1567
1576
  def query_iterator_split_id(self, split_id):
@@ -1635,7 +1644,12 @@ class VastdbApi:
1635
1644
  if record_batch:
1636
1645
  # signal to the thread to read the next record batch and yield the current
1637
1646
  next_sems[split_id].release()
1638
- yield record_batch
1647
+ try:
1648
+ yield record_batch
1649
+ except GeneratorExit:
1650
+ killall = True
1651
+ _logger.debug("cancelling query_iterator")
1652
+ raise
1639
1653
  else:
1640
1654
  done_count += 1
1641
1655
 
@@ -1730,7 +1744,7 @@ class VastdbApi:
1730
1744
  try:
1731
1745
  # prepare query
1732
1746
  queried_columns, arrow_schema, query_data_request, executor_sessions = \
1733
- self._prepare_query(bucket, schema, table, num_sub_splits, filters, field_names)
1747
+ self._prepare_query(bucket, schema, table, num_sub_splits, filters, field_names, response_row_id=response_row_id, txid=txid)
1734
1748
 
1735
1749
  # define the per split threaded query func
1736
1750
  def query_split_id(self, split_id):
@@ -1995,7 +2009,7 @@ class VastdbApi:
1995
2009
  txid, created_txid = self._begin_tx_if_necessary(txid)
1996
2010
 
1997
2011
  if rows:
1998
- columns = self._list_table_columns(bucket, schema, table, field_names=rows.keys())
2012
+ columns = self._list_table_columns(bucket, schema, table, field_names=rows.keys(), txid=txid)
1999
2013
  columns_dict = dict([(column[0], column[1]) for column in columns])
2000
2014
  arrow_schema = pa.schema([])
2001
2015
  arrays = []
@@ -2324,232 +2338,6 @@ class VastdbApi:
2324
2338
 
2325
2339
  return columns, next_key, is_truncated, count
2326
2340
 
2327
- def parse_proto_buf_message(conn, msg_type):
2328
- msg_size = 0
2329
- while msg_size == 0: # keepalive
2330
- msg_size_bytes = conn.read(4)
2331
- msg_size, = struct.unpack('>L', msg_size_bytes)
2332
-
2333
- msg = msg_type()
2334
- msg_bytes = conn.read(msg_size)
2335
- msg.ParseFromString(msg_bytes)
2336
- return msg
2337
-
2338
- def parse_rpc_message(conn, msg_name):
2339
- rpc_msg = parse_proto_buf_message(conn, rpc_pb.Rpc)
2340
- if not rpc_msg.HasField(msg_name):
2341
- raise IOError(f"expected {msg_name} but got rpc_msg={rpc_msg}")
2342
-
2343
- content_size = rpc_msg.content_size
2344
- content = conn.read(content_size)
2345
- return getattr(rpc_msg, msg_name), content
2346
-
2347
-
2348
- def parse_select_row_ids_response(conn, debug=False):
2349
- rows_arr = array.array('Q', [])
2350
- subsplits_state = {}
2351
- while True:
2352
- select_rows_msg, content = parse_rpc_message(conn, 'select_row_ids_response_packet')
2353
- msg_type = select_rows_msg.WhichOneof('type')
2354
- if msg_type == "body":
2355
- subsplit_id = select_rows_msg.body.subsplit.id
2356
- if select_rows_msg.body.subsplit.HasField("state"):
2357
- subsplits_state[subsplit_id] = select_rows_msg.body.subsplit.state
2358
-
2359
- arr = array.array('Q', content)
2360
- rows_arr += arr
2361
- if debug:
2362
- _logger.info(f"arr={arr} metrics={select_rows_msg.body.metrics}")
2363
- else:
2364
- _logger.info(f"num_rows={len(arr)} metrics={select_rows_msg.body.metrics}")
2365
- elif msg_type == "trailing":
2366
- status_code = select_rows_msg.trailing.status.code
2367
- finished_pagination = select_rows_msg.trailing.finished_pagination
2368
- total_metrics = select_rows_msg.trailing.metrics
2369
- _logger.info(f"completed finished_pagination={finished_pagination} res={status_code} metrics={total_metrics}")
2370
- if status_code != 0:
2371
- raise IOError(f"Query data stream failed res={select_rows_msg.trailing.status}")
2372
-
2373
- return rows_arr, subsplits_state, finished_pagination
2374
- else:
2375
- raise EOFError(f"unknown response type={msg_type}")
2376
-
2377
-
2378
- def parse_count_rows_response(conn):
2379
- count_rows_msg, _ = parse_rpc_message(conn, 'count_rows_response_packet')
2380
- assert count_rows_msg.WhichOneof('type') == "body"
2381
- subsplit_id = count_rows_msg.body.subsplit.id
2382
- num_rows = count_rows_msg.body.amount_of_rows
2383
- _logger.info(f"completed num_rows={num_rows} subsplit_id={subsplit_id} metrics={count_rows_msg.trailing.metrics}")
2384
-
2385
- count_rows_msg, _ = parse_rpc_message(conn, 'count_rows_response_packet')
2386
- assert count_rows_msg.WhichOneof('type') == "trailing"
2387
- assert count_rows_msg.trailing.status.code == 0
2388
- assert count_rows_msg.trailing.finished_pagination
2389
-
2390
- return (subsplit_id, num_rows)
2391
-
2392
-
2393
- def get_proto_field_type(f):
2394
- t = type_pb.Type()
2395
- if f.type.equals(pa.string()):
2396
- t.string.nullability = 0
2397
- elif f.type.equals(pa.int8()):
2398
- t.i8.nullability = 0
2399
- elif f.type.equals(pa.int16()):
2400
- t.i16.nullability = 0
2401
- elif f.type.equals(pa.int32()):
2402
- t.i32.nullability = 0
2403
- elif f.type.equals(pa.int64()):
2404
- t.i64.nullability = 0
2405
- elif f.type.equals(pa.float32()):
2406
- t.fp32.nullability = 0
2407
- elif f.type.equals(pa.float64()):
2408
- t.fp64.nullability = 0
2409
- else:
2410
- raise ValueError(f'unsupported type={f.type}')
2411
-
2412
- return t
2413
-
2414
- def serialize_proto_request(req):
2415
- req_str = req.SerializeToString()
2416
- buf = struct.pack('>L', len(req_str))
2417
- buf += req_str
2418
- return buf
2419
-
2420
- def build_read_column_request(ids, schema, handles = [], num_subsplits = 1):
2421
- rpc_msg = rpc_pb.Rpc()
2422
- req = rpc_msg.read_columns_request
2423
- req.num_subsplits = num_subsplits
2424
- block = req.row_ids_blocks.add()
2425
- block.row_ids.info.offset = 0
2426
- block.row_ids.info.size = len(ids)
2427
- rpc_msg.content_size = len(ids)
2428
- if handles:
2429
- req.projection_table_handles.extend(handles)
2430
-
2431
- for f in schema:
2432
- req.column_schema.names.append(f.name)
2433
- t = get_proto_field_type(f)
2434
- req.column_schema.struct.types.append(t)
2435
-
2436
- return serialize_proto_request(rpc_msg) + ids
2437
-
2438
- def build_count_rows_request(schema: 'pa.Schema' = pa.schema([]), filters: dict = None, field_names: list = None,
2439
- split=(0, 1, 1), num_subsplits=1, build_relation=False):
2440
- rpc_msg = rpc_pb.Rpc()
2441
- req = rpc_msg.count_rows_request
2442
- req.split.id = split[0]
2443
- req.split.config.total = split[1]
2444
- req.split.config.row_groups_per_split = split[2]
2445
- # add empty state
2446
- state = rpc_pb.SubSplit.State()
2447
- for _ in range(num_subsplits):
2448
- req.subsplits.states.append(state)
2449
-
2450
- if build_relation:
2451
- # TODO use ibis or other library to build substrait relation
2452
- # meanwhile can be similar to build_count_rows_request
2453
- for field in schema:
2454
- req.relation.read.base_schema.names.append(field.name)
2455
- field_type = get_proto_field_type(field)
2456
- req.relation.read.base_schema.struct.types.append(field_type)
2457
- return serialize_proto_request(rpc_msg)
2458
- else:
2459
- query_data_flatbuffer = build_query_data_request(schema, filters, field_names)
2460
- serialized_flatbuffer = query_data_flatbuffer.serialized
2461
- req.legacy_relation.size = len(serialized_flatbuffer)
2462
- req.legacy_relation.offset = 0
2463
- rpc_msg.content_size = req.legacy_relation.size
2464
- return serialize_proto_request(rpc_msg) + serialized_flatbuffer
2465
-
2466
- """
2467
- Expected messages in the ReadColumns flow:
2468
-
2469
- ProtoMsg+Schema+RecordBatch,
2470
- ProtoMsg+RecordBatch
2471
- ProtoMsg+RecordBatch
2472
- ...
2473
- ProtoMsg+RecordBatch+EOS
2474
- ProtoMsg+Schema+RecordBatch,
2475
- ...
2476
- ProtoMsg+RecordBatch+EOS
2477
- ProtoMsg+Schema+RecordBatch,
2478
- ...
2479
- ProtoMsg+RecordBatch+EOS
2480
- ProtoMsg Completed
2481
- """
2482
- def _iter_read_column_resp_columns(conn, readers):
2483
- while True:
2484
- read_column_resp, content = parse_rpc_message(conn, 'read_columns_response_packet')
2485
- stream = BytesIO(content)
2486
-
2487
- msg_type = read_column_resp.WhichOneof('type')
2488
- if msg_type == "body":
2489
- stream_id = read_column_resp.body.subsplit_id
2490
- start_row_offset = read_column_resp.body.start_row_offset
2491
- arrow_msg_size = read_column_resp.body.arrow_ipc_info.size
2492
- metrics = read_column_resp.body.metrics
2493
- _logger.info(f"start stream_id={stream_id} arrow_msg_size={arrow_msg_size} start_row_offset={start_row_offset} metrics={metrics}")
2494
- elif msg_type == "trailing":
2495
- status_code = read_column_resp.trailing.status.code
2496
- _logger.info(f"completed stream_id={stream_id} res={status_code} metrics{read_column_resp.trailing.metrics}")
2497
- if status_code != 0:
2498
- raise IOError(f"Query data stream failed res={read_column_resp.trailing.status}")
2499
-
2500
- return
2501
- else:
2502
- raise EOFError(f"unknown response type={msg_type}")
2503
-
2504
- start_pos = stream.tell()
2505
- if stream_id not in readers:
2506
- # we implicitly read 1st message (Arrow schema) when constructing RecordBatchStreamReader
2507
- reader = pa.ipc.RecordBatchStreamReader(stream)
2508
- _logger.info(f"read ipc stream_id={stream_id} schema={reader.schema}")
2509
- readers[stream_id] = (reader, [])
2510
-
2511
- (reader, batches) = readers[stream_id]
2512
- while stream.tell() - start_pos < arrow_msg_size:
2513
- try:
2514
- batch = reader.read_next_batch() # read single-column chunk data
2515
- batches.append(batch)
2516
- except StopIteration: # we got an end-of-stream IPC message for a given stream ID
2517
- reader, batches = readers.pop(stream_id) # end of column
2518
- table = pa.Table.from_batches(batches) # concatenate all column chunks (as a single)
2519
- _logger.info(f"end of stream_id={stream_id} rows={len(table)} column={table}")
2520
- yield (start_row_offset, stream_id, table)
2521
-
2522
- ResponsePart = namedtuple('response_part', ['start_row_offset', 'table'])
2523
-
2524
- def _parse_read_column_stream(conn, schema, debug=False):
2525
- is_empty_projection = (len(schema) == 0)
2526
- parsers = defaultdict(lambda: QueryDataParser(schema, debug=debug)) # {stream_id: QueryDataParser}
2527
- readers = {} # {stream_id: pa.ipc.RecordBatchStreamReader}
2528
- streams_list = []
2529
- for start_row_offset, stream_id, table in _iter_read_column_resp_columns(conn, readers):
2530
- parser = parsers[stream_id]
2531
- for column in table.columns:
2532
- parser.parse(column)
2533
-
2534
- parsed_table = parser.build()
2535
- if parsed_table is not None: # when we got all columns (and before starting a new "select_rows" cycle)
2536
- parsers.pop(stream_id)
2537
- if is_empty_projection: # VAST returns an empty RecordBatch, with the correct rows' count
2538
- parsed_table = table
2539
-
2540
- _logger.info(f"parse_read_column_response stream_id={stream_id} rows={len(parsed_table)} table={parsed_table}")
2541
- streams_list.append(ResponsePart(start_row_offset, parsed_table))
2542
-
2543
- if parsers:
2544
- raise EOFError(f'all streams should be done before EOF. {parsers}')
2545
-
2546
- return streams_list
2547
-
2548
- def parse_read_column_response(conn, schema, debug=False):
2549
- response_parts = _parse_read_column_stream(conn, schema, debug)
2550
- response_parts.sort(key=lambda s: s.start_row_offset)
2551
- tables = [s.table for s in response_parts]
2552
- return pa.concat_tables(tables)
2553
2341
 
2554
2342
  def _iter_query_data_response_columns(fileobj, stream_ids=None):
2555
2343
  readers = {} # {stream_id: pa.ipc.RecordBatchStreamReader}
@@ -2837,44 +2625,6 @@ class QueryDataRequest:
2837
2625
  self.response_schema = response_schema
2838
2626
 
2839
2627
 
2840
- def build_select_rows_request(schema: 'pa.Schema' = pa.schema([]), filters: dict = None, field_names: list = None, split_id=0,
2841
- total_split=1, row_group_per_split=8, num_subsplits=1, build_relation=False, limit_rows=0,
2842
- subsplits_state=None):
2843
- rpc_msg = rpc_pb.Rpc()
2844
- select_rows_req = rpc_msg.select_row_ids_request
2845
- select_rows_req.split.id = split_id
2846
- select_rows_req.split.config.total = total_split
2847
- select_rows_req.split.config.row_groups_per_split = row_group_per_split
2848
- if limit_rows:
2849
- select_rows_req.limit_rows = limit_rows
2850
-
2851
- # add empty state
2852
- empty_state = rpc_pb.SubSplit.State()
2853
- for i in range(num_subsplits):
2854
- if subsplits_state and i in subsplits_state:
2855
- select_rows_req.subsplits.states.append(subsplits_state[i])
2856
- else:
2857
- select_rows_req.subsplits.states.append(empty_state)
2858
-
2859
- if build_relation:
2860
- # TODO use ibis or other library to build substrait relation
2861
- # meanwhile can be similar to build_count_rows_request
2862
- for field in schema:
2863
- select_rows_req.relation.read.base_schema.names.append(field.name)
2864
- field_type = get_proto_field_type(field)
2865
- select_rows_req.relation.read.base_schema.struct.types.append(field_type)
2866
- return serialize_proto_request(rpc_msg)
2867
- else:
2868
- query_data_flatbuffer = build_query_data_request(schema, filters, field_names)
2869
- serialized_flatbuffer = query_data_flatbuffer.serialized
2870
- select_rows_req.legacy_relation.size = len(serialized_flatbuffer)
2871
- select_rows_req.legacy_relation.offset = 0
2872
- rpc_msg.content_size = select_rows_req.legacy_relation.size
2873
- return serialize_proto_request(rpc_msg) + serialized_flatbuffer
2874
-
2875
- # TODO use ibis or other library to build SelectRowIds protobuf
2876
- # meanwhile can be similar to build_count_rows_request
2877
-
2878
2628
  def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), filters: dict = None, field_names: list = None):
2879
2629
  filters = filters or {}
2880
2630
 
vastdb/bench_scan.py ADDED
@@ -0,0 +1,45 @@
1
+ from vastdb import api
2
+
3
+ from logbook import Logger, StreamHandler
4
+ import sys
5
+ import time
6
+ import pprint
7
+
8
+ StreamHandler(sys.stdout).push_application()
9
+ log = Logger('Logbook')
10
+
11
+ # access_key_id=F3YUMQZDQB60ZZJ1PBAZ
12
+ # secret_access_key=9a9Q3if6IC5LjUexly/nXFv1UCANBnhGxi++Sw6p
13
+
14
+ a = api.VastdbApi(
15
+ access_key='F3YUMQZDQB60ZZJ1PBAZ',
16
+ secret_key='9a9Q3if6IC5LjUexly/nXFv1UCANBnhGxi++Sw6p',
17
+ host='172.19.111.1:172.19.111.16')
18
+
19
+ kwargs = dict(
20
+ bucket='tabular-slothful-jocular-jack',
21
+ schema='tpcds_schema_create_as_select',
22
+ table='store_sales',
23
+ field_names=['ss_sold_date_sk', 'ss_sold_time_sk', 'ss_item_sk'],
24
+ filters={'ss_item_sk': ['le 1']},
25
+ num_sub_splits=8)
26
+
27
+ pprint.pprint(kwargs)
28
+
29
+ res = a.query_iterator(**kwargs)
30
+
31
+ total_bytes = 0
32
+ total_rows = 0
33
+ start = time.time()
34
+ last_log = None
35
+
36
+ for b in res:
37
+ total_bytes += b.get_total_buffer_size()
38
+ total_rows += len(b)
39
+ dt = time.time() - start
40
+ if last_log != int(dt):
41
+ log.info("{:.3f} Mrow/s, {:.3f} MB/s", (total_rows/dt) / 1e6, (total_bytes/dt) / 1e6)
42
+ last_log = int(dt)
43
+
44
+ dt = time.time() - start
45
+ log.info("Done after {:.3f} seconds, {:.3f} Mrows, {:.3f} MB", dt, total_rows / 1e6, total_bytes / 1e6)
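Note: 0.0.5.3 also makes query_iterator() cancellable: the yield is now wrapped in a try/except GeneratorExit that sets killall and stops the per-split worker threads. A minimal sketch of early cancellation in the spirit of the benchmark above; the endpoint, credentials and bucket/schema/table names are placeholders, not real values:

from vastdb import api

# Placeholders only; substitute a real VIP, credentials and table coordinates.
session = api.VastdbApi(host='vip.example.com', access_key='ACCESS', secret_key='SECRET')
batches = session.query_iterator(bucket='some-bucket', schema='some_schema',
                                 table='some_table', num_sub_splits=8)
rows = 0
for batch in batches:
    rows += len(batch)
    if rows >= 1_000_000:
        break
# Closing the generator raises GeneratorExit inside query_iterator,
# which (as of 0.0.5.3) cancels the remaining worker threads instead of leaking them.
batches.close()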
vastdb/tests/__init__.py ADDED (empty file, no content to show)
vastdb/tests/conftest.py ADDED
@@ -0,0 +1,45 @@
1
+ import pytest
2
+ import boto3
3
+
4
+ from vastdb import v2
5
+
6
+
7
+ def pytest_addoption(parser):
8
+ parser.addoption("--tabular-bucket-name", help="Name of the S3 bucket with Tabular enabled")
9
+ parser.addoption("--tabular-access-key", help="Access key with Tabular permissions")
10
+ parser.addoption("--tabular-secret-key", help="Secret key with Tabular permissions")
11
+ parser.addoption("--tabular-endpoint-url", help="Tabular server endpoint")
12
+
13
+
14
+ @pytest.fixture(scope="module")
15
+ def rpc(request):
16
+ return v2.connect(
17
+ access=request.config.getoption("--tabular-access-key"),
18
+ secret=request.config.getoption("--tabular-secret-key"),
19
+ endpoint=request.config.getoption("--tabular-endpoint-url"),
20
+ )
21
+
22
+
23
+ @pytest.fixture(scope="module")
24
+ def test_bucket_name(request):
25
+ return request.config.getoption("--tabular-bucket-name")
26
+
27
+
28
+ @pytest.fixture(scope="module")
29
+ def clean_bucket_name(request, test_bucket_name, rpc):
30
+ with rpc.transaction() as tx:
31
+ b = tx.bucket(test_bucket_name)
32
+ for s in b.schemas():
33
+ for t in s.tables():
34
+ t.drop()
35
+ s.drop()
36
+ return test_bucket_name
37
+
38
+
39
+ @pytest.fixture(scope="module")
40
+ def s3(request):
41
+ return boto3.client(
42
+ 's3',
43
+ aws_access_key_id=request.config.getoption("--tabular-access-key"),
44
+ aws_secret_access_key=request.config.getoption("--tabular-secret-key"),
45
+ endpoint_url=request.config.getoption("--tabular-endpoint-url"))
vastdb/tests/test_create_table_from_parquets.py ADDED
@@ -0,0 +1,50 @@
1
+ import pytest
2
+ import os
3
+
4
+ import pyarrow as pa
5
+ import pyarrow.parquet as pq
6
+
7
+ from vastdb.v2 import InvalidArgumentError
8
+ from vastdb import util
9
+
10
+
11
+ def test_create_table_from_files(rpc, clean_bucket_name, s3):
12
+ datasets = [
13
+ {'num': [0],
14
+ 'varch': ['z']},
15
+ {'num': [1, 2, 3, 4, 5],
16
+ 'varch': ['a', 'b', 'c', 'd', 'e']},
17
+ {'num': [1, 2, 3, 4, 5],
18
+ 'bool': [True, False, None, None, False],
19
+ 'varch': ['a', 'b', 'c', 'd', 'e']},
20
+ {'num': [1, 2],
21
+ 'bool': [True, True]},
22
+ {'varch': ['a', 'b', 'c'],
23
+ 'mismatch': [1, 2, 3]}
24
+ ]
25
+ for i, ds in enumerate(datasets):
26
+ table = pa.Table.from_pydict(ds)
27
+ pq.write_table(table, f'prq{i}')
28
+ with open(f'prq{i}', 'rb') as f:
29
+ s3.put_object(Bucket=clean_bucket_name, Key=f'prq{i}', Body=f)
30
+ os.remove(f'prq{i}')
31
+
32
+ same_schema_files = [f'/{clean_bucket_name}/prq{i}' for i in range(2)]
33
+ contained_schema_files = [f'/{clean_bucket_name}/prq{i}' for i in range(4)]
34
+ different_schema_files = [f'/{clean_bucket_name}/prq{i}' for i in range(5)]
35
+
36
+ with rpc.transaction() as tx:
37
+ b = tx.bucket(clean_bucket_name)
38
+ s = b.create_schema('s1')
39
+ t = util.create_table_from_files(s, 't1', contained_schema_files)
40
+ assert len(t.arrow_schema) == 3
41
+ assert t.arrow_schema == pa.schema([('num', pa.int64()), ('bool', pa.bool_()), ('varch', pa.string())])
42
+
43
+ with pytest.raises(InvalidArgumentError):
44
+ util.create_table_from_files(s, 't2', different_schema_files)
45
+
46
+ with pytest.raises(InvalidArgumentError):
47
+ util.create_table_from_files(s, 't2', contained_schema_files, schema_merge_func=util.strict_schema_merge)
48
+
49
+ util.create_table_from_files(s, 't2', different_schema_files, schema_merge_func=util.union_schema_merge)
50
+ util.create_table_from_files(s, 't3', same_schema_files, schema_merge_func=util.strict_schema_merge)
vastdb/tests/test_sanity.py ADDED
@@ -0,0 +1,63 @@
1
+ import logging
2
+
3
+ import threading
4
+ from http.server import HTTPServer, BaseHTTPRequestHandler
5
+ from vastdb import api
6
+ from itertools import cycle
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+ def test_hello_world(rpc):
11
+ with rpc.transaction() as tx:
12
+ assert tx.txid is not None
13
+
14
+ def test_version_extraction():
15
+ # A list of version and expected version parsed by API
16
+ TEST_CASES = [
17
+ (None, None), # vast server without version in header
18
+ ("5", None), # major only is not supported
19
+ ("5.2", "5.2"), # major.minor
20
+ ("5.2.0", "5.2.0"), # major.minor.patch
21
+ ("5.2.0.0", "5.2.0.0"), # major.minor.patch.protocol
22
+ ("5.2.0.0 some other things", "5.2.0.0"), # Test forward comptibility 1
23
+ ("5.2.0.0.20 some other things", "5.2.0.0"), # Test forward comptibility 2
24
+ ]
25
+
26
+ # Mock OPTIONS handle that cycles through the test cases response
27
+ class MockOptionsHandler(BaseHTTPRequestHandler):
28
+ versions_iterator = cycle(TEST_CASES)
29
+
30
+ def __init__(self, *args) -> None:
31
+ super().__init__(*args)
32
+
33
+ def do_OPTIONS(self):
34
+ self.send_response(204)
35
+ self.end_headers()
36
+
37
+ def version_string(self):
38
+ version = next(self.versions_iterator)[0]
39
+ return f"vast {version}" if version else "vast"
40
+
41
+ def log_message(self, format, *args):
42
+ log.debug(format,*args)
43
+
44
+ # start the server on localhost on some available port port
45
+ server_address =('localhost', 0)
46
+ httpd = HTTPServer(server_address, MockOptionsHandler)
47
+
48
+ def start_http_server_in_thread():
49
+ log.info(f"Mock HTTP server is running on port {httpd.server_port}")
50
+ httpd.serve_forever()
51
+ log.info("Mock HTTP server killed")
52
+
53
+ # start the server in a thread so we have the main thread to operate the API
54
+ server_thread = threading.Thread(target=start_http_server_in_thread)
55
+ server_thread.start()
56
+
57
+ try:
58
+ for test_case in TEST_CASES:
59
+ tester = api.VastdbApi(endpoint=f"http://localhost:{httpd.server_port}", access_key="abc", secret_key="abc")
60
+ assert tester.vast_version == test_case[1]
61
+ finally:
62
+ # make sure we shut the server down no matter what
63
+ httpd.shutdown()
vastdb/tests/test_schemas.py ADDED
@@ -0,0 +1,39 @@
1
+ import pytest
2
+
3
+
4
+ def test_schemas(rpc, clean_bucket_name):
5
+ with rpc.transaction() as tx:
6
+ b = tx.bucket(clean_bucket_name)
7
+ assert b.schemas() == []
8
+
9
+ s = b.create_schema('s1')
10
+ assert s.bucket == b
11
+ assert b.schemas() == [s]
12
+
13
+ s.rename('s2')
14
+ assert s.bucket == b
15
+ assert s.name == 's2'
16
+ assert b.schemas()[0].name == 's2'
17
+
18
+ s.drop()
19
+ assert b.schemas() == []
20
+
21
+
22
+ def test_commits_and_rollbacks(rpc, clean_bucket_name):
23
+ with rpc.transaction() as tx:
24
+ b = tx.bucket(clean_bucket_name)
25
+ assert b.schemas() == []
26
+ b.create_schema("s3")
27
+ assert b.schemas() != []
28
+ # implicit commit
29
+
30
+ with pytest.raises(ZeroDivisionError):
31
+ with rpc.transaction() as tx:
32
+ b = tx.bucket(clean_bucket_name)
33
+ b.schema("s3").drop()
34
+ assert b.schemas() == []
35
+ 1/0 # rollback schema dropping
36
+
37
+ with rpc.transaction() as tx:
38
+ b = tx.bucket(clean_bucket_name)
39
+ assert b.schemas() != []
vastdb/tests/test_tables.py ADDED
@@ -0,0 +1,40 @@
1
+ import pyarrow as pa
2
+
3
+
4
+ def test_tables(rpc, clean_bucket_name):
5
+ with rpc.transaction() as tx:
6
+ s = tx.bucket(clean_bucket_name).create_schema('s1')
7
+ columns = pa.schema([
8
+ ('a', pa.int16()),
9
+ ('b', pa.float32()),
10
+ ('s', pa.utf8()),
11
+ ])
12
+ assert s.tables() == []
13
+ t = s.create_table('t1', columns)
14
+ assert s.tables() == [t]
15
+
16
+ rb = pa.record_batch(schema=columns, data=[
17
+ [111, 222],
18
+ [0.5, 1.5],
19
+ ['a', 'b'],
20
+ ])
21
+ expected = pa.Table.from_batches([rb])
22
+ t.insert(rb)
23
+
24
+ actual = pa.Table.from_batches(t.select(columns=['a', 'b', 's']))
25
+ assert actual == expected
26
+
27
+ actual = pa.Table.from_batches(t.select(columns=['a', 'b']))
28
+ assert actual == expected.select(['a', 'b'])
29
+
30
+ actual = pa.Table.from_batches(t.select(columns=['b', 's', 'a']))
31
+ assert actual == expected.select(['b', 's', 'a'])
32
+
33
+ actual = pa.Table.from_batches(t.select(columns=['s']))
34
+ assert actual == expected.select(['s'])
35
+
36
+ actual = pa.Table.from_batches(t.select(columns=[]))
37
+ assert actual == expected.select([])
38
+
39
+ t.drop()
40
+ s.drop()
vastdb/util.py ADDED
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import Callable
3
+
4
+ import pyarrow as pa
5
+ import pyarrow.parquet as pq
6
+
7
+ from vastdb.v2 import InvalidArgumentError, Table, Schema
8
+
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
+ def create_table_from_files(
14
+ schema: Schema, table_name: str, parquet_files: [str], schema_merge_func: Callable = None) -> Table:
15
+ if not schema_merge_func:
16
+ schema_merge_func = default_schema_merge
17
+ else:
18
+ assert schema_merge_func in [default_schema_merge, strict_schema_merge, union_schema_merge]
19
+ tx = schema.tx
20
+ current_schema = pa.schema([])
21
+ s3fs = pa.fs.S3FileSystem(
22
+ access_key=tx._rpc.api.access_key, secret_key=tx._rpc.api.secret_key, endpoint_override=tx._rpc.api.url)
23
+ for prq_file in parquet_files:
24
+ if not prq_file.startswith('/'):
25
+ raise InvalidArgumentError(f"Path {prq_file} must start with a '/'")
26
+ parquet_ds = pq.ParquetDataset(prq_file.lstrip('/'), filesystem=s3fs)
27
+ current_schema = schema_merge_func(current_schema, parquet_ds.schema)
28
+
29
+
30
+ log.info("Creating table %s from %d Parquet files, with columns: %s",
31
+ table_name, len(parquet_files), list(current_schema))
32
+ table = schema.create_table(table_name, current_schema)
33
+
34
+ log.info("Starting import of %d files to table: %s", len(parquet_files), table)
35
+ table.import_files(parquet_files)
36
+ log.info("Finished import of %d files to table: %s", len(parquet_files), table)
37
+ return table
38
+
39
+
40
+ def default_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
41
+ """
42
+ This function validates a schema is contained in another schema
43
+ Raises an InvalidArgumentError if a certain field does not exist in the target schema
44
+ """
45
+ if not current_schema.names:
46
+ return new_schema
47
+ s1 = set(current_schema)
48
+ s2 = set(new_schema)
49
+
50
+ if len(s1) > len(s2):
51
+ s1, s2 = s2, s1
52
+ result = current_schema # We need this variable in order to preserve the original fields order
53
+ else:
54
+ result = new_schema
55
+
56
+ if not s1.issubset(s2):
57
+ log.error("Schema mismatch. schema: %s isn't contained in schema: %s.", s1, s2)
58
+ raise InvalidArgumentError("Found mismatch in parquet files schemas.")
59
+ return result
60
+
61
+
62
+ def strict_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
63
+ """
64
+ This function validates two Schemas are identical.
65
+ Raises an InvalidArgumentError if schemas aren't identical.
66
+ """
67
+ if current_schema.names and current_schema != new_schema:
68
+ raise InvalidArgumentError(f"Schemas are not identical. \n {current_schema} \n vs \n {new_schema}")
69
+
70
+ return new_schema
71
+
72
+
73
+ def union_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
74
+ """
75
+ This function returns a unified schema from potentially two different schemas.
76
+ """
77
+ return pa.unify_schemas([current_schema, new_schema])
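The new vastdb/util.py helper builds a single Arrow schema across all of the given Parquet files (using one of the three merge policies above) before creating the table and importing the files into it. A short usage sketch, modeled on vastdb/tests/test_create_table_from_parquets.py above; the bucket name, schema name and file paths are placeholders:

from vastdb import util, v2

rpc = v2.connect(access='ACCESS', secret='SECRET', endpoint='http://vip.example.com')
with rpc.transaction() as tx:
    schema = tx.bucket('some-bucket').create_schema('staging')

    # Default policy: every file's schema must be contained in the widest one.
    table = util.create_table_from_files(
        schema, 'events', ['/some-bucket/part1.parquet', '/some-bucket/part2.parquet'])

    # union_schema_merge instead unifies differing schemas via pa.unify_schemas.
    util.create_table_from_files(
        schema, 'events_union',
        ['/some-bucket/part1.parquet', '/some-bucket/part3.parquet'],
        schema_merge_func=util.union_schema_merge)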
vastdb/v2.py CHANGED
@@ -1,108 +1,360 @@
1
- from vastdb import *
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+ import os
2
4
 
5
+ import boto3
6
+ import botocore
7
+ import ibis
8
+ import pyarrow as pa
9
+ import requests
3
10
 
4
- class Context:
5
- tx: int
6
- _rpc: RPC
11
+ from vastdb.api import VastdbApi, serialize_record_batch, build_query_data_request, parse_query_data_response, TABULAR_INVALID_ROW_ID
7
12
 
8
- def bucket(name: str) -> Bucket
9
13
 
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ class VastException(Exception):
18
+ pass
19
+
20
+
21
+ class NotFoundError(VastException):
22
+ pass
23
+
24
+
25
+ class AccessDeniedError(VastException):
26
+ pass
27
+
28
+
29
+ class ImportFilesError(VastException):
30
+ pass
31
+
32
+
33
+ class InvalidArgumentError(VastException):
34
+ pass
35
+
36
+
37
+ class RPC:
38
+ def __init__(self, access=None, secret=None, endpoint=None):
39
+ if access is None:
40
+ access = os.environ['AWS_ACCESS_KEY_ID']
41
+ if secret is None:
42
+ secret = os.environ['AWS_SECRET_ACCESS_KEY']
43
+ if endpoint is None:
44
+ endpoint = os.environ['AWS_S3_ENDPOINT_URL']
45
+
46
+ self.api = VastdbApi(endpoint, access, secret)
47
+ self.s3 = boto3.client('s3',
48
+ aws_access_key_id=access,
49
+ aws_secret_access_key=secret,
50
+ endpoint_url=endpoint)
51
+
52
+ def __repr__(self):
53
+ return f'RPC(endpoint={self.api.url}, access={self.api.access_key})'
54
+
55
+ def transaction(self):
56
+ return Transaction(self)
57
+
58
+
59
+ def connect(*args, **kw):
60
+ return RPC(*args, **kw)
61
+
62
+
63
+ @dataclass
64
+ class Transaction:
65
+ _rpc: RPC
66
+ txid: int = None
67
+
68
+ def __enter__(self):
69
+ response = self._rpc.api.begin_transaction()
70
+ self.txid = int(response.headers['tabular-txid'])
71
+ log.debug("opened txid=%016x", self.txid)
72
+ return self
73
+
74
+ def __exit__(self, *args):
75
+ if args == (None, None, None):
76
+ log.debug("committing txid=%016x", self.txid)
77
+ self._rpc.api.commit_transaction(self.txid)
78
+ else:
79
+ log.exception("rolling back txid=%016x", self.txid)
80
+ self._rpc.api.rollback_transaction(self.txid)
81
+
82
+ def __repr__(self):
83
+ return f'Transaction(id=0x{self.txid:016x})'
84
+
85
+ def bucket(self, name: str) -> "Bucket":
86
+ try:
87
+ self._rpc.s3.head_bucket(Bucket=name)
88
+ return Bucket(name, self)
89
+ except botocore.exceptions.ClientError as e:
90
+ if e.response['Error']['Code'] == 403:
91
+ raise AccessDeniedError(f"Access is denied to bucket: {name}") from e
92
+ else:
93
+ raise NotFoundError(f"Bucket {name} does not exist") from e
94
+
95
+
96
+ @dataclass
10
97
  class Bucket:
11
- ctx: Context
12
- name: str
98
+ name: str
99
+ tx: Transaction
100
+
101
+ def create_schema(self, path: str) -> "Schema":
102
+ self.tx._rpc.api.create_schema(self.name, path, txid=self.tx.txid)
103
+ log.info("Created schema: %s", path)
104
+ return self.schema(path)
13
105
 
14
- def schema(name: str) -> Schema
106
+ def schema(self, path: str) -> "Schema":
107
+ schema = self.schemas(path)
108
+ log.debug("schema: %s", schema)
109
+ if not schema:
110
+ raise NotFoundError(f"Schema '{path}' was not found in bucket: {self.name}")
111
+ assert len(schema) == 1, f"Expected to receive only a single schema, but got: {len(schema)}. ({schema})"
112
+ log.debug("Found schema: %s", schema[0].name)
113
+ return schema[0]
15
114
 
115
+ def schemas(self, schema: str = None) -> ["Schema"]:
116
+ schemas = []
117
+ next_key = 0
118
+ exact_match = bool(schema)
119
+ log.debug("list schemas param: schema=%s, exact_match=%s", schema, exact_match)
120
+ while True:
121
+ bucket_name, curr_schemas, next_key, is_truncated, _ = \
122
+ self.tx._rpc.api.list_schemas(bucket=self.name, next_key=next_key, txid=self.tx.txid,
123
+ name_prefix=schema, exact_match=exact_match)
124
+ if not curr_schemas:
125
+ break
126
+ schemas.extend(curr_schemas)
127
+ if not is_truncated:
128
+ break
129
+
130
+ return [Schema(name=name, bucket=self) for name, *_ in schemas]
131
+
132
+
133
+ @dataclass
16
134
  class Schema:
17
- ctx: Context
18
- path: str
135
+ name: str
136
+ bucket: Bucket
137
+
138
+ @property
139
+ def tx(self):
140
+ return self.bucket.tx
141
+
142
+ def create_table(self, table_name: str, columns: pa.Schema) -> "Table":
143
+ self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
144
+ log.info("Created table: %s", table_name)
145
+ return self.table(table_name)
146
+
147
+ def table(self, name: str) -> "Table":
148
+ t = self.tables(table_name=name)
149
+ if not t:
150
+ raise NotFoundError(f"Table '{name}' was not found under schema: {self.name}")
151
+ assert len(t) == 1, f"Expected to receive only a single table, but got: {len(t)}. tables: {t}"
152
+ log.debug("Found table: %s", t[0])
153
+ return t[0]
154
+
155
+ def tables(self, table_name=None) -> ["Table"]:
156
+ tables = []
157
+ next_key = 0
158
+ name_prefix = table_name if table_name else ""
159
+ exact_match = bool(table_name)
160
+ while True:
161
+ bucket_name, schema_name, curr_tables, next_key, is_truncated, _ = \
162
+ self.tx._rpc.api.list_tables(
163
+ bucket=self.bucket.name, schema=self.name, next_key=next_key, txid=self.tx.txid,
164
+ exact_match=exact_match, name_prefix=name_prefix)
165
+ if not curr_tables:
166
+ break
167
+ tables.extend(curr_tables)
168
+ if not is_truncated:
169
+ break
170
+
171
+ return [_parse_table_info(table, self) for table in tables]
172
+
173
+ def drop(self) -> None:
174
+ self.tx._rpc.api.drop_schema(self.bucket.name, self.name, txid=self.tx.txid)
175
+ log.info("Dropped schema: %s", self.name)
176
+
177
+ def rename(self, new_name) -> None:
178
+ self.tx._rpc.api.alter_schema(self.bucket.name, self.name, txid=self.tx.txid, new_name=new_name)
179
+ log.info("Renamed schema: %s to %s", self.name, new_name)
180
+ self.name = new_name
181
+
182
+
183
+ @dataclass
184
+ class TableStats:
185
+ num_rows: int
186
+ size: int
187
+
188
+
189
+ @dataclass
190
+ class QueryConfig:
191
+ num_sub_splits: int = 4
192
+ num_splits: int = 1
193
+ data_endpoints: [str] = None
194
+ limit_per_sub_split: int = 128 * 1024
195
+ num_row_groups_per_sub_split: int = 8
19
196
 
20
- def schema(name: str) -> Schema
21
- def table(name: str) -> Table
22
197
 
198
+ @dataclass
23
199
  class Table:
24
- ctx: Context
25
- path: str
200
+ name: str
201
+ schema: pa.Schema
202
+ handle: int
203
+ stats: TableStats
204
+ properties: dict = None
205
+ arrow_schema: pa.Schema = field(init=False, compare=False)
206
+ _ibis_table: ibis.Schema = field(init=False, compare=False)
26
207
 
27
- def import_files(...)
28
- def import_partitioned_files(...)
29
- def select(...) -> ???
208
+ def __post_init__(self):
209
+ self.properties = self.properties or {}
210
+ self.arrow_schema = self.columns()
211
+ self._ibis_table = ibis.Schema.from_pyarrow(self.arrow_schema)
30
212
 
213
+ @property
214
+ def tx(self):
215
+ return self.schema.tx
31
216
 
32
- class RPC:
33
- """
34
- INTERNAL STUFF: actually uses requests to send/receive stuff
35
- Cannot do pagination
36
- """
217
+ @property
218
+ def bucket(self):
219
+ return self.schema.bucket
37
220
 
38
- ### We can just copy-paste stuff from api.py
221
+ def __repr__(self):
222
+ return f"{type(self).__name__}(name={self.name})"
39
223
 
40
- def single_shot_query_data()
41
- def single_shot_list_columns()
224
+ def columns(self) -> pa.Schema:
225
+ cols = self.tx._rpc.api._list_table_columns(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid)
226
+ self.arrow_schema = pa.schema([(col[0], col[1]) for col in cols])
227
+ return self.arrow_schema
42
228
 
43
- @contextmanager
44
- def context(access, secret, endpoint):
45
- rpc = RPC(access, secret, endpoint) # Low-level commands => the user should not use it
46
- tx = rpc.begin_transaction()
47
- try:
48
- yield Context(rpc, tx)
49
- finally:
50
- rpc.close_transaction(tx)
229
+ def import_files(self, files_to_import: [str]) -> None:
230
+ source_files = {}
231
+ for f in files_to_import:
232
+ bucket_name, object_path = _parse_bucket_and_object_names(f)
233
+ source_files[(bucket_name, object_path)] = b''
51
234
 
235
+ self._execute_import(source_files)
52
236
 
53
- with context(access, secret, endpoint) as ctx: # open/closes tx
54
- # tx keep-alive?
55
- b = ctx.bucket("buck") # may raise NotFoundError if bucket is missing
237
+ def import_partitioned_files(self, files_and_partitions: {str: pa.RecordBatch}) -> None:
238
+ source_files = {}
239
+ for f, record_batch in files_and_partitions.items():
240
+ bucket_name, object_path = _parse_bucket_and_object_names(f)
241
+ serialized_batch = _serialize_record_batch(record_batch)
242
+ source_files = {(bucket_name, object_path): serialized_batch.to_pybytes()}
56
243
 
57
- ctx._rpc.strange_thing???
244
+ self._execute_import(source_files)
58
245
 
59
- b.create_schema("s1")
60
- b.create_schema("s1/s2")
61
- b.create_schema("s1/s2/s3")
246
+ def _execute_import(self, source_files):
247
+ try:
248
+ self.tx._rpc.api.import_data(
249
+ self.bucket.name, self.schema.name, self.name, source_files, txid=self.tx.txid)
250
+ except requests.HTTPError as e:
251
+ raise ImportFilesError(f"import_files failed with status: {e.response.status_code}, reason: {e.response.reason}")
252
+ except Exception as e:
253
+ # TODO: investigate and raise proper error in case of failure mid import.
254
+ raise ImportFilesError("import_files failed") from e
62
255
 
63
- iterable_of_schema_objects = b.schemas() # BFS or only top-level?
256
+ def select(self, columns: [str], predicate: ibis.expr.types.BooleanColumn = None,
257
+ config: "QueryConfig" = None):
258
+ if config is None:
259
+ config = QueryConfig()
64
260
 
65
- s = b.schema("s1/s2/s3") # may raise NotFoundError if schema is missing
66
- s = b.schema("s1").schema("s2/s3") # may raise NotFoundError if schema is missing
67
- s = b / "s1" / "s2" / "s3" # may raise NotFoundError if schema is missing
261
+ api = self.tx._rpc.api
262
+ field_names = columns
263
+ filters = []
264
+ bucket = self.bucket.name
265
+ schema = self.schema.name
266
+ table = self.name
267
+ query_data_request = build_query_data_request(
268
+ schema=self.arrow_schema, filters=filters, field_names=field_names)
68
269
 
69
- assert s.schemas() == []
270
+ start_row_ids = {i: 0 for i in range(config.num_sub_splits)}
271
+ assert config.num_splits == 1 # TODO()
272
+ split = (0, 1, config.num_row_groups_per_sub_split)
273
+ response_row_id = False
70
274
 
71
- iterable_of_tables_objects = s.tables()
72
- t = s.table("t") # /bucket/s1/s2/s3/t under tx
275
+ while not all(row_id == TABULAR_INVALID_ROW_ID for row_id in start_row_ids.values()):
276
+ response = api.query_data(
277
+ bucket=bucket,
278
+ schema=schema,
279
+ table=table,
280
+ params=query_data_request.serialized,
281
+ split=split,
282
+ num_sub_splits=config.num_sub_splits,
283
+ response_row_id=response_row_id,
284
+ txid=self.tx.txid,
285
+ limit_rows=config.limit_per_sub_split,
286
+ sub_split_start_row_ids=start_row_ids.items())
73
287
 
74
- s.rename()
75
- s.drop()
76
- ...
288
+ pages_iter = parse_query_data_response(
289
+ conn=response.raw,
290
+ schema=query_data_request.response_schema,
291
+ start_row_ids=start_row_ids)
77
292
 
293
+ for page in pages_iter:
294
+ for batch in page.to_batches():
295
+ if len(batch) > 0:
296
+ yield batch
78
297
 
79
- # may take a while - finishes when all files are done
80
- # if all OK, return None
81
- # in case of error raise ImportFilesError(failed_files_list=[(path, code, reason)])
82
- t.import_files(["/buck1/file1", ... "/buck3/file3"])
83
- t.import_partitioned_files({"/buck1/file1": pa.RecordBatch, ... "/buck3/file3": pa.RecordBatch})
298
+ def insert(self, rows: pa.RecordBatch) -> None:
299
+ blob = serialize_record_batch(rows)
300
+ self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=blob, txid=self.tx.txid)
84
301
 
85
- arrow_schema = t.columns()
86
- iterable_of_record_batches = t.select(
87
- column_names: List[str],
88
- predicate: ibis.BooleanColumn???,
89
- limit: int = None,
90
- config: QueryConfig = None
91
- )
302
+ def drop(self) -> None:
303
+ self.tx._rpc.api.drop_table(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid)
304
+ log.info("Dropped table: %s", self.name)
92
305
 
306
+ def rename(self, new_name) -> None:
307
+ self.tx._rpc.api.alter_table(
308
+ self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, new_name=new_name)
309
+ log.info("Renamed table from %s to %s ", self.name, new_name)
310
+ self.name = new_name
93
311
 
94
- t.drop()
95
- t.rename()
96
- t.add_column()
97
- t.drop_column()
98
- ...
312
+ def add_column(self, new_column: pa.Schema) -> None:
313
+ self.tx._rpc.api.add_columns(self.bucket.name, self.schema.name, self.name, new_column, txid=self.tx.txid)
314
+ log.info("Added column(s): %s", new_column)
315
+ self.arrow_schema = self.columns()
99
316
 
317
+ def drop_column(self, column_to_drop: pa.Schema) -> None:
318
+ self.tx._rpc.api.drop_columns(self.bucket.name, self.schema.name, self.name, column_to_drop, txid=self.tx.txid)
319
+ log.info("Dropped column(s): %s", column_to_drop)
320
+ self.arrow_schema = self.columns()
100
321
 
101
- class QueryConfig:
102
- num_of_subsplits: int = 2
103
- num_of_splits: int = 16?
104
- # how to load balance between VIPs?
105
- # we need a new RPC to get the "data_enpoints" VIPs from VAST and then we can round-robin between them?
106
- # => @alon
107
- ##### list_of_data_endpoints: List[str] = None
108
- limit_per_sub_split: int = 128k
322
+ def rename_column(self, current_column_name: str, new_column_name: str) -> None:
323
+ self.tx._rpc.api.alter_column(self.bucket.name, self.schema.name, self.name, name=current_column_name,
324
+ new_name=new_column_name, txid=self.tx.txid)
325
+ log.info("Renamed column: %s to %s", current_column_name, new_column_name)
326
+ self.arrow_schema = self.columns()
327
+
328
+ def __getitem__(self, col_name):
329
+ return self._ibis_table[col_name]
330
+
331
+
332
+ def _parse_table_info(table_info, schema: "Schema"):
333
+ stats = TableStats(num_rows=table_info.num_rows, size=table_info.size_in_bytes)
334
+ return Table(name=table_info.name, schema=schema, handle=int(table_info.handle), stats=stats)
335
+
336
+
337
+ def _parse_bucket_and_object_names(path: str) -> (str, str):
338
+ if not path.startswith('/'):
339
+ raise InvalidArgumentError(f"Path {path} must start with a '/'")
340
+ components = path.split(os.path.sep)
341
+ bucket_name = components[1]
342
+ object_path = os.path.sep.join(components[2:])
343
+ return bucket_name, object_path
344
+
345
+
346
+ def _serialize_record_batch(record_batch: pa.RecordBatch) -> pa.lib.Buffer:
347
+ sink = pa.BufferOutputStream()
348
+ with pa.ipc.new_stream(sink, record_batch.schema) as writer:
349
+ writer.write(record_batch)
350
+ return sink.getvalue()
351
+
352
+
353
+ def _parse_endpoint(endpoint):
354
+ if ":" in endpoint:
355
+ endpoint, port = endpoint.split(":")
356
+ port = int(port)
357
+ else:
358
+ port = 80
359
+ log.debug("endpoint: %s, port: %d", endpoint, port)
360
+ return endpoint, port
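Taken together, the rewritten v2 module turns the earlier design notes into a working object model: connect() returns an RPC, RPC.transaction() opens a Transaction that commits on success and rolls back on an exception, and Bucket/Schema/Table wrap the underlying VastdbApi calls. An end-to-end sketch in the spirit of the tests above; the endpoint, credentials and object names are placeholders:

import pyarrow as pa
from vastdb import v2

rpc = v2.connect(access='ACCESS', secret='SECRET', endpoint='http://vip.example.com')
with rpc.transaction() as tx:                       # commit on success, rollback on exception
    bucket = tx.bucket('some-bucket')                # NotFoundError / AccessDeniedError if unusable
    schema = bucket.create_schema('s1')
    columns = pa.schema([('a', pa.int16()), ('b', pa.float32()), ('s', pa.utf8())])
    table = schema.create_table('t1', columns)
    table.insert(pa.record_batch(schema=columns, data=[[111, 222], [0.5, 1.5], ['a', 'b']]))
    for batch in table.select(columns=['a', 's']):   # select() yields pa.RecordBatch objects
        print(batch.to_pydict())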
vastdb-0.0.5.1.dist-info/METADATA → vastdb-0.0.5.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vastdb
- Version: 0.0.5.1
+ Version: 0.0.5.3
  Summary: VAST Data SDK
  Home-page: https://github.com/vast-data/vastdb_sdk
  Author: VAST DATA
@@ -14,7 +14,6 @@ Requires-Dist: pyarrow
  Requires-Dist: requests
  Requires-Dist: aws-requests-auth
  Requires-Dist: xmltodict
- Requires-Dist: protobuf (==3.19.6)


  `VastdbApi` is a Python based API designed for interacting with *VastDB* & *Vast Catalog*, enabling operations such as schema and table management, data querying, and transaction handling.
vastdb-0.0.5.1.dist-info/RECORD → vastdb-0.0.5.3.dist-info/RECORD CHANGED
@@ -163,10 +163,18 @@ vast_protobuf/substrait/extensions/extensions_pb2.py,sha256=I_6c6nMmMaYvVtzF-5yc
  vast_protobuf/tabular/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  vast_protobuf/tabular/rpc_pb2.py,sha256=7kW2WrA2sGk6WVbD83mc_cKkZ2MxoImSO5GOVz6NbbE,23776
  vastdb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- vastdb/api.py,sha256=1AWblvumGOElc79AT7SJ0W9ofGhmng2ZzAK3OtWyaNU,135723
- vastdb/v2.py,sha256=0fLulaIQGlIbVNBBFGd6iwYPuGhaaJIHTiJORyio_YQ,2438
- vastdb-0.0.5.1.dist-info/LICENSE,sha256=obffan7LYrq7hLHNrY7vHcn2pKUTBUYXMKu-VOAvDxU,11333
- vastdb-0.0.5.1.dist-info/METADATA,sha256=-qCDf3o5nRkc4NHiqoAmEycmeWlw2tJswd_Sxsp-mL8,1404
- vastdb-0.0.5.1.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
- vastdb-0.0.5.1.dist-info/top_level.txt,sha256=34x_PO17U_yvzCKNMDpipTYsWMat2I0U3D4Df_lWwBM,34
- vastdb-0.0.5.1.dist-info/RECORD,,
+ vastdb/api.py,sha256=u5Cf01LeHGN7x_pcjnzfLV-lU485FGFCv7eTIKpSaB0,124883
+ vastdb/bench_scan.py,sha256=95O34oHS0UehX2ad4T2mok87CKszCFLCDZASMnZp77M,1208
+ vastdb/util.py,sha256=EF892Gbs08BxHVgG3FZ6QvhpKI2-eIL5bPzzrYE_Qd8,2905
+ vastdb/v2.py,sha256=gWZUnhSLEvtrXPxoTpTAwNuzU9qxrCaWKXmeNBpMrGY,12601
+ vastdb/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ vastdb/tests/conftest.py,sha256=dcYFJO0Riyn687qZTwcwKbwGieg6s4yZrVFrJAX-ylU,1461
+ vastdb/tests/test_create_table_from_parquets.py,sha256=dxykmvUR-vui6Z3qUvXPYJ9Nw6V_qcxKl4NDNQK4kiY,1963
+ vastdb/tests/test_sanity.py,sha256=7HmCjuOmtoYnuWiPjMP6m7sYQYop1_qRCzq2ZX0rKlc,2404
+ vastdb/tests/test_schemas.py,sha256=-nntn3ltBaaqSTsUvi-i9J0yr4TYvOTRyTNY039vEIk,1047
+ vastdb/tests/test_tables.py,sha256=KPe0ESVGWixecTSwQ8whzSF-NZrNVZ-Kv-C4Gz-OQnQ,1225
+ vastdb-0.0.5.3.dist-info/LICENSE,sha256=obffan7LYrq7hLHNrY7vHcn2pKUTBUYXMKu-VOAvDxU,11333
+ vastdb-0.0.5.3.dist-info/METADATA,sha256=Yd93AoZE5ZUhJUr0MhtfhcMaQUtSFZ1wbzc6vvEvclQ,1369
+ vastdb-0.0.5.3.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
+ vastdb-0.0.5.3.dist-info/top_level.txt,sha256=Vsj2MKtlhPg0J4so64slQtnwjhgoPmJgcG-6YcVAwVc,20
+ vastdb-0.0.5.3.dist-info/RECORD,,
vastdb-0.0.5.1.dist-info/top_level.txt → vastdb-0.0.5.3.dist-info/top_level.txt CHANGED
@@ -1,3 +1,2 @@
  vast_flatbuf
- vast_protobuf
  vastdb