vastdb 0.0.5.3__py3-none-any.whl → 0.0.5.4__py3-none-any.whl

This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
vastdb/util.py DELETED
@@ -1,77 +0,0 @@
- import logging
- from typing import Callable
-
- import pyarrow as pa
- import pyarrow.parquet as pq
-
- from vastdb.v2 import InvalidArgumentError, Table, Schema
-
-
- log = logging.getLogger(__name__)
-
-
- def create_table_from_files(
-         schema: Schema, table_name: str, parquet_files: [str], schema_merge_func: Callable = None) -> Table:
-     if not schema_merge_func:
-         schema_merge_func = default_schema_merge
-     else:
-         assert schema_merge_func in [default_schema_merge, strict_schema_merge, union_schema_merge]
-     tx = schema.tx
-     current_schema = pa.schema([])
-     s3fs = pa.fs.S3FileSystem(
-         access_key=tx._rpc.api.access_key, secret_key=tx._rpc.api.secret_key, endpoint_override=tx._rpc.api.url)
-     for prq_file in parquet_files:
-         if not prq_file.startswith('/'):
-             raise InvalidArgumentError(f"Path {prq_file} must start with a '/'")
-         parquet_ds = pq.ParquetDataset(prq_file.lstrip('/'), filesystem=s3fs)
-         current_schema = schema_merge_func(current_schema, parquet_ds.schema)
-
-
-     log.info("Creating table %s from %d Parquet files, with columns: %s",
-              table_name, len(parquet_files), list(current_schema))
-     table = schema.create_table(table_name, current_schema)
-
-     log.info("Starting import of %d files to table: %s", len(parquet_files), table)
-     table.import_files(parquet_files)
-     log.info("Finished import of %d files to table: %s", len(parquet_files), table)
-     return table
-
-
- def default_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
-     """
-     This function validates a schema is contained in another schema
-     Raises an InvalidArgumentError if a certain field does not exist in the target schema
-     """
-     if not current_schema.names:
-         return new_schema
-     s1 = set(current_schema)
-     s2 = set(new_schema)
-
-     if len(s1) > len(s2):
-         s1, s2 = s2, s1
-         result = current_schema  # We need this variable in order to preserve the original fields order
-     else:
-         result = new_schema
-
-     if not s1.issubset(s2):
-         log.error("Schema mismatch. schema: %s isn't contained in schema: %s.", s1, s2)
-         raise InvalidArgumentError("Found mismatch in parquet files schemas.")
-     return result
-
-
- def strict_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
-     """
-     This function validates two Schemas are identical.
-     Raises an InvalidArgumentError if schemas aren't identical.
-     """
-     if current_schema.names and current_schema != new_schema:
-         raise InvalidArgumentError(f"Schemas are not identical. \n {current_schema} \n vs \n {new_schema}")
-
-     return new_schema
-
-
- def union_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
-     """
-     This function returns a unified schema from potentially two different schemas.
-     """
-     return pa.unify_schemas([current_schema, new_schema])
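For context, a minimal usage sketch (not part of the package) of the helper removed above, written against vastdb 0.0.5.3 where vastdb.util and vastdb.v2 still exist; the bucket, schema, table name, and Parquet paths are hypothetical placeholders:

import vastdb.v2 as v2
from vastdb.util import create_table_from_files, union_schema_merge

rpc = v2.connect()  # with no arguments, credentials come from the AWS_* environment variables, per the deleted RPC.__init__
with rpc.transaction() as tx:
    schema = tx.bucket('my-bucket').schema('my_schema')  # hypothetical, pre-existing bucket and schema
    create_table_from_files(
        schema, 'events',                                  # hypothetical table name
        parquet_files=['/my-bucket/staging/part-0.parquet',
                       '/my-bucket/staging/part-1.parquet'],  # paths must start with '/'
        schema_merge_func=union_schema_merge)              # unify differing file schemas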
vastdb/v2.py DELETED
@@ -1,360 +0,0 @@
- from dataclasses import dataclass, field
- import logging
- import os
-
- import boto3
- import botocore
- import ibis
- import pyarrow as pa
- import requests
-
- from vastdb.api import VastdbApi, serialize_record_batch, build_query_data_request, parse_query_data_response, TABULAR_INVALID_ROW_ID
-
-
- log = logging.getLogger(__name__)
-
-
- class VastException(Exception):
-     pass
-
-
- class NotFoundError(VastException):
-     pass
-
-
- class AccessDeniedError(VastException):
-     pass
-
-
- class ImportFilesError(VastException):
-     pass
-
-
- class InvalidArgumentError(VastException):
-     pass
-
-
- class RPC:
-     def __init__(self, access=None, secret=None, endpoint=None):
-         if access is None:
-             access = os.environ['AWS_ACCESS_KEY_ID']
-         if secret is None:
-             secret = os.environ['AWS_SECRET_ACCESS_KEY']
-         if endpoint is None:
-             endpoint = os.environ['AWS_S3_ENDPOINT_URL']
-
-         self.api = VastdbApi(endpoint, access, secret)
-         self.s3 = boto3.client('s3',
-                                aws_access_key_id=access,
-                                aws_secret_access_key=secret,
-                                endpoint_url=endpoint)
-
-     def __repr__(self):
-         return f'RPC(endpoint={self.api.url}, access={self.api.access_key})'
-
-     def transaction(self):
-         return Transaction(self)
-
-
- def connect(*args, **kw):
-     return RPC(*args, **kw)
-
-
- @dataclass
- class Transaction:
-     _rpc: RPC
-     txid: int = None
-
-     def __enter__(self):
-         response = self._rpc.api.begin_transaction()
-         self.txid = int(response.headers['tabular-txid'])
-         log.debug("opened txid=%016x", self.txid)
-         return self
-
-     def __exit__(self, *args):
-         if args == (None, None, None):
-             log.debug("committing txid=%016x", self.txid)
-             self._rpc.api.commit_transaction(self.txid)
-         else:
-             log.exception("rolling back txid=%016x", self.txid)
-             self._rpc.api.rollback_transaction(self.txid)
-
-     def __repr__(self):
-         return f'Transaction(id=0x{self.txid:016x})'
-
-     def bucket(self, name: str) -> "Bucket":
-         try:
-             self._rpc.s3.head_bucket(Bucket=name)
-             return Bucket(name, self)
-         except botocore.exceptions.ClientError as e:
-             if e.response['Error']['Code'] == 403:
-                 raise AccessDeniedError(f"Access is denied to bucket: {name}") from e
-             else:
-                 raise NotFoundError(f"Bucket {name} does not exist") from e
-
-
- @dataclass
- class Bucket:
-     name: str
-     tx: Transaction
-
-     def create_schema(self, path: str) -> "Schema":
-         self.tx._rpc.api.create_schema(self.name, path, txid=self.tx.txid)
-         log.info("Created schema: %s", path)
-         return self.schema(path)
-
-     def schema(self, path: str) -> "Schema":
-         schema = self.schemas(path)
-         log.debug("schema: %s", schema)
-         if not schema:
-             raise NotFoundError(f"Schema '{path}' was not found in bucket: {self.name}")
-         assert len(schema) == 1, f"Expected to receive only a single schema, but got: {len(schema)}. ({schema})"
-         log.debug("Found schema: %s", schema[0].name)
-         return schema[0]
-
-     def schemas(self, schema: str = None) -> ["Schema"]:
-         schemas = []
-         next_key = 0
-         exact_match = bool(schema)
-         log.debug("list schemas param: schema=%s, exact_match=%s", schema, exact_match)
-         while True:
-             bucket_name, curr_schemas, next_key, is_truncated, _ = \
-                 self.tx._rpc.api.list_schemas(bucket=self.name, next_key=next_key, txid=self.tx.txid,
-                                               name_prefix=schema, exact_match=exact_match)
-             if not curr_schemas:
-                 break
-             schemas.extend(curr_schemas)
-             if not is_truncated:
-                 break
-
-         return [Schema(name=name, bucket=self) for name, *_ in schemas]
-
-
- @dataclass
- class Schema:
-     name: str
-     bucket: Bucket
-
-     @property
-     def tx(self):
-         return self.bucket.tx
-
-     def create_table(self, table_name: str, columns: pa.Schema) -> "Table":
-         self.tx._rpc.api.create_table(self.bucket.name, self.name, table_name, columns, txid=self.tx.txid)
-         log.info("Created table: %s", table_name)
-         return self.table(table_name)
-
-     def table(self, name: str) -> "Table":
-         t = self.tables(table_name=name)
-         if not t:
-             raise NotFoundError(f"Table '{name}' was not found under schema: {self.name}")
-         assert len(t) == 1, f"Expected to receive only a single table, but got: {len(t)}. tables: {t}"
-         log.debug("Found table: %s", t[0])
-         return t[0]
-
-     def tables(self, table_name=None) -> ["Table"]:
-         tables = []
-         next_key = 0
-         name_prefix = table_name if table_name else ""
-         exact_match = bool(table_name)
-         while True:
-             bucket_name, schema_name, curr_tables, next_key, is_truncated, _ = \
-                 self.tx._rpc.api.list_tables(
-                     bucket=self.bucket.name, schema=self.name, next_key=next_key, txid=self.tx.txid,
-                     exact_match=exact_match, name_prefix=name_prefix)
-             if not curr_tables:
-                 break
-             tables.extend(curr_tables)
-             if not is_truncated:
-                 break
-
-         return [_parse_table_info(table, self) for table in tables]
-
-     def drop(self) -> None:
-         self.tx._rpc.api.drop_schema(self.bucket.name, self.name, txid=self.tx.txid)
-         log.info("Dropped schema: %s", self.name)
-
-     def rename(self, new_name) -> None:
-         self.tx._rpc.api.alter_schema(self.bucket.name, self.name, txid=self.tx.txid, new_name=new_name)
-         log.info("Renamed schema: %s to %s", self.name, new_name)
-         self.name = new_name
-
-
- @dataclass
- class TableStats:
-     num_rows: int
-     size: int
-
-
- @dataclass
- class QueryConfig:
-     num_sub_splits: int = 4
-     num_splits: int = 1
-     data_endpoints: [str] = None
-     limit_per_sub_split: int = 128 * 1024
-     num_row_groups_per_sub_split: int = 8
-
-
- @dataclass
- class Table:
-     name: str
-     schema: pa.Schema
-     handle: int
-     stats: TableStats
-     properties: dict = None
-     arrow_schema: pa.Schema = field(init=False, compare=False)
-     _ibis_table: ibis.Schema = field(init=False, compare=False)
-
-     def __post_init__(self):
-         self.properties = self.properties or {}
-         self.arrow_schema = self.columns()
-         self._ibis_table = ibis.Schema.from_pyarrow(self.arrow_schema)
-
-     @property
-     def tx(self):
-         return self.schema.tx
-
-     @property
-     def bucket(self):
-         return self.schema.bucket
-
-     def __repr__(self):
-         return f"{type(self).__name__}(name={self.name})"
-
-     def columns(self) -> pa.Schema:
-         cols = self.tx._rpc.api._list_table_columns(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid)
-         self.arrow_schema = pa.schema([(col[0], col[1]) for col in cols])
-         return self.arrow_schema
-
-     def import_files(self, files_to_import: [str]) -> None:
-         source_files = {}
-         for f in files_to_import:
-             bucket_name, object_path = _parse_bucket_and_object_names(f)
-             source_files[(bucket_name, object_path)] = b''
-
-         self._execute_import(source_files)
-
-     def import_partitioned_files(self, files_and_partitions: {str: pa.RecordBatch}) -> None:
-         source_files = {}
-         for f, record_batch in files_and_partitions.items():
-             bucket_name, object_path = _parse_bucket_and_object_names(f)
-             serialized_batch = _serialize_record_batch(record_batch)
-             source_files = {(bucket_name, object_path): serialized_batch.to_pybytes()}
-
-         self._execute_import(source_files)
-
-     def _execute_import(self, source_files):
-         try:
-             self.tx._rpc.api.import_data(
-                 self.bucket.name, self.schema.name, self.name, source_files, txid=self.tx.txid)
-         except requests.HTTPError as e:
-             raise ImportFilesError(f"import_files failed with status: {e.response.status_code}, reason: {e.response.reason}")
-         except Exception as e:
-             # TODO: investigate and raise proper error in case of failure mid import.
-             raise ImportFilesError("import_files failed") from e
-
-     def select(self, columns: [str], predicate: ibis.expr.types.BooleanColumn = None,
-                config: "QueryConfig" = None):
-         if config is None:
-             config = QueryConfig()
-
-         api = self.tx._rpc.api
-         field_names = columns
-         filters = []
-         bucket = self.bucket.name
-         schema = self.schema.name
-         table = self.name
-         query_data_request = build_query_data_request(
-             schema=self.arrow_schema, filters=filters, field_names=field_names)
-
-         start_row_ids = {i: 0 for i in range(config.num_sub_splits)}
-         assert config.num_splits == 1  # TODO()
-         split = (0, 1, config.num_row_groups_per_sub_split)
-         response_row_id = False
-
-         while not all(row_id == TABULAR_INVALID_ROW_ID for row_id in start_row_ids.values()):
-             response = api.query_data(
-                 bucket=bucket,
-                 schema=schema,
-                 table=table,
-                 params=query_data_request.serialized,
-                 split=split,
-                 num_sub_splits=config.num_sub_splits,
-                 response_row_id=response_row_id,
-                 txid=self.tx.txid,
-                 limit_rows=config.limit_per_sub_split,
-                 sub_split_start_row_ids=start_row_ids.items())
-
-             pages_iter = parse_query_data_response(
-                 conn=response.raw,
-                 schema=query_data_request.response_schema,
-                 start_row_ids=start_row_ids)
-
-             for page in pages_iter:
-                 for batch in page.to_batches():
-                     if len(batch) > 0:
-                         yield batch
-
-     def insert(self, rows: pa.RecordBatch) -> None:
-         blob = serialize_record_batch(rows)
-         self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name, record_batch=blob, txid=self.tx.txid)
-
-     def drop(self) -> None:
-         self.tx._rpc.api.drop_table(self.bucket.name, self.schema.name, self.name, txid=self.tx.txid)
-         log.info("Dropped table: %s", self.name)
-
-     def rename(self, new_name) -> None:
-         self.tx._rpc.api.alter_table(
-             self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, new_name=new_name)
-         log.info("Renamed table from %s to %s ", self.name, new_name)
-         self.name = new_name
-
-     def add_column(self, new_column: pa.Schema) -> None:
-         self.tx._rpc.api.add_columns(self.bucket.name, self.schema.name, self.name, new_column, txid=self.tx.txid)
-         log.info("Added column(s): %s", new_column)
-         self.arrow_schema = self.columns()
-
-     def drop_column(self, column_to_drop: pa.Schema) -> None:
-         self.tx._rpc.api.drop_columns(self.bucket.name, self.schema.name, self.name, column_to_drop, txid=self.tx.txid)
-         log.info("Dropped column(s): %s", column_to_drop)
-         self.arrow_schema = self.columns()
-
-     def rename_column(self, current_column_name: str, new_column_name: str) -> None:
-         self.tx._rpc.api.alter_column(self.bucket.name, self.schema.name, self.name, name=current_column_name,
-                                       new_name=new_column_name, txid=self.tx.txid)
-         log.info("Renamed column: %s to %s", current_column_name, new_column_name)
-         self.arrow_schema = self.columns()
-
-     def __getitem__(self, col_name):
-         return self._ibis_table[col_name]
-
-
- def _parse_table_info(table_info, schema: "Schema"):
-     stats = TableStats(num_rows=table_info.num_rows, size=table_info.size_in_bytes)
-     return Table(name=table_info.name, schema=schema, handle=int(table_info.handle), stats=stats)
-
-
- def _parse_bucket_and_object_names(path: str) -> (str, str):
-     if not path.startswith('/'):
-         raise InvalidArgumentError(f"Path {path} must start with a '/'")
-     components = path.split(os.path.sep)
-     bucket_name = components[1]
-     object_path = os.path.sep.join(components[2:])
-     return bucket_name, object_path
-
-
- def _serialize_record_batch(record_batch: pa.RecordBatch) -> pa.lib.Buffer:
-     sink = pa.BufferOutputStream()
-     with pa.ipc.new_stream(sink, record_batch.schema) as writer:
-         writer.write(record_batch)
-     return sink.getvalue()
-
-
- def _parse_endpoint(endpoint):
-     if ":" in endpoint:
-         endpoint, port = endpoint.split(":")
-         port = int(port)
-     else:
-         port = 80
-     log.debug("endpoint: %s, port: %d", endpoint, port)
-     return endpoint, port
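Likewise, a hypothetical end-to-end sketch of the removed vastdb.v2 interface, derived only from the deleted code above and runnable only against vastdb 0.0.5.3; the endpoint, credentials, bucket, schema, table, and column names are placeholders:

import pyarrow as pa
import vastdb.v2 as v2

# Explicit credentials/endpoint are placeholders; omitting them falls back to the AWS_* environment variables.
rpc = v2.connect(access='my-access-key', secret='my-secret-key', endpoint='http://vast.example.com')
with rpc.transaction() as tx:                    # commits on success, rolls back on exception
    bucket = tx.bucket('my-bucket')              # existence is verified with an S3 HEAD request
    schema = bucket.create_schema('my_schema')
    columns = pa.schema([('id', pa.int64()), ('name', pa.utf8())])
    table = schema.create_table('users', columns)
    table.insert(pa.record_batch([[1, 2], ['alice', 'bob']], schema=columns))
    for batch in table.select(columns=['id', 'name']):   # select() is a generator of RecordBatches
        print(batch.to_pydict())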