vastdb 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to its public registry, and is provided for informational purposes only.
vastdb/errors.py CHANGED
@@ -71,6 +71,10 @@ class ServiceUnavailable(HttpError):
     pass
 
 
+class Slowdown(ServiceUnavailable):
+    pass
+
+
 class UnexpectedError(HttpError):
     pass
 
@@ -97,6 +101,10 @@ class MissingTransaction(Missing):
     pass
 
 
+class MissingRowIdColumn(Missing):
+    pass
+
+
 class NotSupported(Exception):
     pass
 
@@ -163,6 +171,12 @@ class NotSupportedVersion(NotSupported):
     version: str
 
 
+def handle_unavailable(**kwargs):
+    if kwargs['code'] == 'SlowDown':
+        raise Slowdown(**kwargs)
+    raise ServiceUnavailable(**kwargs)
+
+
 ERROR_TYPES_MAP = {
     HttpStatus.BAD_REQUEST: BadRequest,
     HttpStatus.FOBIDDEN: Forbidden,
@@ -172,7 +186,7 @@ ERROR_TYPES_MAP = {
     HttpStatus.CONFLICT: Conflict,
     HttpStatus.INTERNAL_SERVER_ERROR: InternalServerError,
     HttpStatus.NOT_IMPLEMENTED: NotImplemented,
-    HttpStatus.SERVICE_UNAVAILABLE: ServiceUnavailable,
+    HttpStatus.SERVICE_UNAVAILABLE: handle_unavailable,
 }
 
 
@@ -205,4 +219,4 @@ def from_response(res: requests.Response):
     log.warning("RPC failed: %s", kwargs)
     status = HttpStatus(res.status_code)
     error_type = ERROR_TYPES_MAP.get(status, UnexpectedError)
-    return error_type(**kwargs)
+    return error_type(**kwargs)  # type: ignore
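The `SERVICE_UNAVAILABLE` entry now maps to a callable rather than an exception class, so a 503 carrying the S3-style `SlowDown` code can surface as a dedicated, retriable error. A minimal standalone sketch of the dispatch pattern (simplified: the real classes derive from `HttpError` and receive the RPC's kwargs):

```python
# Sketch: map a status either to an exception class or to a factory that
# inspects the error code before deciding which exception to raise.
class ServiceUnavailable(Exception):
    def __init__(self, **kwargs):
        super().__init__(kwargs)

class Slowdown(ServiceUnavailable):
    """Server asked the client to back off (S3-style 'SlowDown' code)."""

def handle_unavailable(**kwargs):
    if kwargs.get('code') == 'SlowDown':
        raise Slowdown(**kwargs)
    raise ServiceUnavailable(**kwargs)

try:
    handle_unavailable(code='SlowDown', message='please retry later')
except Slowdown:
    pass  # a retry loop (e.g. the `backoff` decorator used in table.py) catches this
```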
vastdb/internal_commands.py CHANGED
@@ -130,46 +130,13 @@ def get_unit_to_flatbuff_time_unit(type):
 class Predicate:
     def __init__(self, schema: 'pa.Schema', expr: ibis.expr.types.BooleanColumn):
         self.schema = schema
+        index = itertools.count()  # used to generate leaf column positions for the VAST QueryData RPC
+        # The Arrow schema contains the top-level columns, where each column may include multiple subfields.
+        # DFS is used to enumerate all the sub-columns, using `index` as an ID allocator.
+        nodes = [FieldNode(field, index) for field in schema]
+        self.nodes_map = {node.field.name: node for node in nodes}
         self.expr = expr
 
-    def get_field_indexes(self, field: 'pa.Field', field_name_per_index: list) -> None:
-        field_name_per_index.append(field.name)
-
-        if isinstance(field.type, pa.StructType):
-            flat_fields = field.flatten()
-        elif isinstance(field.type, pa.MapType):
-            flat_fields = [pa.field(f'{field.name}.entries', pa.struct([field.type.key_field, field.type.item_field]))]
-        elif isinstance(field.type, pa.ListType):
-            flat_fields = [pa.field(f'{field.name}.{field.type.value_field.name}', field.type.value_field.type)]
-        else:
-            return
-
-        for flat_field in flat_fields:
-            self.get_field_indexes(flat_field, field_name_per_index)
-
-    @property
-    def field_name_per_index(self):
-        if self._field_name_per_index is None:
-            _field_name_per_index = []
-            for field in self.schema:
-                self.get_field_indexes(field, _field_name_per_index)
-            self._field_name_per_index = {field: index for index, field in enumerate(_field_name_per_index)}
-        return self._field_name_per_index
-
-    def get_projections(self, builder: 'flatbuffers.builder.Builder', field_names: Optional[List[str]] = None):
-        if field_names is None:
-            field_names = self.field_name_per_index.keys()
-        projection_fields = []
-        for field_name in field_names:
-            fb_field_index.Start(builder)
-            fb_field_index.AddPosition(builder, self.field_name_per_index[field_name])
-            offset = fb_field_index.End(builder)
-            projection_fields.append(offset)
-        fb_source.StartProjectionVector(builder, len(projection_fields))
-        for offset in reversed(projection_fields):
-            builder.PrependUOffsetTRelative(offset)
-        return builder.EndVector()
-
     def serialize(self, builder: 'flatbuffers.builder.Builder'):
         from ibis.expr.operations.generic import (
             IsNull,
@@ -178,6 +145,7 @@ class Predicate:
         )
         from ibis.expr.operations.logical import (
             And,
+            Between,
             Equals,
             Greater,
             GreaterEqual,
@@ -200,10 +168,9 @@ class Predicate:
             IsNull: self.build_is_null,
             Not: self.build_is_not_null,
             StringContains: self.build_match_substring,
+            Between: self.build_between,
         }
 
-        positions_map = dict((f.name, index) for index, f in enumerate(self.schema))  # TODO: BFS
-
         self.builder = builder
 
         offsets = []
@@ -237,6 +204,9 @@ class Predicate:
                     raise NotImplementedError(self.expr)
                 column, = not_arg.args
                 literals = (None,)
+            elif builder_func == self.build_between:
+                column, lower, upper = inner_op.args
+                literals = (None,)
             else:
                 column, arg = inner_op.args
                 if isinstance(arg, tuple):
@@ -256,12 +226,19 @@ class Predicate:
             elif prev_field_name != field_name:
                 raise NotImplementedError(self.expr)
 
-            column_offset = self.build_column(position=positions_map[field_name])
+            node = self.nodes_map[field_name]
+            # TODO: support predicate pushdown for leaf nodes (ORION-160338)
+            if node.children:
+                raise NotImplementedError(node.field)  # no predicate pushdown for nested columns
+            column_offset = self.build_column(position=node.index)
             field = self.schema.field(field_name)
             for literal in literals:
                 args_offsets = [column_offset]
                 if literal is not None:
                     args_offsets.append(self.build_literal(field=field, value=literal.value))
+                if builder_func == self.build_between:
+                    args_offsets.append(self.build_literal(field=field, value=lower.value))
+                    args_offsets.append(self.build_literal(field=field, value=upper.value))
 
                 inner_offsets.append(builder_func(*args_offsets))
 
@@ -556,6 +533,13 @@ class Predicate:
     def build_match_substring(self, column: int, literal: int):
         return self.build_function('match_substring', column, literal)
 
+    def build_between(self, column: int, lower: int, upper: int):
+        offsets = [
+            self.build_greater_equal(column, lower),
+            self.build_less_equal(column, upper),
+        ]
+        return self.build_and(offsets)
+
 
 class FieldNodesState:
     def __init__(self) -> None:
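`build_between` shows how the new ibis `Between` support is lowered: instead of teaching the server a new function, the bound check is rewritten as a conjunction of the two comparisons that already exist. A rough equivalence check in plain Python (independent of the flatbuffer builder above):

```python
# BETWEEN is inclusive on both ends and equivalent to (x >= lower) AND (x <= upper).
def between(value, lower, upper):
    return (value >= lower) and (value <= upper)

assert between(222, 222, 444)       # lower bound included
assert between(444, 222, 444)       # upper bound included
assert not between(555, 222, 444)
```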
@@ -824,12 +808,13 @@ class VastdbApi:
         return prefix
 
     def _fill_common_headers(self, txid=0, client_tags=[], version_id=1):
-        common_headers = {'tabular-txid': str(txid), 'tabular-api-version-id': str(version_id),
-                          'tabular-client-name': 'tabular-api'}
-        for tag in client_tags:
-            common_headers['tabular-client-tags-%d' % client_tags.index(tag)] = tag
+        common_headers = {
+            'tabular-txid': str(txid),
+            'tabular-api-version-id': str(version_id),
+            'tabular-client-name': 'tabular-api'
+        }
 
-        return common_headers
+        return common_headers | {f'tabular-client-tags-{index}': tag for index, tag in enumerate(client_tags)}
 
     def _check_res(self, res, cmd="", expected_retvals=[]):
         if exc := errors.from_response(res):
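The old loop numbered each tag with `client_tags.index(tag)`, which returns the first occurrence and therefore collapses duplicate tags onto the same header; the new comprehension uses `enumerate` and merges into the base dict with the `|` operator (Python 3.9+). A small standalone comparison:

```python
client_tags = ['a', 'b', 'a']

# old approach: the repeated tag 'a' resolves to index 0 twice, so one header is lost
old = {}
for tag in client_tags:
    old['tabular-client-tags-%d' % client_tags.index(tag)] = tag
assert len(old) == 2

# new approach: every tag keeps its own position
new = {f'tabular-client-tags-{i}': tag for i, tag in enumerate(client_tags)}
assert len(new) == 3
```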
@@ -937,8 +922,7 @@ class VastdbApi:
         res_headers = res.headers
         next_key = int(res_headers['tabular-next-key'])
         is_truncated = res_headers['tabular-is-truncated'] == 'true'
-        flatbuf = b''.join(res.iter_content(chunk_size=128))
-        lists = list_schemas.GetRootAs(flatbuf)
+        lists = list_schemas.GetRootAs(res.content)
         bucket_name = lists.BucketName().decode()
         if not bucket.startswith(bucket_name):
             raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
@@ -961,8 +945,7 @@ class VastdbApi:
         res = self.session.get(self._api_prefix(bucket=bucket, command="list", url_params=url_params), headers={}, stream=True)
         self._check_res(res, "list_snapshots")
 
-        out = b''.join(res.iter_content(chunk_size=128))
-        xml_str = out.decode()
+        xml_str = res.content.decode()
         xml_dict = xmltodict.parse(xml_str)
         list_res = xml_dict['ListBucketResult']
         is_truncated = list_res['IsTruncated'] == 'true'
@@ -1044,8 +1027,7 @@ class VastdbApi:
         res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats", url_params=url_params), headers=headers)
         self._check_res(res, "get_table_stats", expected_retvals)
 
-        flatbuf = b''.join(res.iter_content(chunk_size=128))
-        stats = get_table_stats.GetRootAs(flatbuf)
+        stats = get_table_stats.GetRootAs(res.content)
         num_rows = stats.NumRows()
         size_in_bytes = stats.SizeInBytes()
         is_external_rowid_alloc = stats.IsExternalRowidAlloc()
@@ -1144,8 +1126,7 @@ class VastdbApi:
         res_headers = res.headers
         next_key = int(res_headers['tabular-next-key'])
         is_truncated = res_headers['tabular-is-truncated'] == 'true'
-        flatbuf = b''.join(res.iter_content(chunk_size=128))
-        lists = list_tables.GetRootAs(flatbuf)
+        lists = list_tables.GetRootAs(res.content)
         bucket_name = lists.BucketName().decode()
         schema_name = lists.SchemaName().decode()
         if not bucket.startswith(bucket_name):  # ignore snapshot name
@@ -1273,11 +1254,7 @@ class VastdbApi:
         next_key = int(res_headers['tabular-next-key'])
         is_truncated = res_headers['tabular-is-truncated'] == 'true'
         count = int(res_headers['tabular-list-count'])
-        columns = []
-        if not count_only:
-            schema_buf = b''.join(res.iter_content(chunk_size=128))
-            schema_out = pa.ipc.open_stream(schema_buf).schema
-            columns = schema_out
+        columns = [] if count_only else pa.ipc.open_stream(res.content).schema
 
         return columns, next_key, is_truncated, count
 
@@ -1578,8 +1555,7 @@ class VastdbApi:
         headers['Content-Length'] = str(len(record_batch))
         res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
                                 data=record_batch, headers=headers, stream=True)
-        self._check_res(res, "insert_rows", expected_retvals)
-        res.raw.read()  # flush the response
+        return self._check_res(res, "insert_rows", expected_retvals)
 
     def update_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
         """
@@ -1678,8 +1654,7 @@ class VastdbApi:
         res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection-stats", url_params=url_params),
                                headers=headers)
         if res.status_code == 200:
-            flatbuf = b''.join(res.iter_content(chunk_size=128))
-            stats = get_projection_table_stats.GetRootAs(flatbuf)
+            stats = get_projection_table_stats.GetRootAs(res.content)
             num_rows = stats.NumRows()
             size_in_bytes = stats.SizeInBytes()
             dirty_blocks_percentage = stats.DirtyBlocksPercentage()
@@ -1765,8 +1740,7 @@ class VastdbApi:
         next_key = int(res_headers['tabular-next-key'])
         is_truncated = res_headers['tabular-is-truncated'] == 'true'
         count = int(res_headers['tabular-list-count'])
-        flatbuf = b''.join(res.iter_content(chunk_size=128))
-        lists = list_projections.GetRootAs(flatbuf)
+        lists = list_projections.GetRootAs(res.content)
         bucket_name = lists.BucketName().decode()
         schema_name = lists.SchemaName().decode()
         table_name = lists.TableName().decode()
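Several hunks in this file (above and the one below) replace the manual `b''.join(res.iter_content(chunk_size=128))` reassembly with `res.content`. Both produce the complete response body as `bytes`; `requests` buffers the body on `.content` anyway, so the 128-byte join only re-slices an already-available buffer. A toy illustration of why the join is redundant (no HTTP involved):

```python
body = bytes(range(256)) * 40          # stand-in for a buffered response body

# What chunked iteration effectively does over a cached body:
chunks = [body[i:i + 128] for i in range(0, len(body), 128)]

assert b''.join(chunks) == body        # the join reproduces the buffer byte-for-byte
```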
@@ -1813,13 +1787,8 @@ class VastdbApi:
         next_key = int(res_headers['tabular-next-key'])
         is_truncated = res_headers['tabular-is-truncated'] == 'true'
         count = int(res_headers['tabular-list-count'])
-        columns = []
-        if not count_only:
-            schema_buf = b''.join(res.iter_content(chunk_size=128))
-            schema_out = pa.ipc.open_stream(schema_buf).schema
-            for f in schema_out:
-                columns.append([f.name, f.type, f.metadata])
-                # sort_type = f.metadata[b'VAST:sort_type'].decode()
+        columns = [] if count_only else [[f.name, f.type, f.metadata] for f in
+                                         pa.ipc.open_stream(res.content).schema]
 
         return columns, next_key, is_truncated, count
 
@@ -2107,6 +2076,13 @@ class QueryDataRequest:
         self.response_parser = response_parser
 
 
+def get_response_schema(schema: 'pa.Schema' = pa.schema([]), field_names: Optional[List[str]] = None):
+    if field_names is None:
+        field_names = [field.name for field in schema]
+
+    return pa.schema([schema.field(name) for name in field_names])
+
+
 def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: Optional[List[str]] = None):
     builder = flatbuffers.Builder(1024)
 
@@ -2127,13 +2103,11 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibis.expr.types.BooleanColumn = None, field_names: Optional[List[str]] = None):
     filter_obj = predicate.serialize(builder)
 
     parser = QueryDataParser(schema)
-    fields_map = {node.field.name: node.field for node in parser.nodes}
     leaves_map = {node.field.name: [leaf.index for leaf in node._iter_leaves()] for node in parser.nodes}
 
-    if field_names is None:
-        field_names = [field.name for field in schema]
+    response_schema = get_response_schema(schema, field_names)
+    field_names = [field.name for field in response_schema]
 
-    response_schema = pa.schema([fields_map[name] for name in field_names])
     projection_fields = []
     for field_name in field_names:
         # TODO: only root-level projection pushdown is supported (i.e. no support for SELECT s.x FROM t)
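`get_response_schema` factors the `field_names is None` defaulting out of `build_query_data_request`, so `Table.select` can compute a response schema (e.g. for the `predicate=False` shortcut) without building a full request. A hedged sketch of its behavior using pyarrow only (restated standalone for illustration):

```python
from typing import List, Optional

import pyarrow as pa

def get_response_schema(schema: pa.Schema = pa.schema([]), field_names: Optional[List[str]] = None) -> pa.Schema:
    # Default to all top-level columns, preserving schema order.
    if field_names is None:
        field_names = [field.name for field in schema]
    return pa.schema([schema.field(name) for name in field_names])

schema = pa.schema([('a', pa.int64()), ('b', pa.utf8()), ('c', pa.float64())])
assert get_response_schema(schema).names == ['a', 'b', 'c']          # default: all columns
assert get_response_schema(schema, ['c', 'a']).names == ['c', 'a']   # explicit selection keeps caller order
```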
vastdb/session.py CHANGED
@@ -26,11 +26,16 @@ class Features:
         if self.vast_version < (5, 2):
             raise errors.NotSupportedVersion("import_table requires 5.2+", self.vast_version)
 
+    def check_return_row_ids(self):
+        """Check if insert/update/delete can return the row_ids."""
+        if self.vast_version < (5, 1):
+            raise errors.NotSupportedVersion("return_row_ids requires 5.1+", self.vast_version)
+
 
 class Session:
     """VAST database session."""
 
-    def __init__(self, access=None, secret=None, endpoint=None):
+    def __init__(self, access=None, secret=None, endpoint=None, ssl_verify=True):
         """Connect to a VAST Database endpoint, using specified credentials."""
         if access is None:
             access = os.environ['AWS_ACCESS_KEY_ID']
@@ -39,7 +44,7 @@ class Session:
         if endpoint is None:
             endpoint = os.environ['AWS_S3_ENDPOINT_URL']
 
-        self.api = internal_commands.VastdbApi(endpoint, access, secret)
+        self.api = internal_commands.VastdbApi(endpoint, access, secret, ssl_verify=ssl_verify)
         version_tuple = tuple(int(part) for part in self.api.vast_version.split('.'))
         self.features = Features(version_tuple)
         self.s3 = boto3.client('s3',
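With the new `ssl_verify` parameter, certificate verification can be turned off for endpoints fronted by self-signed certificates while the rest of the connection flow stays unchanged. A hedged usage sketch (the endpoint and credentials are placeholders):

```python
from vastdb.session import Session

# ssl_verify defaults to True; pass False only for test clusters with self-signed certificates.
session = Session(
    access='AKIA...',                        # placeholder access key
    secret='SECRET...',                      # placeholder secret key
    endpoint='https://vip-pool.example.com',
    ssl_verify=False,
)

with session.transaction() as tx:            # the usual flow is unchanged
    bucket = tx.bucket('my-bucket')
```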
vastdb/table.py CHANGED
@@ -7,10 +7,12 @@ import queue
 from dataclasses import dataclass, field
 from math import ceil
 from threading import Event
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import backoff
 import ibis
 import pyarrow as pa
+import requests
 
 from . import errors, internal_commands, schema, util
 
@@ -18,10 +20,14 @@ log = logging.getLogger(__name__)
 
 
 INTERNAL_ROW_ID = "$row_id"
+INTERNAL_ROW_ID_FIELD = pa.field(INTERNAL_ROW_ID, pa.uint64())
+
 MAX_ROWS_PER_BATCH = 512 * 1024
 # for insert we need a smaller limit due to response amplification
 # for example insert of 512k uint8 result in 512k*8bytes response since row_ids are uint64
 MAX_INSERT_ROWS_PER_PATCH = 512 * 1024
+# in case insert has TooWideRow - need to insert in smaller batches - each cell could contain up to 128K, and our wire is limited to 5MB
+MAX_COLUMN_IN_BATCH = int(5 * 1024 / 128)
 
 
 @dataclass
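The new `MAX_COLUMN_IN_BATCH` constant follows from the wire budget described in the comment: with up to 128 KiB per cell and roughly 5 MiB per request, about 40 columns of a single row fit into one RPC. A quick check of the arithmetic (the units are an interpretation of that comment):

```python
CELL_LIMIT_KIB = 128           # maximum size of a single cell
WIRE_LIMIT_KIB = 5 * 1024      # ~5MB request budget

MAX_COLUMN_IN_BATCH = int(5 * 1024 / 128)
assert MAX_COLUMN_IN_BATCH == WIRE_LIMIT_KIB // CELL_LIMIT_KIB == 40
```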
@@ -34,19 +40,45 @@ class TableStats:
     endpoints: Tuple[str, ...] = ()
 
 
+RETRIABLE_ERRORS = (
+    errors.Slowdown,
+    requests.exceptions.ConnectionError,
+)
+
+
 @dataclass
 class QueryConfig:
     """Query execution configuration."""
 
+    # allows server-side parallel processing by issuing multiple reads concurrently for a single RPC
     num_sub_splits: int = 4
+
+    # used to split the table into disjoint subsets of rows, to be processed concurrently using multiple RPCs
     num_splits: int = 1
+
+    # each endpoint will be handled by a separate worker thread
+    # a single endpoint can be specified more than once to benefit from multithreaded execution
     data_endpoints: Optional[List[str]] = None
+
+    # a subsplit fiber will finish after sending this number of rows back to the client
     limit_rows_per_sub_split: int = 128 * 1024
+
+    # each fiber will read the following number of rowgroups continuously before skipping
+    # in order to use semi-sorted projections this value must be 8
     num_row_groups_per_sub_split: int = 8
+
+    # can be disabled for benchmarking purposes
     use_semi_sorted_projections: bool = True
+
+    # used to estimate the number of splits, given the table rows' count
     rows_per_split: int = 4000000
+
+    # used for worker threads' naming
     query_id: str = ""
 
+    # allows retrying QueryData when the server is overloaded
+    backoff_func: Any = field(default=backoff.on_exception(backoff.expo, RETRIABLE_ERRORS, max_tries=10))
+
 
 @dataclass
 class ImportConfig:
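The new `backoff_func` field wraps each `query_data` RPC with the `backoff` package's decorator, retrying on `Slowdown` and connection errors with exponential delay. A hedged sketch of overriding that default (the retry budget here is illustrative, not a recommendation):

```python
import backoff
import requests

from vastdb import errors
from vastdb.table import QueryConfig

# Retry more aggressively than the default (exponential backoff, up to 20 attempts).
config = QueryConfig(
    num_splits=4,
    backoff_func=backoff.on_exception(
        backoff.expo,
        (errors.Slowdown, requests.exceptions.ConnectionError),
        max_tries=20,
    ),
)
# The config is then passed to a query, e.g. table.select(columns=['a'], config=config).
```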
@@ -72,7 +104,8 @@ class SelectSplitState:
         Can be called repeatedly, to allow pagination.
         """
         while not self.done:
-            response = api.query_data(
+            query_with_backoff = self.config.backoff_func(api.query_data)
+            response = query_with_backoff(
                 bucket=self.table.bucket.name,
                 schema=self.table.schema.name,
                 table=self.table.name,
@@ -263,7 +296,7 @@ class Table:
         return TableStats(**stats_tuple._asdict())
 
     def select(self, columns: Optional[List[str]] = None,
-               predicate: ibis.expr.types.BooleanColumn = None,
+               predicate: Union[ibis.expr.types.BooleanColumn, ibis.common.deferred.Deferred] = None,
               config: Optional[QueryConfig] = None,
               *,
               internal_row_id: bool = False) -> pa.RecordBatchReader:
@@ -291,11 +324,20 @@ class Table:
 
         query_schema = self.arrow_schema
         if internal_row_id:
-            queried_fields = [pa.field(INTERNAL_ROW_ID, pa.uint64())]
+            queried_fields = [INTERNAL_ROW_ID_FIELD]
             queried_fields.extend(column for column in self.arrow_schema)
             query_schema = pa.schema(queried_fields)
             columns.append(INTERNAL_ROW_ID)
 
+        if predicate is True:
+            predicate = None
+        if predicate is False:
+            response_schema = internal_commands.get_response_schema(schema=query_schema, field_names=columns)
+            return pa.RecordBatchReader.from_batches(response_schema, [])
+
+        if isinstance(predicate, ibis.common.deferred.Deferred):
+            predicate = predicate.resolve(self._ibis_table)  # may raise if the predicate is invalid (e.g. wrong types / missing column)
+
         query_data_request = internal_commands.build_query_data_request(
             schema=query_schema,
             predicate=predicate,
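`select` now also accepts plain booleans (`True` means no filter, `False` short-circuits into an empty reader with the correct schema) and ibis deferred expressions built from `ibis._`, which are resolved against the bound table before serialization. A small sketch of the deferred-resolution step with ibis 8.x (the table here is a stand-in, not a vastdb table):

```python
import ibis

t = ibis.table({'a': 'int64', 'b': 'string'}, name='t')   # stand-in for the table's ibis schema
deferred = ibis._['a'] > 2            # unbound predicate, reusable across tables
bound = deferred.resolve(t)           # what select() does internally before serializing
assert isinstance(bound, ibis.expr.types.BooleanColumn)
```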
@@ -385,27 +427,68 @@ class Table:
 
         return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())
 
-    def insert(self, rows: pa.RecordBatch) -> pa.RecordBatch:
+    def insert_in_column_batches(self, rows: pa.RecordBatch):
+        """Split the RecordBatch into column groups that can be inserted in a single RPC.
+
+        Insert the first MAX_COLUMN_IN_BATCH columns and get the row_ids. Then loop on the rest of the columns and
+        update in groups of MAX_COLUMN_IN_BATCH.
+        """
+        column_record_batch = pa.RecordBatch.from_arrays([_combine_chunks(rows.column(i)) for i in range(0, MAX_COLUMN_IN_BATCH)],
+                                                         schema=pa.schema([rows.schema.field(i) for i in range(0, MAX_COLUMN_IN_BATCH)]))
+        row_ids = self.insert(rows=column_record_batch)  # type: ignore
+
+        columns_names = [field.name for field in rows.schema]
+        columns = list(rows.schema)
+        arrays = [_combine_chunks(rows.column(i)) for i in range(len(rows.schema))]
+        for start in range(MAX_COLUMN_IN_BATCH, len(rows.schema), MAX_COLUMN_IN_BATCH):
+            end = start + MAX_COLUMN_IN_BATCH if start + MAX_COLUMN_IN_BATCH < len(rows.schema) else len(rows.schema)
+            columns_name_chunk = columns_names[start:end]
+            columns_chunks = columns[start:end]
+            arrays_chunks = arrays[start:end]
+            columns_chunks.append(INTERNAL_ROW_ID_FIELD)
+            arrays_chunks.append(row_ids.to_pylist())
+            column_record_batch = pa.RecordBatch.from_arrays(arrays_chunks, schema=pa.schema(columns_chunks))
+            self.update(rows=column_record_batch, columns=columns_name_chunk)
+        return row_ids
+
+    def insert(self, rows: pa.RecordBatch):
         """Insert a RecordBatch into this table."""
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
-        serialized_slices = util.iter_serialized_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
-        for slice in serialized_slices:
-            self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
-                                         txid=self.tx.txid)
+        try:
+            row_ids = []
+            serialized_slices = util.iter_serialized_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
+            for slice in serialized_slices:
+                res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
+                                                   txid=self.tx.txid)
+                (batch,) = pa.RecordBatchStreamReader(res.raw)
+                row_ids.append(batch[INTERNAL_ROW_ID])
+            try:
+                self.tx._rpc.features.check_return_row_ids()
+            except errors.NotSupportedVersion:
+                return  # type: ignore
+            return pa.chunked_array(row_ids)
+        except errors.TooWideRow:
+            self.tx._rpc.features.check_return_row_ids()
+            return self.insert_in_column_batches(rows)
 
     def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: Optional[List[str]] = None) -> None:
         """Update a subset of cells in this table.
 
-        Row IDs are specified using a special field (named "$row_id" of uint64 type).
+        Row IDs are specified using a special field (named "$row_id" of uint64 type) - this function assumes that
+        this special field is part of the arguments.
 
         A subset of columns to be updated can be specified via the `columns` argument.
         """
         if self._imports_table:
             raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+        try:
+            rows_chunk = rows[INTERNAL_ROW_ID]
+        except KeyError:
+            raise errors.MissingRowIdColumn
         if columns is not None:
             update_fields = [(INTERNAL_ROW_ID, pa.uint64())]
-            update_values = [_combine_chunks(rows[INTERNAL_ROW_ID])]
+            update_values = [_combine_chunks(rows_chunk)]
             for col in columns:
                 update_fields.append(rows.field(col))
                 update_values.append(_combine_chunks(rows[col]))
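With these changes, `insert` returns the server-allocated row IDs as a pyarrow chunked array when the cluster supports it (5.1+), and falls back to `insert_in_column_batches` when a single RPC would exceed the wire limit (`TooWideRow`). A hedged usage sketch inside a transaction (bucket and names are placeholders; `session` is a `vastdb.session.Session`):

```python
import pyarrow as pa

rb = pa.record_batch([pa.array([111, 222]), pa.array(['a', 'b'])], names=['a', 'b'])

with session.transaction() as tx:
    s = tx.bucket('my-bucket').create_schema('s')
    t = s.create_table('t', rb.schema)
    row_ids = t.insert(rb)                 # returns None on clusters older than 5.1
    if row_ids is not None:
        assert row_ids.to_pylist() == [0, 1]
```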
@@ -424,8 +507,14 @@ class Table:
 
         Row IDs are specified using a special field (named "$row_id" of uint64 type).
         """
+        if self._imports_table:
+            raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+        try:
+            rows_chunk = rows[INTERNAL_ROW_ID]
+        except KeyError:
+            raise errors.MissingRowIdColumn
         delete_rows_rb = pa.record_batch(schema=pa.schema([(INTERNAL_ROW_ID, pa.uint64())]),
-                                         data=[_combine_chunks(rows[INTERNAL_ROW_ID])])
+                                         data=[_combine_chunks(rows_chunk)])
 
         serialized_slices = util.iter_serialized_slices(delete_rows_rb, MAX_ROWS_PER_BATCH)
         for slice in serialized_slices:
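`update` and `delete` now fail fast with the new `errors.MissingRowIdColumn` when the batch does not carry the `$row_id` column. A hedged sketch of batches that satisfy this requirement, continuing the placeholder table `t` from the previous example:

```python
import pyarrow as pa

# "$row_id" (uint64) selects which rows to touch; 'a' carries the new values.
update_rb = pa.record_batch(
    [pa.array([0, 1], type=pa.uint64()), pa.array([1110, 2220])],
    schema=pa.schema([('$row_id', pa.uint64()), ('a', pa.int64())]),
)
t.update(update_rb)        # without '$row_id' this raises errors.MissingRowIdColumn

delete_rb = pa.record_batch([pa.array([0], type=pa.uint64())],
                            schema=pa.schema([('$row_id', pa.uint64())]))
t.delete(delete_rb)
```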
vastdb/tests/test_nested.py CHANGED
@@ -1,11 +1,15 @@
+import functools
 import itertools
+import operator
 
 import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
 
 from .util import prepare_data
 
 
-def test_nested(session, clean_bucket_name):
+def test_nested_select(session, clean_bucket_name):
     columns = pa.schema([
         ('l', pa.list_(pa.int8())),
         ('m', pa.map_(pa.utf8(), pa.float64())),
@@ -26,3 +30,73 @@ def test_nested(session, clean_bucket_name):
         for cols in itertools.permutations(names, n):
             actual = pa.Table.from_batches(t.select(columns=cols))
             assert actual == expected.select(cols)
+
+
+def test_nested_filter(session, clean_bucket_name):
+    columns = pa.schema([
+        ('x', pa.int64()),
+        ('l', pa.list_(pa.int8())),
+        ('y', pa.int64()),
+        ('m', pa.map_(pa.utf8(), pa.float64())),
+        ('z', pa.int64()),
+        ('s', pa.struct([('x', pa.int16()), ('y', pa.int32())])),
+        ('w', pa.int64()),
+    ])
+    expected = pa.table(schema=columns, data=[
+        [1, 2, 3, None],
+        [[1], [], [2, 3], None],
+        [1, 2, None, 3],
+        [None, {'a': 2.5}, {'b': 0.25, 'c': 0.025}, {}],
+        [1, None, 2, 3],
+        [{'x': 1, 'y': None}, None, {'x': 2, 'y': 3}, {'x': None, 'y': 4}],
+        [None, 1, 2, 3],
+    ])
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+        actual = pa.Table.from_batches(t.select())
+        assert actual == expected
+
+        names = list('xyzw')
+        for n in range(1, len(names) + 1):
+            for cols in itertools.permutations(names, n):
+                ibis_predicate = functools.reduce(
+                    operator.and_,
+                    (t[col] > 2 for col in cols))
+                actual = pa.Table.from_batches(t.select(predicate=ibis_predicate), t.arrow_schema)
+
+                arrow_predicate = functools.reduce(
+                    operator.and_,
+                    (pc.field(col) > 2 for col in cols))
+                assert actual == expected.filter(arrow_predicate)
+
+
+def test_nested_unsupported_filter(session, clean_bucket_name):
+    columns = pa.schema([
+        ('x', pa.int64()),
+        ('l', pa.list_(pa.int8())),
+        ('y', pa.int64()),
+        ('m', pa.map_(pa.utf8(), pa.float64())),
+        ('z', pa.int64()),
+        ('s', pa.struct([('x', pa.int16()), ('y', pa.int32())])),
+        ('w', pa.int64()),
+    ])
+    expected = pa.table(schema=columns, data=[
+        [1, 2, 3, None],
+        [[1], [], [2, 3], None],
+        [1, 2, None, 3],
+        [None, {'a': 2.5}, {'b': 0.25, 'c': 0.025}, {}],
+        [1, None, 2, 3],
+        [{'x': 1, 'y': None}, None, {'x': 2, 'y': 3}, {'x': None, 'y': 4}],
+        [None, 1, 2, 3],
+    ])
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+
+        with pytest.raises(NotImplementedError):
+            list(t.select(predicate=(t['l'].isnull())))
+
+        with pytest.raises(NotImplementedError):
+            list(t.select(predicate=(t['m'].isnull())))
+
+        with pytest.raises(NotImplementedError):
+            list(t.select(predicate=(t['s'].isnull())))
vastdb/tests/test_tables.py CHANGED
@@ -7,6 +7,7 @@ import time
 from contextlib import closing
 from tempfile import NamedTemporaryFile
 
+import ibis
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
@@ -71,6 +72,16 @@ def test_tables(session, clean_bucket_name):
     }
 
 
+def test_insert_wide_row(session, clean_bucket_name):
+    columns = pa.schema([pa.field(f's{i}', pa.utf8()) for i in range(500)])
+    data = [['a' * 10**4] for i in range(500)]
+    expected = pa.table(schema=columns, data=data)
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+        actual = pa.Table.from_batches(t.select())
+        assert actual == expected
+
+
 def test_exists(session, clean_bucket_name):
     with session.transaction() as tx:
         s = tx.bucket(clean_bucket_name).create_schema('s1')
@@ -205,46 +216,47 @@ def test_types(session, clean_bucket_name):
         [dt.datetime(2024, 4, 10, 12, 34, 56, 789789), dt.datetime(2025, 4, 10, 12, 34, 56, 789789), dt.datetime(2026, 4, 10, 12, 34, 56, 789789)],
     ])
 
-    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as table:
         def select(predicate):
-            return pa.Table.from_batches(t.select(predicate=predicate))
+            return pa.Table.from_batches(table.select(predicate=predicate))
 
         assert select(None) == expected
-        assert select(t['tb'] == False) == expected.filter(pc.field('tb') == False)  # noqa: E712
-        assert select(t['a1'] == 2) == expected.filter(pc.field('a1') == 2)
-        assert select(t['a2'] == 2000) == expected.filter(pc.field('a2') == 2000)
-        assert select(t['a4'] == 222111122) == expected.filter(pc.field('a4') == 222111122)
-        assert select(t['b'] == 1.5) == expected.filter(pc.field('b') == 1.5)
-        assert select(t['s'] == "v") == expected.filter(pc.field('s') == "v")
-        assert select(t['d'] == 231.15) == expected.filter(pc.field('d') == 231.15)
-        assert select(t['bin'] == b"\x01\x02") == expected.filter(pc.field('bin') == b"\x01\x02")
+        for t in [table, ibis._]:
+            assert select(t['tb'] == False) == expected.filter(pc.field('tb') == False)  # noqa: E712
+            assert select(t['a1'] == 2) == expected.filter(pc.field('a1') == 2)
+            assert select(t['a2'] == 2000) == expected.filter(pc.field('a2') == 2000)
+            assert select(t['a4'] == 222111122) == expected.filter(pc.field('a4') == 222111122)
+            assert select(t['b'] == 1.5) == expected.filter(pc.field('b') == 1.5)
+            assert select(t['s'] == "v") == expected.filter(pc.field('s') == "v")
+            assert select(t['d'] == 231.15) == expected.filter(pc.field('d') == 231.15)
+            assert select(t['bin'] == b"\x01\x02") == expected.filter(pc.field('bin') == b"\x01\x02")
 
-        date_literal = dt.date(2024, 4, 10)
-        assert select(t['date'] == date_literal) == expected.filter(pc.field('date') == date_literal)
+            date_literal = dt.date(2024, 4, 10)
+            assert select(t['date'] == date_literal) == expected.filter(pc.field('date') == date_literal)
 
-        time_literal = dt.time(12, 34, 56)
-        assert select(t['t0'] == time_literal) == expected.filter(pc.field('t0') == time_literal)
+            time_literal = dt.time(12, 34, 56)
+            assert select(t['t0'] == time_literal) == expected.filter(pc.field('t0') == time_literal)
 
-        time_literal = dt.time(12, 34, 56, 789000)
-        assert select(t['t3'] == time_literal) == expected.filter(pc.field('t3') == time_literal)
+            time_literal = dt.time(12, 34, 56, 789000)
+            assert select(t['t3'] == time_literal) == expected.filter(pc.field('t3') == time_literal)
 
-        time_literal = dt.time(12, 34, 56, 789789)
-        assert select(t['t6'] == time_literal) == expected.filter(pc.field('t6') == time_literal)
+            time_literal = dt.time(12, 34, 56, 789789)
+            assert select(t['t6'] == time_literal) == expected.filter(pc.field('t6') == time_literal)
 
-        time_literal = dt.time(12, 34, 56, 789789)
-        assert select(t['t9'] == time_literal) == expected.filter(pc.field('t9') == time_literal)
+            time_literal = dt.time(12, 34, 56, 789789)
+            assert select(t['t9'] == time_literal) == expected.filter(pc.field('t9') == time_literal)
 
-        ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56)
-        assert select(t['ts0'] == ts_literal) == expected.filter(pc.field('ts0') == ts_literal)
+            ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56)
+            assert select(t['ts0'] == ts_literal) == expected.filter(pc.field('ts0') == ts_literal)
 
-        ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789000)
-        assert select(t['ts3'] == ts_literal) == expected.filter(pc.field('ts3') == ts_literal)
+            ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789000)
+            assert select(t['ts3'] == ts_literal) == expected.filter(pc.field('ts3') == ts_literal)
 
-        ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789789)
-        assert select(t['ts6'] == ts_literal) == expected.filter(pc.field('ts6') == ts_literal)
+            ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789789)
+            assert select(t['ts6'] == ts_literal) == expected.filter(pc.field('ts6') == ts_literal)
 
-        ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789789)
-        assert select(t['ts9'] == ts_literal) == expected.filter(pc.field('ts9') == ts_literal)
+            ts_literal = dt.datetime(2024, 4, 10, 12, 34, 56, 789789)
+            assert select(t['ts9'] == ts_literal) == expected.filter(pc.field('ts9') == ts_literal)
 
 
 def test_filters(session, clean_bucket_name):
@@ -260,57 +272,63 @@ def test_filters(session, clean_bucket_name):
         ['a', 'bb', 'ccc', None, 'xyz'],
     ])
 
-    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as table:
         def select(predicate):
-            return pa.Table.from_batches(t.select(predicate=predicate), t.arrow_schema)
+            return pa.Table.from_batches(table.select(predicate=predicate), table.arrow_schema)
 
         assert select(None) == expected
-
-        assert select(t['a'] > 222) == expected.filter(pc.field('a') > 222)
-        assert select(t['a'] < 222) == expected.filter(pc.field('a') < 222)
-        assert select(t['a'] == 222) == expected.filter(pc.field('a') == 222)
-        assert select(t['a'] != 222) == expected.filter(pc.field('a') != 222)
-        assert select(t['a'] <= 222) == expected.filter(pc.field('a') <= 222)
-        assert select(t['a'] >= 222) == expected.filter(pc.field('a') >= 222)
-
-        assert select(t['b'] > 1.5) == expected.filter(pc.field('b') > 1.5)
-        assert select(t['b'] < 1.5) == expected.filter(pc.field('b') < 1.5)
-        assert select(t['b'] == 1.5) == expected.filter(pc.field('b') == 1.5)
-        assert select(t['b'] != 1.5) == expected.filter(pc.field('b') != 1.5)
-        assert select(t['b'] <= 1.5) == expected.filter(pc.field('b') <= 1.5)
-        assert select(t['b'] >= 1.5) == expected.filter(pc.field('b') >= 1.5)
-
-        assert select(t['s'] > 'bb') == expected.filter(pc.field('s') > 'bb')
-        assert select(t['s'] < 'bb') == expected.filter(pc.field('s') < 'bb')
-        assert select(t['s'] == 'bb') == expected.filter(pc.field('s') == 'bb')
-        assert select(t['s'] != 'bb') == expected.filter(pc.field('s') != 'bb')
-        assert select(t['s'] <= 'bb') == expected.filter(pc.field('s') <= 'bb')
-        assert select(t['s'] >= 'bb') == expected.filter(pc.field('s') >= 'bb')
-
-        assert select((t['a'] > 111) & (t['b'] > 0) & (t['s'] < 'ccc')) == expected.filter((pc.field('a') > 111) & (pc.field('b') > 0) & (pc.field('s') < 'ccc'))
-        assert select((t['a'] > 111) & (t['b'] < 2.5)) == expected.filter((pc.field('a') > 111) & (pc.field('b') < 2.5))
-        assert select((t['a'] > 111) & (t['a'] < 333)) == expected.filter((pc.field('a') > 111) & (pc.field('a') < 333))
-
-        assert select((t['a'] > 111) | (t['a'] < 333)) == expected.filter((pc.field('a') > 111) | (pc.field('a') < 333))
-        assert select(((t['a'] > 111) | (t['a'] < 333)) & (t['b'] < 2.5)) == expected.filter(((pc.field('a') > 111) | (pc.field('a') < 333)) & (pc.field('b') < 2.5))
-        with pytest.raises(NotImplementedError):
-            assert select((t['a'] > 111) | (t['b'] > 0) | (t['s'] < 'ccc')) == expected.filter((pc.field('a') > 111) | (pc.field('b') > 0) | (pc.field('s') < 'ccc'))
-        assert select((t['a'] > 111) | (t['a'] < 333) | (t['a'] == 777)) == expected.filter((pc.field('a') > 111) | (pc.field('a') < 333) | (pc.field('a') == 777))
-
-        assert select(t['s'].isnull()) == expected.filter(pc.field('s').is_null())
-        assert select((t['s'].isnull()) | (t['s'] == 'bb')) == expected.filter((pc.field('s').is_null()) | (pc.field('s') == 'bb'))
-        assert select((t['s'].isnull()) & (t['b'] == 3.5)) == expected.filter((pc.field('s').is_null()) & (pc.field('b') == 3.5))
-
-        assert select(~t['s'].isnull()) == expected.filter(~pc.field('s').is_null())
-        assert select(t['s'].contains('b')) == expected.filter(pc.field('s') == 'bb')
-        assert select(t['s'].contains('y')) == expected.filter(pc.field('s') == 'xyz')
-
-        assert select(t['a'].isin([555])) == expected.filter(pc.field('a').isin([555]))
-        assert select(t['a'].isin([111, 222, 999])) == expected.filter(pc.field('a').isin([111, 222, 999]))
-        assert select((t['a'] == 111) | t['a'].isin([333, 444]) | (t['a'] > 600)) == expected.filter((pc.field('a') == 111) | pc.field('a').isin([333, 444]) | (pc.field('a') > 600))
-
-        with pytest.raises(NotImplementedError):
-            select(t['a'].isin([]))
+        assert select(True) == expected
+        assert select(False) == pa.Table.from_batches([], schema=columns)
+
+        for t in [table, ibis._]:
+            assert select(t['a'].between(222, 444)) == expected.filter((pc.field('a') >= 222) & (pc.field('a') <= 444))
+            assert select((t['a'].between(222, 444)) & (t['b'] > 2.5)) == expected.filter((pc.field('a') >= 222) & (pc.field('a') <= 444) & (pc.field('b') > 2.5))
+
+            assert select(t['a'] > 222) == expected.filter(pc.field('a') > 222)
+            assert select(t['a'] < 222) == expected.filter(pc.field('a') < 222)
+            assert select(t['a'] == 222) == expected.filter(pc.field('a') == 222)
+            assert select(t['a'] != 222) == expected.filter(pc.field('a') != 222)
+            assert select(t['a'] <= 222) == expected.filter(pc.field('a') <= 222)
+            assert select(t['a'] >= 222) == expected.filter(pc.field('a') >= 222)
+
+            assert select(t['b'] > 1.5) == expected.filter(pc.field('b') > 1.5)
+            assert select(t['b'] < 1.5) == expected.filter(pc.field('b') < 1.5)
+            assert select(t['b'] == 1.5) == expected.filter(pc.field('b') == 1.5)
+            assert select(t['b'] != 1.5) == expected.filter(pc.field('b') != 1.5)
+            assert select(t['b'] <= 1.5) == expected.filter(pc.field('b') <= 1.5)
+            assert select(t['b'] >= 1.5) == expected.filter(pc.field('b') >= 1.5)
+
+            assert select(t['s'] > 'bb') == expected.filter(pc.field('s') > 'bb')
+            assert select(t['s'] < 'bb') == expected.filter(pc.field('s') < 'bb')
+            assert select(t['s'] == 'bb') == expected.filter(pc.field('s') == 'bb')
+            assert select(t['s'] != 'bb') == expected.filter(pc.field('s') != 'bb')
+            assert select(t['s'] <= 'bb') == expected.filter(pc.field('s') <= 'bb')
+            assert select(t['s'] >= 'bb') == expected.filter(pc.field('s') >= 'bb')
+
+            assert select((t['a'] > 111) & (t['b'] > 0) & (t['s'] < 'ccc')) == expected.filter((pc.field('a') > 111) & (pc.field('b') > 0) & (pc.field('s') < 'ccc'))
+            assert select((t['a'] > 111) & (t['b'] < 2.5)) == expected.filter((pc.field('a') > 111) & (pc.field('b') < 2.5))
+            assert select((t['a'] > 111) & (t['a'] < 333)) == expected.filter((pc.field('a') > 111) & (pc.field('a') < 333))
+
+            assert select((t['a'] > 111) | (t['a'] < 333)) == expected.filter((pc.field('a') > 111) | (pc.field('a') < 333))
+            assert select(((t['a'] > 111) | (t['a'] < 333)) & (t['b'] < 2.5)) == expected.filter(((pc.field('a') > 111) | (pc.field('a') < 333)) & (pc.field('b') < 2.5))
+            with pytest.raises(NotImplementedError):
+                assert select((t['a'] > 111) | (t['b'] > 0) | (t['s'] < 'ccc')) == expected.filter((pc.field('a') > 111) | (pc.field('b') > 0) | (pc.field('s') < 'ccc'))
+            assert select((t['a'] > 111) | (t['a'] < 333) | (t['a'] == 777)) == expected.filter((pc.field('a') > 111) | (pc.field('a') < 333) | (pc.field('a') == 777))
+
+            assert select(t['s'].isnull()) == expected.filter(pc.field('s').is_null())
+            assert select((t['s'].isnull()) | (t['s'] == 'bb')) == expected.filter((pc.field('s').is_null()) | (pc.field('s') == 'bb'))
+            assert select((t['s'].isnull()) & (t['b'] == 3.5)) == expected.filter((pc.field('s').is_null()) & (pc.field('b') == 3.5))
+
+            assert select(~t['s'].isnull()) == expected.filter(~pc.field('s').is_null())
+            assert select(t['s'].contains('b')) == expected.filter(pc.field('s') == 'bb')
+            assert select(t['s'].contains('y')) == expected.filter(pc.field('s') == 'xyz')
+
+            assert select(t['a'].isin([555])) == expected.filter(pc.field('a').isin([555]))
+            assert select(t['a'].isin([111, 222, 999])) == expected.filter(pc.field('a').isin([111, 222, 999]))
+            assert select((t['a'] == 111) | t['a'].isin([333, 444]) | (t['a'] > 600)) == expected.filter((pc.field('a') == 111) | pc.field('a').isin([333, 444]) | (pc.field('a') > 600))
+
+            with pytest.raises(NotImplementedError):
+                select(t['a'].isin([]))
 
 
 def test_parquet_export(session, clean_bucket_name):
@@ -331,7 +349,8 @@ def test_parquet_export(session, clean_bucket_name):
         ['a', 'b'],
     ])
     expected = pa.Table.from_batches([rb])
-    t.insert(rb)
+    rb = t.insert(rb)
+    assert rb.to_pylist() == [0, 1]
     actual = pa.Table.from_batches(t.select())
     assert actual == expected
 
vastdb/tests/util.py CHANGED
@@ -9,7 +9,9 @@ def prepare_data(session, clean_bucket_name, schema_name, table_name, arrow_table):
     with session.transaction() as tx:
         s = tx.bucket(clean_bucket_name).create_schema(schema_name)
         t = s.create_table(table_name, arrow_table.schema)
-        t.insert(arrow_table)
+        row_ids_array = t.insert(arrow_table)
+        row_ids = row_ids_array.to_pylist()
+        assert row_ids == list(range(arrow_table.num_rows))
         yield t
         t.drop()
         s.drop()
vastdb/transaction.py CHANGED
@@ -63,7 +63,7 @@ class Transaction:
         except botocore.exceptions.ClientError as e:
             log.warning("res: %s", e.response)
             if e.response['Error']['Code'] == '404':
-                raise errors.MissingBucket(name)
+                raise errors.MissingBucket(name) from e
             raise
         return bucket.Bucket(name, self)
 
vastdb-0.1.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vastdb
-Version: 0.1.3
+Version: 0.1.5
 Summary: VAST Data SDK
 Home-page: https://github.com/vast-data/vastdb_sdk
 Author: VAST DATA
@@ -25,6 +25,7 @@ Requires-Dist: ibis-framework ==8.0.0
 Requires-Dist: pyarrow
 Requires-Dist: requests
 Requires-Dist: xmltodict
+Requires-Dist: backoff ==2.2.1
 
 
 `vastdb` is a Python-based SDK designed for interacting
vastdb-0.1.5.dist-info/RECORD CHANGED
@@ -151,27 +151,27 @@
 vast_flatbuf/tabular/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vastdb/__init__.py,sha256=cMJtZuJ0IL9aKyM3DUWqTCzuP1H1MXXVivKKE1-q0DY,292
 vastdb/bucket.py,sha256=4rPEm9qlPTg7ccWO6VGmd4LKb8w-BDhJYwzXGjn03sc,3566
 vastdb/conftest.py,sha256=pKpo_46Vq4QHzTDQAFxasrVhnZ2V2L-y6IMLxojxaFM,2132
-vastdb/errors.py,sha256=vKWoq1yXrHyafMWwJgW_sQkSxQYxlI1JbTVCLz5Xi9Y,3793
-vastdb/internal_commands.py,sha256=ZD2YXYvZ3lJWYzZU0oHtv8G3lNtDQUF0e8yg8813Xt4,99575
+vastdb/errors.py,sha256=fj8IlPnGi1lbJWIl1-8MSjLavL9bYQ-YUoboWbXCo54,4047
+vastdb/internal_commands.py,sha256=kIdkLHabW8r4-GSygGl1Gdrr4puxD79WPO8Jkx8aszg,98490
 vastdb/schema.py,sha256=ql4TPB1W_FQ_BHov3CKHI8JX3krXMlcKWz7dTrjpQ1w,3346
-vastdb/session.py,sha256=ciYS8Je2cRpuaAEE6Wjk79VsW0KAPdnRB2cqfxFCjis,2323
-vastdb/table.py,sha256=xnSTWUUa0QHzXC5MUQWsGT1fsG8yAgMLy3nrgSH4j5Q,25661
-vastdb/transaction.py,sha256=g8YTcYnsNPIhB2udbHyT5RIFB5kHnBLJcvV2CWRICwI,2845
+vastdb/session.py,sha256=UTaz1Fh3u71Bnay2r6IyCHNMDrAszbzjnwylPURzhsk,2603
+vastdb/table.py,sha256=1ikj6toITImFowI2WHiimmqSiObmTfAohCdWC89q71Y,30031
+vastdb/transaction.py,sha256=u4pJBLooZQ_YGjsRgEWVL6RPAlt3lgm5oOpPHzPcayM,2852
 vastdb/util.py,sha256=rs7nLL2Qz-OVEZDSVIqAvS-uETMq-zxQs5jBksB5-JA,4276
 vastdb/bench/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vastdb/bench/test_perf.py,sha256=iHE3E60fvyU5SBDHPi4h03Dj6QcY6VI9l9mMhgNMtPc,1117
 vastdb/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vastdb/tests/test_duckdb.py,sha256=KDuv4PrjGEwChCGHG36xNT2JiFlBOt6K3DQ3L06Kq-A,1913
 vastdb/tests/test_imports.py,sha256=48kbJKsa_MrEXcBYQUbUDr1e9wzjG4FHQ7C3wUEQfXA,5705
-vastdb/tests/test_nested.py,sha256=3kejEvtSqV0LrUgb1QglRjrlxnKI4_AXTFw2nE7Q520,951
+vastdb/tests/test_nested.py,sha256=FHYMmaKYvqVh0NvsocUFLr2LDVlSfXZYgqUSopWOSM0,3512
 vastdb/tests/test_projections.py,sha256=_cDNfD5zTwbCXLk6uGpPUWGN0P-4HElu5OjubWu-Jg0,1255
 vastdb/tests/test_sanity.py,sha256=ixx0QPo73hLHjAa7bByFXjS1XST0WvmSwLEpgnHh_JY,2960
 vastdb/tests/test_schemas.py,sha256=qoHTLX51D-0S4bMxdCpRh9gaYQd-BkZdT_agGOwFwTM,1739
-vastdb/tests/test_tables.py,sha256=joeEQ30TwKBQc-2N_qGIdviZVnQr4rs6thlNsy5s_og,26672
+vastdb/tests/test_tables.py,sha256=Q3N5P-7mOPVcfAFEfpAzomqkyCJ5gKZmfE4SUW5jehk,27859
 vastdb/tests/test_util.py,sha256=owRAU3TCKMq-kz54NRdA5wX2O_bZIHqG5ucUR77jm5k,1046
-vastdb/tests/util.py,sha256=NaCzKymEGy1xuiyMxyt2_0frKVfVk9iGrFwLf3GHjTI,435
-vastdb-0.1.3.dist-info/LICENSE,sha256=obffan7LYrq7hLHNrY7vHcn2pKUTBUYXMKu-VOAvDxU,11333
-vastdb-0.1.3.dist-info/METADATA,sha256=3h3JttUxw9oMMsxV_CVG_LMYwhgegsS9-b4gZkihrM0,1319
-vastdb-0.1.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-vastdb-0.1.3.dist-info/top_level.txt,sha256=Vsj2MKtlhPg0J4so64slQtnwjhgoPmJgcG-6YcVAwVc,20
-vastdb-0.1.3.dist-info/RECORD,,
+vastdb/tests/util.py,sha256=dpRJYbboDnlqL4qIdvScpp8--5fxRUBIcIYitrfcj9o,555
+vastdb-0.1.5.dist-info/LICENSE,sha256=obffan7LYrq7hLHNrY7vHcn2pKUTBUYXMKu-VOAvDxU,11333
+vastdb-0.1.5.dist-info/METADATA,sha256=NJzrnkyfPs4lliFamaEdJy2elLYLzYJtlCxEMRSiLtg,1350
+vastdb-0.1.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+vastdb-0.1.5.dist-info/top_level.txt,sha256=Vsj2MKtlhPg0J4so64slQtnwjhgoPmJgcG-6YcVAwVc,20
+vastdb-0.1.5.dist-info/RECORD,,