vastdb 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
vastdb/schema.py CHANGED
@@ -4,12 +4,17 @@ VAST S3 buckets can be used to create Database schemas and tables.
 It is possible to list and access VAST snapshots generated over a bucket.
 """

-from . import bucket, errors, schema, table
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional

 import pyarrow as pa

-from dataclasses import dataclass
-import logging
+from . import bucket, errors, schema, table
+
+if TYPE_CHECKING:
+    from .table import Table
+

 log = logging.getLogger(__name__)

@@ -26,29 +31,37 @@ class Schema:
         """VAST transaction used for this schema."""
         return self.bucket.tx

-    def create_table(self, table_name: str, columns: pa.Schema) -> "table.Table":
+    def create_table(self, table_name: str, columns: pa.Schema, fail_if_exists=True) -> "Table":
         """Create a new table under this schema."""
+        if current := self.table(table_name, fail_if_missing=False):
+            if fail_if_exists:
+                raise errors.TableExists(self.bucket.name, self.name, table_name)
+            else:
+                return current
         self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
         log.info("Created table: %s", table_name)
-        return self.table(table_name)
+        return self.table(table_name)  # type: ignore[return-value]

-    def table(self, name: str) -> "table.Table":
+    def table(self, name: str, fail_if_missing=True) -> Optional["table.Table"]:
         """Get a specific table under this schema."""
         t = self.tables(table_name=name)
         if not t:
-            raise errors.MissingTable(self.bucket.name, self.name, name)
+            if fail_if_missing:
+                raise errors.MissingTable(self.bucket.name, self.name, name)
+            else:
+                return None
         assert len(t) == 1, f"Expected to receive only a single table, but got: {len(t)}. tables: {t}"
         log.debug("Found table: %s", t[0])
         return t[0]

-    def tables(self, table_name=None) -> ["table.Table"]:
+    def tables(self, table_name=None) -> List["Table"]:
         """List all tables under this schema."""
         tables = []
         next_key = 0
         name_prefix = table_name if table_name else ""
         exact_match = bool(table_name)
         while True:
-            bucket_name, schema_name, curr_tables, next_key, is_truncated, _ = \
+            _bucket_name, _schema_name, curr_tables, next_key, is_truncated, _ = \
                 self.tx._rpc.api.list_tables(
                     bucket=self.bucket.name, schema=self.name, next_key=next_key, txid=self.tx.txid,
                     exact_match=exact_match, name_prefix=name_prefix, include_list_stats=exact_match)
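In 0.1.2, `Schema.create_table()` and `Schema.table()` gain `fail_if_exists` / `fail_if_missing` flags, so existence checks no longer have to go through exception handling. A minimal sketch of the new calls, assuming a `session` object created per the SDK docs and a database-enabled bucket named `my-bucket` (both are placeholders, not taken from this diff):

```python
import pyarrow as pa

# Placeholders: `session` is a vastdb Session created per the SDK docs,
# and "my-bucket" is an existing database-enabled bucket.
columns = pa.schema([('a', pa.int32()), ('b', pa.float64())])

with session.transaction() as tx:
    schema = tx.bucket("my-bucket").create_schema("s1", fail_if_exists=False)

    # Default behavior: raise errors.TableExists if "t1" is already there.
    # With fail_if_exists=False the existing Table is returned instead.
    t = schema.create_table("t1", columns, fail_if_exists=False)

    # table() now returns None instead of raising errors.MissingTable
    # when fail_if_missing=False is passed.
    assert schema.table("no-such-table", fail_if_missing=False) is None
```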
vastdb/session.py CHANGED
@@ -7,12 +7,11 @@ For more details see:
 - [Tabular identity policy with the proper permissions](https://support.vastdata.com/s/article/UUID-14322b60-d6a2-89ac-3df0-3dfbb6974182)
 """

-from . import internal_commands
-from . import transaction
+import os

 import boto3

-import os
+from . import internal_commands, transaction


 class Session:
vastdb/table.py CHANGED
@@ -1,19 +1,16 @@
-from . import errors, schema
-from .internal_commands import build_query_data_request, parse_query_data_response, \
-    TABULAR_INVALID_ROW_ID, VastdbApi
-
-import pyarrow as pa
-import ibis
-
 import concurrent.futures
+import logging
+import os
 import queue
-from threading import Event
+from dataclasses import dataclass, field
 from math import ceil
+from threading import Event
+from typing import Dict, List, Optional, Tuple, Union

-from dataclasses import dataclass, field
-from typing import List, Union
-import logging
-import os
+import ibis
+import pyarrow as pa
+
+from . import errors, internal_commands, schema

 log = logging.getLogger(__name__)

@@ -24,18 +21,20 @@ MAX_ROWS_PER_BATCH = 512 * 1024
 # for example insert of 512k uint8 result in 512k*8bytes response since row_ids are uint64
 MAX_INSERT_ROWS_PER_PATCH = 512 * 1024

+
 @dataclass
 class TableStats:
     num_rows: int
     size_in_bytes: int
     is_external_rowid_alloc: bool = False
-    endpoints: List[str] = None
+    endpoints: Tuple[str, ...] = ()
+

 @dataclass
 class QueryConfig:
     num_sub_splits: int = 4
     num_splits: int = 1
-    data_endpoints: [str] = None
+    data_endpoints: Optional[List[str]] = None
     limit_rows_per_sub_split: int = 128 * 1024
     num_row_groups_per_sub_split: int = 8
     use_semi_sorted_projections: bool = True
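`QueryConfig.data_endpoints` is now typed `Optional[List[str]]` and `TableStats.endpoints` defaults to an empty tuple instead of `None`. A query config can still be built field by field and handed to `Table.select()`; a short sketch using only fields visible in this diff (`t` is an assumed `Table` handle obtained inside an open transaction):

```python
from vastdb.table import QueryConfig

# `t` is assumed to be a vastdb Table obtained inside an open transaction.
config = QueryConfig(
    num_sub_splits=1,
    num_splits=1,
    num_row_groups_per_sub_split=1,
    limit_rows_per_sub_split=100,
    data_endpoints=None,  # None -> use the endpoints reported by the table stats
)
reader = t.select(columns=['a'], config=config)  # returns a pa.RecordBatchReader
first_batch = next(iter(reader))
```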
@@ -47,15 +46,16 @@ class QueryConfig:
 class ImportConfig:
     import_concurrency: int = 2

+
 class SelectSplitState():
-    def __init__(self, query_data_request, table : "Table", split_id : int, config: QueryConfig) -> None:
+    def __init__(self, query_data_request, table: "Table", split_id: int, config: QueryConfig) -> None:
         self.split_id = split_id
         self.subsplits_state = {i: 0 for i in range(config.num_sub_splits)}
         self.config = config
         self.query_data_request = query_data_request
         self.table = table

-    def batches(self, api : VastdbApi):
+    def batches(self, api: internal_commands.VastdbApi):
         while not self.done:
             response = api.query_data(
                 bucket=self.table.bucket.name,
@@ -69,20 +69,21 @@ class SelectSplitState():
                 limit_rows=self.config.limit_rows_per_sub_split,
                 sub_split_start_row_ids=self.subsplits_state.items(),
                 enable_sorted_projections=self.config.use_semi_sorted_projections)
-            pages_iter = parse_query_data_response(
+            pages_iter = internal_commands.parse_query_data_response(
                 conn=response.raw,
                 schema=self.query_data_request.response_schema,
-                start_row_ids=self.subsplits_state)
+                start_row_ids=self.subsplits_state,
+                parser=self.query_data_request.response_parser)

             for page in pages_iter:
                 for batch in page.to_batches():
                     if len(batch) > 0:
                         yield batch

-
     @property
     def done(self):
-        return all(row_id == TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+        return all(row_id == internal_commands.TABULAR_INVALID_ROW_ID for row_id in self.subsplits_state.values())
+


 @dataclass
@@ -90,12 +91,10 @@ class Table:
     schema: "schema.Schema"
     handle: int
     stats: TableStats
-    properties: dict = None
     arrow_schema: pa.Schema = field(init=False, compare=False)
     _ibis_table: ibis.Schema = field(init=False, compare=False)

     def __post_init__(self):
-        self.properties = self.properties or {}
         self.arrow_schema = self.columns()

         table_path = f'{self.schema.bucket.name}/{self.schema.name}/{self.name}'
@@ -133,13 +132,13 @@ class Table:
         log.debug("Found projection: %s", projs[0])
         return projs[0]

-    def projections(self, projection_name=None) -> ["Projection"]:
+    def projections(self, projection_name=None) -> List["Projection"]:
         projections = []
         next_key = 0
         name_prefix = projection_name if projection_name else ""
         exact_match = bool(projection_name)
         while True:
-            bucket_name, schema_name, table_name, curr_projections, next_key, is_truncated, _ = \
+            _bucket_name, _schema_name, _table_name, curr_projections, next_key, is_truncated, _ = \
                 self.tx._rpc.api.list_projections(
                     bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=next_key, txid=self.tx.txid,
                     exact_match=exact_match, name_prefix=name_prefix)
@@ -150,7 +149,7 @@ class Table:
                 break
         return [_parse_projection_info(projection, self) for projection in projections]

-    def import_files(self, files_to_import: [str], config: ImportConfig = None) -> None:
+    def import_files(self, files_to_import: List[str], config: Optional[ImportConfig] = None) -> None:
         source_files = {}
         for f in files_to_import:
             bucket_name, object_path = _parse_bucket_and_object_names(f)
@@ -158,7 +157,7 @@ class Table:

         self._execute_import(source_files, config=config)

-    def import_partitioned_files(self, files_and_partitions: {str: pa.RecordBatch}, config: ImportConfig = None) -> None:
+    def import_partitioned_files(self, files_and_partitions: Dict[str, pa.RecordBatch], config: Optional[ImportConfig] = None) -> None:
         source_files = {}
         for f, record_batch in files_and_partitions.items():
             bucket_name, object_path = _parse_bucket_and_object_names(f)
@@ -206,7 +205,7 @@ class Table:
                 max_workers=config.import_concurrency, thread_name_prefix='import_thread') as pool:
             try:
                 for endpoint in endpoints:
-                    session = VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
+                    session = internal_commands.VastdbApi(endpoint, self.tx._rpc.api.access_key, self.tx._rpc.api.secret_key)
                     futures.append(pool.submit(import_worker, files_queue, session))

                 log.debug("Waiting for import workers to finish")
@@ -215,24 +214,30 @@ class Table:
             finally:
                 stop_event.set()
                 # ThreadPoolExecutor will be joined at the end of the context
-    def refresh_stats(self):
+
+    def get_stats(self) -> TableStats:
         stats_tuple = self.tx._rpc.api.get_table_stats(
             bucket=self.bucket.name, schema=self.schema.name, name=self.name, txid=self.tx.txid)
-        self.stats = TableStats(**stats_tuple._asdict())
+        return TableStats(**stats_tuple._asdict())

-    def select(self, columns: [str] = None,
+    def select(self, columns: Optional[List[str]] = None,
               predicate: ibis.expr.types.BooleanColumn = None,
-               config: QueryConfig = None,
+               config: Optional[QueryConfig] = None,
               *,
               internal_row_id: bool = False) -> pa.RecordBatchReader:
         if config is None:
             config = QueryConfig()

-        self.refresh_stats()
+        # Take a snapshot of enpoints
+        stats = self.get_stats()
+        endpoints = stats.endpoints if config.data_endpoints is None else config.data_endpoints
+
+        if stats.num_rows > config.rows_per_split and config.num_splits is None:
+            config.num_splits = stats.num_rows // config.rows_per_split
+            log.debug(f"num_rows={stats.num_rows} rows_per_splits={config.rows_per_split} num_splits={config.num_splits} ")

-        if self.stats.num_rows > config.rows_per_split and config.num_splits is None:
-            config.num_splits = self.stats.num_rows // config.rows_per_split
-            log.debug(f"num_rows={self.stats.num_rows} rows_per_splits={config.rows_per_split} num_splits={config.num_splits} ")
+        if columns is None:
+            columns = [f.name for f in self.arrow_schema]

         query_schema = self.arrow_schema
         if internal_row_id:
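`refresh_stats()` (which overwrote `self.stats` in place) is replaced by `get_stats()`, which returns a fresh `TableStats` snapshot that `select()` now uses to pick its data endpoints. A hedged sketch of the new call, with `t` again standing in for a `Table` from an open transaction:

```python
# `t` is an assumed vastdb Table handle from an open transaction.
stats = t.get_stats()  # returns a TableStats dataclass; t.stats is not mutated
print(stats.num_rows, stats.size_in_bytes)
print(stats.endpoints)  # tuple of data endpoints, () when none are reported
```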
@@ -241,12 +246,12 @@ class Table:
             query_schema = pa.schema(queried_fields)
             columns.append(INTERNAL_ROW_ID)

-        query_data_request = build_query_data_request(
+        query_data_request = internal_commands.build_query_data_request(
             schema=query_schema,
             predicate=predicate,
             field_names=columns)

-        splits_queue = queue.Queue()
+        splits_queue: queue.Queue[int] = queue.Queue()

         for split in range(config.num_splits):
             splits_queue.put(split)
@@ -254,8 +259,10 @@ class Table:
         # this queue shouldn't be large it is marely a pipe through which the results
         # are sent to the main thread. Most of the pages actually held in the
         # threads that fetch the pages.
-        record_batches_queue = queue.Queue(maxsize=2)
+        record_batches_queue: queue.Queue[pa.RecordBatch] = queue.Queue(maxsize=2)
+
         stop_event = Event()
+
         class StoppedException(Exception):
             pass

@@ -263,9 +270,9 @@ class Table:
             if stop_event.is_set():
                 raise StoppedException

-        def single_endpoint_worker(endpoint : str):
+        def single_endpoint_worker(endpoint: str):
             try:
-                host_api = VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
+                host_api = internal_commands.VastdbApi(endpoint=endpoint, access_key=self.tx._rpc.api.access_key, secret_key=self.tx._rpc.api.secret_key)
                 while True:
                     check_stop()
                     try:
@@ -290,12 +297,11 @@ class Table:
                 log.debug("exiting")
                 record_batches_queue.put(None)

-        # Take a snapshot of enpoints
-        endpoints = list(self.stats.endpoints) if config.data_endpoints is None else list(config.data_endpoints)
-
         def batches_iterator():
-            def propagate_first_exception(futures : List[concurrent.futures.Future], block = False):
+            def propagate_first_exception(futures: List[concurrent.futures.Future], block=False):
                 done, not_done = concurrent.futures.wait(futures, None if block else 0, concurrent.futures.FIRST_EXCEPTION)
+                if self.tx.txid is None:
+                    raise errors.MissingTransaction()
                 for future in done:
                     future.result()
                 return not_done
@@ -305,7 +311,7 @@ class Table:
             if config.query_id:
                 threads_prefix = threads_prefix + "-" + config.query_id

-            with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp: # TODO: concurrency == enpoints is just a heuristic
+            with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp:  # TODO: concurrency == enpoints is just a heuristic
                 futures = [tp.submit(single_endpoint_worker, endpoint) for endpoint in endpoints]
                 tasks_running = len(futures)
                 try:
@@ -327,7 +333,7 @@ class Table:
                     if record_batches_queue.get() is None:
                         tasks_running -= 1

-        return pa.RecordBatchReader.from_batches(query_data_request.response_schema.arrow_schema, batches_iterator())
+        return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())

     def _combine_chunks(self, col):
         if hasattr(col, "combine_chunks"):
@@ -337,16 +343,11 @@ class Table:

     def insert(self, rows: pa.RecordBatch) -> pa.RecordBatch:
         serialized_slices = self.tx._rpc.api._record_batch_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
-        row_ids = []
         for slice in serialized_slices:
-            res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
+            self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=slice,
                                          txid=self.tx.txid)
-            (batch,) = pa.RecordBatchStreamReader(res.raw)
-            row_ids.append(batch[INTERNAL_ROW_ID])
-
-        return pa.chunked_array(row_ids)

-    def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: list = None) -> None:
+    def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: Optional[List[str]] = None) -> None:
         if columns is not None:
             update_fields = [(INTERNAL_ROW_ID, pa.uint64())]
             update_values = [self._combine_chunks(rows[INTERNAL_ROW_ID])]
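`insert()` no longer reads the inserted row IDs back from the response, so in 0.1.2 it returns `None`; callers that need row IDs have to obtain them another way. Selecting with `internal_row_id=True` and feeding the result to `update()` looks like the intended pattern, but that is an inference from the surrounding code, not something this diff states explicitly:

```python
import pyarrow as pa

# `t` is an assumed vastdb Table handle from an open transaction.
data = pa.table({'a': pa.array([1, 2, 3], type=pa.int64())})
t.insert(data)  # 0.1.0 returned a chunked array of row IDs; 0.1.2 returns None

# Assumed pattern: fetch rows together with the internal row-id column,
# then pass the (possibly modified) batch back to update() for the listed columns.
reader = t.select(columns=['a'], internal_row_id=True)
batch = next(iter(reader))
t.update(batch, columns=['a'])
```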
@@ -414,7 +415,6 @@ class Projection:
     table: Table
     handle: int
     stats: TableStats
-    properties: dict = None

     @property
     def bucket(self):
@@ -435,7 +435,7 @@ class Projection:
         columns = []
         next_key = 0
         while True:
-            curr_columns, next_key, is_truncated, count, _ = \
+            curr_columns, next_key, is_truncated, _count, _ = \
                 self.tx._rpc.api.list_projection_columns(
                     self.bucket.name, self.schema.name, self.table.name, self.name, txid=self.table.tx.txid, next_key=next_key)
             if not curr_columns:
@@ -464,9 +464,9 @@ def _parse_projection_info(projection_info, table: "Table"):
     return Projection(name=projection_info.name, table=table, stats=stats, handle=int(projection_info.handle))


-def _parse_bucket_and_object_names(path: str) -> (str, str):
+def _parse_bucket_and_object_names(path: str) -> Tuple[str, str]:
     if not path.startswith('/'):
-        raise errors.InvalidArgumentError(f"Path {path} must start with a '/'")
+        raise errors.InvalidArgument(f"Path {path} must start with a '/'")
     components = path.split(os.path.sep)
     bucket_name = components[1]
     object_path = os.path.sep.join(components[2:])
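The path validation helper now raises `errors.InvalidArgument` (renamed from `errors.InvalidArgumentError`), so import paths without a leading `/` fail with the new exception type. A small sketch of catching it around `import_files()`; the path below is a placeholder:

```python
import logging

from vastdb import errors

log = logging.getLogger(__name__)

# `t` is an assumed vastdb Table handle; the path below is a placeholder
# that is intentionally missing the required leading '/'.
try:
    t.import_files(["bucket-name/staging/file.parquet"])
except errors.InvalidArgument as exc:
    # 0.1.0 raised errors.InvalidArgumentError here; 0.1.2 renames the exception.
    log.warning("bad import path: %s", exc)
```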
@@ -0,0 +1,61 @@
+import logging
+
+import duckdb
+import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
+
+from ..table import QueryConfig
+from .util import prepare_data
+
+log = logging.getLogger(__name__)
+
+
+def test_duckdb(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int32()),
+        ('b', pa.float64()),
+    ])
+    data = pa.table(schema=columns, data=[
+        [111, 222, 333],
+        [0.5, 1.5, 2.5],
+    ])
+    with prepare_data(session, clean_bucket_name, 's', 't', data) as t:
+        conn = duckdb.connect()
+        batches = t.select(columns=['a'], predicate=(t['b'] < 2))  # noqa: F841
+        actual = conn.execute('SELECT max(a) as "a_max" FROM batches').arrow()
+        expected = (data
+                    .filter(pc.field('b') < 2)
+                    .group_by([])
+                    .aggregate([('a', 'max')]))
+        assert actual == expected
+
+
+def test_closed_tx(session, clean_bucket_name):
+    columns = pa.schema([
+        ('a', pa.int64()),
+    ])
+    data = pa.table(schema=columns, data=[
+        list(range(10000)),
+    ])
+
+    with session.transaction() as tx:
+        t = tx.bucket(clean_bucket_name).create_schema("s1").create_table("t1", columns)
+        t.insert(data)
+
+        config = QueryConfig(
+            num_sub_splits=1,
+            num_splits=1,
+            num_row_groups_per_sub_split=1,
+            limit_rows_per_sub_split=100)
+        batches = t.select(config=config)  # noqa: F841
+        first = next(batches)  # make sure that HTTP response processing has started
+        assert first['a'].to_pylist() == list(range(100))
+
+        conn = duckdb.connect()
+        res = conn.execute('SELECT a FROM batches')
+        log.debug("closing tx=%s after first batch=%s", t.tx, first)
+
+    # transaction is closed, collecting the result should fail
+    with pytest.raises(duckdb.InvalidInputException, match="Detail: Python exception: MissingTransaction"):
+        res.arrow()
@@ -1,14 +1,12 @@
-import pytest
-
-from tempfile import NamedTemporaryFile
 import logging
+from tempfile import NamedTemporaryFile

 import pyarrow as pa
 import pyarrow.parquet as pq
+import pytest

-from vastdb.errors import InvalidArgument, ImportFilesError
 from vastdb import util
-
+from vastdb.errors import ImportFilesError, InvalidArgument

 log = logging.getLogger(__name__)

@@ -0,0 +1,28 @@
+import itertools
+
+import pyarrow as pa
+
+from .util import prepare_data
+
+
+def test_nested(session, clean_bucket_name):
+    columns = pa.schema([
+        ('l', pa.list_(pa.int8())),
+        ('m', pa.map_(pa.utf8(), pa.float64())),
+        ('s', pa.struct([('x', pa.int16()), ('y', pa.int32())])),
+    ])
+    expected = pa.table(schema=columns, data=[
+        [[1], [], [2, 3], None],
+        [None, {'a': 2.5}, {'b': 0.25, 'c': 0.025}, {}],
+        [{'x': 1, 'y': None}, None, {'x': 2, 'y': 3}, {'x': None, 'y': 4}],
+    ])
+
+    with prepare_data(session, clean_bucket_name, 's', 't', expected) as t:
+        actual = pa.Table.from_batches(t.select())
+        assert actual == expected
+
+        names = [f.name for f in columns]
+        for n in range(len(names) + 1):
+            for cols in itertools.permutations(names, n):
+                actual = pa.Table.from_batches(t.select(columns=cols))
+                assert actual == expected.select(cols)
@@ -1,8 +1,10 @@
-import pyarrow as pa
 import logging

+import pyarrow as pa
+
 log = logging.getLogger(__name__)

+
 def test_basic_projections(session, clean_bucket_name):
     with session.transaction() as tx:
         s = tx.bucket(clean_bucket_name).create_schema('s1')
@@ -1,15 +1,14 @@
-from http.server import HTTPServer, BaseHTTPRequestHandler
-from itertools import cycle
+import contextlib
 import logging
 import threading
-import contextlib
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from itertools import cycle

 import pytest
 import requests

 import vastdb

-
 log = logging.getLogger(__name__)


@@ -58,10 +57,10 @@ def test_version_extraction():
             return f"vast {version}" if version else "vast"

         def log_message(self, format, *args):
-            log.debug(format,*args)
+            log.debug(format, *args)

     # start the server on localhost on some available port port
-    server_address =('localhost', 0)
+    server_address = ('localhost', 0)
     httpd = HTTPServer(server_address, MockOptionsHandler)

     def start_http_server_in_thread():
@@ -1,5 +1,7 @@
 import pytest

+from .. import errors
+


 def test_schemas(session, clean_bucket_name):
@@ -19,6 +21,22 @@ def test_schemas(session, clean_bucket_name):
         assert b.schemas() == []


+def test_exists(session, clean_bucket_name):
+    with session.transaction() as tx:
+        b = tx.bucket(clean_bucket_name)
+        assert b.schemas() == []
+
+        s = b.create_schema('s1')
+
+        assert b.schemas() == [s]
+        with pytest.raises(errors.SchemaExists):
+            b.create_schema('s1')
+
+        assert b.schemas() == [s]
+        assert b.create_schema('s1', fail_if_exists=False) == s
+        assert b.schemas() == [s]
+
+
 def test_commits_and_rollbacks(session, clean_bucket_name):
     with session.transaction() as tx:
         b = tx.bucket(clean_bucket_name)
@@ -32,12 +50,13 @@ def test_commits_and_rollbacks(session, clean_bucket_name):
         b = tx.bucket(clean_bucket_name)
         b.schema("s3").drop()
         assert b.schemas() == []
-        1/0  # rollback schema dropping
+        1 / 0  # rollback schema dropping

     with session.transaction() as tx:
         b = tx.bucket(clean_bucket_name)
         assert b.schemas() != []

+
 def test_list_snapshots(session, clean_bucket_name):
     with session.transaction() as tx:
         b = tx.bucket(clean_bucket_name)