vastdb 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -8,7 +8,7 @@ import urllib.parse
  from collections import defaultdict, namedtuple
  from enum import Enum
  from ipaddress import IPv4Address, IPv6Address
- from typing import Iterator, Optional, Union
+ from typing import Any, Dict, Iterator, List, Optional, Union

  import flatbuffers
  import ibis
@@ -92,7 +92,7 @@ UINT64_MAX = 18446744073709551615
  TABULAR_KEEP_ALIVE_STREAM_ID = 0xFFFFFFFF
  TABULAR_QUERY_DATA_COMPLETED_STREAM_ID = 0xFFFFFFFF - 1
  TABULAR_QUERY_DATA_FAILED_STREAM_ID = 0xFFFFFFFF - 2
- TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF # (1<<48)-1
+ TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF  # (1<<48)-1
  ESTORE_INVALID_EHANDLE = UINT64_MAX
  IMPORTED_OBJECTS_TABLE_NAME = "vastdb-imported-objects"

@@ -127,11 +127,11 @@ def get_unit_to_flatbuff_time_unit(type):
  }
  return unit_to_flatbuff_time_unit[type]

+
  class Predicate:
  def __init__(self, schema: 'pa.Schema', expr: ibis.expr.types.BooleanColumn):
  self.schema = schema
  self.expr = expr
- self.builder = None

  def get_field_indexes(self, field: 'pa.Field', field_name_per_index: list) -> None:
  field_name_per_index.append(field.name)
@@ -157,8 +157,8 @@ class Predicate:
  self._field_name_per_index = {field: index for index, field in enumerate(_field_name_per_index)}
  return self._field_name_per_index

- def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: list = None):
- if not field_names:
+ def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: Optional[List[str]] = None):
+ if field_names is None:
  field_names = self.field_name_per_index.keys()
  projection_fields = []
  for field_name in field_names:
@@ -172,7 +172,11 @@ class Predicate:
  return builder.EndVector()

  def serialize(self, builder: 'flatbuffers.builder.Builder'):
- from ibis.expr.operations.generic import IsNull, Literal, TableColumn
+ from ibis.expr.operations.generic import (
+ IsNull,
+ Literal,
+ TableColumn,
+ )
  from ibis.expr.operations.logical import (
  And,
  Equals,
@@ -198,7 +202,7 @@ class Predicate:
  StringContains: self.build_match_substring,
  }

- positions_map = dict((f.name, index) for index, f in enumerate(self.schema)) # TODO: BFS
+ positions_map = dict((f.name, index) for index, f in enumerate(self.schema))  # TODO: BFS

  self.builder = builder

@@ -215,7 +219,7 @@ class Predicate:
  prev_field_name = None
  for inner_op in or_args:
  _logger.debug('inner_op %s', inner_op)
- builder_func = builder_map.get(type(inner_op))
+ builder_func: Any = builder_map.get(type(inner_op))
  if not builder_func:
  raise NotImplementedError(inner_op.name)

@@ -270,20 +274,6 @@ class Predicate:
  fb_expression.AddImpl(self.builder, ref)
  return fb_expression.End(self.builder)

- def build_domain(self, column: int, field_name: str):
- offsets = []
- filters = self.filters[field_name]
- if not filters:
- return self.build_or([self.build_is_not_null(column)])
-
- field_name, *field_attrs = field_name.split('.')
- field = self.schema.field(field_name)
- for attr in field_attrs:
- field = field.type[attr]
- for filter_by_name in filters:
- offsets.append(self.build_range(column=column, field=field, filter_by_name=filter_by_name))
- return self.build_or(offsets)
-
  def rule_to_operator(self, raw_rule: str):
  operator_matcher = {
  'eq': self.build_equal,
@@ -339,6 +329,8 @@ class Predicate:
  return fb_expression.End(self.builder)

  def build_literal(self, field: pa.Field, value):
+ literal_type: Any
+
  if field.type.equals(pa.int64()):
  literal_type = fb_int64_lit
  literal_impl = LiteralImpl.Int64Literal
@@ -551,13 +543,20 @@ class Predicate:
  return self.build_function('match_substring', column, literal)


+ class FieldNodesState:
+ def __init__(self) -> None:
+ # will be set during by the parser (see below)
+ self.buffers: Dict[int, Any] = defaultdict(lambda: None) # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
+ self.length: Dict[int, Any] = defaultdict(lambda: None) # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
+
+
  class FieldNode:
  """Helper class for representing nested Arrow fields and handling QueryData requests"""
  def __init__(self, field: pa.Field, index_iter, parent: Optional['FieldNode'] = None, debug: bool = False):
- self.index = next(index_iter) # we use DFS-first enumeration for communicating the column positions to VAST
+ self.index = next(index_iter)  # we use DFS-first enumeration for communicating the column positions to VAST
  self.field = field
  self.type = field.type
- self.parent = parent # will be None if this is the top-level field
+ self.parent = parent  # will be None if this is the top-level field
  self.debug = debug
  if isinstance(self.type, pa.StructType):
  self.children = [FieldNode(field, index_iter, parent=self) for field in self.type]
@@ -574,11 +573,7 @@ class FieldNode:
  field = pa.field('entries', pa.struct([self.type.key_field, self.type.item_field]))
  self.children = [FieldNode(field, index_iter, parent=self)]
  else:
- self.children = [] # for non-nested types
-
- # will be set during by the parser (see below)
- self.buffers = None # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
- self.length = None # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
+ self.children = []  # for non-nested types

  def _iter_to_root(self) -> Iterator['FieldNode']:
  yield self
@@ -599,22 +594,14 @@ class FieldNode:
  for child in self.children:
  yield from child._iter_leaves()

- def _iter_leaves(self) -> Iterator['FieldNode']:
- """Generate only leaf nodes (i.e. columns having scalar types)."""
- if not self.children:
- yield self
- else:
- for child in self.children:
- yield from child._iter_leaves()
-
  def debug_log(self, level=0):
  """Recursively dump this node state to log."""
  bufs = self.buffers and [b and b.hex() for b in self.buffers]
- _logger.debug('%s%d: %s, bufs=%s, len=%s', ' '*level, self.index, self.field, bufs, self.length)
+ _logger.debug('%s%d: %s, bufs=%s, len=%s', ' ' * level, self.index, self.field, bufs, self.length)
  for child in self.children:
- child.debug_log(level=level+1)
+ child.debug_log(level=level + 1)

- def set(self, arr: pa.Array):
+ def set(self, arr: pa.Array, state: FieldNodesState):
  """
  Assign the relevant Arrow buffers from the received array into this node.

@@ -626,34 +613,39 @@ class FieldNode:
  For example, `Struct<A, B>` is sent as two separate columns: `Struct<A>` and `Struct<B>`.
  Also, `Map<K, V>` is sent (as its underlying representation): `List<Struct<K>>` and `List<Struct<V>>`
  """
- buffers = arr.buffers()[:arr.type.num_buffers] # slicing is needed because Array.buffers() returns also nested array buffers
+ buffers = arr.buffers()[:arr.type.num_buffers]  # slicing is needed because Array.buffers() returns also nested array buffers
  if self.debug:
  _logger.debug("set: index=%d %s %s", self.index, self.field, [b and b.hex() for b in buffers])
- if self.buffers is None:
- self.buffers = buffers
- self.length = len(arr)
+ if state.buffers[self.index] is None:
+ state.buffers[self.index] = buffers
+ state.length[self.index] = len(arr)
  else:
  # Make sure subsequent assignments are consistent with each other
  if self.debug:
- if not self.buffers == buffers:
- raise ValueError(f'self.buffers: {self.buffers} are not equal with buffers: {buffers}')
- if not self.length == len(arr):
- raise ValueError(f'self.length: {self.length} are not equal with len(arr): {len(arr)}')
+ if not state.buffers[self.index] == buffers:
+ raise ValueError(f'self.buffers: {state.buffers[self.index]} are not equal with buffers: {buffers}')
+ if not state.length[self.index] == len(arr):
+ raise ValueError(f'self.length: {state.length[self.index]} are not equal with len(arr): {len(arr)}')

- def build(self) -> pa.Array:
+ def build(self, state: FieldNodesState) -> pa.Array:
  """Construct an Arrow array from the collected buffers (recursively)."""
- children = self.children and [node.build() for node in self.children]
- result = pa.Array.from_buffers(self.type, self.length, buffers=self.buffers, children=children)
+ children = self.children and [node.build(state) for node in self.children]
+ result = pa.Array.from_buffers(self.type, state.length[self.index], buffers=state.buffers[self.index], children=children)
  if self.debug:
  _logger.debug('%s result=%s', self.field, result)
  return result


  class QueryDataParser:
+ class QueryDataParserState(FieldNodesState):
+ def __init__(self) -> None:
+ super().__init__()
+ self.leaf_offset = 0
+
  """Used to parse VAST QueryData RPC response."""
  def __init__(self, arrow_schema: pa.Schema, *, debug=False):
  self.arrow_schema = arrow_schema
- index = itertools.count() # used to generate leaf column positions for VAST QueryData RPC
+ index = itertools.count()  # used to generate leaf column positions for VAST QueryData RPC
  self.nodes = [FieldNode(field, index, debug=debug) for field in arrow_schema]
  self.debug = debug
  if self.debug:
@@ -661,14 +653,12 @@ class QueryDataParser:
  node.debug_log()
  self.leaves = [leaf for node in self.nodes for leaf in node._iter_leaves()]

- self.leaf_offset = 0
-
- def parse(self, column: pa.Array):
+ def parse(self, column: pa.Array, state: QueryDataParserState):
  """Parse a single column response from VAST (see FieldNode.set for details)"""
- if not self.leaf_offset < len(self.leaves):
- raise ValueError(f'self.leaf_offset: {self.leaf_offset} are not < '
+ if not state.leaf_offset < len(self.leaves):
+ raise ValueError(f'state.leaf_offset: {state.leaf_offset} are not < '
  f'than len(self.leaves): {len(self.leaves)}')
- leaf = self.leaves[self.leaf_offset]
+ leaf = self.leaves[state.leaf_offset]

  # A column response may be sent in multiple chunks, therefore we need to combine
  # it into a single chunk to allow reconstruction using `Array.from_buffers()`.
@@ -685,13 +675,13 @@ class QueryDataParser:
  raise ValueError(f'len(array_list): {len(array_list)} are not eq '
  f'with len(node_list): {len(node_list)}')
  for node, arr in zip(node_list, array_list):
- node.set(arr)
+ node.set(arr, state)

- self.leaf_offset += 1
+ state.leaf_offset += 1

- def build(self) -> Optional[pa.Table]:
+ def build(self, state: QueryDataParserState) -> Optional[pa.Table]:
  """Try to build the resulting Table object (if all columns were parsed)"""
- if self.leaf_offset < len(self.leaves):
+ if state.leaf_offset < len(self.leaves):
  return None

  if self.debug:
@@ -699,11 +689,12 @@ class QueryDataParser:
  node.debug_log()

  result = pa.Table.from_arrays(
- arrays=[node.build() for node in self.nodes],
+ arrays=[node.build(state) for node in self.nodes],
  schema=self.arrow_schema)
- result.validate(full=self.debug) # does expensive validation checks only if debug is enabled
+ result.validate(full=self.debug)  # does expensive validation checks only if debug is enabled
  return result

+
  def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
  """Iterate over a single column response, and recursively generate all of its children."""
  yield column
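The hunks above move the per-stream parsing state out of `FieldNode`/`QueryDataParser` into a separate `FieldNodesState`/`QueryDataParserState` object that is passed explicitly to `set()`, `parse()` and `build()`. A minimal sketch of the resulting call pattern (not taken from the package; the schema and the column-feeding loop are illustrative placeholders):

```python
import pyarrow as pa

# One parser per response schema; it only holds the FieldNode tree.
schema = pa.schema([pa.field('id', pa.int64()), pa.field('name', pa.utf8())])
parser = QueryDataParser(schema)

# One mutable state object per response stream; buffers, lengths and the
# current leaf offset accumulate here instead of on the FieldNode objects.
state = QueryDataParser.QueryDataParserState()

# for column in columns_received_from_vast:   # placeholder for the RPC loop
#     parser.parse(column, state)
# table = parser.build(state)                 # returns None until every leaf column arrived
```

This lets a single parser be shared across several concurrent response streams, with one lightweight state object per stream.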
@@ -715,7 +706,9 @@ def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
  yield from _iter_nested_arrays(column.values) # Note: Map is serialized in VAST as a List<Struct<K, V>>


- TableInfo = namedtuple('table_info', 'name properties handle num_rows size_in_bytes')
+ TableInfo = namedtuple('TableInfo', 'name properties handle num_rows size_in_bytes')
+
+
  def _parse_table_info(obj):

  name = obj.Name().decode()
@@ -725,6 +718,7 @@ def _parse_table_info(obj):
  used_bytes = obj.SizeInBytes()
  return TableInfo(name, properties, handle, num_rows, used_bytes)

+
  def build_record_batch(column_info, column_values):
  fields = [pa.field(column_name, column_type) for column_type, column_name in column_info]
  schema = pa.schema(fields)
@@ -732,6 +726,7 @@ def build_record_batch(column_info, column_values):
  batch = pa.record_batch(arrays, schema)
  return serialize_record_batch(batch)

+
  def serialize_record_batch(batch):
  sink = pa.BufferOutputStream()
  with pa.ipc.new_stream(sink, batch.schema) as writer:
@@ -739,61 +734,45 @@ def serialize_record_batch(batch):
  return sink.getvalue()

  # Results that returns from tablestats
- TableStatsResult = namedtuple("TableStatsResult",["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
+
+
+ TableStatsResult = namedtuple("TableStatsResult", ["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
+

  class VastdbApi:
  # we expect the vast version to be <major>.<minor>.<patch>.<protocol>
  VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')

- def __init__(self, endpoint, access_key, secret_key, username=None, password=None,
- secure=False, auth_type=AuthType.SIGV4):
- url_dict = urllib3.util.parse_url(endpoint)._asdict()
+ def __init__(self, endpoint, access_key, secret_key, auth_type=AuthType.SIGV4, ssl_verify=True):
+ url = urllib3.util.parse_url(endpoint)
  self.access_key = access_key
  self.secret_key = secret_key
- self.username = username
- self.password = password
- self.secure = secure
- self.auth_type = auth_type
- self.executor_hosts = [endpoint] # TODO: remove
-
- username = username or ''
- password = password or ''
- if not url_dict['port']:
- url_dict['port'] = 443 if secure else 80
-
- self.port = url_dict['port']

  self.default_max_list_columns_page_size = 1000
  self.session = requests.Session()
- self.session.verify = False
+ self.session.verify = ssl_verify
  self.session.headers['user-agent'] = "VastData Tabular API 1.0 - 2022 (c)"
- if auth_type == AuthType.BASIC:
- self.session.auth = requests.auth.HTTPBasicAuth(username, password)
- else:
- if url_dict['port'] != 80 and url_dict['port'] != 443:
- self.aws_host = '{host}:{port}'.format(**url_dict)
- else:
- self.aws_host = '{host}'.format(**url_dict)
-
- self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
- aws_secret_access_key=secret_key,
- aws_host=self.aws_host,
- aws_region='us-east-1',
- aws_service='s3')

- if not url_dict['scheme']:
- url_dict['scheme'] = "https" if secure else "http"
+ if url.port in {80, 443, None}:
+ self.aws_host = f'{url.host}'
+ else:
+ self.aws_host = f'{url.host}:{url.port}'

- url = urllib3.util.Url(**url_dict)
  self.url = str(url)
  _logger.debug('url=%s aws_host=%s', self.url, self.aws_host)

+ self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
+ aws_secret_access_key=secret_key,
+ aws_host=self.aws_host,
+ aws_region='us-east-1',
+ aws_service='s3')
+
  # probe the cluster for its version
  self.vast_version = None
- res = self.session.options(self.url)
+ res = self.session.get(self.url)
  server_header = res.headers.get("Server")
  if server_header is None:
- _logger.error("OPTIONS response doesn't contain 'Server' header")
+ _logger.error("Response doesn't contain 'Server' header")
  else:
  _logger.debug("Server header is '%s'", server_header)
  if m := self.VAST_VERSION_REGEX.match(server_header):
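The constructor above drops the `username`/`password`/`secure`/BASIC-auth options: the session now always signs requests with SigV4, and certificate verification is controlled by the new `ssl_verify` flag instead of being disabled unconditionally. A hedged usage sketch (endpoint and credentials are placeholders, not real values):

```python
api = VastdbApi(
    endpoint='http://vip-pool.example.com',  # placeholder endpoint
    access_key='AKIAEXAMPLE',                # placeholder credentials
    secret_key='SECRETEXAMPLE',
    ssl_verify=True,                         # replaces the hard-coded session.verify = False from 0.1.1
)
print(api.vast_version)  # probed from the 'Server' header via a GET request (was OPTIONS in 0.1.1)
```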
@@ -994,9 +973,8 @@ class VastdbApi:

  return snapshots, is_truncated, marker

-
  def create_table(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[],
- topic_partitions=0, create_imports_table=False):
+ topic_partitions=0, create_imports_table=False, use_external_row_ids_allocation=False):
  """
  Create a table, use the following request
  POST /bucket/schema/table?table HTTP/1.1
@@ -1017,6 +995,9 @@ class VastdbApi:

  serialized_schema = arrow_schema.serialize()
  headers['Content-Length'] = str(len(serialized_schema))
+ if use_external_row_ids_allocation:
+ headers['use-external-row-ids-alloc'] = str(use_external_row_ids_allocation)
+
  url_params = {'topic_partitions': str(topic_partitions)} if topic_partitions else {}
  if create_imports_table:
  url_params['sub-table'] = IMPORTED_OBJECTS_TABLE_NAME
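When the new `use_external_row_ids_allocation` flag is set, it is forwarded to the server as the `use-external-row-ids-alloc` header. A short sketch of calling the new signature (bucket, schema and table names are placeholders):

```python
import pyarrow as pa

arrow_schema = pa.schema([pa.field('id', pa.int64()), pa.field('value', pa.float64())])
api.create_table('mybucket', 'myschema', 'mytable', arrow_schema,
                 use_external_row_ids_allocation=True)
```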
@@ -1033,8 +1014,8 @@ class VastdbApi:
  if parquet_path:
  parquet_ds = pq.ParquetDataset(parquet_path)
  elif parquet_bucket_name and parquet_object_name:
- s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
- parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name,parquet_object_name]), filesystem=s3fs)
+ s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
+ parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name, parquet_object_name]), filesystem=s3fs)
  else:
  raise RuntimeError(f'invalid params parquet_path={parquet_path} parquet_bucket_name={parquet_bucket_name} parquet_object_name={parquet_object_name}')

@@ -1049,7 +1030,6 @@ class VastdbApi:
  # create the table
  return self.create_table(bucket, schema, name, arrow_schema, txid, client_tags, expected_retvals)

-
  def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[]):
  """
  GET /mybucket/myschema/mytable?stats HTTP/1.1
@@ -1060,29 +1040,33 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats"), headers=headers)
- if res.status_code == 200:
- flatbuf = b''.join(res.iter_content(chunk_size=128))
- stats = get_table_stats.GetRootAs(flatbuf)
- num_rows = stats.NumRows()
- size_in_bytes = stats.SizeInBytes()
- is_external_rowid_alloc = stats.IsExternalRowidAlloc()
- endpoints = []
- if stats.VipsLength() == 0:
- endpoints.append(self.url)
- else:
- ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
- vips = [stats.Vips(i) for i in range(stats.VipsLength())]
- ips = []
- # extract the vips into list of IPs
- for vip in vips:
- start_ip = int(ip_cls(vip.StartAddress().decode()))
- ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
- for ip in ips:
- prefix = "http" if not self.secure else "https"
- endpoints.append(f"{prefix}://{str(ip)}:{self.port}")
- return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, endpoints)
-
- return self._check_res(res, "get_table_stats", expected_retvals)
+ self._check_res(res, "get_table_stats", expected_retvals)
+
+ flatbuf = b''.join(res.iter_content(chunk_size=128))
+ stats = get_table_stats.GetRootAs(flatbuf)
+ num_rows = stats.NumRows()
+ size_in_bytes = stats.SizeInBytes()
+ is_external_rowid_alloc = stats.IsExternalRowidAlloc()
+ endpoints = []
+ if stats.VipsLength() == 0:
+ endpoints.append(self.url)
+ else:
+ url = urllib3.util.parse_url(self.url)
+
+ ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
+ vips = [stats.Vips(i) for i in range(stats.VipsLength())]
+ ips = []
+ # extract the vips into list of IPs
+ for vip in vips:
+ start_ip = int(ip_cls(vip.StartAddress().decode()))
+ ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
+ # build a list of endpoint URLs, reusing schema and port (if specified when constructing the session).
+ # it is assumed that the client can access the returned IPs (e.g. if they are part of the VIP pool).
+ for ip in ips:
+ d = url._asdict()
+ d['host'] = str(ip)
+ endpoints.append(str(urllib3.util.Url(**d)))
+ return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, tuple(endpoints))

  def alter_table(self, bucket, schema, name, txid=0, client_tags=[], table_properties="",
  new_name="", expected_retvals=[]):
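`get_table_stats` now raises through `_check_res` on errors and returns the endpoints as a tuple of URLs built from the VIP list, reusing the scheme and port of the configured endpoint. A hedged sketch of consuming the result (bucket/schema/table names are placeholders):

```python
stats = api.get_table_stats('mybucket', 'myschema', 'mytable')
print(stats.num_rows, stats.size_in_bytes, stats.is_external_rowid_alloc)
for endpoint in stats.endpoints:  # one URL per VIP address, or just api.url when no VIPs are reported
    print(endpoint)
```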
@@ -1171,7 +1155,6 @@ class VastdbApi:

  return bucket_name, schema_name, tables, next_key, is_truncated, count

-
  def add_columns(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[]):
  """
  Add a column to table, use the following request
@@ -1197,7 +1180,7 @@ class VastdbApi:
  return self._check_res(res, "add_columns", expected_retvals)

  def alter_column(self, bucket, schema, table, name, txid=0, client_tags=[], column_properties="",
- new_name="", column_sep = ".", column_stats="", expected_retvals=[]):
+ new_name="", column_sep=".", column_stats="", expected_retvals=[]):
  """
  PUT /bucket/schema/table?column&tabular-column-name=ColumnName&tabular-new-column-name=NewColumnName HTTP/1.1
  Content-Length: ContentLength
@@ -1226,7 +1209,7 @@ class VastdbApi:
  headers['tabular-column-sep'] = column_sep
  headers['Content-Length'] = str(len(alter_column_req))

- url_params = {'tabular-column-name': name }
+ url_params = {'tabular-column-name': name}
  if len(new_name):
  url_params['tabular-new-column-name'] = new_name

@@ -1573,7 +1556,7 @@ class VastdbApi:
  return self._check_res(res, "import_data", expected_retvals)

  def _record_batch_slices(self, batch, rows_per_slice=None):
- max_slice_size_in_bytes = int(0.9*5*1024*1024) # 0.9 * 5MB
+ max_slice_size_in_bytes = int(0.9 * 5 * 1024 * 1024) # 0.9 * 5MB
  batch_len = len(batch)
  serialized_batch = serialize_record_batch(batch)
  batch_size_in_bytes = len(serialized_batch)
@@ -1591,10 +1574,10 @@ class VastdbApi:
  # Attempt slicing according to the current rows_per_slice
  offset = 0
  serialized_slices = []
- for i in range(math.ceil(batch_len/rows_per_slice)):
+ for i in range(math.ceil(batch_len / rows_per_slice)):
  offset = rows_per_slice * i
  if offset >= batch_len:
- done_slicing=True
+ done_slicing = True
  break
  slice_batch = batch.slice(offset, rows_per_slice)
  serialized_slice_batch = serialize_record_batch(slice_batch)
@@ -1605,7 +1588,7 @@ class VastdbApi:
  else:
  _logger.info(f'Using rows_per_slice {rows_per_slice} slice {i} size {sizeof_serialized_slice_batch} exceeds {max_slice_size_in_bytes} bytes, trying smaller rows_per_slice')
  # We have a slice that is too large
- rows_per_slice = int(rows_per_slice/2)
+ rows_per_slice = int(rows_per_slice / 2)
  if rows_per_slice < 1:
  raise ValueError('cannot decrease batch size below 1 row')
  break
@@ -1628,7 +1611,8 @@ class VastdbApi:
  headers['Content-Length'] = str(len(record_batch))
  res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
  data=record_batch, headers=headers, stream=True)
- return self._check_res(res, "insert_rows", expected_retvals)
+ self._check_res(res, "insert_rows", expected_retvals)
+ res.raw.read() # flush the response

  def update_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1644,7 +1628,7 @@ class VastdbApi:
  headers['Content-Length'] = str(len(record_batch))
  res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
  data=record_batch, headers=headers)
- return self._check_res(res, "update_rows", expected_retvals)
+ self._check_res(res, "update_rows", expected_retvals)

  def delete_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[],
  delete_from_imports_table=False):
@@ -1663,7 +1647,7 @@ class VastdbApi:

  res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows", url_params=url_params),
  data=record_batch, headers=headers)
- return self._check_res(res, "delete_rows", expected_retvals)
+ self._check_res(res, "delete_rows", expected_retvals)

  def create_projection(self, bucket, schema, table, name, columns, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1873,6 +1857,10 @@ class VastdbApi:
  return columns, next_key, is_truncated, count


+ class QueryDataInternalError(Exception):
+ pass
+
+
  def _iter_query_data_response_columns(fileobj, stream_ids=None):
  readers = {} # {stream_id: pa.ipc.RecordBatchStreamReader}
  while True:
@@ -1897,8 +1885,8 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
  if stream_id == TABULAR_QUERY_DATA_FAILED_STREAM_ID:
  # read the terminating end chunk from socket
  res = fileobj.read()
- _logger.warning("stream_id=%d res=%s (failed)", stream_id, res)
- raise IOError(f"Query data stream failed res={res}")
+ _logger.debug("stream_id=%d res=%s (failed)", stream_id, res)
+ raise QueryDataInternalError() # connection closed by server due to an internal error

  next_row_id_bytes = fileobj.read(8)
  next_row_id, = struct.unpack('<Q', next_row_id_bytes)
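A server-side failure on the QueryData stream now surfaces as the dedicated `QueryDataInternalError` instead of a generic `IOError`, so callers can recognize it specifically. An illustrative sketch (the handler names and retry policy are placeholders, not part of the package):

```python
try:
    for table in parse_query_data_response(conn, schema):  # conn/schema as set up by the caller
        handle_table(table)                                 # placeholder for the caller's logic
except QueryDataInternalError:
    # the server closed the connection due to an internal error; the caller may choose to retry
    retry_query()                                           # placeholder
```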
@@ -1913,7 +1901,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):

  (reader, batches) = readers[stream_id]
  try:
- batch = reader.read_next_batch() # read single-column chunk data
+ batch = reader.read_next_batch()  # read single-column chunk data
  _logger.debug("stream_id=%d rows=%d chunk=%s", stream_id, len(batch), batch)
  batches.append(batch)
  except StopIteration: # we got an end-of-stream IPC message for a given stream ID
@@ -1923,7 +1911,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
  yield (stream_id, next_row_id, table)


- def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False):
+ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False, parser: Optional[QueryDataParser] = None):
  """
  Generates pyarrow.Table objects from QueryData API response stream.

@@ -1933,16 +1921,18 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
  start_row_ids = {}

  is_empty_projection = (len(schema) == 0)
- parsers = defaultdict(lambda: QueryDataParser(schema, debug=debug)) # {stream_id: QueryDataParser}
+ if parser is None:
+ parser = QueryDataParser(schema, debug=debug)
+ states: Dict[int, QueryDataParser.QueryDataParserState] = defaultdict(lambda: QueryDataParser.QueryDataParserState()) # {stream_id: QueryDataParser}

  for stream_id, next_row_id, table in _iter_query_data_response_columns(conn, stream_ids):
- parser = parsers[stream_id]
+ state = states[stream_id]
  for column in table.columns:
- parser.parse(column)
+ parser.parse(column, state)

- parsed_table = parser.build()
+ parsed_table = parser.build(state)
  if parsed_table is not None: # when we got all columns (and before starting a new "select_rows" cycle)
- parsers.pop(stream_id)
+ states.pop(stream_id)
  if is_empty_projection: # VAST returns an empty RecordBatch, with the correct rows' count
  parsed_table = table

@@ -1951,8 +1941,9 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
  start_row_ids[stream_id] = next_row_id
  yield parsed_table # the result of a single "select_rows()" cycle

- if parsers:
- raise EOFError(f'all streams should be done before EOF. {parsers}')
+ if states:
+ raise EOFError(f'all streams should be done before EOF. {states}')
+

  def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
  if field.type.equals(pa.int64()):
@@ -2095,6 +2086,7 @@ def get_field_type(builder: flatbuffers.Builder, field: pa.Field):

  return field_type, field_type_type

+
  def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
  children = None
  if isinstance(f.type, pa.StructType):
@@ -2142,12 +2134,13 @@ def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):


  class QueryDataRequest:
- def __init__(self, serialized, response_schema):
+ def __init__(self, serialized, response_schema, response_parser):
  self.serialized = serialized
  self.response_schema = response_schema
+ self.response_parser = response_parser


- def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: list = None):
+ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: Optional[List[str]] = None):
  builder = flatbuffers.Builder(1024)

  source_name = builder.CreateString('') # required
@@ -2201,7 +2194,8 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
  relation = fb_relation.End(builder)

  builder.Finish(relation)
- return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema)
+
+ return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema, response_parser=QueryDataParser(response_schema))


  def convert_column_types(table: 'pa.Table') -> 'pa.Table':
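`QueryDataRequest` now also carries a pre-built `response_parser`, which `parse_query_data_response` can reuse via its new `parser` argument instead of constructing a parser per stream. A hedged end-to-end sketch (`my_schema` and `conn` stand for the caller's Arrow schema and the raw QueryData response stream, which are not shown here):

```python
request = build_query_data_request(schema=my_schema, field_names=['id', 'name'])

# send request.serialized to the QueryData endpoint, then parse the streamed reply:
# for table in parse_query_data_response(conn, request.response_schema,
#                                         parser=request.response_parser):
#     ...
```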