vastdb 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,13 @@
  import itertools
  import json
  import logging
- import math
  import re
  import struct
  import urllib.parse
  from collections import defaultdict, namedtuple
  from enum import Enum
  from ipaddress import IPv4Address, IPv6Address
- from typing import Iterator, Optional, Union
+ from typing import Any, Dict, Iterator, List, Optional, Union
 
  import flatbuffers
  import ibis
@@ -92,7 +91,7 @@ UINT64_MAX = 18446744073709551615
  TABULAR_KEEP_ALIVE_STREAM_ID = 0xFFFFFFFF
  TABULAR_QUERY_DATA_COMPLETED_STREAM_ID = 0xFFFFFFFF - 1
  TABULAR_QUERY_DATA_FAILED_STREAM_ID = 0xFFFFFFFF - 2
- TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF # (1<<48)-1
+ TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF  # (1<<48)-1
  ESTORE_INVALID_EHANDLE = UINT64_MAX
  IMPORTED_OBJECTS_TABLE_NAME = "vastdb-imported-objects"
 
@@ -127,11 +126,11 @@ def get_unit_to_flatbuff_time_unit(type):
      }
      return unit_to_flatbuff_time_unit[type]
 
+
  class Predicate:
      def __init__(self, schema: 'pa.Schema', expr: ibis.expr.types.BooleanColumn):
          self.schema = schema
          self.expr = expr
-         self.builder = None
 
      def get_field_indexes(self, field: 'pa.Field', field_name_per_index: list) -> None:
          field_name_per_index.append(field.name)
@@ -157,8 +156,8 @@ class Predicate:
          self._field_name_per_index = {field: index for index, field in enumerate(_field_name_per_index)}
          return self._field_name_per_index
 
-     def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: list = None):
-         if not field_names:
+     def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: Optional[List[str]] = None):
+         if field_names is None:
              field_names = self.field_name_per_index.keys()
          projection_fields = []
          for field_name in field_names:
@@ -172,12 +171,17 @@ class Predicate:
          return builder.EndVector()
 
      def serialize(self, builder: 'flatbuffers.builder.Builder'):
-         from ibis.expr.operations.generic import IsNull, Literal, TableColumn
+         from ibis.expr.operations.generic import (
+             IsNull,
+             Literal,
+             TableColumn,
+         )
          from ibis.expr.operations.logical import (
              And,
              Equals,
              Greater,
              GreaterEqual,
+             InValues,
              Less,
              LessEqual,
              Not,
@@ -198,7 +202,7 @@ class Predicate:
              StringContains: self.build_match_substring,
          }
 
-         positions_map = dict((f.name, index) for index, f in enumerate(self.schema)) # TODO: BFS
+         positions_map = dict((f.name, index) for index, f in enumerate(self.schema))  # TODO: BFS
 
          self.builder = builder
 
@@ -215,40 +219,54 @@ class Predicate:
              prev_field_name = None
              for inner_op in or_args:
                  _logger.debug('inner_op %s', inner_op)
-                 builder_func = builder_map.get(type(inner_op))
+                 op_type = type(inner_op)
+                 builder_func: Any = builder_map.get(op_type)
                  if not builder_func:
-                     raise NotImplementedError(inner_op.name)
+                     if op_type == InValues:
+                         builder_func = self.build_equal
+                     else:
+                         raise NotImplementedError(self.expr)
 
                  if builder_func == self.build_is_null:
                      column, = inner_op.args
-                     literal = None
+                     literals = (None,)
                  elif builder_func == self.build_is_not_null:
                      not_arg, = inner_op.args
                      # currently we only support not is_null, checking we really got is_null under the not:
                      if not builder_map.get(type(not_arg)) == self.build_is_null:
-                         raise NotImplementedError(not_arg.args[0].name)
+                         raise NotImplementedError(self.expr)
                      column, = not_arg.args
-                     literal = None
+                     literals = (None,)
                  else:
-                     column, literal = inner_op.args
-                     if not isinstance(literal, Literal):
-                         raise NotImplementedError(inner_op.name)
+                     column, arg = inner_op.args
+                     if isinstance(arg, tuple):
+                         literals = arg
+                     else:
+                         literals = (arg,)
+                     for literal in literals:
+                         if not isinstance(literal, Literal):
+                             raise NotImplementedError(self.expr)
 
                  if not isinstance(column, TableColumn):
-                     raise NotImplementedError(inner_op.name)
+                     raise NotImplementedError(self.expr)
 
                  field_name = column.name
                  if prev_field_name is None:
                      prev_field_name = field_name
                  elif prev_field_name != field_name:
-                     raise NotImplementedError(op.name)
+                     raise NotImplementedError(self.expr)
 
-                 args_offsets = [self.build_column(position=positions_map[field_name])]
-                 if literal:
-                     field = self.schema.field(field_name)
-                     args_offsets.append(self.build_literal(field=field, value=literal.value))
+                 column_offset = self.build_column(position=positions_map[field_name])
+                 field = self.schema.field(field_name)
+                 for literal in literals:
+                     args_offsets = [column_offset]
+                     if literal is not None:
+                         args_offsets.append(self.build_literal(field=field, value=literal.value))
 
-                 inner_offsets.append(builder_func(*args_offsets))
+                     inner_offsets.append(builder_func(*args_offsets))
+
+             if not inner_offsets:
+                 raise NotImplementedError(self.expr) # an empty OR is equivalent to a 'FALSE' literal
 
              domain_offset = self.build_or(inner_offsets)
              offsets.append(domain_offset)
@@ -270,20 +288,6 @@ class Predicate:
          fb_expression.AddImpl(self.builder, ref)
          return fb_expression.End(self.builder)
 
-     def build_domain(self, column: int, field_name: str):
-         offsets = []
-         filters = self.filters[field_name]
-         if not filters:
-             return self.build_or([self.build_is_not_null(column)])
-
-         field_name, *field_attrs = field_name.split('.')
-         field = self.schema.field(field_name)
-         for attr in field_attrs:
-             field = field.type[attr]
-         for filter_by_name in filters:
-             offsets.append(self.build_range(column=column, field=field, filter_by_name=filter_by_name))
-         return self.build_or(offsets)
-
      def rule_to_operator(self, raw_rule: str):
          operator_matcher = {
              'eq': self.build_equal,
@@ -339,6 +343,8 @@ class Predicate:
          return fb_expression.End(self.builder)
 
      def build_literal(self, field: pa.Field, value):
+         literal_type: Any
+
          if field.type.equals(pa.int64()):
              literal_type = fb_int64_lit
              literal_impl = LiteralImpl.Int64Literal
@@ -551,13 +557,20 @@ class Predicate:
          return self.build_function('match_substring', column, literal)
 
 
+ class FieldNodesState:
+     def __init__(self) -> None:
+         # will be set during by the parser (see below)
+         self.buffers: Dict[int, Any] = defaultdict(lambda: None) # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
+         self.length: Dict[int, Any] = defaultdict(lambda: None) # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
+
+
  class FieldNode:
      """Helper class for representing nested Arrow fields and handling QueryData requests"""
      def __init__(self, field: pa.Field, index_iter, parent: Optional['FieldNode'] = None, debug: bool = False):
-         self.index = next(index_iter) # we use DFS-first enumeration for communicating the column positions to VAST
+         self.index = next(index_iter)  # we use DFS-first enumeration for communicating the column positions to VAST
          self.field = field
          self.type = field.type
-         self.parent = parent # will be None if this is the top-level field
+         self.parent = parent  # will be None if this is the top-level field
          self.debug = debug
          if isinstance(self.type, pa.StructType):
              self.children = [FieldNode(field, index_iter, parent=self) for field in self.type]
@@ -574,11 +587,7 @@ class FieldNode:
              field = pa.field('entries', pa.struct([self.type.key_field, self.type.item_field]))
              self.children = [FieldNode(field, index_iter, parent=self)]
          else:
-             self.children = [] # for non-nested types
-
-         # will be set during by the parser (see below)
-         self.buffers = None # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
-         self.length = None # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
+             self.children = []  # for non-nested types
 
      def _iter_to_root(self) -> Iterator['FieldNode']:
          yield self
@@ -599,22 +608,14 @@ class FieldNode:
              for child in self.children:
                  yield from child._iter_leaves()
 
-     def _iter_leaves(self) -> Iterator['FieldNode']:
-         """Generate only leaf nodes (i.e. columns having scalar types)."""
-         if not self.children:
-             yield self
-         else:
-             for child in self.children:
-                 yield from child._iter_leaves()
-
      def debug_log(self, level=0):
          """Recursively dump this node state to log."""
          bufs = self.buffers and [b and b.hex() for b in self.buffers]
-         _logger.debug('%s%d: %s, bufs=%s, len=%s', ' '*level, self.index, self.field, bufs, self.length)
+         _logger.debug('%s%d: %s, bufs=%s, len=%s', ' ' * level, self.index, self.field, bufs, self.length)
          for child in self.children:
-             child.debug_log(level=level+1)
+             child.debug_log(level=level + 1)
 
-     def set(self, arr: pa.Array):
+     def set(self, arr: pa.Array, state: FieldNodesState):
          """
          Assign the relevant Arrow buffers from the received array into this node.
 
@@ -626,34 +627,39 @@ class FieldNode:
          For example, `Struct<A, B>` is sent as two separate columns: `Struct<A>` and `Struct<B>`.
          Also, `Map<K, V>` is sent (as its underlying representation): `List<Struct<K>>` and `List<Struct<V>>`
          """
-         buffers = arr.buffers()[:arr.type.num_buffers] # slicing is needed because Array.buffers() returns also nested array buffers
+         buffers = arr.buffers()[:arr.type.num_buffers]  # slicing is needed because Array.buffers() returns also nested array buffers
          if self.debug:
              _logger.debug("set: index=%d %s %s", self.index, self.field, [b and b.hex() for b in buffers])
-         if self.buffers is None:
-             self.buffers = buffers
-             self.length = len(arr)
+         if state.buffers[self.index] is None:
+             state.buffers[self.index] = buffers
+             state.length[self.index] = len(arr)
          else:
              # Make sure subsequent assignments are consistent with each other
              if self.debug:
-                 if not self.buffers == buffers:
-                     raise ValueError(f'self.buffers: {self.buffers} are not equal with buffers: {buffers}')
-                 if not self.length == len(arr):
-                     raise ValueError(f'self.length: {self.length} are not equal with len(arr): {len(arr)}')
+                 if not state.buffers[self.index] == buffers:
+                     raise ValueError(f'self.buffers: {state.buffers[self.index]} are not equal with buffers: {buffers}')
+                 if not state.length[self.index] == len(arr):
+                     raise ValueError(f'self.length: {state.length[self.index]} are not equal with len(arr): {len(arr)}')
 
-     def build(self) -> pa.Array:
+     def build(self, state: FieldNodesState) -> pa.Array:
          """Construct an Arrow array from the collected buffers (recursively)."""
-         children = self.children and [node.build() for node in self.children]
-         result = pa.Array.from_buffers(self.type, self.length, buffers=self.buffers, children=children)
+         children = self.children and [node.build(state) for node in self.children]
+         result = pa.Array.from_buffers(self.type, state.length[self.index], buffers=state.buffers[self.index], children=children)
          if self.debug:
              _logger.debug('%s result=%s', self.field, result)
          return result
 
 
  class QueryDataParser:
+     class QueryDataParserState(FieldNodesState):
+         def __init__(self) -> None:
+             super().__init__()
+             self.leaf_offset = 0
+
      """Used to parse VAST QueryData RPC response."""
      def __init__(self, arrow_schema: pa.Schema, *, debug=False):
          self.arrow_schema = arrow_schema
-         index = itertools.count() # used to generate leaf column positions for VAST QueryData RPC
+         index = itertools.count()  # used to generate leaf column positions for VAST QueryData RPC
          self.nodes = [FieldNode(field, index, debug=debug) for field in arrow_schema]
          self.debug = debug
          if self.debug:
@@ -661,14 +667,12 @@ class QueryDataParser:
                  node.debug_log()
          self.leaves = [leaf for node in self.nodes for leaf in node._iter_leaves()]
 
-         self.leaf_offset = 0
-
-     def parse(self, column: pa.Array):
+     def parse(self, column: pa.Array, state: QueryDataParserState):
          """Parse a single column response from VAST (see FieldNode.set for details)"""
-         if not self.leaf_offset < len(self.leaves):
-             raise ValueError(f'self.leaf_offset: {self.leaf_offset} are not < '
+         if not state.leaf_offset < len(self.leaves):
+             raise ValueError(f'state.leaf_offset: {state.leaf_offset} are not < '
                               f'than len(self.leaves): {len(self.leaves)}')
-         leaf = self.leaves[self.leaf_offset]
+         leaf = self.leaves[state.leaf_offset]
 
          # A column response may be sent in multiple chunks, therefore we need to combine
          # it into a single chunk to allow reconstruction using `Array.from_buffers()`.
@@ -685,13 +689,13 @@ class QueryDataParser:
              raise ValueError(f'len(array_list): {len(array_list)} are not eq '
                               f'with len(node_list): {len(node_list)}')
          for node, arr in zip(node_list, array_list):
-             node.set(arr)
+             node.set(arr, state)
 
-         self.leaf_offset += 1
+         state.leaf_offset += 1
 
-     def build(self) -> Optional[pa.Table]:
+     def build(self, state: QueryDataParserState) -> Optional[pa.Table]:
          """Try to build the resulting Table object (if all columns were parsed)"""
-         if self.leaf_offset < len(self.leaves):
+         if state.leaf_offset < len(self.leaves):
              return None
 
          if self.debug:
@@ -699,11 +703,12 @@ class QueryDataParser:
                  node.debug_log()
 
          result = pa.Table.from_arrays(
-             arrays=[node.build() for node in self.nodes],
+             arrays=[node.build(state) for node in self.nodes],
              schema=self.arrow_schema)
-         result.validate(full=self.debug) # does expensive validation checks only if debug is enabled
+         result.validate(full=self.debug)  # does expensive validation checks only if debug is enabled
          return result
 
+
  def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
      """Iterate over a single column response, and recursively generate all of its children."""
      yield column
@@ -715,7 +720,9 @@ def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
          yield from _iter_nested_arrays(column.values) # Note: Map is serialized in VAST as a List<Struct<K, V>>
 
 
- TableInfo = namedtuple('table_info', 'name properties handle num_rows size_in_bytes')
+ TableInfo = namedtuple('TableInfo', 'name properties handle num_rows size_in_bytes')
+
+
  def _parse_table_info(obj):
 
      name = obj.Name().decode()
@@ -725,75 +732,47 @@ def _parse_table_info(obj):
      used_bytes = obj.SizeInBytes()
      return TableInfo(name, properties, handle, num_rows, used_bytes)
 
- def build_record_batch(column_info, column_values):
-     fields = [pa.field(column_name, column_type) for column_type, column_name in column_info]
-     schema = pa.schema(fields)
-     arrays = [pa.array(column_values[column_type], type=column_type) for column_type, _ in column_info]
-     batch = pa.record_batch(arrays, schema)
-     return serialize_record_batch(batch)
-
- def serialize_record_batch(batch):
-     sink = pa.BufferOutputStream()
-     with pa.ipc.new_stream(sink, batch.schema) as writer:
-         writer.write(batch)
-     return sink.getvalue()
 
  # Results that returns from tablestats
- TableStatsResult = namedtuple("TableStatsResult",["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
+
+
+ TableStatsResult = namedtuple("TableStatsResult", ["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
+
 
  class VastdbApi:
      # we expect the vast version to be <major>.<minor>.<patch>.<protocol>
      VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')
 
-     def __init__(self, endpoint, access_key, secret_key, username=None, password=None,
-                  secure=False, auth_type=AuthType.SIGV4):
-         url_dict = urllib3.util.parse_url(endpoint)._asdict()
+     def __init__(self, endpoint, access_key, secret_key, auth_type=AuthType.SIGV4, ssl_verify=True):
+         url = urllib3.util.parse_url(endpoint)
          self.access_key = access_key
          self.secret_key = secret_key
-         self.username = username
-         self.password = password
-         self.secure = secure
-         self.auth_type = auth_type
-         self.executor_hosts = [endpoint] # TODO: remove
-
-         username = username or ''
-         password = password or ''
-         if not url_dict['port']:
-             url_dict['port'] = 443 if secure else 80
-
-         self.port = url_dict['port']
 
          self.default_max_list_columns_page_size = 1000
          self.session = requests.Session()
-         self.session.verify = False
+         self.session.verify = ssl_verify
          self.session.headers['user-agent'] = "VastData Tabular API 1.0 - 2022 (c)"
-         if auth_type == AuthType.BASIC:
-             self.session.auth = requests.auth.HTTPBasicAuth(username, password)
-         else:
-             if url_dict['port'] != 80 and url_dict['port'] != 443:
-                 self.aws_host = '{host}:{port}'.format(**url_dict)
-             else:
-                 self.aws_host = '{host}'.format(**url_dict)
-
-             self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
-                                                 aws_secret_access_key=secret_key,
-                                                 aws_host=self.aws_host,
-                                                 aws_region='us-east-1',
-                                                 aws_service='s3')
 
-         if not url_dict['scheme']:
-             url_dict['scheme'] = "https" if secure else "http"
+         if url.port in {80, 443, None}:
+             self.aws_host = f'{url.host}'
+         else:
+             self.aws_host = f'{url.host}:{url.port}'
 
-         url = urllib3.util.Url(**url_dict)
          self.url = str(url)
          _logger.debug('url=%s aws_host=%s', self.url, self.aws_host)
 
+         self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
+                                             aws_secret_access_key=secret_key,
+                                             aws_host=self.aws_host,
+                                             aws_region='us-east-1',
+                                             aws_service='s3')
+
          # probe the cluster for its version
          self.vast_version = None
-         res = self.session.options(self.url)
+         res = self.session.get(self.url)
          server_header = res.headers.get("Server")
          if server_header is None:
-             _logger.error("OPTIONS response doesn't contain 'Server' header")
+             _logger.error("Response doesn't contain 'Server' header")
          else:
              _logger.debug("Server header is '%s'", server_header)
              if m := self.VAST_VERSION_REGEX.match(server_header):
@@ -973,30 +952,30 @@ class VastdbApi:
 
          return bucket_name, schemas, next_key, is_truncated, count
 
-     def list_snapshots(self, bucket, max_keys=1000, next_token=None, expected_retvals=None):
+     def list_snapshots(self, bucket, max_keys=1000, next_token=None, name_prefix=''):
          next_token = next_token or ''
-         expected_retvals = expected_retvals or []
-         url_params = {'list_type': '2', 'prefix': '.snapshot/', 'delimiter': '/', 'max_keys': str(max_keys)}
+         url_params = {'list_type': '2', 'prefix': '.snapshot/' + name_prefix, 'delimiter': '/', 'max_keys': str(max_keys)}
          if next_token:
              url_params['continuation-token'] = next_token
 
          res = self.session.get(self._api_prefix(bucket=bucket, command="list", url_params=url_params), headers={}, stream=True)
-         self._check_res(res, "list_snapshots", expected_retvals)
-         if res.status_code == 200:
-             out = b''.join(res.iter_content(chunk_size=128))
-             xml_str = out.decode()
-             xml_dict = xmltodict.parse(xml_str)
-             list_res = xml_dict['ListBucketResult']
-             is_truncated = list_res['IsTruncated'] == 'true'
-             marker = list_res['Marker']
-             common_prefixes = list_res['CommonPrefixes'] if 'CommonPrefixes' in list_res else []
-             snapshots = [v['Prefix'] for v in common_prefixes]
+         self._check_res(res, "list_snapshots")
 
-             return snapshots, is_truncated, marker
+         out = b''.join(res.iter_content(chunk_size=128))
+         xml_str = out.decode()
+         xml_dict = xmltodict.parse(xml_str)
+         list_res = xml_dict['ListBucketResult']
+         is_truncated = list_res['IsTruncated'] == 'true'
+         marker = list_res['Marker']
+         common_prefixes = list_res.get('CommonPrefixes', [])
+         if isinstance(common_prefixes, dict): # in case there is a single snapshot
+             common_prefixes = [common_prefixes]
+         snapshots = [v['Prefix'] for v in common_prefixes]
 
+         return snapshots, is_truncated, marker
 
      def create_table(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[],
-                      topic_partitions=0, create_imports_table=False):
+                      topic_partitions=0, create_imports_table=False, use_external_row_ids_allocation=False):
          """
          Create a table, use the following request
          POST /bucket/schema/table?table HTTP/1.1
@@ -1017,6 +996,9 @@ class VastdbApi:
 
          serialized_schema = arrow_schema.serialize()
          headers['Content-Length'] = str(len(serialized_schema))
+         if use_external_row_ids_allocation:
+             headers['use-external-row-ids-alloc'] = str(use_external_row_ids_allocation)
+
          url_params = {'topic_partitions': str(topic_partitions)} if topic_partitions else {}
          if create_imports_table:
              url_params['sub-table'] = IMPORTED_OBJECTS_TABLE_NAME
@@ -1033,8 +1015,8 @@ class VastdbApi:
          if parquet_path:
              parquet_ds = pq.ParquetDataset(parquet_path)
          elif parquet_bucket_name and parquet_object_name:
-             s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
-             parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name,parquet_object_name]), filesystem=s3fs)
+             s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
+             parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name, parquet_object_name]), filesystem=s3fs)
          else:
              raise RuntimeError(f'invalid params parquet_path={parquet_path} parquet_bucket_name={parquet_bucket_name} parquet_object_name={parquet_object_name}')
 
@@ -1049,8 +1031,7 @@ class VastdbApi:
          # create the table
          return self.create_table(bucket, schema, name, arrow_schema, txid, client_tags, expected_retvals)
 
-
-     def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[]):
+     def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[], imports_table_stats=False):
          """
          GET /mybucket/myschema/mytable?stats HTTP/1.1
          tabular-txid: TransactionId
@@ -1059,30 +1040,35 @@ class VastdbApi:
          The Command will return the statistics in flatbuf format
          """
          headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
-         res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats"), headers=headers)
-         if res.status_code == 200:
-             flatbuf = b''.join(res.iter_content(chunk_size=128))
-             stats = get_table_stats.GetRootAs(flatbuf)
-             num_rows = stats.NumRows()
-             size_in_bytes = stats.SizeInBytes()
-             is_external_rowid_alloc = stats.IsExternalRowidAlloc()
-             endpoints = []
-             if stats.VipsLength() == 0:
-                 endpoints.append(self.url)
-             else:
-                 ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
-                 vips = [stats.Vips(i) for i in range(stats.VipsLength())]
-                 ips = []
-                 # extract the vips into list of IPs
-                 for vip in vips:
-                     start_ip = int(ip_cls(vip.StartAddress().decode()))
-                     ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
-                 for ip in ips:
-                     prefix = "http" if not self.secure else "https"
-                     endpoints.append(f"{prefix}://{str(ip)}:{self.port}")
-             return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, endpoints)
-
-         return self._check_res(res, "get_table_stats", expected_retvals)
+         url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if imports_table_stats else {}
+         res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats", url_params=url_params), headers=headers)
+         self._check_res(res, "get_table_stats", expected_retvals)
+
+         flatbuf = b''.join(res.iter_content(chunk_size=128))
+         stats = get_table_stats.GetRootAs(flatbuf)
+         num_rows = stats.NumRows()
+         size_in_bytes = stats.SizeInBytes()
+         is_external_rowid_alloc = stats.IsExternalRowidAlloc()
+         endpoints = []
+         if stats.VipsLength() == 0:
+             endpoints.append(self.url)
+         else:
+             url = urllib3.util.parse_url(self.url)
+
+             ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
+             vips = [stats.Vips(i) for i in range(stats.VipsLength())]
+             ips = []
+             # extract the vips into list of IPs
+             for vip in vips:
+                 start_ip = int(ip_cls(vip.StartAddress().decode()))
+                 ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
+             # build a list of endpoint URLs, reusing schema and port (if specified when constructing the session).
+             # it is assumed that the client can access the returned IPs (e.g. if they are part of the VIP pool).
+             for ip in ips:
+                 d = url._asdict()
+                 d['host'] = str(ip)
+                 endpoints.append(str(urllib3.util.Url(**d)))
+         return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, tuple(endpoints))
 
      def alter_table(self, bucket, schema, name, txid=0, client_tags=[], table_properties="",
                      new_name="", expected_retvals=[]):
@@ -1171,7 +1157,6 @@ class VastdbApi:
 
          return bucket_name, schema_name, tables, next_key, is_truncated, count
 
-
      def add_columns(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[]):
          """
          Add a column to table, use the following request
@@ -1197,7 +1182,7 @@ class VastdbApi:
          return self._check_res(res, "add_columns", expected_retvals)
 
      def alter_column(self, bucket, schema, table, name, txid=0, client_tags=[], column_properties="",
-                       new_name="", column_sep = ".", column_stats="", expected_retvals=[]):
+                       new_name="", column_sep=".", column_stats="", expected_retvals=[]):
          """
          PUT /bucket/schema/table?column&tabular-column-name=ColumnName&tabular-new-column-name=NewColumnName HTTP/1.1
          Content-Length: ContentLength
@@ -1226,7 +1211,7 @@ class VastdbApi:
          headers['tabular-column-sep'] = column_sep
          headers['Content-Length'] = str(len(alter_column_req))
 
-         url_params = {'tabular-column-name': name }
+         url_params = {'tabular-column-name': name}
          if len(new_name):
              url_params['tabular-new-column-name'] = new_name
 
@@ -1544,11 +1529,18 @@ class VastdbApi:
          if response.status_code != 200:
              return response
 
+         ALLOWED_IMPORT_STATES = {
+             'Success',
+             'TabularInProgress',
+             'TabularAlreadyImported',
+             'TabularImportNotStarted',
+         }
+
          chunk_size = 1024
          for chunk in response.iter_content(chunk_size=chunk_size):
              chunk_dict = json.loads(chunk)
              _logger.debug("import data chunk=%s, result: %s", chunk_dict, chunk_dict['res'])
-             if chunk_dict['res'] != 'Success' and chunk_dict['res'] != 'TabularInProgress' and chunk_dict['res'] != 'TabularAlreadyImported':
+             if chunk_dict['res'] not in ALLOWED_IMPORT_STATES:
                  raise errors.ImportFilesError(
                      f"Encountered an error during import_data. status: {chunk_dict['res']}, "
                      f"error message: {chunk_dict['err_msg'] or 'Unexpected error'} during import of "
@@ -1572,48 +1564,6 @@ class VastdbApi:
 
          return self._check_res(res, "import_data", expected_retvals)
 
-     def _record_batch_slices(self, batch, rows_per_slice=None):
-         max_slice_size_in_bytes = int(0.9*5*1024*1024) # 0.9 * 5MB
-         batch_len = len(batch)
-         serialized_batch = serialize_record_batch(batch)
-         batch_size_in_bytes = len(serialized_batch)
-         _logger.debug('max_slice_size_in_bytes=%d batch_len=%d batch_size_in_bytes=%d',
-                       max_slice_size_in_bytes, batch_len, batch_size_in_bytes)
-
-         if not rows_per_slice:
-             if batch_size_in_bytes < max_slice_size_in_bytes:
-                 rows_per_slice = batch_len
-             else:
-                 rows_per_slice = int(0.9 * batch_len * max_slice_size_in_bytes / batch_size_in_bytes)
-
-         done_slicing = False
-         while not done_slicing:
-             # Attempt slicing according to the current rows_per_slice
-             offset = 0
-             serialized_slices = []
-             for i in range(math.ceil(batch_len/rows_per_slice)):
-                 offset = rows_per_slice * i
-                 if offset >= batch_len:
-                     done_slicing=True
-                     break
-                 slice_batch = batch.slice(offset, rows_per_slice)
-                 serialized_slice_batch = serialize_record_batch(slice_batch)
-                 sizeof_serialized_slice_batch = len(serialized_slice_batch)
-
-                 if sizeof_serialized_slice_batch <= max_slice_size_in_bytes:
-                     serialized_slices.append(serialized_slice_batch)
-                 else:
-                     _logger.info(f'Using rows_per_slice {rows_per_slice} slice {i} size {sizeof_serialized_slice_batch} exceeds {max_slice_size_in_bytes} bytes, trying smaller rows_per_slice')
-                     # We have a slice that is too large
-                     rows_per_slice = int(rows_per_slice/2)
-                     if rows_per_slice < 1:
-                         raise ValueError('cannot decrease batch size below 1 row')
-                     break
-             else:
-                 done_slicing = True
-
-         return serialized_slices
-
      def insert_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
          """
          POST /mybucket/myschema/mytable?rows HTTP/1.1
@@ -1628,7 +1578,8 @@ class VastdbApi:
          headers['Content-Length'] = str(len(record_batch))
          res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
                                  data=record_batch, headers=headers, stream=True)
-         return self._check_res(res, "insert_rows", expected_retvals)
+         self._check_res(res, "insert_rows", expected_retvals)
+         res.raw.read() # flush the response
 
      def update_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
          """
@@ -1644,7 +1595,7 @@ class VastdbApi:
          headers['Content-Length'] = str(len(record_batch))
          res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
                                 data=record_batch, headers=headers)
-         return self._check_res(res, "update_rows", expected_retvals)
+         self._check_res(res, "update_rows", expected_retvals)
 
      def delete_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[],
                      delete_from_imports_table=False):
@@ -1663,7 +1614,7 @@ class VastdbApi:
 
          res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows", url_params=url_params),
                                    data=record_batch, headers=headers)
-         return self._check_res(res, "delete_rows", expected_retvals)
+         self._check_res(res, "delete_rows", expected_retvals)
 
      def create_projection(self, bucket, schema, table, name, columns, txid=0, client_tags=[], expected_retvals=[]):
          """
@@ -1873,6 +1824,10 @@ class VastdbApi:
          return columns, next_key, is_truncated, count
 
 
+ class QueryDataInternalError(Exception):
+     pass
+
+
  def _iter_query_data_response_columns(fileobj, stream_ids=None):
      readers = {} # {stream_id: pa.ipc.RecordBatchStreamReader}
      while True:
@@ -1897,8 +1852,8 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
          if stream_id == TABULAR_QUERY_DATA_FAILED_STREAM_ID:
              # read the terminating end chunk from socket
              res = fileobj.read()
-             _logger.warning("stream_id=%d res=%s (failed)", stream_id, res)
-             raise IOError(f"Query data stream failed res={res}")
+             _logger.debug("stream_id=%d res=%s (failed)", stream_id, res)
+             raise QueryDataInternalError() # connection closed by server due to an internal error
 
          next_row_id_bytes = fileobj.read(8)
          next_row_id, = struct.unpack('<Q', next_row_id_bytes)
@@ -1913,7 +1868,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
 
          (reader, batches) = readers[stream_id]
          try:
-             batch = reader.read_next_batch() # read single-column chunk data
+             batch = reader.read_next_batch()  # read single-column chunk data
              _logger.debug("stream_id=%d rows=%d chunk=%s", stream_id, len(batch), batch)
              batches.append(batch)
          except StopIteration: # we got an end-of-stream IPC message for a given stream ID
@@ -1923,7 +1878,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
              yield (stream_id, next_row_id, table)
 
 
- def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False):
+ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False, parser: Optional[QueryDataParser] = None):
      """
      Generates pyarrow.Table objects from QueryData API response stream.
 
@@ -1933,16 +1888,18 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
          start_row_ids = {}
 
      is_empty_projection = (len(schema) == 0)
-     parsers = defaultdict(lambda: QueryDataParser(schema, debug=debug)) # {stream_id: QueryDataParser}
+     if parser is None:
+         parser = QueryDataParser(schema, debug=debug)
+     states: Dict[int, QueryDataParser.QueryDataParserState] = defaultdict(lambda: QueryDataParser.QueryDataParserState()) # {stream_id: QueryDataParser}
 
      for stream_id, next_row_id, table in _iter_query_data_response_columns(conn, stream_ids):
-         parser = parsers[stream_id]
+         state = states[stream_id]
          for column in table.columns:
-             parser.parse(column)
+             parser.parse(column, state)
 
-         parsed_table = parser.build()
+         parsed_table = parser.build(state)
          if parsed_table is not None: # when we got all columns (and before starting a new "select_rows" cycle)
-             parsers.pop(stream_id)
+             states.pop(stream_id)
              if is_empty_projection: # VAST returns an empty RecordBatch, with the correct rows' count
                  parsed_table = table
 
@@ -1951,8 +1908,9 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
              start_row_ids[stream_id] = next_row_id
              yield parsed_table # the result of a single "select_rows()" cycle
 
-     if parsers:
-         raise EOFError(f'all streams should be done before EOF. {parsers}')
+     if states:
+         raise EOFError(f'all streams should be done before EOF. {states}')
+
 
  def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
      if field.type.equals(pa.int64()):
@@ -2095,6 +2053,7 @@ def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
 
      return field_type, field_type_type
 
+
  def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
      children = None
      if isinstance(f.type, pa.StructType):
@@ -2142,12 +2101,13 @@ def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
 
 
  class QueryDataRequest:
-     def __init__(self, serialized, response_schema):
+     def __init__(self, serialized, response_schema, response_parser):
          self.serialized = serialized
          self.response_schema = response_schema
+         self.response_parser = response_parser
 
 
- def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: list = None):
+ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: Optional[List[str]] = None):
      builder = flatbuffers.Builder(1024)
 
      source_name = builder.CreateString('') # required
@@ -2201,7 +2161,8 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
      relation = fb_relation.End(builder)
 
      builder.Finish(relation)
-     return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema)
+
+     return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema, response_parser=QueryDataParser(response_schema))
 
 
  def convert_column_types(table: 'pa.Table') -> 'pa.Table':
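
The QueryData-related changes above move per-stream parsing state out of QueryDataParser into an explicit QueryDataParser.QueryDataParserState object, and build_query_data_request now returns a QueryDataRequest that carries a ready-made response_parser. A minimal sketch of how the reworked flow fits together, based only on the signatures shown in this diff; the `columns` iterable is a placeholder for single-column chunks already read from a QueryData response stream, and the two-column schema is purely illustrative:

    import pyarrow as pa

    schema = pa.schema([('id', pa.int64()), ('value', pa.int64())])
    request = build_query_data_request(schema=schema)  # also constructs request.response_parser
    parser = request.response_parser                   # one parser can serve many streams
    state = QueryDataParser.QueryDataParserState()     # per-stream buffers, lengths and leaf offset

    for column in columns:                             # placeholder: chunks from the response stream
        parser.parse(column, state)                    # accumulates the column's Arrow buffers into `state`

    table = parser.build(state)                        # returns None until every leaf column has arrived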