vastdb 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between these two publicly released package versions, as published to their public registry, and is provided for informational purposes only.
@@ -1,26 +1,23 @@
1
+ import itertools
2
+ import json
1
3
  import logging
4
+ import math
5
+ import re
2
6
  import struct
3
7
  import urllib.parse
4
8
  from collections import defaultdict, namedtuple
5
- from datetime import datetime
6
9
  from enum import Enum
7
- from typing import Union, Optional, Iterator
8
- import ibis
9
- import xmltodict
10
- import math
11
- from functools import cmp_to_key
12
- import pyarrow.parquet as pq
10
+ from ipaddress import IPv4Address, IPv6Address
11
+ from typing import Any, Dict, Iterator, List, Optional, Union
12
+
13
13
  import flatbuffers
14
+ import ibis
14
15
  import pyarrow as pa
16
+ import pyarrow.parquet as pq
15
17
  import requests
16
- import json
17
- import itertools
18
- from aws_requests_auth.aws_auth import AWSRequestsAuth
19
18
  import urllib3
20
- import re
21
-
22
- from . import errors
23
- from ipaddress import IPv4Address, IPv6Address
19
+ import xmltodict
20
+ from aws_requests_auth.aws_auth import AWSRequestsAuth
24
21
 
25
22
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BinaryLiteral as fb_binary_lit
26
23
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BooleanLiteral as fb_bool_lit
@@ -32,10 +29,10 @@ import vast_flatbuf.org.apache.arrow.computeir.flatbuf.FieldIndex as fb_field_in
32
29
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.FieldRef as fb_field_ref
33
30
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Float32Literal as fb_float32_lit
34
31
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Float64Literal as fb_float64_lit
32
+ import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Int8Literal as fb_int8_lit
35
33
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Int16Literal as fb_int16_lit
36
34
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Int32Literal as fb_int32_lit
37
35
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Int64Literal as fb_int64_lit
38
- import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Int8Literal as fb_int8_lit
39
36
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Literal as fb_literal
40
37
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.Relation as fb_relation
41
38
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.RelationImpl as rel_impl
@@ -48,45 +45,54 @@ import vast_flatbuf.org.apache.arrow.flatbuf.Bool as fb_bool
48
45
  import vast_flatbuf.org.apache.arrow.flatbuf.Date as fb_date
49
46
  import vast_flatbuf.org.apache.arrow.flatbuf.Decimal as fb_decimal
50
47
  import vast_flatbuf.org.apache.arrow.flatbuf.Field as fb_field
48
+ import vast_flatbuf.org.apache.arrow.flatbuf.FixedSizeBinary as fb_fixed_size_binary
51
49
  import vast_flatbuf.org.apache.arrow.flatbuf.FloatingPoint as fb_floating_point
52
50
  import vast_flatbuf.org.apache.arrow.flatbuf.Int as fb_int
53
- import vast_flatbuf.org.apache.arrow.flatbuf.Schema as fb_schema
54
- import vast_flatbuf.org.apache.arrow.flatbuf.Time as fb_time
55
- import vast_flatbuf.org.apache.arrow.flatbuf.Struct_ as fb_struct
56
51
  import vast_flatbuf.org.apache.arrow.flatbuf.List as fb_list
57
52
  import vast_flatbuf.org.apache.arrow.flatbuf.Map as fb_map
58
- import vast_flatbuf.org.apache.arrow.flatbuf.FixedSizeBinary as fb_fixed_size_binary
53
+ import vast_flatbuf.org.apache.arrow.flatbuf.Schema as fb_schema
54
+ import vast_flatbuf.org.apache.arrow.flatbuf.Struct_ as fb_struct
55
+ import vast_flatbuf.org.apache.arrow.flatbuf.Time as fb_time
59
56
  import vast_flatbuf.org.apache.arrow.flatbuf.Timestamp as fb_timestamp
60
57
  import vast_flatbuf.org.apache.arrow.flatbuf.Utf8 as fb_utf8
61
58
  import vast_flatbuf.tabular.AlterColumnRequest as tabular_alter_column
59
+ import vast_flatbuf.tabular.AlterProjectionTableRequest as tabular_alter_projection
62
60
  import vast_flatbuf.tabular.AlterSchemaRequest as tabular_alter_schema
63
61
  import vast_flatbuf.tabular.AlterTableRequest as tabular_alter_table
64
- import vast_flatbuf.tabular.AlterProjectionTableRequest as tabular_alter_projection
62
+ import vast_flatbuf.tabular.Column as tabular_projecion_column
63
+ import vast_flatbuf.tabular.ColumnType as tabular_proj_column_type
64
+ import vast_flatbuf.tabular.CreateProjectionRequest as tabular_create_projection
65
65
  import vast_flatbuf.tabular.CreateSchemaRequest as tabular_create_schema
66
66
  import vast_flatbuf.tabular.ImportDataRequest as tabular_import_data
67
67
  import vast_flatbuf.tabular.S3File as tabular_s3_file
68
- import vast_flatbuf.tabular.CreateProjectionRequest as tabular_create_projection
69
- import vast_flatbuf.tabular.Column as tabular_projecion_column
70
- import vast_flatbuf.tabular.ColumnType as tabular_proj_column_type
71
-
72
68
  from vast_flatbuf.org.apache.arrow.computeir.flatbuf.Deref import Deref
73
- from vast_flatbuf.org.apache.arrow.computeir.flatbuf.ExpressionImpl import ExpressionImpl
69
+ from vast_flatbuf.org.apache.arrow.computeir.flatbuf.ExpressionImpl import (
70
+ ExpressionImpl,
71
+ )
74
72
  from vast_flatbuf.org.apache.arrow.computeir.flatbuf.LiteralImpl import LiteralImpl
75
73
  from vast_flatbuf.org.apache.arrow.flatbuf.DateUnit import DateUnit
76
74
  from vast_flatbuf.org.apache.arrow.flatbuf.TimeUnit import TimeUnit
77
75
  from vast_flatbuf.org.apache.arrow.flatbuf.Type import Type
76
+ from vast_flatbuf.tabular.GetProjectionTableStatsResponse import (
77
+ GetProjectionTableStatsResponse as get_projection_table_stats,
78
+ )
79
+ from vast_flatbuf.tabular.GetTableStatsResponse import (
80
+ GetTableStatsResponse as get_table_stats,
81
+ )
82
+ from vast_flatbuf.tabular.ListProjectionsResponse import (
83
+ ListProjectionsResponse as list_projections,
84
+ )
78
85
  from vast_flatbuf.tabular.ListSchemasResponse import ListSchemasResponse as list_schemas
79
86
  from vast_flatbuf.tabular.ListTablesResponse import ListTablesResponse as list_tables
80
- from vast_flatbuf.tabular.GetTableStatsResponse import GetTableStatsResponse as get_table_stats
81
- from vast_flatbuf.tabular.GetProjectionTableStatsResponse import GetProjectionTableStatsResponse as get_projection_table_stats
82
- from vast_flatbuf.tabular.ListProjectionsResponse import ListProjectionsResponse as list_projections
87
+
88
+ from . import errors
83
89
 
84
90
  UINT64_MAX = 18446744073709551615
85
91
 
86
92
  TABULAR_KEEP_ALIVE_STREAM_ID = 0xFFFFFFFF
87
93
  TABULAR_QUERY_DATA_COMPLETED_STREAM_ID = 0xFFFFFFFF - 1
88
94
  TABULAR_QUERY_DATA_FAILED_STREAM_ID = 0xFFFFFFFF - 2
89
- TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF # (1<<48)-1
95
+ TABULAR_INVALID_ROW_ID = 0xFFFFFFFFFFFF # (1<<48)-1
90
96
  ESTORE_INVALID_EHANDLE = UINT64_MAX
91
97
  IMPORTED_OBJECTS_TABLE_NAME = "vastdb-imported-objects"
92
98
 
@@ -121,18 +127,11 @@ def get_unit_to_flatbuff_time_unit(type):
121
127
  }
122
128
  return unit_to_flatbuff_time_unit[type]
123
129
 
124
- class Predicate:
125
- unit_to_epoch = {
126
- 'ns': 1_000_000,
127
- 'us': 1_000,
128
- 'ms': 1,
129
- 's': 0.001
130
- }
131
130
 
131
+ class Predicate:
132
132
  def __init__(self, schema: 'pa.Schema', expr: ibis.expr.types.BooleanColumn):
133
133
  self.schema = schema
134
134
  self.expr = expr
135
- self.builder = None
136
135
 
137
136
  def get_field_indexes(self, field: 'pa.Field', field_name_per_index: list) -> None:
138
137
  field_name_per_index.append(field.name)
@@ -158,8 +157,8 @@ class Predicate:
158
157
  self._field_name_per_index = {field: index for index, field in enumerate(_field_name_per_index)}
159
158
  return self._field_name_per_index
160
159
 
161
- def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: list = None):
162
- if not field_names:
160
+ def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: Optional[List[str]] = None):
161
+ if field_names is None:
163
162
  field_names = self.field_name_per_index.keys()
164
163
  projection_fields = []
165
164
  for field_name in field_names:
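Note on the hunk above: switching from `if not field_names:` to `if field_names is None:` is a behavioural fix, not just a type-annotation cleanup — an explicitly empty projection list is now honoured instead of silently expanding to all fields. A minimal illustration of the difference (plain Python, not the package's code):

```python
def pick(field_names=None, all_fields=('a', 'b')):
    # new behaviour: only a missing argument falls back to every field
    if field_names is None:
        field_names = all_fields
    return list(field_names)

print(pick())    # ['a', 'b']
print(pick([]))  # []  -- the old `if not field_names:` check would also return ['a', 'b']
```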
@@ -173,8 +172,22 @@ class Predicate:
173
172
  return builder.EndVector()
174
173
 
175
174
  def serialize(self, builder: 'flatbuffers.builder.Builder'):
176
- from ibis.expr.operations.generic import TableColumn, Literal, IsNull
177
- from ibis.expr.operations.logical import Greater, GreaterEqual, Less, LessEqual, Equals, NotEquals, And, Or, Not
175
+ from ibis.expr.operations.generic import (
176
+ IsNull,
177
+ Literal,
178
+ TableColumn,
179
+ )
180
+ from ibis.expr.operations.logical import (
181
+ And,
182
+ Equals,
183
+ Greater,
184
+ GreaterEqual,
185
+ Less,
186
+ LessEqual,
187
+ Not,
188
+ NotEquals,
189
+ Or,
190
+ )
178
191
  from ibis.expr.operations.strings import StringContains
179
192
 
180
193
  builder_map = {
@@ -189,7 +202,7 @@ class Predicate:
189
202
  StringContains: self.build_match_substring,
190
203
  }
191
204
 
192
- positions_map = dict((f.name, index) for index, f in enumerate(self.schema)) # TODO: BFS
205
+ positions_map = dict((f.name, index) for index, f in enumerate(self.schema)) # TODO: BFS
193
206
 
194
207
  self.builder = builder
195
208
 
@@ -206,7 +219,7 @@ class Predicate:
206
219
  prev_field_name = None
207
220
  for inner_op in or_args:
208
221
  _logger.debug('inner_op %s', inner_op)
209
- builder_func = builder_map.get(type(inner_op))
222
+ builder_func: Any = builder_map.get(type(inner_op))
210
223
  if not builder_func:
211
224
  raise NotImplementedError(inner_op.name)
212
225
 
@@ -261,20 +274,6 @@ class Predicate:
261
274
  fb_expression.AddImpl(self.builder, ref)
262
275
  return fb_expression.End(self.builder)
263
276
 
264
- def build_domain(self, column: int, field_name: str):
265
- offsets = []
266
- filters = self.filters[field_name]
267
- if not filters:
268
- return self.build_or([self.build_is_not_null(column)])
269
-
270
- field_name, *field_attrs = field_name.split('.')
271
- field = self.schema.field(field_name)
272
- for attr in field_attrs:
273
- field = field.type[attr]
274
- for filter_by_name in filters:
275
- offsets.append(self.build_range(column=column, field=field, filter_by_name=filter_by_name))
276
- return self.build_or(offsets)
277
-
278
277
  def rule_to_operator(self, raw_rule: str):
279
278
  operator_matcher = {
280
279
  'eq': self.build_equal,
@@ -330,6 +329,8 @@ class Predicate:
330
329
  return fb_expression.End(self.builder)
331
330
 
332
331
  def build_literal(self, field: pa.Field, value):
332
+ literal_type: Any
333
+
333
334
  if field.type.equals(pa.int64()):
334
335
  literal_type = fb_int64_lit
335
336
  literal_impl = LiteralImpl.Int64Literal
@@ -403,7 +404,7 @@ class Predicate:
403
404
  field_type = fb_utf8.End(self.builder)
404
405
 
405
406
  value = self.builder.CreateString(value)
406
- elif field.type.equals(pa.date32()): # pa.date64()
407
+ elif field.type.equals(pa.date32()): # pa.date64() is not supported
407
408
  literal_type = fb_date32_lit
408
409
  literal_impl = LiteralImpl.DateLiteral
409
410
 
@@ -411,37 +412,49 @@ class Predicate:
411
412
  fb_date.Start(self.builder)
412
413
  fb_date.AddUnit(self.builder, DateUnit.DAY)
413
414
  field_type = fb_date.End(self.builder)
414
-
415
- start_date = datetime.fromtimestamp(0).date()
416
- date_delta = value - start_date
417
- value = date_delta.days
415
+ value, = pa.array([value], field.type).cast(pa.int32()).to_pylist()
418
416
  elif isinstance(field.type, pa.TimestampType):
419
417
  literal_type = fb_timestamp_lit
420
418
  literal_impl = LiteralImpl.TimestampLiteral
421
419
 
420
+ if field.type.equals(pa.timestamp('s')):
421
+ unit = TimeUnit.SECOND
422
+ if field.type.equals(pa.timestamp('ms')):
423
+ unit = TimeUnit.MILLISECOND
424
+ if field.type.equals(pa.timestamp('us')):
425
+ unit = TimeUnit.MICROSECOND
426
+ if field.type.equals(pa.timestamp('ns')):
427
+ unit = TimeUnit.NANOSECOND
428
+
422
429
  field_type_type = Type.Timestamp
423
430
  fb_timestamp.Start(self.builder)
424
- fb_timestamp.AddUnit(self.builder, get_unit_to_flatbuff_time_unit(field.type.unit))
431
+ fb_timestamp.AddUnit(self.builder, unit)
425
432
  field_type = fb_timestamp.End(self.builder)
426
-
427
- value = int(int(value) * self.unit_to_epoch[field.type.unit])
428
- elif field.type.equals(pa.time32('s')) or field.type.equals(pa.time32('ms')) or field.type.equals(pa.time64('us')) or field.type.equals(pa.time64('ns')):
429
-
433
+ value, = pa.array([value], field.type).cast(pa.int64()).to_pylist()
434
+ elif isinstance(field.type, (pa.Time32Type, pa.Time64Type)):
430
435
  literal_type = fb_time_lit
431
436
  literal_impl = LiteralImpl.TimeLiteral
432
437
 
433
- field_type_str = str(field.type)
434
- start = field_type_str.index('[')
435
- end = field_type_str.index(']')
436
- unit = field_type_str[start + 1:end]
438
+ if field.type.equals(pa.time32('s')):
439
+ target_type = pa.int32()
440
+ unit = TimeUnit.SECOND
441
+ if field.type.equals(pa.time32('ms')):
442
+ target_type = pa.int32()
443
+ unit = TimeUnit.MILLISECOND
444
+ if field.type.equals(pa.time64('us')):
445
+ target_type = pa.int64()
446
+ unit = TimeUnit.MICROSECOND
447
+ if field.type.equals(pa.time64('ns')):
448
+ target_type = pa.int64()
449
+ unit = TimeUnit.NANOSECOND
437
450
 
438
451
  field_type_type = Type.Time
439
452
  fb_time.Start(self.builder)
440
453
  fb_time.AddBitWidth(self.builder, field.type.bit_width)
441
- fb_time.AddUnit(self.builder, get_unit_to_flatbuff_time_unit(unit))
454
+ fb_time.AddUnit(self.builder, unit)
442
455
  field_type = fb_time.End(self.builder)
443
456
 
444
- value = int(value) * self.unit_to_epoch[unit]
457
+ value, = pa.array([value], field.type).cast(target_type).to_pylist()
445
458
  elif field.type.equals(pa.bool_()):
446
459
  literal_type = fb_bool_lit
447
460
  literal_impl = LiteralImpl.BooleanLiteral
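This hunk drops the hand-maintained `unit_to_epoch` table and the `datetime` arithmetic in favour of round-tripping the literal through a pyarrow cast, so unit conversion is delegated to Arrow itself. A small standalone sketch of the underlying pyarrow behaviour (not part of the package):

```python
import datetime

import pyarrow as pa

# date32 literals become days since the Unix epoch
days, = pa.array([datetime.date(1970, 1, 2)], pa.date32()).cast(pa.int32()).to_pylist()
assert days == 1

# timestamp literals become an integer count in the column's own unit
ms, = pa.array([datetime.datetime(1970, 1, 1, 0, 0, 1)], pa.timestamp('ms')).cast(pa.int64()).to_pylist()
assert ms == 1000
```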
@@ -530,13 +543,20 @@ class Predicate:
530
543
  return self.build_function('match_substring', column, literal)
531
544
 
532
545
 
546
+ class FieldNodesState:
547
+ def __init__(self) -> None:
548
+ # will be set during by the parser (see below)
549
+ self.buffers: Dict[int, Any] = defaultdict(lambda: None) # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
550
+ self.length: Dict[int, Any] = defaultdict(lambda: None) # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
551
+
552
+
533
553
  class FieldNode:
534
554
  """Helper class for representing nested Arrow fields and handling QueryData requests"""
535
555
  def __init__(self, field: pa.Field, index_iter, parent: Optional['FieldNode'] = None, debug: bool = False):
536
- self.index = next(index_iter) # we use DFS-first enumeration for communicating the column positions to VAST
556
+ self.index = next(index_iter) # we use DFS-first enumeration for communicating the column positions to VAST
537
557
  self.field = field
538
558
  self.type = field.type
539
- self.parent = parent # will be None if this is the top-level field
559
+ self.parent = parent # will be None if this is the top-level field
540
560
  self.debug = debug
541
561
  if isinstance(self.type, pa.StructType):
542
562
  self.children = [FieldNode(field, index_iter, parent=self) for field in self.type]
@@ -553,13 +573,7 @@ class FieldNode:
553
573
  field = pa.field('entries', pa.struct([self.type.key_field, self.type.item_field]))
554
574
  self.children = [FieldNode(field, index_iter, parent=self)]
555
575
  else:
556
- self.children = [] # for non-nested types
557
-
558
- # will be set during by the parser (see below)
559
- self.buffers = None # a list of Arrow buffers (https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout)
560
- self.length = None # each array must have it's length specified (https://arrow.apache.org/docs/python/generated/pyarrow.Array.html#pyarrow.Array.from_buffers)
561
- self.is_projected = False
562
- self.projected_field = self.field
576
+ self.children = [] # for non-nested types
563
577
 
564
578
  def _iter_to_root(self) -> Iterator['FieldNode']:
565
579
  yield self
@@ -580,24 +594,14 @@ class FieldNode:
580
594
  for child in self.children:
581
595
  yield from child._iter_leaves()
582
596
 
583
- def _iter_projected_leaves(self) -> Iterator['FieldNode']:
584
- """Generate only leaf nodes (i.e. columns having scalar types)."""
585
- if not self.children:
586
- if self.is_projected:
587
- yield self
588
- else:
589
- for child in self.children:
590
- if child.is_projected:
591
- yield from child._iter_projected_leaves()
592
-
593
597
  def debug_log(self, level=0):
594
598
  """Recursively dump this node state to log."""
595
599
  bufs = self.buffers and [b and b.hex() for b in self.buffers]
596
- _logger.debug('%s%d: %s, bufs=%s, len=%s', ' '*level, self.index, self.field, bufs, self.length)
600
+ _logger.debug('%s%d: %s, bufs=%s, len=%s', ' ' * level, self.index, self.field, bufs, self.length)
597
601
  for child in self.children:
598
- child.debug_log(level=level+1)
602
+ child.debug_log(level=level + 1)
599
603
 
600
- def set(self, arr: pa.Array):
604
+ def set(self, arr: pa.Array, state: FieldNodesState):
601
605
  """
602
606
  Assign the relevant Arrow buffers from the received array into this node.
603
607
 
@@ -609,68 +613,52 @@ class FieldNode:
609
613
  For example, `Struct<A, B>` is sent as two separate columns: `Struct<A>` and `Struct<B>`.
610
614
  Also, `Map<K, V>` is sent (as its underlying representation): `List<Struct<K>>` and `List<Struct<V>>`
611
615
  """
612
- buffers = arr.buffers()[:arr.type.num_buffers] # slicing is needed because Array.buffers() returns also nested array buffers
616
+ buffers = arr.buffers()[:arr.type.num_buffers] # slicing is needed because Array.buffers() returns also nested array buffers
613
617
  if self.debug:
614
618
  _logger.debug("set: index=%d %s %s", self.index, self.field, [b and b.hex() for b in buffers])
615
- if self.buffers is None:
616
- self.buffers = buffers
617
- self.length = len(arr)
619
+ if state.buffers[self.index] is None:
620
+ state.buffers[self.index] = buffers
621
+ state.length[self.index] = len(arr)
618
622
  else:
619
623
  # Make sure subsequent assignments are consistent with each other
620
624
  if self.debug:
621
- if not self.buffers == buffers:
622
- raise ValueError(f'self.buffers: {self.buffers} are not equal with buffers: {buffers}')
623
- if not self.length == len(arr):
624
- raise ValueError(f'self.length: {self.length} are not equal with len(arr): {len(arr)}')
625
+ if not state.buffers[self.index] == buffers:
626
+ raise ValueError(f'self.buffers: {state.buffers[self.index]} are not equal with buffers: {buffers}')
627
+ if not state.length[self.index] == len(arr):
628
+ raise ValueError(f'self.length: {state.length[self.index]} are not equal with len(arr): {len(arr)}')
625
629
 
626
- def build(self) -> pa.Array:
630
+ def build(self, state: FieldNodesState) -> pa.Array:
627
631
  """Construct an Arrow array from the collected buffers (recursively)."""
628
- children = self.children and [node.build() for node in self.children if node.is_projected]
629
- _logger.debug('build: self.field.name=%s, self.projected_field.type=%s, self.length=%s, self.buffers=%s children=%s',
630
- self.field.name, self.projected_field.type, self.length, self.buffers, children)
631
- result = pa.Array.from_buffers(self.projected_field.type, self.length, buffers=self.buffers, children=children)
632
+ children = self.children and [node.build(state) for node in self.children]
633
+ result = pa.Array.from_buffers(self.type, state.length[self.index], buffers=state.buffers[self.index], children=children)
632
634
  if self.debug:
633
635
  _logger.debug('%s result=%s', self.field, result)
634
636
  return result
635
637
 
636
- def build_projected_field(self):
637
- if isinstance(self.type, pa.StructType):
638
- [child.build_projected_field() for child in self.children if child.is_projected]
639
- self.projected_field = pa.field(self.field.name,
640
- pa.struct([child.projected_field for child in self.children if child.is_projected]),
641
- self.field.nullable,
642
- self.field.metadata)
643
638
 
644
639
  class QueryDataParser:
640
+ class QueryDataParserState(FieldNodesState):
641
+ def __init__(self) -> None:
642
+ super().__init__()
643
+ self.leaf_offset = 0
644
+
645
645
  """Used to parse VAST QueryData RPC response."""
646
- def __init__(self, arrow_schema: pa.Schema, *, debug=False, projection_positions=None):
646
+ def __init__(self, arrow_schema: pa.Schema, *, debug=False):
647
647
  self.arrow_schema = arrow_schema
648
- self.projection_positions = projection_positions
649
- index = itertools.count() # used to generate leaf column positions for VAST QueryData RPC
648
+ index = itertools.count() # used to generate leaf column positions for VAST QueryData RPC
650
649
  self.nodes = [FieldNode(field, index, debug=debug) for field in arrow_schema]
651
650
  self.debug = debug
652
651
  if self.debug:
653
652
  for node in self.nodes:
654
653
  node.debug_log()
655
654
  self.leaves = [leaf for node in self.nodes for leaf in node._iter_leaves()]
656
- self.mark_projected_nodes()
657
- [node.build_projected_field() for node in self.nodes]
658
- self.projected_leaves = [leaf for node in self.nodes for leaf in node._iter_projected_leaves()]
659
-
660
- self.leaf_offset = 0
661
655
 
662
- def mark_projected_nodes(self):
663
- for leaf in self.leaves:
664
- if self.projection_positions is None or leaf.index in self.projection_positions:
665
- for node in leaf._iter_to_root():
666
- node.is_projected = True
667
-
668
- def parse(self, column: pa.Array):
656
+ def parse(self, column: pa.Array, state: QueryDataParserState):
669
657
  """Parse a single column response from VAST (see FieldNode.set for details)"""
670
- if not self.leaf_offset < len(self.projected_leaves):
671
- raise ValueError(f'self.leaf_offset: {self.leaf_offset} are not < '
658
+ if not state.leaf_offset < len(self.leaves):
659
+ raise ValueError(f'state.leaf_offset: {state.leaf_offset} are not < '
672
660
  f'than len(self.leaves): {len(self.leaves)}')
673
- leaf = self.projected_leaves[self.leaf_offset]
661
+ leaf = self.leaves[state.leaf_offset]
674
662
 
675
663
  # A column response may be sent in multiple chunks, therefore we need to combine
676
664
  # it into a single chunk to allow reconstruction using `Array.from_buffers()`.
@@ -687,38 +675,26 @@ class QueryDataParser:
687
675
  raise ValueError(f'len(array_list): {len(array_list)} are not eq '
688
676
  f'with len(node_list): {len(node_list)}')
689
677
  for node, arr in zip(node_list, array_list):
690
- node.set(arr)
678
+ node.set(arr, state)
691
679
 
692
- self.leaf_offset += 1
680
+ state.leaf_offset += 1
693
681
 
694
- def build(self, output_field_names=None) -> Optional[pa.Table]:
682
+ def build(self, state: QueryDataParserState) -> Optional[pa.Table]:
695
683
  """Try to build the resulting Table object (if all columns were parsed)"""
696
- if self.projection_positions is not None:
697
- if self.leaf_offset < len(self.projection_positions):
698
- return None
699
- else:
700
- if self.leaf_offset < len(self.leaves):
701
- return None
684
+ if state.leaf_offset < len(self.leaves):
685
+ return None
702
686
 
703
687
  if self.debug:
704
688
  for node in self.nodes:
705
689
  node.debug_log()
706
690
 
707
- # sort resulting table according to the output field names
708
- projected_nodes = [node for node in self.nodes if node.is_projected]
709
- if output_field_names is not None:
710
- def key_func(projected_node):
711
- return output_field_names.index(projected_node.field.name)
712
- sorted_projected_nodes = sorted(projected_nodes, key=key_func)
713
- else:
714
- sorted_projected_nodes = projected_nodes
715
-
716
691
  result = pa.Table.from_arrays(
717
- arrays=[node.build() for node in sorted_projected_nodes],
718
- schema = pa.schema([node.projected_field for node in sorted_projected_nodes]))
719
- result.validate(full=True) # does expensive validation checks only if debug is enabled
692
+ arrays=[node.build(state) for node in self.nodes],
693
+ schema=self.arrow_schema)
694
+ result.validate(full=self.debug) # does expensive validation checks only if debug is enabled
720
695
  return result
721
696
 
697
+
722
698
  def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
723
699
  """Iterate over a single column response, and recursively generate all of its children."""
724
700
  yield column
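Across the parser hunks above, the buffer and length bookkeeping that used to live on each `FieldNode` moves into a separate `FieldNodesState` keyed by node index, and `QueryDataParser.parse`/`build` now take that state explicitly, so one parser tree per schema can serve many response streams. The shape of the refactor in a simplified standalone form (names below are illustrative, not the package's API):

```python
from collections import defaultdict


class StreamState:
    """Per-stream mutable state; the parser tree itself stays immutable and shared."""
    def __init__(self):
        self.buffers = defaultdict(lambda: None)  # {node index: list of Arrow buffers}
        self.length = defaultdict(lambda: None)   # {node index: array length}
        self.leaf_offset = 0                      # next leaf column expected from the server


# one parser per schema, one state per response stream:
#   parser = QueryDataParser(arrow_schema)
#   states = defaultdict(StreamState)
#   parser.parse(column, states[stream_id])
#   table = parser.build(states[stream_id])  # None until every leaf has arrived
```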
@@ -730,7 +706,9 @@ def _iter_nested_arrays(column: pa.Array) -> Iterator[pa.Array]:
730
706
  yield from _iter_nested_arrays(column.values) # Note: Map is serialized in VAST as a List<Struct<K, V>>
731
707
 
732
708
 
733
- TableInfo = namedtuple('table_info', 'name properties handle num_rows size_in_bytes')
709
+ TableInfo = namedtuple('TableInfo', 'name properties handle num_rows size_in_bytes')
710
+
711
+
734
712
  def _parse_table_info(obj):
735
713
 
736
714
  name = obj.Name().decode()
@@ -740,6 +718,7 @@ def _parse_table_info(obj):
740
718
  used_bytes = obj.SizeInBytes()
741
719
  return TableInfo(name, properties, handle, num_rows, used_bytes)
742
720
 
721
+
743
722
  def build_record_batch(column_info, column_values):
744
723
  fields = [pa.field(column_name, column_type) for column_type, column_name in column_info]
745
724
  schema = pa.schema(fields)
@@ -747,6 +726,7 @@ def build_record_batch(column_info, column_values):
747
726
  batch = pa.record_batch(arrays, schema)
748
727
  return serialize_record_batch(batch)
749
728
 
729
+
750
730
  def serialize_record_batch(batch):
751
731
  sink = pa.BufferOutputStream()
752
732
  with pa.ipc.new_stream(sink, batch.schema) as writer:
@@ -754,61 +734,45 @@ def serialize_record_batch(batch):
754
734
  return sink.getvalue()
755
735
 
756
736
  # Results that returns from tablestats
757
- TableStatsResult = namedtuple("TableStatsResult",["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
737
+
738
+
739
+ TableStatsResult = namedtuple("TableStatsResult", ["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])
740
+
758
741
 
759
742
  class VastdbApi:
760
743
  # we expect the vast version to be <major>.<minor>.<patch>.<protocol>
761
744
  VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')
762
745
 
763
- def __init__(self, endpoint, access_key, secret_key, username=None, password=None,
764
- secure=False, auth_type=AuthType.SIGV4):
765
- url_dict = urllib3.util.parse_url(endpoint)._asdict()
746
+ def __init__(self, endpoint, access_key, secret_key, auth_type=AuthType.SIGV4, ssl_verify=True):
747
+ url = urllib3.util.parse_url(endpoint)
766
748
  self.access_key = access_key
767
749
  self.secret_key = secret_key
768
- self.username = username
769
- self.password = password
770
- self.secure = secure
771
- self.auth_type = auth_type
772
- self.executor_hosts = [endpoint] # TODO: remove
773
-
774
- username = username or ''
775
- password = password or ''
776
- if not url_dict['port']:
777
- url_dict['port'] = 443 if secure else 80
778
-
779
- self.port = url_dict['port']
780
750
 
781
751
  self.default_max_list_columns_page_size = 1000
782
752
  self.session = requests.Session()
783
- self.session.verify = False
753
+ self.session.verify = ssl_verify
784
754
  self.session.headers['user-agent'] = "VastData Tabular API 1.0 - 2022 (c)"
785
- if auth_type == AuthType.BASIC:
786
- self.session.auth = requests.auth.HTTPBasicAuth(username, password)
787
- else:
788
- if url_dict['port'] != 80 and url_dict['port'] != 443:
789
- self.aws_host = '{host}:{port}'.format(**url_dict)
790
- else:
791
- self.aws_host = '{host}'.format(**url_dict)
792
755
 
793
- self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
794
- aws_secret_access_key=secret_key,
795
- aws_host=self.aws_host,
796
- aws_region='us-east-1',
797
- aws_service='s3')
798
-
799
- if not url_dict['scheme']:
800
- url_dict['scheme'] = "https" if secure else "http"
756
+ if url.port in {80, 443, None}:
757
+ self.aws_host = f'{url.host}'
758
+ else:
759
+ self.aws_host = f'{url.host}:{url.port}'
801
760
 
802
- url = urllib3.util.Url(**url_dict)
803
761
  self.url = str(url)
804
762
  _logger.debug('url=%s aws_host=%s', self.url, self.aws_host)
805
763
 
764
+ self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
765
+ aws_secret_access_key=secret_key,
766
+ aws_host=self.aws_host,
767
+ aws_region='us-east-1',
768
+ aws_service='s3')
769
+
806
770
  # probe the cluster for its version
807
771
  self.vast_version = None
808
- res = self.session.options(self.url)
772
+ res = self.session.get(self.url)
809
773
  server_header = res.headers.get("Server")
810
774
  if server_header is None:
811
- _logger.error("OPTIONS response doesn't contain 'Server' header")
775
+ _logger.error("Response doesn't contain 'Server' header")
812
776
  else:
813
777
  _logger.debug("Server header is '%s'", server_header)
814
778
  if m := self.VAST_VERSION_REGEX.match(server_header):
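The constructor above drops the `username`/`password`/`secure` parameters and basic-auth support, takes TLS verification from a new `ssl_verify` flag (previously `session.verify` was hard-coded to `False`), derives scheme, host and port from the endpoint URL, and probes the cluster version with `GET` instead of `OPTIONS`. A hedged usage sketch based only on the signature shown here (the import path and all values are placeholders, not taken from the package's documentation):

```python
from vastdb.api import VastdbApi  # assumed module path

api = VastdbApi(
    endpoint='https://vast.example.com:8443',  # scheme and port now come from the URL itself
    access_key='AKIAEXAMPLE',                  # placeholder credentials
    secret_key='SECRETEXAMPLE',
    ssl_verify=True,                           # TLS verification is now on by default
)
```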
@@ -1009,9 +973,8 @@ class VastdbApi:
1009
973
 
1010
974
  return snapshots, is_truncated, marker
1011
975
 
1012
-
1013
976
  def create_table(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[],
1014
- topic_partitions=0, create_imports_table=False):
977
+ topic_partitions=0, create_imports_table=False, use_external_row_ids_allocation=False):
1015
978
  """
1016
979
  Create a table, use the following request
1017
980
  POST /bucket/schema/table?table HTTP/1.1
@@ -1032,6 +995,9 @@ class VastdbApi:
1032
995
 
1033
996
  serialized_schema = arrow_schema.serialize()
1034
997
  headers['Content-Length'] = str(len(serialized_schema))
998
+ if use_external_row_ids_allocation:
999
+ headers['use-external-row-ids-alloc'] = str(use_external_row_ids_allocation)
1000
+
1035
1001
  url_params = {'topic_partitions': str(topic_partitions)} if topic_partitions else {}
1036
1002
  if create_imports_table:
1037
1003
  url_params['sub-table'] = IMPORTED_OBJECTS_TABLE_NAME
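`create_table` gains a `use_external_row_ids_allocation` flag that is forwarded as the `use-external-row-ids-alloc` header. An illustrative call based on the extended signature (assumes the `api` client from the constructor sketch above; bucket, schema and table names are placeholders):

```python
import pyarrow as pa

arrow_schema = pa.schema([('id', pa.int64()), ('name', pa.utf8())])
api.create_table('my-bucket', 'my-schema', 'my-table', arrow_schema,
                 use_external_row_ids_allocation=True)
```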
@@ -1048,8 +1014,8 @@ class VastdbApi:
1048
1014
  if parquet_path:
1049
1015
  parquet_ds = pq.ParquetDataset(parquet_path)
1050
1016
  elif parquet_bucket_name and parquet_object_name:
1051
- s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
1052
- parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name,parquet_object_name]), filesystem=s3fs)
1017
+ s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
1018
+ parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name, parquet_object_name]), filesystem=s3fs)
1053
1019
  else:
1054
1020
  raise RuntimeError(f'invalid params parquet_path={parquet_path} parquet_bucket_name={parquet_bucket_name} parquet_object_name={parquet_object_name}')
1055
1021
 
@@ -1064,7 +1030,6 @@ class VastdbApi:
1064
1030
  # create the table
1065
1031
  return self.create_table(bucket, schema, name, arrow_schema, txid, client_tags, expected_retvals)
1066
1032
 
1067
-
1068
1033
  def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[]):
1069
1034
  """
1070
1035
  GET /mybucket/myschema/mytable?stats HTTP/1.1
@@ -1075,29 +1040,33 @@ class VastdbApi:
1075
1040
  """
1076
1041
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
1077
1042
  res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats"), headers=headers)
1078
- if res.status_code == 200:
1079
- flatbuf = b''.join(res.iter_content(chunk_size=128))
1080
- stats = get_table_stats.GetRootAs(flatbuf)
1081
- num_rows = stats.NumRows()
1082
- size_in_bytes = stats.SizeInBytes()
1083
- is_external_rowid_alloc = stats.IsExternalRowidAlloc()
1084
- endpoints = []
1085
- if stats.VipsLength() == 0:
1086
- endpoints.append(self.url)
1087
- else:
1088
- ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
1089
- vips = [stats.Vips(i) for i in range(stats.VipsLength())]
1090
- ips = []
1091
- # extract the vips into list of IPs
1092
- for vip in vips:
1093
- start_ip = int(ip_cls(vip.StartAddress().decode()))
1094
- ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
1095
- for ip in ips:
1096
- prefix = "http" if not self.secure else "https"
1097
- endpoints.append(f"{prefix}://{str(ip)}:{self.port}")
1098
- return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, endpoints)
1099
-
1100
- return self._check_res(res, "get_table_stats", expected_retvals)
1043
+ self._check_res(res, "get_table_stats", expected_retvals)
1044
+
1045
+ flatbuf = b''.join(res.iter_content(chunk_size=128))
1046
+ stats = get_table_stats.GetRootAs(flatbuf)
1047
+ num_rows = stats.NumRows()
1048
+ size_in_bytes = stats.SizeInBytes()
1049
+ is_external_rowid_alloc = stats.IsExternalRowidAlloc()
1050
+ endpoints = []
1051
+ if stats.VipsLength() == 0:
1052
+ endpoints.append(self.url)
1053
+ else:
1054
+ url = urllib3.util.parse_url(self.url)
1055
+
1056
+ ip_cls = IPv6Address if (stats.AddressType() == "ipv6") else IPv4Address
1057
+ vips = [stats.Vips(i) for i in range(stats.VipsLength())]
1058
+ ips = []
1059
+ # extract the vips into list of IPs
1060
+ for vip in vips:
1061
+ start_ip = int(ip_cls(vip.StartAddress().decode()))
1062
+ ips.extend(ip_cls(start_ip + i) for i in range(vip.AddressCount()))
1063
+ # build a list of endpoint URLs, reusing schema and port (if specified when constructing the session).
1064
+ # it is assumed that the client can access the returned IPs (e.g. if they are part of the VIP pool).
1065
+ for ip in ips:
1066
+ d = url._asdict()
1067
+ d['host'] = str(ip)
1068
+ endpoints.append(str(urllib3.util.Url(**d)))
1069
+ return TableStatsResult(num_rows, size_in_bytes, is_external_rowid_alloc, tuple(endpoints))
1101
1070
 
1102
1071
  def alter_table(self, bucket, schema, name, txid=0, client_tags=[], table_properties="",
1103
1072
  new_name="", expected_retvals=[]):
@@ -1186,7 +1155,6 @@ class VastdbApi:
1186
1155
 
1187
1156
  return bucket_name, schema_name, tables, next_key, is_truncated, count
1188
1157
 
1189
-
1190
1158
  def add_columns(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[]):
1191
1159
  """
1192
1160
  Add a column to table, use the following request
@@ -1212,7 +1180,7 @@ class VastdbApi:
1212
1180
  return self._check_res(res, "add_columns", expected_retvals)
1213
1181
 
1214
1182
  def alter_column(self, bucket, schema, table, name, txid=0, client_tags=[], column_properties="",
1215
- new_name="", column_sep = ".", column_stats="", expected_retvals=[]):
1183
+ new_name="", column_sep=".", column_stats="", expected_retvals=[]):
1216
1184
  """
1217
1185
  PUT /bucket/schema/table?column&tabular-column-name=ColumnName&tabular-new-column-name=NewColumnName HTTP/1.1
1218
1186
  Content-Length: ContentLength
@@ -1241,7 +1209,7 @@ class VastdbApi:
1241
1209
  headers['tabular-column-sep'] = column_sep
1242
1210
  headers['Content-Length'] = str(len(alter_column_req))
1243
1211
 
1244
- url_params = {'tabular-column-name': name }
1212
+ url_params = {'tabular-column-name': name}
1245
1213
  if len(new_name):
1246
1214
  url_params['tabular-new-column-name'] = new_name
1247
1215
 
@@ -1588,7 +1556,7 @@ class VastdbApi:
1588
1556
  return self._check_res(res, "import_data", expected_retvals)
1589
1557
 
1590
1558
  def _record_batch_slices(self, batch, rows_per_slice=None):
1591
- max_slice_size_in_bytes = int(0.9*5*1024*1024) # 0.9 * 5MB
1559
+ max_slice_size_in_bytes = int(0.9 * 5 * 1024 * 1024) # 0.9 * 5MB
1592
1560
  batch_len = len(batch)
1593
1561
  serialized_batch = serialize_record_batch(batch)
1594
1562
  batch_size_in_bytes = len(serialized_batch)
@@ -1606,10 +1574,10 @@ class VastdbApi:
1606
1574
  # Attempt slicing according to the current rows_per_slice
1607
1575
  offset = 0
1608
1576
  serialized_slices = []
1609
- for i in range(math.ceil(batch_len/rows_per_slice)):
1577
+ for i in range(math.ceil(batch_len / rows_per_slice)):
1610
1578
  offset = rows_per_slice * i
1611
1579
  if offset >= batch_len:
1612
- done_slicing=True
1580
+ done_slicing = True
1613
1581
  break
1614
1582
  slice_batch = batch.slice(offset, rows_per_slice)
1615
1583
  serialized_slice_batch = serialize_record_batch(slice_batch)
@@ -1620,7 +1588,7 @@ class VastdbApi:
1620
1588
  else:
1621
1589
  _logger.info(f'Using rows_per_slice {rows_per_slice} slice {i} size {sizeof_serialized_slice_batch} exceeds {max_slice_size_in_bytes} bytes, trying smaller rows_per_slice')
1622
1590
  # We have a slice that is too large
1623
- rows_per_slice = int(rows_per_slice/2)
1591
+ rows_per_slice = int(rows_per_slice / 2)
1624
1592
  if rows_per_slice < 1:
1625
1593
  raise ValueError('cannot decrease batch size below 1 row')
1626
1594
  break
@@ -1643,7 +1611,8 @@ class VastdbApi:
1643
1611
  headers['Content-Length'] = str(len(record_batch))
1644
1612
  res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
1645
1613
  data=record_batch, headers=headers, stream=True)
1646
- return self._check_res(res, "insert_rows", expected_retvals)
1614
+ self._check_res(res, "insert_rows", expected_retvals)
1615
+ res.raw.read() # flush the response
1647
1616
 
1648
1617
  def update_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
1649
1618
  """
@@ -1659,9 +1628,10 @@ class VastdbApi:
1659
1628
  headers['Content-Length'] = str(len(record_batch))
1660
1629
  res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
1661
1630
  data=record_batch, headers=headers)
1662
- return self._check_res(res, "update_rows", expected_retvals)
1631
+ self._check_res(res, "update_rows", expected_retvals)
1663
1632
 
1664
- def delete_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
1633
+ def delete_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[],
1634
+ delete_from_imports_table=False):
1665
1635
  """
1666
1636
  DELETE /mybucket/myschema/mytable?rows HTTP/1.1
1667
1637
  Content-Length: ContentLength
@@ -1673,9 +1643,11 @@ class VastdbApi:
1673
1643
  """
1674
1644
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
1675
1645
  headers['Content-Length'] = str(len(record_batch))
1676
- res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
1677
- data=record_batch, headers=headers)
1678
- return self._check_res(res, "delete_rows", expected_retvals)
1646
+ url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if delete_from_imports_table else {}
1647
+
1648
+ res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows", url_params=url_params),
1649
+ data=record_batch, headers=headers)
1650
+ self._check_res(res, "delete_rows", expected_retvals)
1679
1651
 
1680
1652
  def create_projection(self, bucket, schema, table, name, columns, txid=0, client_tags=[], expected_retvals=[]):
1681
1653
  """
@@ -1885,6 +1857,10 @@ class VastdbApi:
1885
1857
  return columns, next_key, is_truncated, count
1886
1858
 
1887
1859
 
1860
+ class QueryDataInternalError(Exception):
1861
+ pass
1862
+
1863
+
1888
1864
  def _iter_query_data_response_columns(fileobj, stream_ids=None):
1889
1865
  readers = {} # {stream_id: pa.ipc.RecordBatchStreamReader}
1890
1866
  while True:
@@ -1909,8 +1885,8 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
1909
1885
  if stream_id == TABULAR_QUERY_DATA_FAILED_STREAM_ID:
1910
1886
  # read the terminating end chunk from socket
1911
1887
  res = fileobj.read()
1912
- _logger.warning("stream_id=%d res=%s (failed)", stream_id, res)
1913
- raise IOError(f"Query data stream failed res={res}")
1888
+ _logger.debug("stream_id=%d res=%s (failed)", stream_id, res)
1889
+ raise QueryDataInternalError() # connection closed by server due to an internal error
1914
1890
 
1915
1891
  next_row_id_bytes = fileobj.read(8)
1916
1892
  next_row_id, = struct.unpack('<Q', next_row_id_bytes)
@@ -1925,7 +1901,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
1925
1901
 
1926
1902
  (reader, batches) = readers[stream_id]
1927
1903
  try:
1928
- batch = reader.read_next_batch() # read single-column chunk data
1904
+ batch = reader.read_next_batch() # read single-column chunk data
1929
1905
  _logger.debug("stream_id=%d rows=%d chunk=%s", stream_id, len(batch), batch)
1930
1906
  batches.append(batch)
1931
1907
  except StopIteration: # we got an end-of-stream IPC message for a given stream ID
@@ -1935,7 +1911,7 @@ def _iter_query_data_response_columns(fileobj, stream_ids=None):
1935
1911
  yield (stream_id, next_row_id, table)
1936
1912
 
1937
1913
 
1938
- def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False):
1914
+ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None, debug=False, parser: Optional[QueryDataParser] = None):
1939
1915
  """
1940
1916
  Generates pyarrow.Table objects from QueryData API response stream.
1941
1917
 
@@ -1943,20 +1919,20 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
1943
1919
  """
1944
1920
  if start_row_ids is None:
1945
1921
  start_row_ids = {}
1946
- projection_positions = schema.projection_positions
1947
- arrow_schema = schema.arrow_schema
1948
- output_field_names = schema.output_field_names
1949
- _logger.debug(f'projection_positions={projection_positions} len(arrow_schema)={len(arrow_schema)} arrow_schema={arrow_schema}')
1950
- is_empty_projection = (len(projection_positions) == 0)
1951
- parsers = defaultdict(lambda: QueryDataParser(arrow_schema, debug=debug, projection_positions=projection_positions)) # {stream_id: QueryDataParser}
1922
+
1923
+ is_empty_projection = (len(schema) == 0)
1924
+ if parser is None:
1925
+ parser = QueryDataParser(schema, debug=debug)
1926
+ states: Dict[int, QueryDataParser.QueryDataParserState] = defaultdict(lambda: QueryDataParser.QueryDataParserState()) # {stream_id: QueryDataParser}
1927
+
1952
1928
  for stream_id, next_row_id, table in _iter_query_data_response_columns(conn, stream_ids):
1953
- parser = parsers[stream_id]
1929
+ state = states[stream_id]
1954
1930
  for column in table.columns:
1955
- parser.parse(column)
1931
+ parser.parse(column, state)
1956
1932
 
1957
- parsed_table = parser.build(output_field_names)
1933
+ parsed_table = parser.build(state)
1958
1934
  if parsed_table is not None: # when we got all columns (and before starting a new "select_rows" cycle)
1959
- parsers.pop(stream_id)
1935
+ states.pop(stream_id)
1960
1936
  if is_empty_projection: # VAST returns an empty RecordBatch, with the correct rows' count
1961
1937
  parsed_table = table
1962
1938
 
@@ -1965,8 +1941,9 @@ def parse_query_data_response(conn, schema, stream_ids=None, start_row_ids=None,
1965
1941
  start_row_ids[stream_id] = next_row_id
1966
1942
  yield parsed_table # the result of a single "select_rows()" cycle
1967
1943
 
1968
- if parsers:
1969
- raise EOFError(f'all streams should be done before EOF. {parsers}')
1944
+ if states:
1945
+ raise EOFError(f'all streams should be done before EOF. {states}')
1946
+
1970
1947
 
1971
1948
  def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
1972
1949
  if field.type.equals(pa.int64()):
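`parse_query_data_response` now accepts a prebuilt `QueryDataParser`, keeps a `QueryDataParserState` per stream ID, and raises the new `QueryDataInternalError` (instead of a generic `IOError`) when the server closes a stream after an internal failure. A hedged caller-side fragment based only on the signatures in this diff (import path assumed; `conn`, `request` and `process` stand for a live response stream, a previously built request, and the caller's own handling):

```python
from vastdb.api import QueryDataInternalError, parse_query_data_response  # assumed module path

try:
    for table in parse_query_data_response(conn, request.response_schema,
                                            parser=request.response_parser):
        process(table)  # each yielded object is a pyarrow.Table for one select_rows() cycle
except QueryDataInternalError:
    pass  # server-side failure on the stream; the caller decides whether to retry
```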
@@ -2042,7 +2019,7 @@ def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
2042
2019
  fb_utf8.Start(builder)
2043
2020
  field_type = fb_utf8.End(builder)
2044
2021
 
2045
- elif field.type.equals(pa.date32()): # pa.date64()
2022
+ elif field.type.equals(pa.date32()): # pa.date64() is not supported
2046
2023
  field_type_type = Type.Date
2047
2024
  fb_date.Start(builder)
2048
2025
  fb_date.AddUnit(builder, DateUnit.DAY)
@@ -2109,6 +2086,7 @@ def get_field_type(builder: flatbuffers.Builder, field: pa.Field):
2109
2086
 
2110
2087
  return field_type, field_type_type
2111
2088
 
2089
+
2112
2090
  def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
2113
2091
  children = None
2114
2092
  if isinstance(f.type, pa.StructType):
@@ -2155,19 +2133,14 @@ def build_field(builder: flatbuffers.Builder, f: pa.Field, name: str):
2155
2133
  return fb_field.End(builder)
2156
2134
 
2157
2135
 
2158
- class VastDBResponseSchema:
2159
- def __init__(self, arrow_schema, projection_positions, output_field_names):
2160
- self.arrow_schema = arrow_schema
2161
- self.projection_positions = projection_positions
2162
- self.output_field_names = output_field_names
2163
-
2164
2136
  class QueryDataRequest:
2165
- def __init__(self, serialized, response_schema):
2137
+ def __init__(self, serialized, response_schema, response_parser):
2166
2138
  self.serialized = serialized
2167
2139
  self.response_schema = response_schema
2140
+ self.response_parser = response_parser
2168
2141
 
2169
2142
 
2170
- def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: list = None):
2143
+ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: Optional[List[str]] = None):
2171
2144
  builder = flatbuffers.Builder(1024)
2172
2145
 
2173
2146
  source_name = builder.CreateString('') # required
@@ -2187,31 +2160,17 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
2187
2160
  filter_obj = predicate.serialize(builder)
2188
2161
 
2189
2162
  parser = QueryDataParser(schema)
2190
- leaves_map = {}
2191
- for node in parser.nodes:
2192
- for descendent in node._iter_nodes():
2193
- if descendent.parent and isinstance(descendent.parent.type, (pa.ListType, pa.MapType)):
2194
- continue
2195
- iter_from_root = reversed(list(descendent._iter_to_root()))
2196
- descendent_full_name = '.'.join([n.field.name for n in iter_from_root])
2197
- descendent_leaves = [leaf.index for leaf in descendent._iter_leaves()]
2198
- leaves_map[descendent_full_name] = descendent_leaves
2199
-
2200
- output_field_names = None
2163
+ fields_map = {node.field.name: node.field for node in parser.nodes}
2164
+ leaves_map = {node.field.name: [leaf.index for leaf in node._iter_leaves()] for node in parser.nodes}
2165
+
2201
2166
  if field_names is None:
2202
2167
  field_names = [field.name for field in schema]
2203
- else:
2204
- output_field_names = [f.split('.')[0] for f in field_names]
2205
- # sort projected field_names according to positions to maintain ordering according to the schema
2206
- def compare_field_names_by_pos(field_name1, field_name2):
2207
- return leaves_map[field_name1][0]-leaves_map[field_name2][0]
2208
- field_names = sorted(field_names, key=cmp_to_key(compare_field_names_by_pos))
2209
2168
 
2169
+ response_schema = pa.schema([fields_map[name] for name in field_names])
2210
2170
  projection_fields = []
2211
- projection_positions = []
2212
2171
  for field_name in field_names:
2172
+ # TODO: only root-level projection pushdown is supported (i.e. no support for SELECT s.x FROM t)
2213
2173
  positions = leaves_map[field_name]
2214
- projection_positions.extend(positions)
2215
2174
  for leaf_position in positions:
2216
2175
  fb_field_index.Start(builder)
2217
2176
  fb_field_index.AddPosition(builder, leaf_position)
@@ -2222,8 +2181,6 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
2222
2181
  builder.PrependUOffsetTRelative(offset)
2223
2182
  projection = builder.EndVector()
2224
2183
 
2225
- response_schema = VastDBResponseSchema(schema, projection_positions, output_field_names=output_field_names)
2226
-
2227
2184
  fb_source.Start(builder)
2228
2185
  fb_source.AddName(builder, source_name)
2229
2186
  fb_source.AddSchema(builder, schema_obj)
@@ -2237,7 +2194,8 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
2237
2194
  relation = fb_relation.End(builder)
2238
2195
 
2239
2196
  builder.Finish(relation)
2240
- return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema)
2197
+
2198
+ return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema, response_parser=QueryDataParser(response_schema))
2241
2199
 
2242
2200
 
2243
2201
  def convert_column_types(table: 'pa.Table') -> 'pa.Table':
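Finally, `build_query_data_request` no longer computes `projection_positions`/`output_field_names` or reorders the projected columns; the returned `QueryDataRequest` carries a plain `pa.Schema` of the projected root-level fields plus a matching `response_parser`, and the new TODO notes that only root-level projection pushdown is supported. A rough, hedged summary of the returned object's shape, based only on the code visible in this diff (the full function body, including predicate handling, is not shown here):

```python
import pyarrow as pa

schema = pa.schema([('id', pa.int64()), ('tags', pa.list_(pa.utf8()))])
# req = build_query_data_request(schema, predicate=..., field_names=['id'])  # needs ibis + the package
# req.serialized       -- flatbuffer payload sent to the QueryData RPC
# req.response_schema  -- pa.schema([schema.field('id')]): only the projected root-level fields
# req.response_parser  -- a QueryDataParser built from that response_schema, reused when parsing
```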