vastdb 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vastdb/table.py CHANGED
@@ -1,3 +1,5 @@
+ """VAST Database table."""
+
  import concurrent.futures
  import logging
  import os
@@ -5,25 +7,32 @@ import queue
  from dataclasses import dataclass, field
  from math import ceil
  from threading import Event
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

+ import backoff
  import ibis
  import pyarrow as pa

- from . import errors, internal_commands, schema
+ from . import errors, internal_commands, schema, util

  log = logging.getLogger(__name__)


  INTERNAL_ROW_ID = "$row_id"
+ INTERNAL_ROW_ID_FIELD = pa.field(INTERNAL_ROW_ID, pa.uint64())
+
  MAX_ROWS_PER_BATCH = 512 * 1024
  # for insert we need a smaller limit due to response amplification
  # for example insert of 512k uint8 result in 512k*8bytes response since row_ids are uint64
  MAX_INSERT_ROWS_PER_PATCH = 512 * 1024
+ # in case insert has TooWideRow - need to insert in smaller batches - each cell could contain up to 128K, and our wire is limited to 5MB
+ MAX_COLUMN_IN_BATCH = int(5 * 1024 / 128)


  @dataclass
  class TableStats:
+ """Table-related information."""
+
  num_rows: int
  size_in_bytes: int
  is_external_rowid_alloc: bool = False
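The new MAX_COLUMN_IN_BATCH constant follows from the wire-size arithmetic in the comment above: with a 5 MB request limit and up to 128 KB per cell, at most 40 columns fit in one insert RPC. A quick sketch of that arithmetic (the variable names below are illustrative, not part of the SDK):

    # Derived from the constants above: 5 MB wire limit, up to 128 KB per cell.
    WIRE_LIMIT_KB = 5 * 1024          # 5 MB expressed in KB
    MAX_CELL_KB = 128                 # worst-case size of a single cell
    assert int(WIRE_LIMIT_KB / MAX_CELL_KB) == 40   # == MAX_COLUMN_IN_BATCH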
@@ -32,6 +41,8 @@ class TableStats:

  @dataclass
  class QueryConfig:
+ """Query execution configuration."""
+
  num_sub_splits: int = 4
  num_splits: int = 1
  data_endpoints: Optional[List[str]] = None
@@ -40,15 +51,22 @@ class QueryConfig:
  use_semi_sorted_projections: bool = True
  rows_per_split: int = 4000000
  query_id: str = ""
+ max_slowdown_retry: int = 10
+ backoff_func: Any = field(default=backoff.on_exception(backoff.expo, errors.Slowdown, max_tries=max_slowdown_retry))


  @dataclass
  class ImportConfig:
+ """Import execution configuration."""
+
  import_concurrency: int = 2


- class SelectSplitState():
+ class SelectSplitState:
+ """State of a specific query split execution."""
+
  def __init__(self, query_data_request, table: "Table", split_id: int, config: QueryConfig) -> None:
+ """Initialize query split state."""
  self.split_id = split_id
  self.subsplits_state = {i: 0 for i in range(config.num_sub_splits)}
  self.config = config
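QueryConfig now retries `errors.Slowdown` responses via the `backoff` package, and `SelectSplitState.batches()` wraps `api.query_data` with `config.backoff_func` (see the next hunk). A minimal sketch of overriding the retry policy, assuming `table` is an already-opened Table handle (the default is exponential backoff with up to 10 tries):

    import backoff

    from vastdb import errors
    from vastdb.table import QueryConfig

    # Illustrative only: retry Slowdown responses at most 5 times instead of the default 10.
    config = QueryConfig(
        num_splits=2,
        backoff_func=backoff.on_exception(backoff.expo, errors.Slowdown, max_tries=5),
    )
    reader = table.select(columns=["num"], config=config)  # `table` is assumed to exist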
@@ -56,8 +74,13 @@ class SelectSplitState():
  self.table = table

  def batches(self, api: internal_commands.VastdbApi):
+ """Execute QueryData request, and yield parsed RecordBatch objects.
+
+ Can be called repeatedly, to allow pagination.
+ """
  while not self.done:
- response = api.query_data(
+ query_with_backoff = self.config.backoff_func(api.query_data)
+ response = query_with_backoff(
  bucket=self.table.bucket.name,
  schema=self.table.schema.name,
  table=self.table.name,
@@ -68,7 +91,8 @@ class SelectSplitState():
  txid=self.table.tx.txid,
  limit_rows=self.config.limit_rows_per_sub_split,
  sub_split_start_row_ids=self.subsplits_state.items(),
- enable_sorted_projections=self.config.use_semi_sorted_projections)
+ enable_sorted_projections=self.config.use_semi_sorted_projections,
+ query_imports_table=self.table._imports_table)
  pages_iter = internal_commands.parse_query_data_response(
  conn=response.raw,
  schema=self.query_data_request.response_schema,
@@ -82,19 +106,24 @@ class SelectSplitState():

  @property
  def done(self):
+ """Return True iff the pagination is over."""
  return all(row_id == internal_commands.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())


  @dataclass
  class Table:
+ """VAST Table."""
+
  name: str
  schema: "schema.Schema"
  handle: int
  stats: TableStats
- arrow_schema: pa.Schema = field(init=False, compare=False)
- _ibis_table: ibis.Schema = field(init=False, compare=False)
+ arrow_schema: pa.Schema = field(init=False, compare=False, repr=False)
+ _ibis_table: ibis.Schema = field(init=False, compare=False, repr=False)
+ _imports_table: bool

  def __post_init__(self):
+ """Also, load columns' metadata."""
  self.arrow_schema = self.columns()

  table_path = f'{self.schema.bucket.name}/{self.schema.name}/{self.name}'
@@ -102,21 +131,21 @@ class Table:

  @property
  def tx(self):
+ """Return transaction."""
  return self.schema.tx

  @property
  def bucket(self):
+ """Return bucket."""
  return self.schema.bucket

- def __repr__(self):
- return f"{type(self).__name__}(name={self.name})"
-
  def columns(self) -> pa.Schema:
+ """Return columns' metadata."""
  fields = []
  next_key = 0
  while True:
  cur_columns, next_key, is_truncated, _count = self.tx._rpc.api.list_columns(
- bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=next_key, txid=self.tx.txid)
+ bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=next_key, txid=self.tx.txid, list_imports_table=self._imports_table)
  fields.extend(cur_columns)
  if not is_truncated:
  break
@@ -125,6 +154,9 @@ class Table:
  return self.arrow_schema

  def projection(self, name: str) -> "Projection":
+ """Get a specific semi-sorted projection of this table."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  projs = self.projections(projection_name=name)
  if not projs:
  raise errors.MissingProjection(self.bucket.name, self.schema.name, self.name, name)
@@ -133,6 +165,9 @@ class Table:
  return projs[0]

  def projections(self, projection_name=None) -> List["Projection"]:
+ """List all semi-sorted projections of this table."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  projections = []
  next_key = 0
  name_prefix = projection_name if projection_name else ""
@@ -150,6 +185,12 @@ class Table:
  return [_parse_projection_info(projection, self) for projection in projections]

  def import_files(self, files_to_import: List[str], config: Optional[ImportConfig] = None) -> None:
+ """Import a list of Parquet files into this table.
+
+ The files must reside on the VAST S3 server and be accessible using the current credentials.
+ """
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  source_files = {}
  for f in files_to_import:
  bucket_name, object_path = _parse_bucket_and_object_names(f)
@@ -158,6 +199,13 @@ class Table:
  self._execute_import(source_files, config=config)

  def import_partitioned_files(self, files_and_partitions: Dict[str, pa.RecordBatch], config: Optional[ImportConfig] = None) -> None:
+ """Import a list of Parquet files into this table.
+
+ The files must reside on the VAST S3 server and be accessible using the current credentials.
+ Each file must have its own partition values defined as an Arrow RecordBatch.
+ """
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  source_files = {}
  for f, record_batch in files_and_partitions.items():
  bucket_name, object_path = _parse_bucket_and_object_names(f)
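For reference, a hedged sketch of how these import entry points are typically driven. The object paths below are hypothetical; the exact "<bucket>/<object-key>" layout expected by `_parse_bucket_and_object_names()` is an assumption, and the single-row partition batch is only one possible shape:

    import pyarrow as pa

    from vastdb.table import ImportConfig

    # Hypothetical object paths on the VAST S3 bucket.
    files = [f"mybucket/staging/prq{i}.parquet" for i in range(3)]
    t.import_files(files)  # `t` is an existing Table handle

    # import_partitioned_files() additionally takes per-file partition values as Arrow RecordBatches.
    partitions = {
        "mybucket/staging/prq0.parquet": pa.record_batch([pa.array([2024])], names=["year"]),
    }
    t.import_partitioned_files(partitions, config=ImportConfig(import_concurrency=4))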
@@ -216,8 +264,10 @@ class Table:
  # ThreadPoolExecutor will be joined at the end of the context

  def get_stats(self) -> TableStats:
+ """Get the statistics of this table."""
  stats_tuple = self.tx._rpc.api.get_table_stats(
- bucket=self.bucket.name, schema=self.schema.name, name=self.name, txid=self.tx.txid)
+ bucket=self.bucket.name, schema=self.schema.name, name=self.name, txid=self.tx.txid,
+ imports_table_stats=self._imports_table)
  return TableStats(**stats_tuple._asdict())

  def select(self, columns: Optional[List[str]] = None,
@@ -225,6 +275,14 @@ class Table:
  config: Optional[QueryConfig] = None,
  *,
  internal_row_id: bool = False) -> pa.RecordBatchReader:
+ """Execute a query over this table.
+
+ To read a subset of the columns, specify their names via the `columns` argument. Otherwise, all columns will be read.
+
+ In order to apply a filter, a predicate can be specified. See https://github.com/vast-data/vastdb_sdk/blob/main/README.md#filters-and-projections for more details.
+
+ Query-execution configuration options can be specified via the optional `config` argument.
+ """
  if config is None:
  config = QueryConfig()

@@ -241,11 +299,17 @@ class Table:

  query_schema = self.arrow_schema
  if internal_row_id:
- queried_fields = [pa.field(INTERNAL_ROW_ID, pa.uint64())]
+ queried_fields = [INTERNAL_ROW_ID_FIELD]
  queried_fields.extend(column for column in self.arrow_schema)
  query_schema = pa.schema(queried_fields)
  columns.append(INTERNAL_ROW_ID)

+ if predicate is True:
+ predicate = None
+ if predicate is False:
+ response_schema = internal_commands.get_response_schema(schema=query_schema, field_names=columns)
+ return pa.RecordBatchReader.from_batches(response_schema, [])
+
  query_data_request = internal_commands.build_query_data_request(
  schema=query_schema,
  predicate=predicate,
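The `True`/`False` shortcuts above make a trivially-true predicate behave like "no filter" and a trivially-false one return an empty reader without issuing a QueryData RPC. A short sketch of the predicate forms exercised by the tests later in this diff, assuming an existing Table handle `t` with columns `a` and `s`:

    t.select()                                           # all columns, no filter
    t.select(columns=["a"], predicate=t["a"] > 222)      # ibis-like comparison pushed down
    t.select(predicate=t["a"].isin([111, 222]))          # membership filter
    t.select(predicate=(t["a"].between(1, 10)) & (t["s"].contains("b")))
    t.select(predicate=False)                            # returns an empty RecordBatchReader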
@@ -335,82 +399,176 @@ class Table:

  return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())

- def _combine_chunks(self, col):
- if hasattr(col, "combine_chunks"):
- return col.combine_chunks()
- else:
- return col
-
- def insert(self, rows: pa.RecordBatch) -> pa.RecordBatch:
- serialized_slices = self.tx._rpc.api._record_batch_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
- for slice in serialized_slices:
- self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
- txid=self.tx.txid)
+ def insert_in_column_batches(self, rows: pa.RecordBatch):
+ """Split the RecordBatch into column groups that can be inserted in a single RPC.
+
+ Insert the first MAX_COLUMN_IN_BATCH columns and get the row_ids. Then loop over the rest of the columns and
+ update in groups of MAX_COLUMN_IN_BATCH.
+ """
+ column_record_batch = pa.RecordBatch.from_arrays([_combine_chunks(rows.column(i)) for i in range(0, MAX_COLUMN_IN_BATCH)],
+ schema=pa.schema([rows.schema.field(i) for i in range(0, MAX_COLUMN_IN_BATCH)]))
+ row_ids = self.insert(rows=column_record_batch)  # type: ignore
+
+ columns_names = [field.name for field in rows.schema]
+ columns = list(rows.schema)
+ arrays = [_combine_chunks(rows.column(i)) for i in range(len(rows.schema))]
+ for start in range(MAX_COLUMN_IN_BATCH, len(rows.schema), MAX_COLUMN_IN_BATCH):
+ end = start + MAX_COLUMN_IN_BATCH if start + MAX_COLUMN_IN_BATCH < len(rows.schema) else len(rows.schema)
+ columns_name_chunk = columns_names[start:end]
+ columns_chunks = columns[start:end]
+ arrays_chunks = arrays[start:end]
+ columns_chunks.append(INTERNAL_ROW_ID_FIELD)
+ arrays_chunks.append(row_ids.to_pylist())
+ column_record_batch = pa.RecordBatch.from_arrays(arrays_chunks, schema=pa.schema(columns_chunks))
+ self.update(rows=column_record_batch, columns=columns_name_chunk)
+ return row_ids
+
+ def insert(self, rows: pa.RecordBatch):
+ """Insert a RecordBatch into this table."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+ try:
+ row_ids = []
+ serialized_slices = util.iter_serialized_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
+ for slice in serialized_slices:
+ res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
+ txid=self.tx.txid)
+ (batch,) = pa.RecordBatchStreamReader(res.raw)
+ row_ids.append(batch[INTERNAL_ROW_ID])
+ try:
+ self.tx._rpc.features.check_return_row_ids()
+ except errors.NotSupportedVersion:
+ return  # type: ignore
+ return pa.chunked_array(row_ids)
+ except errors.TooWideRow:
+ self.tx._rpc.features.check_return_row_ids()
+ return self.insert_in_column_batches(rows)

  def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: Optional[List[str]] = None) -> None:
+ """Update a subset of cells in this table.
+
+ Row IDs are specified using a special field (named "$row_id" of uint64 type) - this function assumes that this
+ special field is part of the arguments.
+
+ A subset of columns to be updated can be specified via the `columns` argument.
+ """
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+ try:
+ rows_chunk = rows[INTERNAL_ROW_ID]
+ except KeyError:
+ raise errors.MissingRowIdColumn
  if columns is not None:
  update_fields = [(INTERNAL_ROW_ID, pa.uint64())]
- update_values = [self._combine_chunks(rows[INTERNAL_ROW_ID])]
+ update_values = [_combine_chunks(rows_chunk)]
  for col in columns:
  update_fields.append(rows.field(col))
- update_values.append(self._combine_chunks(rows[col]))
+ update_values.append(_combine_chunks(rows[col]))

  update_rows_rb = pa.record_batch(schema=pa.schema(update_fields), data=update_values)
  else:
  update_rows_rb = rows

- serialized_slices = self.tx._rpc.api._record_batch_slices(update_rows_rb, MAX_ROWS_PER_BATCH)
+ serialized_slices = util.iter_serialized_slices(update_rows_rb, MAX_ROWS_PER_BATCH)
  for slice in serialized_slices:
  self.tx._rpc.api.update_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
  txid=self.tx.txid)

  def delete(self, rows: Union[pa.RecordBatch, pa.Table]) -> None:
+ """Delete a subset of rows in this table.
+
+ Row IDs are specified using a special field (named "$row_id" of uint64 type).
+ """
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
+ try:
+ rows_chunk = rows[INTERNAL_ROW_ID]
+ except KeyError:
+ raise errors.MissingRowIdColumn
  delete_rows_rb = pa.record_batch(schema=pa.schema([(INTERNAL_ROW_ID, pa.uint64())]),
- data=[self._combine_chunks(rows[INTERNAL_ROW_ID])])
+ data=[_combine_chunks(rows_chunk)])

- serialized_slices = self.tx._rpc.api._record_batch_slices(delete_rows_rb, MAX_ROWS_PER_BATCH)
+ serialized_slices = util.iter_serialized_slices(delete_rows_rb, MAX_ROWS_PER_BATCH)
  for slice in serialized_slices:
  self.tx._rpc.api.delete_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
- txid=self.tx.txid)
+ txid=self.tx.txid, delete_from_imports_table=self._imports_table)

  def drop(self) -> None:
- self.tx._rpc.api.drop_table(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid)
+ """Drop this table."""
+ self.tx._rpc.api.drop_table(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, remove_imports_table=self._imports_table)
  log.info("Dropped table: %s", self.name)

  def rename(self, new_name) -> None:
+ """Rename this table."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  self.tx._rpc.api.alter_table(
  self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, new_name=new_name)
  log.info("Renamed table from %s to %s ", self.name, new_name)
  self.name = new_name

  def add_column(self, new_column: pa.Schema) -> None:
+ """Add a new column."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  self.tx._rpc.api.add_columns(self.bucket.name, self.schema.name, self.name, new_column, txid=self.tx.txid)
  log.info("Added column(s): %s", new_column)
  self.arrow_schema = self.columns()

  def drop_column(self, column_to_drop: pa.Schema) -> None:
+ """Drop an existing column."""
+ if self._imports_table:
+ raise errors.NotSupported(self.bucket.name, self.schema.name, self.name)
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  self.tx._rpc.api.drop_columns(self.bucket.name, self.schema.name, self.name, column_to_drop, txid=self.tx.txid)
  log.info("Dropped column(s): %s", column_to_drop)
  self.arrow_schema = self.columns()

  def rename_column(self, current_column_name: str, new_column_name: str) -> None:
+ """Rename an existing column."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  self.tx._rpc.api.alter_column(self.bucket.name, self.schema.name, self.name, name=current_column_name,
  new_name=new_column_name, txid=self.tx.txid)
  log.info("Renamed column: %s to %s", current_column_name, new_column_name)
  self.arrow_schema = self.columns()

  def create_projection(self, projection_name: str, sorted_columns: List[str], unsorted_columns: List[str]) -> "Projection":
+ """Create a new semi-sorted projection."""
+ if self._imports_table:
+ raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
  columns = [(sorted_column, "Sorted") for sorted_column in sorted_columns] + [(unsorted_column, "Unorted") for unsorted_column in unsorted_columns]
  self.tx._rpc.api.create_projection(self.bucket.name, self.schema.name, self.name, projection_name, columns=columns, txid=self.tx.txid)
  log.info("Created projection: %s", projection_name)
  return self.projection(projection_name)

+ def create_imports_table(self, fail_if_exists=True) -> "Table":
+ """Create imports table."""
+ self.tx._rpc.features.check_imports_table()
+ empty_schema = pa.schema([])
+ self.tx._rpc.api.create_table(self.bucket.name, self.schema.name, self.name, empty_schema, txid=self.tx.txid,
+ create_imports_table=True)
+ log.info("Created imports table for table: %s", self.name)
+ return self.imports_table()  # type: ignore[return-value]
+
+ def imports_table(self) -> Optional["Table"]:
+ """Get the imports table of this table."""
+ self.tx._rpc.features.check_imports_table()
+ return Table(name=self.name, schema=self.schema, handle=int(self.handle), stats=self.stats, _imports_table=True)
+
  def __getitem__(self, col_name):
+ """Allow constructing ibis-like column expressions from this table.
+
+ It is useful for constructing expressions for predicate pushdown in the `Table.select()` method.
+ """
  return self._ibis_table[col_name]


  @dataclass
  class Projection:
+ """VAST semi-sorted projection."""
+
  name: str
  table: Table
  handle: int
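`insert()` now returns the allocated `$row_id` values (a `pyarrow.ChunkedArray`, or `None` on servers that predate the row-ID feature) and falls back to `insert_in_column_batches()` when a single row is too wide for one RPC. A hedged sketch of feeding those row IDs back into `delete()`/`update()`, assuming `t` is a Table handle and `rb` is a RecordBatch matching its schema:

    import pyarrow as pa

    row_ids = t.insert(rb)                     # `t` and `rb` are assumed to exist
    if row_ids is not None:                    # older servers may not return row IDs
        to_delete = pa.record_batch(
            [pa.array(row_ids.to_pylist()[:10], pa.uint64())],
            schema=pa.schema([("$row_id", pa.uint64())]))
        t.delete(to_delete)                    # rows are addressed by the special "$row_id" field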
@@ -418,20 +576,21 @@ class Projection:

  @property
  def bucket(self):
+ """Return bucket."""
  return self.table.schema.bucket

  @property
  def schema(self):
+ """Return schema."""
  return self.table.schema

  @property
  def tx(self):
+ """Return transaction."""
  return self.table.schema.tx

- def __repr__(self):
- return f"{type(self).__name__}(name={self.name})"
-
  def columns(self) -> pa.Schema:
+ """Return this projection's columns as an Arrow schema."""
  columns = []
  next_key = 0
  while True:
@@ -447,12 +606,14 @@ class Projection:
  return self.arrow_schema

  def rename(self, new_name) -> None:
+ """Rename this projection."""
  self.tx._rpc.api.alter_projection(self.bucket.name, self.schema.name,
  self.table.name, self.name, txid=self.tx.txid, new_name=new_name)
  log.info("Renamed projection from %s to %s ", self.name, new_name)
  self.name = new_name

  def drop(self) -> None:
+ """Drop this projection."""
  self.tx._rpc.api.drop_projection(self.bucket.name, self.schema.name, self.table.name,
  self.name, txid=self.tx.txid)
  log.info("Dropped projection: %s", self.name)
@@ -478,3 +639,10 @@ def _serialize_record_batch(record_batch: pa.RecordBatch) -> pa.lib.Buffer:
  with pa.ipc.new_stream(sink, record_batch.schema) as writer:
  writer.write(record_batch)
  return sink.getvalue()
+
+
+ def _combine_chunks(col):
+ if hasattr(col, "combine_chunks"):
+ return col.combine_chunks()
+ else:
+ return col
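The new `_imports_table` flag threads through most of the RPCs above: a table's imports table is a companion table that tracks imported objects (the import test later in this diff queries its `ObjectName` column). A minimal sketch, assuming `t` is an existing Table handle on a cluster that supports the imports-table feature (`check_imports_table`):

    imports = t.imports_table()                # companion Table opened with _imports_table=True
    names = pa.Table.from_batches(imports.select(columns=["ObjectName"]))
    log.info("imported objects: %s", names.to_pydict()["ObjectName"])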
@@ -6,7 +6,7 @@ import pyarrow.parquet as pq
  import pytest

  from vastdb import util
- from vastdb.errors import ImportFilesError, InvalidArgument
+ from vastdb.errors import ImportFilesError, InternalServerError, InvalidArgument

  log = logging.getLogger(__name__)

@@ -34,12 +34,24 @@ def test_parallel_imports(session, clean_bucket_name, s3):
  b = tx.bucket(clean_bucket_name)
  s = b.create_schema('s1')
  t = s.create_table('t1', pa.schema([('num', pa.int64())]))
+ with pytest.raises(InternalServerError):
+ t.create_imports_table()
  log.info("Starting import of %d files", num_files)
  t.import_files(files)
  arrow_table = pa.Table.from_batches(t.select(columns=['num']))
  assert arrow_table.num_rows == num_rows * num_files
  arrow_table = pa.Table.from_batches(t.select(columns=['num'], predicate=t['num'] == 100))
  assert arrow_table.num_rows == num_files
+ import_table = t.imports_table()
+ # checking all imports are on the imports table:
+ objects_name = pa.Table.from_batches(import_table.select(columns=["ObjectName"]))
+ objects_name = objects_name.to_pydict()
+ object_names = set(objects_name['ObjectName'])
+ prefix = 'prq'
+ numbers = set(range(53))
+ assert all(name.startswith(prefix) for name in object_names)
+ numbers.issubset(int(name.replace(prefix, '')) for name in object_names)
+ assert len(object_names) == len(objects_name['ObjectName'])


  def test_create_table_from_files(session, clean_bucket_name, s3):
@@ -60,5 +60,4 @@ def test_commits_and_rollbacks(session, clean_bucket_name):
  def test_list_snapshots(session, clean_bucket_name):
  with session.transaction() as tx:
  b = tx.bucket(clean_bucket_name)
- s = b.snapshots()
- assert s == []
+ b.snapshots()  # VAST Catalog may create some snapshots
@@ -3,6 +3,7 @@ import decimal
  import logging
  import random
  import threading
+ import time
  from contextlib import closing
  from tempfile import NamedTemporaryFile

@@ -70,6 +71,16 @@ def test_tables(session, clean_bucket_name):
  }


+ def test_insert_wide_row(session, clean_bucket_name):
+ columns = pa.schema([pa.field(f's{i}', pa.utf8()) for i in range(500)])
+ data = [['a' * 10**4] for i in range(500)]
+ expected = pa.table(schema=columns, data=data)
+
+ with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+ actual = pa.Table.from_batches(t.select())
+ assert actual == expected
+
+
  def test_exists(session, clean_bucket_name):
  with session.transaction() as tx:
  s = tx.bucket(clean_bucket_name).create_schema('s1')
@@ -261,9 +272,14 @@ def test_filters(session, clean_bucket_name):

  with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
  def select(predicate):
- return pa.Table.from_batches(t.select(predicate=predicate))
+ return pa.Table.from_batches(t.select(predicate=predicate), t.arrow_schema)

  assert select(None) == expected
+ assert select(True) == expected
+ assert select(False) == pa.Table.from_batches([], schema=columns)
+
+ assert select(t['a'].between(222, 444)) == expected.filter((pc.field('a') >= 222) & (pc.field('a') <= 444))
+ assert select((t['a'].between(222, 444)) & (t['b'] > 2.5)) == expected.filter((pc.field('a') >= 222) & (pc.field('a') <= 444) & (pc.field('b') > 2.5))

  assert select(t['a'] > 222) == expected.filter(pc.field('a') > 222)
  assert select(t['a'] < 222) == expected.filter(pc.field('a') < 222)
@@ -304,6 +320,13 @@ def test_filters(session, clean_bucket_name):
  assert select(t['s'].contains('b')) == expected.filter(pc.field('s') == 'bb')
  assert select(t['s'].contains('y')) == expected.filter(pc.field('s') == 'xyz')

+ assert select(t['a'].isin([555])) == expected.filter(pc.field('a').isin([555]))
+ assert select(t['a'].isin([111, 222, 999])) == expected.filter(pc.field('a').isin([111, 222, 999]))
+ assert select((t['a'] == 111) | t['a'].isin([333, 444]) | (t['a'] > 600)) == expected.filter((pc.field('a') == 111) | pc.field('a').isin([333, 444]) | (pc.field('a') > 600))
+
+ with pytest.raises(NotImplementedError):
+ select(t['a'].isin([]))
+

  def test_parquet_export(session, clean_bucket_name):
  with session.transaction() as tx:
@@ -323,7 +346,8 @@ def test_parquet_export(session, clean_bucket_name):
  ['a', 'b'],
  ])
  expected = pa.Table.from_batches([rb])
- t.insert(rb)
+ rb = t.insert(rb)
+ assert rb.to_pylist() == [0, 1]
  actual = pa.Table.from_batches(t.select())
  assert actual == expected

@@ -638,3 +662,20 @@ def test_select_stop(session, clean_bucket_name):

  # validate that all query threads were killed.
  assert active_threads() == 0
+
+
+ def test_big_catalog_select(session, clean_bucket_name):
+ with session.transaction() as tx:
+ bc = tx.catalog()
+ actual = pa.Table.from_batches(bc.select(['name']))
+ assert actual
+ log.info("actual=%s", actual)
+
+
+ def test_audit_log_select(session, clean_bucket_name):
+ with session.transaction() as tx:
+ a = tx.audit_log()
+ a.columns()
+ time.sleep(1)
+ actual = pa.Table.from_batches(a.select(), a.arrow_schema)
+ log.info("actual=%s", actual)
@@ -0,0 +1,39 @@
+ import pyarrow as pa
+ import pytest
+
+ from .. import errors, util
+
+
+ def test_slices():
+ ROWS = 1 << 20
+ t = pa.table({"x": range(ROWS), "y": [i / 1000 for i in range(ROWS)]})
+
+ chunks = list(util.iter_serialized_slices(t))
+ assert len(chunks) > 1
+ sizes = [len(c) for c in chunks]
+
+ assert max(sizes) < util.MAX_RECORD_BATCH_SLICE_SIZE
+ assert t == pa.Table.from_batches(_parse(chunks))
+
+ chunks = list(util.iter_serialized_slices(t, 1000))
+ assert len(chunks) > 1
+ sizes = [len(c) for c in chunks]
+
+ assert max(sizes) < util.MAX_RECORD_BATCH_SLICE_SIZE
+ assert t == pa.Table.from_batches(_parse(chunks))
+
+
+ def test_wide_row():
+ cols = [pa.field(f"x{i}", pa.utf8()) for i in range(1000)]
+ values = [['a' * 10000]] * len(cols)
+ t = pa.table(values, schema=pa.schema(cols))
+ assert len(t) == 1
+
+ with pytest.raises(errors.TooWideRow):
+ list(util.iter_serialized_slices(t))
+
+
+ def _parse(bufs):
+ for buf in bufs:
+ with pa.ipc.open_stream(buf) as reader:
+ yield from reader
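These new tests cover `util.iter_serialized_slices()`, which replaces the private `api._record_batch_slices()` helper previously used by `insert`/`update`/`delete` in `table.py`. A short sketch of the behaviour the tests assert, assuming `t` is an existing Arrow table and the same `pyarrow as pa` / `util` imports as above:

    for buf in util.iter_serialized_slices(t, 1000):   # optional max-rows-per-slice argument
        assert len(buf) < util.MAX_RECORD_BATCH_SLICE_SIZE
        with pa.ipc.open_stream(buf) as reader:        # each slice is a standalone Arrow IPC stream
            for batch in reader:
                pass                                   # e.g. send each serialized slice over one RPC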
vastdb/tests/util.py CHANGED
@@ -9,7 +9,9 @@ def prepare_data(session, clean_bucket_name, schema_name, table_name, arrow_tabl
  with session.transaction() as tx:
  s = tx.bucket(clean_bucket_name).create_schema(schema_name)
  t = s.create_table(table_name, arrow_table.schema)
- t.insert(arrow_table)
+ row_ids_array = t.insert(arrow_table)
+ row_ids = row_ids_array.to_pylist()
+ assert row_ids == list(range(arrow_table.num_rows))
  yield t
  t.drop()
  s.drop()