databricks-sql-connector 4.0.0b2__tar.gz → 4.0.0b3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/CHANGELOG.md +15 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/PKG-INFO +1 -1
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/pyproject.toml +1 -1
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/__init__.py +1 -1
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/thrift_http_client.py +25 -16
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/client.py +104 -9
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/cloudfetch/download_manager.py +5 -4
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/cloudfetch/downloader.py +6 -7
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_backend.py +21 -54
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/types.py +48 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/utils.py +154 -18
- databricks_sql_connector-4.0.0b3/src/databricks/sqlalchemy/__init__.py +6 -0
- databricks_sql_connector-4.0.0b2/src/databricks/sqlalchemy/__init__.py +0 -6
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/LICENSE +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/README.md +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/auth.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/authenticators.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/endpoint.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/oauth.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/retry.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/exc.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/experimental/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/native.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/py.typed +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/py.typed +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/__init__.py +0 -0
|
@@ -1,5 +1,20 @@
|
|
|
1
1
|
# Release History
|
|
2
2
|
|
|
3
|
+
# 3.6.0 (2024-10-25)
|
|
4
|
+
|
|
5
|
+
- Support encryption headers in the cloud fetch request (https://github.com/databricks/databricks-sql-python/pull/460 by @jackyhu-db)
|
|
6
|
+
|
|
7
|
+
# 3.5.0 (2024-10-18)
|
|
8
|
+
|
|
9
|
+
- Create a non pyarrow flow to handle small results for the column set (databricks/databricks-sql-python#440 by @jprakash-db)
|
|
10
|
+
- Fix: On non-retryable error, ensure PySQL includes useful information in error (databricks/databricks-sql-python#447 by @shivam2680)
|
|
11
|
+
|
|
12
|
+
# 3.4.0 (2024-08-27)
|
|
13
|
+
|
|
14
|
+
- Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal)
|
|
15
|
+
- Make OAuth as the default authenticator if no authentication setting is provided (databricks/databricks-sql-python#419 by @jackyhu-db)
|
|
16
|
+
- Fix (regression): use SSL options with HTTPS connection pool (databricks/databricks-sql-python#425 by @kravets-levko)
|
|
17
|
+
|
|
3
18
|
# 3.3.0 (2024-07-18)
|
|
4
19
|
|
|
5
20
|
- Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo)
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import logging
|
|
3
3
|
import urllib.parse
|
|
4
|
-
from typing import Dict, Union
|
|
4
|
+
from typing import Dict, Union, Optional
|
|
5
5
|
|
|
6
6
|
import six
|
|
7
7
|
import thrift
|
|
8
8
|
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
9
|
import ssl
|
|
12
10
|
import warnings
|
|
13
11
|
from http.client import HTTPResponse
|
|
@@ -16,6 +14,9 @@ from io import BytesIO
|
|
|
16
14
|
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
|
|
17
15
|
from urllib3.util import make_headers
|
|
18
16
|
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
|
|
17
|
+
from databricks.sql.types import SSLOptions
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
@@ -25,13 +26,12 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
25
26
|
uri_or_host,
|
|
26
27
|
port=None,
|
|
27
28
|
path=None,
|
|
28
|
-
|
|
29
|
-
cert_file=None,
|
|
30
|
-
key_file=None,
|
|
31
|
-
ssl_context=None,
|
|
29
|
+
ssl_options: Optional[SSLOptions] = None,
|
|
32
30
|
max_connections: int = 1,
|
|
33
31
|
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
|
|
34
32
|
):
|
|
33
|
+
self._ssl_options = ssl_options
|
|
34
|
+
|
|
35
35
|
if port is not None:
|
|
36
36
|
warnings.warn(
|
|
37
37
|
"Please use the THttpClient('http{s}://host:port/path') constructor",
|
|
@@ -48,13 +48,11 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
48
48
|
self.scheme = parsed.scheme
|
|
49
49
|
assert self.scheme in ("http", "https")
|
|
50
50
|
if self.scheme == "https":
|
|
51
|
-
self.
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
else ssl_context
|
|
57
|
-
)
|
|
51
|
+
if self._ssl_options is not None:
|
|
52
|
+
# TODO: Not sure if those options are used anywhere - need to double-check
|
|
53
|
+
self.certfile = self._ssl_options.tls_client_cert_file
|
|
54
|
+
self.keyfile = self._ssl_options.tls_client_cert_key_file
|
|
55
|
+
self.context = self._ssl_options.create_ssl_context()
|
|
58
56
|
self.port = parsed.port
|
|
59
57
|
self.host = parsed.hostname
|
|
60
58
|
self.path = parsed.path
|
|
@@ -109,12 +107,23 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
109
107
|
def open(self):
|
|
110
108
|
|
|
111
109
|
# self.__pool replaces the self.__http used by the original THttpClient
|
|
110
|
+
_pool_kwargs = {"maxsize": self.max_connections}
|
|
111
|
+
|
|
112
112
|
if self.scheme == "http":
|
|
113
113
|
pool_class = HTTPConnectionPool
|
|
114
114
|
elif self.scheme == "https":
|
|
115
115
|
pool_class = HTTPSConnectionPool
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
_pool_kwargs.update(
|
|
117
|
+
{
|
|
118
|
+
"cert_reqs": ssl.CERT_REQUIRED
|
|
119
|
+
if self._ssl_options.tls_verify
|
|
120
|
+
else ssl.CERT_NONE,
|
|
121
|
+
"ca_certs": self._ssl_options.tls_trusted_ca_file,
|
|
122
|
+
"cert_file": self._ssl_options.tls_client_cert_file,
|
|
123
|
+
"key_file": self._ssl_options.tls_client_cert_key_file,
|
|
124
|
+
"key_password": self._ssl_options.tls_client_cert_key_password,
|
|
125
|
+
}
|
|
126
|
+
)
|
|
118
127
|
|
|
119
128
|
if self.using_proxy():
|
|
120
129
|
proxy_manager = ProxyManager(
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/client.py
RENAMED
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
|
|
2
2
|
|
|
3
3
|
import pandas
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import pyarrow
|
|
7
|
+
except ImportError:
|
|
8
|
+
pyarrow = None
|
|
4
9
|
import requests
|
|
5
10
|
import json
|
|
6
11
|
import os
|
|
@@ -21,6 +26,8 @@ from databricks.sql.utils import (
|
|
|
21
26
|
ParamEscaper,
|
|
22
27
|
inject_parameters,
|
|
23
28
|
transform_paramstyle,
|
|
29
|
+
ColumnTable,
|
|
30
|
+
ColumnQueue,
|
|
24
31
|
)
|
|
25
32
|
from databricks.sql.parameters.native import (
|
|
26
33
|
DbsqlParameterBase,
|
|
@@ -34,7 +41,7 @@ from databricks.sql.parameters.native import (
|
|
|
34
41
|
)
|
|
35
42
|
|
|
36
43
|
|
|
37
|
-
from databricks.sql.types import Row
|
|
44
|
+
from databricks.sql.types import Row, SSLOptions
|
|
38
45
|
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
|
|
39
46
|
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
|
|
40
47
|
|
|
@@ -42,10 +49,6 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
|
42
49
|
TSparkParameter,
|
|
43
50
|
)
|
|
44
51
|
|
|
45
|
-
try:
|
|
46
|
-
import pyarrow
|
|
47
|
-
except ImportError:
|
|
48
|
-
pyarrow = None
|
|
49
52
|
|
|
50
53
|
logger = logging.getLogger(__name__)
|
|
51
54
|
|
|
@@ -181,8 +184,9 @@ class Connection:
|
|
|
181
184
|
# _tls_trusted_ca_file
|
|
182
185
|
# Set to the path of the file containing trusted CA certificates for server certificate
|
|
183
186
|
# verification. If not provide, uses system truststore.
|
|
184
|
-
# _tls_client_cert_file, _tls_client_cert_key_file
|
|
187
|
+
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
|
|
185
188
|
# Set client SSL certificate.
|
|
189
|
+
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
|
|
186
190
|
# _retry_stop_after_attempts_count
|
|
187
191
|
# The maximum number of attempts during a request retry sequence (defaults to 24)
|
|
188
192
|
# _socket_timeout
|
|
@@ -223,12 +227,25 @@ class Connection:
|
|
|
223
227
|
|
|
224
228
|
base_headers = [("User-Agent", useragent_header)]
|
|
225
229
|
|
|
230
|
+
self._ssl_options = SSLOptions(
|
|
231
|
+
# Double negation is generally a bad thing, but we have to keep backward compatibility
|
|
232
|
+
tls_verify=not kwargs.get(
|
|
233
|
+
"_tls_no_verify", False
|
|
234
|
+
), # by default - verify cert and host
|
|
235
|
+
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
|
|
236
|
+
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
|
|
237
|
+
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
|
|
238
|
+
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
|
|
239
|
+
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
|
|
240
|
+
)
|
|
241
|
+
|
|
226
242
|
self.thrift_backend = ThriftBackend(
|
|
227
243
|
self.host,
|
|
228
244
|
self.port,
|
|
229
245
|
http_path,
|
|
230
246
|
(http_headers or []) + base_headers,
|
|
231
247
|
auth_provider,
|
|
248
|
+
ssl_options=self._ssl_options,
|
|
232
249
|
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
|
|
233
250
|
**kwargs,
|
|
234
251
|
)
|
|
@@ -1132,6 +1149,18 @@ class ResultSet:
|
|
|
1132
1149
|
self.results = results
|
|
1133
1150
|
self.has_more_rows = has_more_rows
|
|
1134
1151
|
|
|
1152
|
+
def _convert_columnar_table(self, table):
|
|
1153
|
+
column_names = [c[0] for c in self.description]
|
|
1154
|
+
ResultRow = Row(*column_names)
|
|
1155
|
+
result = []
|
|
1156
|
+
for row_index in range(table.num_rows):
|
|
1157
|
+
curr_row = []
|
|
1158
|
+
for col_index in range(table.num_columns):
|
|
1159
|
+
curr_row.append(table.get_item(col_index, row_index))
|
|
1160
|
+
result.append(ResultRow(*curr_row))
|
|
1161
|
+
|
|
1162
|
+
return result
|
|
1163
|
+
|
|
1135
1164
|
def _convert_arrow_table(self, table):
|
|
1136
1165
|
column_names = [c[0] for c in self.description]
|
|
1137
1166
|
ResultRow = Row(*column_names)
|
|
@@ -1199,6 +1228,48 @@ class ResultSet:
|
|
|
1199
1228
|
|
|
1200
1229
|
return results
|
|
1201
1230
|
|
|
1231
|
+
def merge_columnar(self, result1, result2):
|
|
1232
|
+
"""
|
|
1233
|
+
Function to merge / combining the columnar results into a single result
|
|
1234
|
+
:param result1:
|
|
1235
|
+
:param result2:
|
|
1236
|
+
:return:
|
|
1237
|
+
"""
|
|
1238
|
+
|
|
1239
|
+
if result1.column_names != result2.column_names:
|
|
1240
|
+
raise ValueError("The columns in the results don't match")
|
|
1241
|
+
|
|
1242
|
+
merged_result = [
|
|
1243
|
+
result1.column_table[i] + result2.column_table[i]
|
|
1244
|
+
for i in range(result1.num_columns)
|
|
1245
|
+
]
|
|
1246
|
+
return ColumnTable(merged_result, result1.column_names)
|
|
1247
|
+
|
|
1248
|
+
def fetchmany_columnar(self, size: int):
|
|
1249
|
+
"""
|
|
1250
|
+
Fetch the next set of rows of a query result, returning a Columnar Table.
|
|
1251
|
+
An empty sequence is returned when no more rows are available.
|
|
1252
|
+
"""
|
|
1253
|
+
if size < 0:
|
|
1254
|
+
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
|
|
1255
|
+
|
|
1256
|
+
results = self.results.next_n_rows(size)
|
|
1257
|
+
n_remaining_rows = size - results.num_rows
|
|
1258
|
+
self._next_row_index += results.num_rows
|
|
1259
|
+
|
|
1260
|
+
while (
|
|
1261
|
+
n_remaining_rows > 0
|
|
1262
|
+
and not self.has_been_closed_server_side
|
|
1263
|
+
and self.has_more_rows
|
|
1264
|
+
):
|
|
1265
|
+
self._fill_results_buffer()
|
|
1266
|
+
partial_results = self.results.next_n_rows(n_remaining_rows)
|
|
1267
|
+
results = self.merge_columnar(results, partial_results)
|
|
1268
|
+
n_remaining_rows -= partial_results.num_rows
|
|
1269
|
+
self._next_row_index += partial_results.num_rows
|
|
1270
|
+
|
|
1271
|
+
return results
|
|
1272
|
+
|
|
1202
1273
|
def fetchall_arrow(self) -> "pyarrow.Table":
|
|
1203
1274
|
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
|
|
1204
1275
|
results = self.results.remaining_rows()
|
|
@@ -1212,12 +1283,30 @@ class ResultSet:
|
|
|
1212
1283
|
|
|
1213
1284
|
return results
|
|
1214
1285
|
|
|
1286
|
+
def fetchall_columnar(self):
|
|
1287
|
+
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
|
|
1288
|
+
results = self.results.remaining_rows()
|
|
1289
|
+
self._next_row_index += results.num_rows
|
|
1290
|
+
|
|
1291
|
+
while not self.has_been_closed_server_side and self.has_more_rows:
|
|
1292
|
+
self._fill_results_buffer()
|
|
1293
|
+
partial_results = self.results.remaining_rows()
|
|
1294
|
+
results = self.merge_columnar(results, partial_results)
|
|
1295
|
+
self._next_row_index += partial_results.num_rows
|
|
1296
|
+
|
|
1297
|
+
return results
|
|
1298
|
+
|
|
1215
1299
|
def fetchone(self) -> Optional[Row]:
|
|
1216
1300
|
"""
|
|
1217
1301
|
Fetch the next row of a query result set, returning a single sequence,
|
|
1218
1302
|
or None when no more data is available.
|
|
1219
1303
|
"""
|
|
1220
|
-
|
|
1304
|
+
|
|
1305
|
+
if isinstance(self.results, ColumnQueue):
|
|
1306
|
+
res = self._convert_columnar_table(self.fetchmany_columnar(1))
|
|
1307
|
+
else:
|
|
1308
|
+
res = self._convert_arrow_table(self.fetchmany_arrow(1))
|
|
1309
|
+
|
|
1221
1310
|
if len(res) > 0:
|
|
1222
1311
|
return res[0]
|
|
1223
1312
|
else:
|
|
@@ -1227,7 +1316,10 @@ class ResultSet:
|
|
|
1227
1316
|
"""
|
|
1228
1317
|
Fetch all (remaining) rows of a query result, returning them as a list of rows.
|
|
1229
1318
|
"""
|
|
1230
|
-
|
|
1319
|
+
if isinstance(self.results, ColumnQueue):
|
|
1320
|
+
return self._convert_columnar_table(self.fetchall_columnar())
|
|
1321
|
+
else:
|
|
1322
|
+
return self._convert_arrow_table(self.fetchall_arrow())
|
|
1231
1323
|
|
|
1232
1324
|
def fetchmany(self, size: int) -> List[Row]:
|
|
1233
1325
|
"""
|
|
@@ -1235,7 +1327,10 @@ class ResultSet:
|
|
|
1235
1327
|
|
|
1236
1328
|
An empty sequence is returned when no more rows are available.
|
|
1237
1329
|
"""
|
|
1238
|
-
|
|
1330
|
+
if isinstance(self.results, ColumnQueue):
|
|
1331
|
+
return self._convert_columnar_table(self.fetchmany_columnar(size))
|
|
1332
|
+
else:
|
|
1333
|
+
return self._convert_arrow_table(self.fetchmany_arrow(size))
|
|
1239
1334
|
|
|
1240
1335
|
def close(self) -> None:
|
|
1241
1336
|
"""
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
|
|
3
|
-
from ssl import SSLContext
|
|
4
3
|
from concurrent.futures import ThreadPoolExecutor, Future
|
|
5
4
|
from typing import List, Union
|
|
6
5
|
|
|
@@ -9,6 +8,8 @@ from databricks.sql.cloudfetch.downloader import (
|
|
|
9
8
|
DownloadableResultSettings,
|
|
10
9
|
DownloadedFile,
|
|
11
10
|
)
|
|
11
|
+
from databricks.sql.types import SSLOptions
|
|
12
|
+
|
|
12
13
|
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
|
|
13
14
|
|
|
14
15
|
logger = logging.getLogger(__name__)
|
|
@@ -20,7 +21,7 @@ class ResultFileDownloadManager:
|
|
|
20
21
|
links: List[TSparkArrowResultLink],
|
|
21
22
|
max_download_threads: int,
|
|
22
23
|
lz4_compressed: bool,
|
|
23
|
-
|
|
24
|
+
ssl_options: SSLOptions,
|
|
24
25
|
):
|
|
25
26
|
self._pending_links: List[TSparkArrowResultLink] = []
|
|
26
27
|
for link in links:
|
|
@@ -38,7 +39,7 @@ class ResultFileDownloadManager:
|
|
|
38
39
|
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
|
|
39
40
|
|
|
40
41
|
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
|
|
41
|
-
self.
|
|
42
|
+
self._ssl_options = ssl_options
|
|
42
43
|
|
|
43
44
|
def get_next_downloaded_file(
|
|
44
45
|
self, next_row_offset: int
|
|
@@ -95,7 +96,7 @@ class ResultFileDownloadManager:
|
|
|
95
96
|
handler = ResultSetDownloadHandler(
|
|
96
97
|
settings=self._downloadable_result_settings,
|
|
97
98
|
link=link,
|
|
98
|
-
|
|
99
|
+
ssl_options=self._ssl_options,
|
|
99
100
|
)
|
|
100
101
|
task = self._thread_pool.submit(handler.run)
|
|
101
102
|
self._download_tasks.append(task)
|
|
@@ -3,13 +3,12 @@ from dataclasses import dataclass
|
|
|
3
3
|
|
|
4
4
|
import requests
|
|
5
5
|
from requests.adapters import HTTPAdapter, Retry
|
|
6
|
-
from ssl import SSLContext, CERT_NONE
|
|
7
6
|
import lz4.frame
|
|
8
7
|
import time
|
|
9
8
|
|
|
10
9
|
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
|
|
11
|
-
|
|
12
10
|
from databricks.sql.exc import Error
|
|
11
|
+
from databricks.sql.types import SSLOptions
|
|
13
12
|
|
|
14
13
|
logger = logging.getLogger(__name__)
|
|
15
14
|
|
|
@@ -66,11 +65,11 @@ class ResultSetDownloadHandler:
|
|
|
66
65
|
self,
|
|
67
66
|
settings: DownloadableResultSettings,
|
|
68
67
|
link: TSparkArrowResultLink,
|
|
69
|
-
|
|
68
|
+
ssl_options: SSLOptions,
|
|
70
69
|
):
|
|
71
70
|
self.settings = settings
|
|
72
71
|
self.link = link
|
|
73
|
-
self.
|
|
72
|
+
self._ssl_options = ssl_options
|
|
74
73
|
|
|
75
74
|
def run(self) -> DownloadedFile:
|
|
76
75
|
"""
|
|
@@ -95,14 +94,14 @@ class ResultSetDownloadHandler:
|
|
|
95
94
|
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
|
|
96
95
|
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
|
|
97
96
|
|
|
98
|
-
ssl_verify = self._ssl_context.verify_mode != CERT_NONE
|
|
99
|
-
|
|
100
97
|
try:
|
|
101
98
|
# Get the file via HTTP request
|
|
102
99
|
response = session.get(
|
|
103
100
|
self.link.fileLink,
|
|
104
101
|
timeout=self.settings.download_timeout,
|
|
105
|
-
verify=
|
|
102
|
+
verify=self._ssl_options.tls_verify,
|
|
103
|
+
headers=self.link.httpHeaders
|
|
104
|
+
# TODO: Pass cert from `self._ssl_options`
|
|
106
105
|
)
|
|
107
106
|
response.raise_for_status()
|
|
108
107
|
|
|
@@ -5,9 +5,12 @@ import math
|
|
|
5
5
|
import time
|
|
6
6
|
import uuid
|
|
7
7
|
import threading
|
|
8
|
-
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
|
|
9
8
|
from typing import List, Union
|
|
10
9
|
|
|
10
|
+
try:
|
|
11
|
+
import pyarrow
|
|
12
|
+
except ImportError:
|
|
13
|
+
pyarrow = None
|
|
11
14
|
import thrift.transport.THttpClient
|
|
12
15
|
import thrift.protocol.TBinaryProtocol
|
|
13
16
|
import thrift.transport.TSocket
|
|
@@ -35,11 +38,7 @@ from databricks.sql.utils import (
|
|
|
35
38
|
convert_decimals_in_arrow_table,
|
|
36
39
|
convert_column_based_set_to_arrow_table,
|
|
37
40
|
)
|
|
38
|
-
|
|
39
|
-
try:
|
|
40
|
-
import pyarrow
|
|
41
|
-
except ImportError:
|
|
42
|
-
pyarrow = None
|
|
41
|
+
from databricks.sql.types import SSLOptions
|
|
43
42
|
|
|
44
43
|
logger = logging.getLogger(__name__)
|
|
45
44
|
|
|
@@ -89,6 +88,7 @@ class ThriftBackend:
|
|
|
89
88
|
http_path: str,
|
|
90
89
|
http_headers,
|
|
91
90
|
auth_provider: AuthProvider,
|
|
91
|
+
ssl_options: SSLOptions,
|
|
92
92
|
staging_allowed_local_path: Union[None, str, List[str]] = None,
|
|
93
93
|
**kwargs,
|
|
94
94
|
):
|
|
@@ -97,16 +97,6 @@ class ThriftBackend:
|
|
|
97
97
|
# Tag to add to User-Agent header. For use by partners.
|
|
98
98
|
# _username, _password
|
|
99
99
|
# Username and password Basic authentication (no official support)
|
|
100
|
-
# _tls_no_verify
|
|
101
|
-
# Set to True (Boolean) to completely disable SSL verification.
|
|
102
|
-
# _tls_verify_hostname
|
|
103
|
-
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
|
|
104
|
-
# _tls_trusted_ca_file
|
|
105
|
-
# Set to the path of the file containing trusted CA certificates for server certificate
|
|
106
|
-
# verification. If not provide, uses system truststore.
|
|
107
|
-
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
|
|
108
|
-
# Set client SSL certificate.
|
|
109
|
-
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
|
|
110
100
|
# _connection_uri
|
|
111
101
|
# Overrides server_hostname and http_path.
|
|
112
102
|
# RETRY/ATTEMPT POLICY
|
|
@@ -166,29 +156,7 @@ class ThriftBackend:
|
|
|
166
156
|
# Cloud fetch
|
|
167
157
|
self.max_download_threads = kwargs.get("max_download_threads", 10)
|
|
168
158
|
|
|
169
|
-
|
|
170
|
-
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
|
|
171
|
-
if kwargs.get("_tls_no_verify") is True:
|
|
172
|
-
ssl_context.check_hostname = False
|
|
173
|
-
ssl_context.verify_mode = CERT_NONE
|
|
174
|
-
elif kwargs.get("_tls_verify_hostname") is False:
|
|
175
|
-
ssl_context.check_hostname = False
|
|
176
|
-
ssl_context.verify_mode = CERT_REQUIRED
|
|
177
|
-
else:
|
|
178
|
-
ssl_context.check_hostname = True
|
|
179
|
-
ssl_context.verify_mode = CERT_REQUIRED
|
|
180
|
-
|
|
181
|
-
tls_client_cert_file = kwargs.get("_tls_client_cert_file")
|
|
182
|
-
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
|
|
183
|
-
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
|
|
184
|
-
if tls_client_cert_file:
|
|
185
|
-
ssl_context.load_cert_chain(
|
|
186
|
-
certfile=tls_client_cert_file,
|
|
187
|
-
keyfile=tls_client_cert_key_file,
|
|
188
|
-
password=tls_client_cert_key_password,
|
|
189
|
-
)
|
|
190
|
-
|
|
191
|
-
self._ssl_context = ssl_context
|
|
159
|
+
self._ssl_options = ssl_options
|
|
192
160
|
|
|
193
161
|
self._auth_provider = auth_provider
|
|
194
162
|
|
|
@@ -229,7 +197,7 @@ class ThriftBackend:
|
|
|
229
197
|
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
|
|
230
198
|
auth_provider=self._auth_provider,
|
|
231
199
|
uri_or_host=uri,
|
|
232
|
-
|
|
200
|
+
ssl_options=self._ssl_options,
|
|
233
201
|
**additional_transport_args, # type: ignore
|
|
234
202
|
)
|
|
235
203
|
|
|
@@ -656,12 +624,6 @@ class ThriftBackend:
|
|
|
656
624
|
|
|
657
625
|
@staticmethod
|
|
658
626
|
def _hive_schema_to_arrow_schema(t_table_schema):
|
|
659
|
-
|
|
660
|
-
if pyarrow is None:
|
|
661
|
-
raise ImportError(
|
|
662
|
-
"pyarrow is required to convert Hive schema to Arrow schema"
|
|
663
|
-
)
|
|
664
|
-
|
|
665
627
|
def map_type(t_type_entry):
|
|
666
628
|
if t_type_entry.primitiveEntry:
|
|
667
629
|
return {
|
|
@@ -767,12 +729,17 @@ class ThriftBackend:
|
|
|
767
729
|
description = self._hive_schema_to_description(
|
|
768
730
|
t_result_set_metadata_resp.schema
|
|
769
731
|
)
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
732
|
+
|
|
733
|
+
if pyarrow:
|
|
734
|
+
schema_bytes = (
|
|
735
|
+
t_result_set_metadata_resp.arrowSchema
|
|
736
|
+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
|
|
737
|
+
.serialize()
|
|
738
|
+
.to_pybytes()
|
|
739
|
+
)
|
|
740
|
+
else:
|
|
741
|
+
schema_bytes = None
|
|
742
|
+
|
|
776
743
|
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
|
|
777
744
|
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
|
|
778
745
|
if direct_results and direct_results.resultSet:
|
|
@@ -786,7 +753,7 @@ class ThriftBackend:
|
|
|
786
753
|
max_download_threads=self.max_download_threads,
|
|
787
754
|
lz4_compressed=lz4_compressed,
|
|
788
755
|
description=description,
|
|
789
|
-
|
|
756
|
+
ssl_options=self._ssl_options,
|
|
790
757
|
)
|
|
791
758
|
else:
|
|
792
759
|
arrow_queue_opt = None
|
|
@@ -1018,7 +985,7 @@ class ThriftBackend:
|
|
|
1018
985
|
max_download_threads=self.max_download_threads,
|
|
1019
986
|
lz4_compressed=lz4_compressed,
|
|
1020
987
|
description=description,
|
|
1021
|
-
|
|
988
|
+
ssl_options=self._ssl_options,
|
|
1022
989
|
)
|
|
1023
990
|
|
|
1024
991
|
return queue, resp.hasMoreRows
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/types.py
RENAMED
|
@@ -19,6 +19,54 @@
|
|
|
19
19
|
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
|
|
20
20
|
import datetime
|
|
21
21
|
import decimal
|
|
22
|
+
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SSLOptions:
|
|
26
|
+
tls_verify: bool
|
|
27
|
+
tls_verify_hostname: bool
|
|
28
|
+
tls_trusted_ca_file: Optional[str]
|
|
29
|
+
tls_client_cert_file: Optional[str]
|
|
30
|
+
tls_client_cert_key_file: Optional[str]
|
|
31
|
+
tls_client_cert_key_password: Optional[str]
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
tls_verify: bool = True,
|
|
36
|
+
tls_verify_hostname: bool = True,
|
|
37
|
+
tls_trusted_ca_file: Optional[str] = None,
|
|
38
|
+
tls_client_cert_file: Optional[str] = None,
|
|
39
|
+
tls_client_cert_key_file: Optional[str] = None,
|
|
40
|
+
tls_client_cert_key_password: Optional[str] = None,
|
|
41
|
+
):
|
|
42
|
+
self.tls_verify = tls_verify
|
|
43
|
+
self.tls_verify_hostname = tls_verify_hostname
|
|
44
|
+
self.tls_trusted_ca_file = tls_trusted_ca_file
|
|
45
|
+
self.tls_client_cert_file = tls_client_cert_file
|
|
46
|
+
self.tls_client_cert_key_file = tls_client_cert_key_file
|
|
47
|
+
self.tls_client_cert_key_password = tls_client_cert_key_password
|
|
48
|
+
|
|
49
|
+
def create_ssl_context(self) -> SSLContext:
|
|
50
|
+
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
|
|
51
|
+
|
|
52
|
+
if self.tls_verify is False:
|
|
53
|
+
ssl_context.check_hostname = False
|
|
54
|
+
ssl_context.verify_mode = CERT_NONE
|
|
55
|
+
elif self.tls_verify_hostname is False:
|
|
56
|
+
ssl_context.check_hostname = False
|
|
57
|
+
ssl_context.verify_mode = CERT_REQUIRED
|
|
58
|
+
else:
|
|
59
|
+
ssl_context.check_hostname = True
|
|
60
|
+
ssl_context.verify_mode = CERT_REQUIRED
|
|
61
|
+
|
|
62
|
+
if self.tls_client_cert_file:
|
|
63
|
+
ssl_context.load_cert_chain(
|
|
64
|
+
certfile=self.tls_client_cert_file,
|
|
65
|
+
keyfile=self.tls_client_cert_key_file,
|
|
66
|
+
password=self.tls_client_cert_key_password,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
return ssl_context
|
|
22
70
|
|
|
23
71
|
|
|
24
72
|
class Row(tuple):
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/utils.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import pytz
|
|
3
4
|
import datetime
|
|
4
5
|
import decimal
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
@@ -9,10 +10,14 @@ from decimal import Decimal
|
|
|
9
10
|
from enum import Enum
|
|
10
11
|
from typing import Any, Dict, List, Optional, Union
|
|
11
12
|
import re
|
|
12
|
-
from ssl import SSLContext
|
|
13
13
|
|
|
14
14
|
import lz4.frame
|
|
15
15
|
|
|
16
|
+
try:
|
|
17
|
+
import pyarrow
|
|
18
|
+
except ImportError:
|
|
19
|
+
pyarrow = None
|
|
20
|
+
|
|
16
21
|
from databricks.sql import OperationalError, exc
|
|
17
22
|
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
|
|
18
23
|
from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
@@ -20,17 +25,14 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
|
20
25
|
TSparkArrowResultLink,
|
|
21
26
|
TSparkRowSetType,
|
|
22
27
|
)
|
|
28
|
+
from databricks.sql.types import SSLOptions
|
|
23
29
|
|
|
24
30
|
from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
|
|
25
31
|
|
|
26
|
-
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
|
|
27
|
-
|
|
28
32
|
import logging
|
|
29
33
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
except ImportError:
|
|
33
|
-
pyarrow = None
|
|
34
|
+
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
|
|
35
|
+
DEFAULT_ERROR_CONTEXT = "Unknown error"
|
|
34
36
|
|
|
35
37
|
logger = logging.getLogger(__name__)
|
|
36
38
|
|
|
@@ -52,7 +54,7 @@ class ResultSetQueueFactory(ABC):
|
|
|
52
54
|
t_row_set: TRowSet,
|
|
53
55
|
arrow_schema_bytes: bytes,
|
|
54
56
|
max_download_threads: int,
|
|
55
|
-
|
|
57
|
+
ssl_options: SSLOptions,
|
|
56
58
|
lz4_compressed: bool = True,
|
|
57
59
|
description: Optional[List[List[Any]]] = None,
|
|
58
60
|
) -> ResultSetQueue:
|
|
@@ -66,7 +68,7 @@ class ResultSetQueueFactory(ABC):
|
|
|
66
68
|
lz4_compressed (bool): Whether result data has been lz4 compressed.
|
|
67
69
|
description (List[List[Any]]): Hive table schema description.
|
|
68
70
|
max_download_threads (int): Maximum number of downloader thread pool threads.
|
|
69
|
-
|
|
71
|
+
ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
|
|
70
72
|
|
|
71
73
|
Returns:
|
|
72
74
|
ResultSetQueue
|
|
@@ -80,13 +82,15 @@ class ResultSetQueueFactory(ABC):
|
|
|
80
82
|
)
|
|
81
83
|
return ArrowQueue(converted_arrow_table, n_valid_rows)
|
|
82
84
|
elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
|
|
83
|
-
|
|
85
|
+
column_table, column_names = convert_column_based_set_to_column_table(
|
|
84
86
|
t_row_set.columns, description
|
|
85
87
|
)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
+
|
|
89
|
+
converted_column_table = convert_to_assigned_datatypes_in_column_table(
|
|
90
|
+
column_table, description
|
|
88
91
|
)
|
|
89
|
-
|
|
92
|
+
|
|
93
|
+
return ColumnQueue(ColumnTable(converted_column_table, column_names))
|
|
90
94
|
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
|
|
91
95
|
return CloudFetchQueue(
|
|
92
96
|
schema_bytes=arrow_schema_bytes,
|
|
@@ -95,12 +99,65 @@ class ResultSetQueueFactory(ABC):
|
|
|
95
99
|
lz4_compressed=lz4_compressed,
|
|
96
100
|
description=description,
|
|
97
101
|
max_download_threads=max_download_threads,
|
|
98
|
-
|
|
102
|
+
ssl_options=ssl_options,
|
|
99
103
|
)
|
|
100
104
|
else:
|
|
101
105
|
raise AssertionError("Row set type is not valid")
|
|
102
106
|
|
|
103
107
|
|
|
108
|
+
class ColumnTable:
    """Column-major, in-memory result table used for non-Arrow (column-based) results.

    ``column_table`` holds one sequence of values per column and ``column_names``
    holds the matching names. ``num_rows`` reads the length of the first column,
    so all columns are assumed to have the same length.
    """

    def __init__(self, column_table, column_names):
        self.column_table = column_table
        self.column_names = column_names

    @property
    def num_rows(self):
        """Number of rows (length of the first column), or 0 when there are no columns."""
        if len(self.column_table) == 0:
            return 0
        return len(self.column_table[0])

    @property
    def num_columns(self):
        """Number of columns, as given by ``column_names``."""
        return len(self.column_names)

    def get_item(self, col_index, row_index):
        """Return the value stored at ``row_index`` of column ``col_index``."""
        return self.column_table[col_index][row_index]

    def slice(self, curr_index, length):
        """Return a new ColumnTable with up to ``length`` rows starting at ``curr_index``.

        Each column is sliced independently; ``column_names`` is shared with this table.
        """
        end = curr_index + length
        sliced_column_table = [column[curr_index:end] for column in self.column_table]
        return ColumnTable(sliced_column_table, self.column_names)

    def __eq__(self, other):
        # Comparing against an unrelated type must not blow up on missing
        # attributes; NotImplemented lets Python fall back to its default.
        if not isinstance(other, ColumnTable):
            return NotImplemented
        return (
            self.column_table == other.column_table
            and self.column_names == other.column_names
        )

    # Defining __eq__ already makes instances unhashable; state it explicitly
    # because the table wraps mutable column containers.
    __hash__ = None

    def __repr__(self):
        return "{}(column_table={!r}, column_names={!r})".format(
            type(self).__name__, self.column_table, self.column_names
        )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class ColumnQueue(ResultSetQueue):
    """ResultSetQueue that serves rows of an in-memory ColumnTable in slices."""

    def __init__(self, column_table: ColumnTable):
        self.column_table = column_table
        self.cur_row_index = 0
        self.n_valid_rows = column_table.num_rows

    def next_n_rows(self, num_rows):
        """Return up to ``num_rows`` unread rows as a ColumnTable and advance the cursor.

        Fewer rows are returned when the queue is nearly exhausted. A non-positive
        ``num_rows`` yields an empty table: without clamping, a negative length can
        give a negative slice end, which Python interprets as counting from the
        back of each column.
        """
        rows_left = self.n_valid_rows - self.cur_row_index
        length = max(0, min(num_rows, rows_left))

        rows = self.column_table.slice(self.cur_row_index, length)
        self.cur_row_index += rows.num_rows
        return rows

    def remaining_rows(self):
        """Return every unread row as a ColumnTable and mark them all as read."""
        rows = self.column_table.slice(
            self.cur_row_index, self.n_valid_rows - self.cur_row_index
        )
        self.cur_row_index += rows.num_rows
        return rows
|
|
159
|
+
|
|
160
|
+
|
|
104
161
|
class ArrowQueue(ResultSetQueue):
|
|
105
162
|
def __init__(
|
|
106
163
|
self,
|
|
@@ -141,7 +198,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
141
198
|
self,
|
|
142
199
|
schema_bytes,
|
|
143
200
|
max_download_threads: int,
|
|
144
|
-
|
|
201
|
+
ssl_options: SSLOptions,
|
|
145
202
|
start_row_offset: int = 0,
|
|
146
203
|
result_links: Optional[List[TSparkArrowResultLink]] = None,
|
|
147
204
|
lz4_compressed: bool = True,
|
|
@@ -164,7 +221,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
164
221
|
self.result_links = result_links
|
|
165
222
|
self.lz4_compressed = lz4_compressed
|
|
166
223
|
self.description = description
|
|
167
|
-
self.
|
|
224
|
+
self._ssl_options = ssl_options
|
|
168
225
|
|
|
169
226
|
logger.debug(
|
|
170
227
|
"Initialize CloudFetch loader, row set start offset: {}, file list:".format(
|
|
@@ -182,7 +239,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
182
239
|
links=result_links or [],
|
|
183
240
|
max_download_threads=self.max_download_threads,
|
|
184
241
|
lz4_compressed=self.lz4_compressed,
|
|
185
|
-
|
|
242
|
+
ssl_options=self._ssl_options,
|
|
186
243
|
)
|
|
187
244
|
|
|
188
245
|
self.table = self._create_next_table()
|
|
@@ -361,7 +418,12 @@ class RequestErrorInfo(
|
|
|
361
418
|
user_friendly_error_message = "{}: {}".format(
|
|
362
419
|
user_friendly_error_message, self.error_message
|
|
363
420
|
)
|
|
364
|
-
|
|
421
|
+
try:
|
|
422
|
+
error_context = str(self.error)
|
|
423
|
+
except:
|
|
424
|
+
error_context = DEFAULT_ERROR_CONTEXT
|
|
425
|
+
|
|
426
|
+
return user_friendly_error_message + ". " + error_context
|
|
365
427
|
|
|
366
428
|
|
|
367
429
|
# Taken from PyHive
|
|
@@ -566,6 +628,37 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
|
|
|
566
628
|
return table
|
|
567
629
|
|
|
568
630
|
|
|
631
|
+
def _parse_utc_timestamp(value):
    """Parse ``YYYY-MM-DD HH:MM:SS[.ffffff]`` into a UTC-aware ``datetime``.

    The fractional-seconds part is optional: ``strptime`` with a mandatory
    ``.%f`` rejects whole-second values such as ``"2024-01-02 03:04:05"``.

    Raises:
        ValueError: if ``value`` matches neither accepted layout.
    """
    try:
        parsed = datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        parsed = datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
    # stdlib UTC; aware datetimes tagged with it compare equal to pytz.UTC ones.
    return parsed.replace(tzinfo=datetime.timezone.utc)


def convert_to_assigned_datatypes_in_column_table(column_table, description):
    """Convert string-encoded column values to Python types based on ``description``.

    Args:
        column_table: one sequence of values per column.
        description: per-column descriptors; ``description[i][1]`` is the type
            name of column ``i``.

    Returns:
        A list with one entry per column.
        - ``"decimal"``, ``"date"`` and ``"timestamp"`` columns become tuples of
          ``Decimal``, ``datetime.date`` and UTC-aware ``datetime.datetime``
          values respectively, with ``None`` kept as ``None``.
        - Any other column is passed through as the same object.
    """
    converters = {
        "decimal": Decimal,
        "date": datetime.date.fromisoformat,
        "timestamp": _parse_utc_timestamp,
    }

    converted_column_table = []
    for i, col in enumerate(column_table):
        convert = converters.get(description[i][1])
        if convert is None:
            converted_column_table.append(col)
        else:
            converted_column_table.append(
                tuple(v if v is None else convert(v) for v in col)
            )

    return converted_column_table
|
|
660
|
+
|
|
661
|
+
|
|
569
662
|
def convert_column_based_set_to_arrow_table(columns, description):
|
|
570
663
|
arrow_table = pyarrow.Table.from_arrays(
|
|
571
664
|
[_convert_column_to_arrow_array(c) for c in columns],
|
|
@@ -577,6 +670,13 @@ def convert_column_based_set_to_arrow_table(columns, description):
|
|
|
577
670
|
return arrow_table, arrow_table.num_rows
|
|
578
671
|
|
|
579
672
|
|
|
673
|
+
def convert_column_based_set_to_column_table(columns, description):
    """Turn Thrift column-based result columns into plain Python columns.

    Returns:
        A ``(column_table, column_names)`` pair.
        - ``column_table`` holds one converted column per entry of ``columns``
          (see ``_convert_column_to_list``).
        - ``column_names`` is the first element of each ``description`` entry.
    """
    column_names = [entry[0] for entry in description]

    column_table = []
    for t_col in columns:
        column_table.append(_convert_column_to_list(t_col))

    return column_table, column_names
|
|
678
|
+
|
|
679
|
+
|
|
580
680
|
def _convert_column_to_arrow_array(t_col):
|
|
581
681
|
"""
|
|
582
682
|
Return a pyarrow array from the values in a TColumn instance.
|
|
@@ -601,6 +701,26 @@ def _convert_column_to_arrow_array(t_col):
|
|
|
601
701
|
raise OperationalError("Empty TColumn instance {}".format(t_col))
|
|
602
702
|
|
|
603
703
|
|
|
704
|
+
def _convert_column_to_list(t_col):
    """Return the values of a Thrift TColumn as a tuple (nulls become ``None``).

    The value wrappers are probed in a fixed order, and the first one that is
    set (truthy) is converted by ``_create_python_tuple``.

    Raises:
        OperationalError: if none of the supported value wrappers is set.
    """
    value_field_names = (
        "boolVal",
        "byteVal",
        "i16Val",
        "i32Val",
        "i64Val",
        "doubleVal",
        "stringVal",
        "binaryVal",
    )

    # Lazily read each wrapper so fields after the first populated one are never touched.
    candidates = (getattr(t_col, name) for name in value_field_names)
    populated = next((wrapper for wrapper in candidates if wrapper), None)

    if populated is None:
        raise OperationalError("Empty TColumn instance {}".format(t_col))
    return _create_python_tuple(populated)
|
|
722
|
+
|
|
723
|
+
|
|
604
724
|
def _create_arrow_array(t_col_value_wrapper, arrow_type):
|
|
605
725
|
result = t_col_value_wrapper.values
|
|
606
726
|
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
|
|
@@ -615,3 +735,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
|
|
|
615
735
|
result[i] = None
|
|
616
736
|
|
|
617
737
|
return pyarrow.array(result, type=arrow_type)
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def _create_python_tuple(t_col_value_wrapper):
    """Return a Thrift column wrapper's ``values`` as a tuple, with null entries as ``None``.

    ``nulls`` is a bitfield in which bit ``i`` (least-significant bit first within
    each byte) marks value ``i`` as null. The bitfield may be shorter or longer
    than ``values``:
    - values without a corresponding bit are kept as-is;
    - extra bits are ignored.

    The wrapper's ``values`` list is not modified.
    """
    values = t_col_value_wrapper.values
    nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
    assert isinstance(nulls, bytes)

    null_bit_count = len(nulls) * 8

    # 1 << (i & 0x7) is the same LSB-first mask as BIT_MASKS[i & 0x7].
    return tuple(
        None
        if i < null_bit_count and nulls[i >> 3] & (1 << (i & 0x7))
        else value
        for i, value in enumerate(values)
    )
|
|
File without changes
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/exc.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/py.typed
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|