vastdb 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,17 +5,37 @@ import re
  import struct
  import urllib.parse
  from collections import defaultdict, namedtuple
+ from dataclasses import dataclass, field
  from enum import Enum
- from typing import Any, Dict, Iterator, List, Optional, Union
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

+ import backoff
  import flatbuffers
  import ibis
  import pyarrow as pa
- import pyarrow.parquet as pq
  import requests
  import urllib3
  import xmltodict
  from aws_requests_auth.aws_auth import AWSRequestsAuth
+ from ibis.expr.operations.generic import (
+ IsNull,
+ Literal,
+ )
+ from ibis.expr.operations.logical import (
+ And,
+ Between,
+ Equals,
+ Greater,
+ GreaterEqual,
+ InValues,
+ Less,
+ LessEqual,
+ Not,
+ NotEquals,
+ Or,
+ )
+ from ibis.expr.operations.relations import Field
+ from ibis.expr.operations.strings import StringContains

  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BinaryLiteral as fb_binary_lit
  import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BooleanLiteral as fb_bool_lit
@@ -137,26 +157,6 @@ class Predicate:
  self.expr = expr

  def serialize(self, builder: 'flatbuffers.builder.Builder'):
- from ibis.expr.operations.generic import (
- IsNull,
- Literal,
- TableColumn,
- )
- from ibis.expr.operations.logical import (
- And,
- Between,
- Equals,
- Greater,
- GreaterEqual,
- InValues,
- Less,
- LessEqual,
- Not,
- NotEquals,
- Or,
- )
- from ibis.expr.operations.strings import StringContains
-
  builder_map = {
  Greater: self.build_greater,
  GreaterEqual: self.build_greater_equal,
@@ -216,7 +216,7 @@ class Predicate:
  if not isinstance(literal, Literal):
  raise NotImplementedError(self.expr)

- if not isinstance(column, TableColumn):
+ if not isinstance(column, Field):
  raise NotImplementedError(self.expr)

  field_name = column.name
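The hunk above tracks an ibis API change: column references are now represented by the Field operation (imported at module level in the first hunk) instead of the removed TableColumn class. As a rough illustration only, not part of the package diff and assuming a recent ibis release that ships Field; the table and column names are made up:

    import ibis
    from ibis.expr.operations.generic import Literal
    from ibis.expr.operations.relations import Field

    t = ibis.table({"x": "int64"}, name="t")   # hypothetical table
    op = (t.x > 5).op()                        # a Greater node over (column, literal)
    assert isinstance(op.left, Field)          # column side, formerly a TableColumn
    assert isinstance(op.right, Literal)       # literal side, unchanged
    print(op.left.name)                        # "x" - the same attribute read as field_name above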
@@ -722,19 +722,59 @@ def _parse_table_info(obj):
  TableStatsResult = namedtuple("TableStatsResult", ["num_rows", "size_in_bytes", "is_external_rowid_alloc", "endpoints"])


+ def _backoff_giveup(exc: Exception) -> bool:
+
+ if isinstance(exc, errors.Slowdown):
+ # the server is overloaded, retry later
+ return False
+
+ if isinstance(exc, requests.exceptions.ConnectionError):
+ if exc.request.method == "GET":
+ # low-level connection issue, it is safe to retry only read-only requests
+ return False
+
+ return True # giveup in case of other exceptions
+
+
+ @dataclass
+ class BackoffConfig:
+ wait_gen: Callable = field(default=backoff.expo)
+ max_tries: int = 10
+ max_time: float = 60.0 # in seconds
+ backoff_log_level: int = logging.DEBUG
+
+
  class VastdbApi:
  # we expect the vast version to be <major>.<minor>.<patch>.<protocol>
  VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')

- def __init__(self, endpoint, access_key, secret_key, auth_type=AuthType.SIGV4, ssl_verify=True):
+ def __init__(self, endpoint, access_key, secret_key,
+ *,
+ auth_type=AuthType.SIGV4,
+ ssl_verify=True,
+ backoff_config: Optional[BackoffConfig] = None):
+
+ from . import __version__ # import lazily here (to avoid circular dependencies)
+ self.client_sdk_version = f"VAST Database Python SDK {__version__} - 2024 (c)"
+
  url = urllib3.util.parse_url(endpoint)
  self.access_key = access_key
  self.secret_key = secret_key

  self.default_max_list_columns_page_size = 1000
- self.session = requests.Session()
- self.session.verify = ssl_verify
- self.session.headers['user-agent'] = "VastData Tabular API 1.0 - 2022 (c)"
+ self._session = requests.Session()
+ self._session.verify = ssl_verify
+ self._session.headers['user-agent'] = self.client_sdk_version
+
+ backoff_config = backoff_config or BackoffConfig()
+ backoff_decorator = backoff.on_exception(
+ wait_gen=backoff_config.wait_gen,
+ exception=(requests.exceptions.ConnectionError, errors.Slowdown),
+ giveup=_backoff_giveup,
+ max_tries=backoff_config.max_tries,
+ max_time=backoff_config.max_time,
+ backoff_log_level=backoff_config.backoff_log_level)
+ self._request = backoff_decorator(self._single_request)

  if url.port in {80, 443, None}:
  self.aws_host = f'{url.host}'
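The new retry plumbing above wires a backoff.on_exception decorator around _single_request, driven by the BackoffConfig dataclass. For orientation only (not part of the package diff), a minimal sketch of how a caller could override the defaults, assuming the vastdb.api module path; the endpoint and credentials are placeholders:

    import backoff
    from vastdb.api import BackoffConfig, VastdbApi  # module path assumed

    # Retry Slowdown / connection errors up to 5 times or 30 seconds in total,
    # instead of the default 10 tries / 60 seconds, still with exponential backoff.
    config = BackoffConfig(wait_gen=backoff.expo, max_tries=5, max_time=30.0)

    api = VastdbApi(
        "https://vip-pool.example.com",  # placeholder endpoint
        "ACCESS_KEY",                    # placeholder credentials
        "SECRET_KEY",
        backoff_config=config,
    )

Note that auth_type, ssl_verify and backoff_config became keyword-only arguments (the bare * in the new signature), so they must be passed by name.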
@@ -744,22 +784,21 @@ class VastdbApi:
  self.url = str(url)
  _logger.debug('url=%s aws_host=%s', self.url, self.aws_host)

- self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
+ self._session.auth = AWSRequestsAuth(aws_access_key=access_key,
  aws_secret_access_key=secret_key,
  aws_host=self.aws_host,
- aws_region='us-east-1',
+ aws_region='',
  aws_service='s3')

  # probe the cluster for its version
- self.vast_version = None
- res = self.session.get(self.url)
+ res = self._request(method="GET", url=self._url(command="transaction"), skip_status_check=True) # used only for the response headers
+ _logger.debug("headers=%s code=%s content=%s", res.headers, res.status_code, res.content)
  server_header = res.headers.get("Server")
  if server_header is None:
  _logger.error("Response doesn't contain 'Server' header")
  else:
- _logger.debug("Server header is '%s'", server_header)
  if m := self.VAST_VERSION_REGEX.match(server_header):
- self.vast_version, = m.groups()
+ self.vast_version: Tuple[int, ...] = tuple(int(v) for v in m.group(1).split("."))
  return
  else:
  _logger.error("'Server' header '%s' doesn't match the expected pattern", server_header)
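A side note on the version probe above (illustration only, not part of the diff): the cluster version is now parsed into a tuple of ints rather than kept as a string, which makes ordered comparisons reliable. Using the VAST_VERSION_REGEX defined earlier and a made-up Server header value:

    import re

    VAST_VERSION_REGEX = re.compile(r'^vast (\d+\.\d+\.\d+\.\d+)$')

    header = "vast 5.1.0.123"  # hypothetical 'Server' response header
    if m := VAST_VERSION_REGEX.match(header):
        vast_version = tuple(int(v) for v in m.group(1).split("."))
        print(vast_version)            # (5, 1, 0, 123)
        print(vast_version >= (5, 1))  # True - tuple comparison, unlike "5.1..." strings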
@@ -772,15 +811,14 @@ class VastdbApi:
  _logger.critical(msg)
  raise NotImplementedError(msg)

- def update_mgmt_session(self, access_key: str, secret_key: str, auth_type=AuthType.SIGV4):
- if auth_type != AuthType.BASIC:
- self.session.auth = AWSRequestsAuth(aws_access_key=access_key,
- aws_secret_access_key=secret_key,
- aws_host=self.aws_host,
- aws_region='us-east-1',
- aws_service='s3')
+ def _single_request(self, *, method, url, skip_status_check=False, **kwargs):
+ res = self._session.request(method=method, url=url, **kwargs)
+ if not skip_status_check:
+ if exc := errors.from_response(res):
+ raise exc # application-level error
+ return res # successful response

- def _api_prefix(self, bucket="", schema="", table="", command="", url_params={}):
+ def _url(self, bucket="", schema="", table="", command="", url_params={}):
  prefix_list = [self.url]
  if len(bucket):
  prefix_list.append(bucket)
@@ -815,11 +853,6 @@ class VastdbApi:

  return common_headers | {f'tabular-client-tags-{index}': tag for index, tag in enumerate(client_tags)}

- def _check_res(self, res, cmd="", expected_retvals=[]):
- if exc := errors.from_response(res):
- raise exc
- return res
-

  def create_schema(self, bucket, name, txid=0, client_tags=[], schema_properties="", expected_retvals=[]):
  """
  Create a collection of tables, use the following request
@@ -841,10 +874,10 @@ class VastdbApi:

  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  headers['Content-Length'] = str(len(create_schema_req))
- res = self.session.post(self._api_prefix(bucket=bucket, schema=name, command="schema"),
- data=create_schema_req, headers=headers, stream=True)
-
- return self._check_res(res, "create_schema", expected_retvals)
+ self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=name, command="schema"),
+ data=create_schema_req, headers=headers)

  def alter_schema(self, bucket, name, txid=0, client_tags=[], schema_properties="", new_name="", expected_retvals=[]):
  """
@@ -870,10 +903,10 @@ class VastdbApi:
  headers['Content-Length'] = str(len(alter_schema_req))
  url_params = {'tabular-new-schema-name': new_name} if len(new_name) else {}

- res = self.session.put(self._api_prefix(bucket=bucket, schema=name, command="schema", url_params=url_params),
- data=alter_schema_req, headers=headers)
-
- return self._check_res(res, "alter_schema", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(bucket=bucket, schema=name, command="schema", url_params=url_params),
+ data=alter_schema_req, headers=headers)

  def drop_schema(self, bucket, name, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -884,9 +917,10 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)

- res = self.session.delete(self._api_prefix(bucket=bucket, schema=name, command="schema"), headers=headers)
-
- return self._check_res(res, "drop_schema", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(bucket=bucket, schema=name, command="schema"),
+ headers=headers)

  def list_schemas(self, bucket, schema="", txid=0, client_tags=[], max_keys=1000, next_key=0, name_prefix="",
  exact_match=False, expected_retvals=[], count_only=False):
@@ -915,25 +949,27 @@ class VastdbApi:

  schemas = []
  schema = schema or ""
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, command="schema"), headers=headers, stream=True)
- self._check_res(res, "list_schemas", expected_retvals)
- if res.status_code == 200:
- res_headers = res.headers
- next_key = int(res_headers['tabular-next-key'])
- is_truncated = res_headers['tabular-is-truncated'] == 'true'
- lists = list_schemas.GetRootAs(res.content)
- bucket_name = lists.BucketName().decode()
- if not bucket.startswith(bucket_name):
- raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
- schemas_length = lists.SchemasLength()
- count = int(res_headers['tabular-list-count']) if 'tabular-list-count' in res_headers else schemas_length
- for i in range(schemas_length):
- schema_obj = lists.Schemas(i)
- name = schema_obj.Name().decode()
- properties = schema_obj.Properties().decode()
- schemas.append([name, properties])
-
- return bucket_name, schemas, next_key, is_truncated, count
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, command="schema"),
+ headers=headers)
+
+ res_headers = res.headers
+ next_key = int(res_headers['tabular-next-key'])
+ is_truncated = res_headers['tabular-is-truncated'] == 'true'
+ lists = list_schemas.GetRootAs(res.content)
+ bucket_name = lists.BucketName().decode()
+ if not bucket.startswith(bucket_name):
+ raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
+ schemas_length = lists.SchemasLength()
+ count = int(res_headers['tabular-list-count']) if 'tabular-list-count' in res_headers else schemas_length
+ for i in range(schemas_length):
+ schema_obj = lists.Schemas(i)
+ name = schema_obj.Name().decode()
+ properties = schema_obj.Properties().decode()
+ schemas.append([name, properties])
+
+ return bucket_name, schemas, next_key, is_truncated, count

  def list_snapshots(self, bucket, max_keys=1000, next_token=None, name_prefix=''):
  next_token = next_token or ''
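The list_schemas rewrite above keeps the same pagination contract (the tabular-next-key and tabular-is-truncated response headers) while routing the call through the retrying _request helper. A hypothetical pagination loop over the returned values, assuming `api` is a VastdbApi instance and the bucket name is a placeholder:

    next_key = 0
    all_schemas = []
    while True:
        bucket_name, schemas, next_key, is_truncated, count = api.list_schemas(
            "mybucket", next_key=next_key)   # placeholder bucket
        all_schemas.extend(schemas)          # each entry is [name, properties]
        if not is_truncated:
            break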
@@ -941,8 +977,9 @@ class VastdbApi:
  if next_token:
  url_params['continuation-token'] = next_token

- res = self.session.get(self._api_prefix(bucket=bucket, command="list", url_params=url_params), headers={}, stream=True)
- self._check_res(res, "list_snapshots")
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, command="list", url_params=url_params))

  xml_str = res.content.decode()
  xml_dict = xmltodict.parse(xml_str)
@@ -985,33 +1022,10 @@ class VastdbApi:
  if create_imports_table:
  url_params['sub-table'] = IMPORTED_OBJECTS_TABLE_NAME

- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
- data=serialized_schema, headers=headers)
- return self._check_res(res, "create_table", expected_retvals)
-
- def create_table_from_parquet_schema(self, bucket, schema, name, parquet_path=None,
- parquet_bucket_name=None, parquet_object_name=None,
- txid=0, client_tags=[], expected_retvals=[]):
-
- # Use pyarrow.parquet.ParquetDataset to open the Parquet file
- if parquet_path:
- parquet_ds = pq.ParquetDataset(parquet_path)
- elif parquet_bucket_name and parquet_object_name:
- s3fs = pa.fs.S3FileSystem(access_key=self.access_key, secret_key=self.secret_key, endpoint_override=self.url)
- parquet_ds = pq.ParquetDataset('/'.join([parquet_bucket_name, parquet_object_name]), filesystem=s3fs)
- else:
- raise RuntimeError(f'invalid params parquet_path={parquet_path} parquet_bucket_name={parquet_bucket_name} parquet_object_name={parquet_object_name}')
-
- # Get the schema of the Parquet file
- if isinstance(parquet_ds.schema, pq.ParquetSchema):
- arrow_schema = parquet_ds.schema.to_arrow_schema()
- elif isinstance(parquet_ds.schema, pa.Schema):
- arrow_schema = parquet_ds.schema
- else:
- raise RuntimeError(f'invalid type(parquet_ds.schema) = {type(parquet_ds.schema)}')
-
- # create the table
- return self.create_table(bucket, schema, name, arrow_schema, txid, client_tags, expected_retvals)
+ self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
+ data=serialized_schema, headers=headers)

  def get_table_stats(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[], imports_table_stats=False):
  """
@@ -1023,8 +1037,10 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if imports_table_stats else {}
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=name, command="stats", url_params=url_params), headers=headers)
- self._check_res(res, "get_table_stats", expected_retvals)
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=name, command="stats", url_params=url_params),
+ headers=headers)

  stats = get_table_stats.GetRootAs(res.content)
  num_rows = stats.NumRows()
@@ -1059,10 +1075,10 @@ class VastdbApi:
  headers['Content-Length'] = str(len(alter_table_req))
  url_params = {'tabular-new-table-name': schema + "/" + new_name} if len(new_name) else {}

- res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
- data=alter_table_req, headers=headers)
-
- return self._check_res(res, "alter_table", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
+ data=alter_table_req, headers=headers)

  def drop_table(self, bucket, schema, name, txid=0, client_tags=[], expected_retvals=[], remove_imports_table=False):
  """
@@ -1075,9 +1091,10 @@ class VastdbApi:
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if remove_imports_table else {}

- res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
- headers=headers)
- return self._check_res(res, "drop_table", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(bucket=bucket, schema=schema, table=name, command="table", url_params=url_params),
+ headers=headers)

  def list_tables(self, bucket, schema, txid=0, client_tags=[], max_keys=1000, next_key=0, name_prefix="",
  exact_match=False, expected_retvals=[], include_list_stats=False, count_only=False):
@@ -1101,23 +1118,25 @@ class VastdbApi:
  headers['tabular-include-list-stats'] = str(include_list_stats)

  tables = []
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, command="table"), headers=headers)
- self._check_res(res, "list_table", expected_retvals)
- if res.status_code == 200:
- res_headers = res.headers
- next_key = int(res_headers['tabular-next-key'])
- is_truncated = res_headers['tabular-is-truncated'] == 'true'
- lists = list_tables.GetRootAs(res.content)
- bucket_name = lists.BucketName().decode()
- schema_name = lists.SchemaName().decode()
- if not bucket.startswith(bucket_name): # ignore snapshot name
- raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
- tables_length = lists.TablesLength()
- count = int(res_headers['tabular-list-count']) if 'tabular-list-count' in res_headers else tables_length
- for i in range(tables_length):
- tables.append(_parse_table_info(lists.Tables(i)))
-
- return bucket_name, schema_name, tables, next_key, is_truncated, count
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, command="table"),
+ headers=headers)
+
+ res_headers = res.headers
+ next_key = int(res_headers['tabular-next-key'])
+ is_truncated = res_headers['tabular-is-truncated'] == 'true'
+ lists = list_tables.GetRootAs(res.content)
+ bucket_name = lists.BucketName().decode()
+ schema_name = lists.SchemaName().decode()
+ if not bucket.startswith(bucket_name): # ignore snapshot name
+ raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
+ tables_length = lists.TablesLength()
+ count = int(res_headers['tabular-list-count']) if 'tabular-list-count' in res_headers else tables_length
+ for i in range(tables_length):
+ tables.append(_parse_table_info(lists.Tables(i)))
+
+ return bucket_name, schema_name, tables, next_key, is_truncated, count

  def add_columns(self, bucket, schema, name, arrow_schema, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1139,9 +1158,10 @@ class VastdbApi:
  serialized_schema = arrow_schema.serialize()
  headers['Content-Length'] = str(len(serialized_schema))

- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=name, command="column"),
- data=serialized_schema, headers=headers)
- return self._check_res(res, "add_columns", expected_retvals)
+ self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=schema, table=name, command="column"),
+ data=serialized_schema, headers=headers)

  def alter_column(self, bucket, schema, table, name, txid=0, client_tags=[], column_properties="",
  new_name="", column_sep=".", column_stats="", expected_retvals=[]):
@@ -1177,9 +1197,10 @@ class VastdbApi:
  if len(new_name):
  url_params['tabular-new-column-name'] = new_name

- res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="column", url_params=url_params),
- data=alter_column_req, headers=headers)
- return self._check_res(res, "alter_column", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="column", url_params=url_params),
+ data=alter_column_req, headers=headers)

  def drop_columns(self, bucket, schema, table, arrow_schema, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1192,9 +1213,10 @@ class VastdbApi:
  serialized_schema = arrow_schema.serialize()
  headers['Content-Length'] = str(len(serialized_schema))

- res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="column"),
- data=serialized_schema, headers=headers)
- return self._check_res(res, "drop_columns", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="column"),
+ data=serialized_schema, headers=headers)

  def list_columns(self, bucket, schema, table, *, txid=0, client_tags=None, max_keys=None, next_key=0,
  count_only=False, name_prefix="", exact_match=False,
@@ -1226,18 +1248,18 @@ class VastdbApi:
  headers['tabular-name-prefix'] = name_prefix

  url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if list_imports_table else {}
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="column",
- url_params=url_params),
- headers=headers, stream=True)
- self._check_res(res, "list_columns", expected_retvals)
- if res.status_code == 200:
- res_headers = res.headers
- next_key = int(res_headers['tabular-next-key'])
- is_truncated = res_headers['tabular-is-truncated'] == 'true'
- count = int(res_headers['tabular-list-count'])
- columns = [] if count_only else pa.ipc.open_stream(res.content).schema
-
- return columns, next_key, is_truncated, count
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="column", url_params=url_params),
+ headers=headers)
+
+ res_headers = res.headers
+ next_key = int(res_headers['tabular-next-key'])
+ is_truncated = res_headers['tabular-is-truncated'] == 'true'
+ count = int(res_headers['tabular-list-count'])
+ columns = [] if count_only else pa.ipc.open_stream(res.content).schema
+
+ return columns, next_key, is_truncated, count

  def begin_transaction(self, client_tags=[], expected_retvals=[]):
  """
@@ -1248,8 +1270,10 @@ class VastdbApi:
  tabular-txid: TransactionId
  """
  headers = self._fill_common_headers(client_tags=client_tags)
- res = self.session.post(self._api_prefix(command="transaction"), headers=headers)
- return self._check_res(res, "begin_transaction", expected_retvals)
+ return self._request(
+ method="POST",
+ url=self._url(command="transaction"),
+ headers=headers)

  def commit_transaction(self, txid, client_tags=[], expected_retvals=[]):
@@ -1258,8 +1282,10 @@ class VastdbApi:
  tabular-client-tag: ClientTag
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- res = self.session.put(self._api_prefix(command="transaction"), headers=headers)
- return self._check_res(res, "commit_transaction", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(command="transaction"),
+ headers=headers)

  def rollback_transaction(self, txid, client_tags=[], expected_retvals=[]):
  """
@@ -1268,8 +1294,10 @@ class VastdbApi:
  tabular-client-tag: ClientTag
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- res = self.session.delete(self._api_prefix(command="transaction"), headers=headers)
- return self._check_res(res, "rollback_transaction", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(command="transaction"),
+ headers=headers)

  def get_transaction(self, txid, client_tags=[], expected_retvals=[]):
  """
@@ -1278,56 +1306,10 @@ class VastdbApi:
  tabular-client-tag: ClientTag
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- res = self.session.get(self._api_prefix(command="transaction"), headers=headers)
- return self._check_res(res, "get_transaction", expected_retvals)
-
- def select_row_ids(self, bucket, schema, table, params, txid=0, client_tags=[], expected_retvals=[],
- retry_count=0, enable_sorted_projections=True):
- """
- POST /mybucket/myschema/mytable?query-data=SelectRowIds HTTP/1.1
- """
-
- # add query option select-only and read-only
- headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- headers['Content-Length'] = str(len(params))
- headers['tabular-enable-sorted-projections'] = str(enable_sorted_projections)
- if retry_count > 0:
- headers['tabular-retry-count'] = str(retry_count)
-
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="query-data=SelectRowIds",),
- data=params, headers=headers, stream=True)
- return self._check_res(res, "query_data", expected_retvals)
-
- def read_columns_data(self, bucket, schema, table, params, txid=0, client_tags=[], expected_retvals=[], tenant_guid=None,
- retry_count=0, enable_sorted_projections=True):
- """
- POST /mybucket/myschema/mytable?query-data=ReadColumns HTTP/1.1
- """
-
- headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- headers['Content-Length'] = str(len(params))
- headers['tabular-enable-sorted-projections'] = str(enable_sorted_projections)
- if retry_count > 0:
- headers['tabular-retry-count'] = str(retry_count)
-
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="query-data=ReadColumns",),
- data=params, headers=headers, stream=True)
- return self._check_res(res, "query_data", expected_retvals)
-
- def count_rows(self, bucket, schema, table, params, txid=0, client_tags=[], expected_retvals=[], tenant_guid=None,
- retry_count=0, enable_sorted_projections=True):
- """
- POST /mybucket/myschema/mytable?query-data=CountRows HTTP/1.1
- """
- headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
- headers['Content-Length'] = str(len(params))
- headers['tabular-enable-sorted-projections'] = str(enable_sorted_projections)
- if retry_count > 0:
- headers['tabular-retry-count'] = str(retry_count)
-
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="query-data=CountRows",),
- data=params, headers=headers, stream=True)
- return self._check_res(res, "query_data", expected_retvals)
+ self._request(
+ method="GET",
+ url=self._url(command="transaction"),
+ headers=headers)

  def _build_query_data_headers(self, txid, client_tags, params, split, num_sub_splits, request_format, response_format,
  enable_sorted_projections, limit_rows, schedule_id, retry_count, search_path, tenant_guid,
@@ -1369,35 +1351,6 @@ class VastdbApi:
  url_params['name'] = projection
  return url_params

- def legacy_query_data(self, bucket, schema, table, params, split=(0, 1, 8), num_sub_splits=1, response_row_id=False,
- txid=0, client_tags=[], expected_retvals=[], limit_rows=0, schedule_id=None, retry_count=0,
- search_path=None, sub_split_start_row_ids=[], tenant_guid=None, projection='', enable_sorted_projections=True,
- request_format='string', response_format='string', query_imports_table=False):
- """
- POST /mybucket/myschema/mytable?query-data=LegacyQueryData HTTP/1.1
- Content-Length: ContentLength
- tabular-txid: TransactionId
- tabular-client-tag: ClientTag
- tabular-split: "split_id,total_splits,num_row_groups_per_split"
- tabular-num-of-subsplits: "total"
- tabular-request-format: "string"
- tabular-response-format: "string" #arrow/trino
- tabular-schedule-id: "schedule-id"
-
- Request Body (flatbuf)
- projections_chunk [expressions]
- predicate_chunk "formatted_data", (required)
-
- """
- headers = self._build_query_data_headers(txid, client_tags, params, split, num_sub_splits, request_format, response_format,
- enable_sorted_projections, limit_rows, schedule_id, retry_count, search_path, tenant_guid,
- sub_split_start_row_ids)
- url_params = self._build_query_data_url_params(projection, query_imports_table)
-
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="query-data=LegacyQueryData",
- url_params=url_params), data=params, headers=headers, stream=True)
- return self._check_res(res, "legacy_query_data", expected_retvals)
-

  def query_data(self, bucket, schema, table, params, split=(0, 1, 8), num_sub_splits=1, response_row_id=False,
  txid=0, client_tags=[], expected_retvals=[], limit_rows=0, schedule_id=None, retry_count=0,
@@ -1427,9 +1380,10 @@ class VastdbApi:

  url_params = self._build_query_data_url_params(projection, query_imports_table)

- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="data", url_params=url_params),
- data=params, headers=headers, stream=True)
- return self._check_res(res, "query_data", expected_retvals)
+ return self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="data", url_params=url_params),
+ data=params, headers=headers, stream=True)

  """
  source_files: list of (bucket_name, file_name)
@@ -1515,12 +1469,14 @@ class VastdbApi:
  headers['tabular-schedule-id'] = str(schedule_id)
  if retry_count > 0:
  headers['tabular-retry-count'] = str(retry_count)
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="data"),
- data=import_req, headers=headers, stream=True)
+ res = self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="data"),
+ data=import_req, headers=headers, stream=True)
  if blocking:
  res = iterate_over_import_data_response(res)

- return self._check_res(res, "import_data", expected_retvals)
+ return res

  def insert_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1534,9 +1490,10 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  headers['Content-Length'] = str(len(record_batch))
- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
- data=record_batch, headers=headers, stream=True)
- return self._check_res(res, "insert_rows", expected_retvals)
+ return self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="rows"),
+ data=record_batch, headers=headers)

  def update_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1550,9 +1507,10 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  headers['Content-Length'] = str(len(record_batch))
- res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows"),
- data=record_batch, headers=headers)
- self._check_res(res, "update_rows", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="rows"),
+ data=record_batch, headers=headers)

  def delete_rows(self, bucket, schema, table, record_batch, txid=0, client_tags=[], expected_retvals=[],
  delete_from_imports_table=False):
@@ -1569,9 +1527,10 @@ class VastdbApi:
  headers['Content-Length'] = str(len(record_batch))
  url_params = {'sub-table': IMPORTED_OBJECTS_TABLE_NAME} if delete_from_imports_table else {}

- res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="rows", url_params=url_params),
- data=record_batch, headers=headers)
- self._check_res(res, "delete_rows", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="rows", url_params=url_params),
+ data=record_batch, headers=headers)

  def create_projection(self, bucket, schema, table, name, columns, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1618,9 +1577,10 @@ class VastdbApi:
  headers['Content-Length'] = str(len(create_projection_req))
  url_params = {'name': name}

- res = self.session.post(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
- data=create_projection_req, headers=headers)
- return self._check_res(res, "create_projection", expected_retvals)
+ self._request(
+ method="POST",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
+ data=create_projection_req, headers=headers)

  def get_projection_stats(self, bucket, schema, table, name, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1632,17 +1592,17 @@ class VastdbApi:
  """
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  url_params = {'name': name}
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection-stats", url_params=url_params),
- headers=headers)
- if res.status_code == 200:
- stats = get_projection_table_stats.GetRootAs(res.content)
- num_rows = stats.NumRows()
- size_in_bytes = stats.SizeInBytes()
- dirty_blocks_percentage = stats.DirtyBlocksPercentage()
- initial_sync_progress = stats.InitialSyncProgress()
- return num_rows, size_in_bytes, dirty_blocks_percentage, initial_sync_progress
-
- return self._check_res(res, "get_projection_stats", expected_retvals)
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection-stats", url_params=url_params),
+ headers=headers)
+
+ stats = get_projection_table_stats.GetRootAs(res.content)
+ num_rows = stats.NumRows()
+ size_in_bytes = stats.SizeInBytes()
+ dirty_blocks_percentage = stats.DirtyBlocksPercentage()
+ initial_sync_progress = stats.InitialSyncProgress()
+ return num_rows, size_in_bytes, dirty_blocks_percentage, initial_sync_progress

  def alter_projection(self, bucket, schema, table, name, txid=0, client_tags=[], table_properties="",
  new_name="", expected_retvals=[]):
@@ -1674,10 +1634,10 @@ class VastdbApi:
  headers['Content-Length'] = str(len(alter_projection_req))
  url_params = {'name': name}

- res = self.session.put(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
- data=alter_projection_req, headers=headers)
-
- return self._check_res(res, "alter_projection", expected_retvals)
+ self._request(
+ method="PUT",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
+ data=alter_projection_req, headers=headers)

  def drop_projection(self, bucket, schema, table, name, txid=0, client_tags=[], expected_retvals=[]):
  """
@@ -1688,9 +1648,10 @@ class VastdbApi:
  headers = self._fill_common_headers(txid=txid, client_tags=client_tags)
  url_params = {'name': name}

- res = self.session.delete(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
- headers=headers)
- return self._check_res(res, "drop_projection", expected_retvals)
+ self._request(
+ method="DELETE",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection", url_params=url_params),
+ headers=headers)

  def list_projections(self, bucket, schema, table, txid=0, client_tags=[], max_keys=1000, next_key=0, name_prefix="",
  exact_match=False, expected_retvals=[], include_list_stats=False, count_only=False):
@@ -1714,24 +1675,26 @@ class VastdbApi:
  headers['tabular-include-list-stats'] = str(include_list_stats)

  projections = []
- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection"), headers=headers)
- self._check_res(res, "list_projections", expected_retvals)
- if res.status_code == 200:
- res_headers = res.headers
- next_key = int(res_headers['tabular-next-key'])
- is_truncated = res_headers['tabular-is-truncated'] == 'true'
- count = int(res_headers['tabular-list-count'])
- lists = list_projections.GetRootAs(res.content)
- bucket_name = lists.BucketName().decode()
- schema_name = lists.SchemaName().decode()
- table_name = lists.TableName().decode()
- if not bucket.startswith(bucket_name): # ignore snapshot name
- raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
- projections_length = lists.ProjectionsLength()
- for i in range(projections_length):
- projections.append(_parse_table_info(lists.Projections(i)))
-
- return bucket_name, schema_name, table_name, projections, next_key, is_truncated, count
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection"),
+ headers=headers)
+
+ res_headers = res.headers
+ next_key = int(res_headers['tabular-next-key'])
+ is_truncated = res_headers['tabular-is-truncated'] == 'true'
+ count = int(res_headers['tabular-list-count'])
+ lists = list_projections.GetRootAs(res.content)
+ bucket_name = lists.BucketName().decode()
+ schema_name = lists.SchemaName().decode()
+ table_name = lists.TableName().decode()
+ if not bucket.startswith(bucket_name): # ignore snapshot name
+ raise ValueError(f'bucket: {bucket} did not start from {bucket_name}')
+ projections_length = lists.ProjectionsLength()
+ for i in range(projections_length):
+ projections.append(_parse_table_info(lists.Projections(i)))
+
+ return bucket_name, schema_name, table_name, projections, next_key, is_truncated, count

  def list_projection_columns(self, bucket, schema, table, projection, txid=0, client_tags=[], max_keys=1000,
  next_key=0, count_only=False, name_prefix="", exact_match=False,
@@ -1759,19 +1722,20 @@ class VastdbApi:

  url_params = {'name': projection}

- res = self.session.get(self._api_prefix(bucket=bucket, schema=schema, table=table, command="projection-columns", url_params=url_params),
- headers=headers, stream=True)
- self._check_res(res, "list_projection_columns", expected_retvals)
+ res = self._request(
+ method="GET",
+ url=self._url(bucket=bucket, schema=schema, table=table, command="projection-columns", url_params=url_params),
+ headers=headers)
+
  # list projection columns response will also show column type Sorted/UnSorted
- if res.status_code == 200:
- res_headers = res.headers
- next_key = int(res_headers['tabular-next-key'])
- is_truncated = res_headers['tabular-is-truncated'] == 'true'
- count = int(res_headers['tabular-list-count'])
- columns = [] if count_only else [[f.name, f.type, f.metadata] for f in
- pa.ipc.open_stream(res.content).schema]
+ res_headers = res.headers
+ next_key = int(res_headers['tabular-next-key'])
+ is_truncated = res_headers['tabular-is-truncated'] == 'true'
+ count = int(res_headers['tabular-list-count'])
+ columns = [] if count_only else [[f.name, f.type, f.metadata] for f in
+ pa.ipc.open_stream(res.content).schema]

- return columns, next_key, is_truncated, count
+ return columns, next_key, is_truncated, count


  class QueryDataInternalError(Exception):
@@ -2118,40 +2082,3 @@ def build_query_data_request(schema: 'pa.Schema' = pa.schema([]), predicate: ibi
  builder.Finish(relation)

  return QueryDataRequest(serialized=builder.Output(), response_schema=response_schema, response_parser=QueryDataParser(response_schema))
-
-
- def convert_column_types(table: 'pa.Table') -> 'pa.Table':
- """
- Adjusting table values
-
- 1. Because the timestamp resolution is too high it is necessary to trim it. ORION-96961
- 2. Since the values of nfs_mode_bits are returned in decimal, need to convert them to octal,
- as in all representations, so that the mode of 448 turn into 700
- 3. for owner_name and group_owner_name 0 -> root, and 65534 -> nobody
- """
- ts_indexes = []
- indexes_of_fields_to_change = {}
- sid_to_name = {
- '0': 'root',
- '65534': 'nobody' # NFSNOBODY_UID_16_BIT
- }
- column_matcher = { # column_name: custom converting rule
- 'nfs_mode_bits': lambda val: int(oct(val).replace('0o', '')) if val is not None else val,
- 'owner_name': lambda val: sid_to_name.get(val, val),
- 'group_owner_name': lambda val: sid_to_name.get(val, val),
- }
- for index, field in enumerate(table.schema):
- if isinstance(field.type, pa.TimestampType) and field.type.unit == 'ns':
- ts_indexes.append(index)
- if field.name in column_matcher:
- indexes_of_fields_to_change[field.name] = index
- for changing_index in ts_indexes:
- field_name = table.schema[changing_index].name
- new_column = table[field_name].cast(pa.timestamp('us'), safe=False)
- table = table.set_column(changing_index, field_name, new_column)
- for field_name, changing_index in indexes_of_fields_to_change.items():
- new_column = table[field_name].to_pylist()
- new_column = list(map(column_matcher[field_name], new_column))
- new_column = pa.array(new_column, table[field_name].type)
- table = table.set_column(changing_index, field_name, new_column)
- return table