databricks-sql-connector 4.0.0b2.tar.gz → 4.0.0b3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/CHANGELOG.md +15 -0
  2. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/PKG-INFO +1 -1
  3. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/pyproject.toml +1 -1
  4. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/__init__.py +1 -1
  5. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/thrift_http_client.py +25 -16
  6. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/client.py +104 -9
  7. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/cloudfetch/download_manager.py +5 -4
  8. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/cloudfetch/downloader.py +6 -7
  9. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_backend.py +21 -54
  10. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/types.py +48 -0
  11. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/utils.py +154 -18
  12. databricks_sql_connector-4.0.0b3/src/databricks/sqlalchemy/__init__.py +6 -0
  13. databricks_sql_connector-4.0.0b2/src/databricks/sqlalchemy/__init__.py +0 -6
  14. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/LICENSE +0 -0
  15. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/README.md +0 -0
  16. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/__init__.py +0 -0
  17. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/__init__.py +0 -0
  18. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/auth.py +0 -0
  19. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/authenticators.py +0 -0
  20. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/endpoint.py +0 -0
  21. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/oauth.py +0 -0
  22. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
  23. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/auth/retry.py +0 -0
  24. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/exc.py +0 -0
  25. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/experimental/__init__.py +0 -0
  26. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
  27. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/__init__.py +0 -0
  28. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/native.py +0 -0
  29. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/parameters/py.typed +0 -0
  30. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/py.typed +0 -0
  31. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
  32. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
  33. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
  34. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
  35. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
  36. {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b3}/src/databricks/sql/thrift_api/__init__.py +0 -0
@@ -1,5 +1,20 @@
1
1
  # Release History
2
2
 
3
+ # 3.6.0 (2024-10-25)
4
+
5
+ - Support encryption headers in the cloud fetch request (https://github.com/databricks/databricks-sql-python/pull/460 by @jackyhu-db)
6
+
7
+ # 3.5.0 (2024-10-18)
8
+
9
+ - Create a non pyarrow flow to handle small results for the column set (databricks/databricks-sql-python#440 by @jprakash-db)
10
+ - Fix: On non-retryable error, ensure PySQL includes useful information in error (databricks/databricks-sql-python#447 by @shivam2680)
11
+
12
+ # 3.4.0 (2024-08-27)
13
+
14
+ - Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal)
15
+ - Make OAuth as the default authenticator if no authentication setting is provided (databricks/databricks-sql-python#419 by @jackyhu-db)
16
+ - Fix (regression): use SSL options with HTTPS connection pool (databricks/databricks-sql-python#425 by @kravets-levko)
17
+
3
18
  # 3.3.0 (2024-07-18)
4
19
 
5
20
  - Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: databricks-sql-connector
3
- Version: 4.0.0b2
3
+ Version: 4.0.0b3
4
4
  Summary: Databricks SQL Connector for Python
5
5
  License: Apache-2.0
6
6
  Author: Databricks
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "databricks-sql-connector"
3
- version = "4.0.0.b2"
3
+ version = "4.0.0.b3"
4
4
  description = "Databricks SQL Connector for Python"
5
5
  authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
6
6
  license = "Apache-2.0"
@@ -68,7 +68,7 @@ DATETIME = DBAPITypeObject("timestamp")
68
68
  DATE = DBAPITypeObject("date")
69
69
  ROWID = DBAPITypeObject()
70
70
 
71
- __version__ = "3.3.0"
71
+ __version__ = "3.6.0"
72
72
  USER_AGENT_NAME = "PyDatabricksSqlConnector"
73
73
 
74
74
  # These two functions are pyhive legacy
@@ -1,13 +1,11 @@
1
1
  import base64
2
2
  import logging
3
3
  import urllib.parse
4
- from typing import Dict, Union
4
+ from typing import Dict, Union, Optional
5
5
 
6
6
  import six
7
7
  import thrift
8
8
 
9
- logger = logging.getLogger(__name__)
10
-
11
9
  import ssl
12
10
  import warnings
13
11
  from http.client import HTTPResponse
@@ -16,6 +14,9 @@ from io import BytesIO
16
14
  from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
17
15
  from urllib3.util import make_headers
18
16
  from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
17
+ from databricks.sql.types import SSLOptions
18
+
19
+ logger = logging.getLogger(__name__)
19
20
 
20
21
 
21
22
  class THttpClient(thrift.transport.THttpClient.THttpClient):
@@ -25,13 +26,12 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
25
26
  uri_or_host,
26
27
  port=None,
27
28
  path=None,
28
- cafile=None,
29
- cert_file=None,
30
- key_file=None,
31
- ssl_context=None,
29
+ ssl_options: Optional[SSLOptions] = None,
32
30
  max_connections: int = 1,
33
31
  retry_policy: Union[DatabricksRetryPolicy, int] = 0,
34
32
  ):
33
+ self._ssl_options = ssl_options
34
+
35
35
  if port is not None:
36
36
  warnings.warn(
37
37
  "Please use the THttpClient('http{s}://host:port/path') constructor",
@@ -48,13 +48,11 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
48
48
  self.scheme = parsed.scheme
49
49
  assert self.scheme in ("http", "https")
50
50
  if self.scheme == "https":
51
- self.certfile = cert_file
52
- self.keyfile = key_file
53
- self.context = (
54
- ssl.create_default_context(cafile=cafile)
55
- if (cafile and not ssl_context)
56
- else ssl_context
57
- )
51
+ if self._ssl_options is not None:
52
+ # TODO: Not sure if those options are used anywhere - need to double-check
53
+ self.certfile = self._ssl_options.tls_client_cert_file
54
+ self.keyfile = self._ssl_options.tls_client_cert_key_file
55
+ self.context = self._ssl_options.create_ssl_context()
58
56
  self.port = parsed.port
59
57
  self.host = parsed.hostname
60
58
  self.path = parsed.path
@@ -109,12 +107,23 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
109
107
  def open(self):
110
108
 
111
109
  # self.__pool replaces the self.__http used by the original THttpClient
110
+ _pool_kwargs = {"maxsize": self.max_connections}
111
+
112
112
  if self.scheme == "http":
113
113
  pool_class = HTTPConnectionPool
114
114
  elif self.scheme == "https":
115
115
  pool_class = HTTPSConnectionPool
116
-
117
- _pool_kwargs = {"maxsize": self.max_connections}
116
+ _pool_kwargs.update(
117
+ {
118
+ "cert_reqs": ssl.CERT_REQUIRED
119
+ if self._ssl_options.tls_verify
120
+ else ssl.CERT_NONE,
121
+ "ca_certs": self._ssl_options.tls_trusted_ca_file,
122
+ "cert_file": self._ssl_options.tls_client_cert_file,
123
+ "key_file": self._ssl_options.tls_client_cert_key_file,
124
+ "key_password": self._ssl_options.tls_client_cert_key_password,
125
+ }
126
+ )
118
127
 
119
128
  if self.using_proxy():
120
129
  proxy_manager = ProxyManager(
@@ -1,6 +1,11 @@
1
1
  from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
2
2
 
3
3
  import pandas
4
+
5
+ try:
6
+ import pyarrow
7
+ except ImportError:
8
+ pyarrow = None
4
9
  import requests
5
10
  import json
6
11
  import os
@@ -21,6 +26,8 @@ from databricks.sql.utils import (
21
26
  ParamEscaper,
22
27
  inject_parameters,
23
28
  transform_paramstyle,
29
+ ColumnTable,
30
+ ColumnQueue,
24
31
  )
25
32
  from databricks.sql.parameters.native import (
26
33
  DbsqlParameterBase,
@@ -34,7 +41,7 @@ from databricks.sql.parameters.native import (
34
41
  )
35
42
 
36
43
 
37
- from databricks.sql.types import Row
44
+ from databricks.sql.types import Row, SSLOptions
38
45
  from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
39
46
  from databricks.sql.experimental.oauth_persistence import OAuthPersistence
40
47
 
@@ -42,10 +49,6 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
42
49
  TSparkParameter,
43
50
  )
44
51
 
45
- try:
46
- import pyarrow
47
- except ImportError:
48
- pyarrow = None
49
52
 
50
53
  logger = logging.getLogger(__name__)
51
54
 
@@ -181,8 +184,9 @@ class Connection:
181
184
  # _tls_trusted_ca_file
182
185
  # Set to the path of the file containing trusted CA certificates for server certificate
183
186
  # verification. If not provide, uses system truststore.
184
- # _tls_client_cert_file, _tls_client_cert_key_file
187
+ # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
185
188
  # Set client SSL certificate.
189
+ # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
186
190
  # _retry_stop_after_attempts_count
187
191
  # The maximum number of attempts during a request retry sequence (defaults to 24)
188
192
  # _socket_timeout
@@ -223,12 +227,25 @@ class Connection:
223
227
 
224
228
  base_headers = [("User-Agent", useragent_header)]
225
229
 
230
+ self._ssl_options = SSLOptions(
231
+ # Double negation is generally a bad thing, but we have to keep backward compatibility
232
+ tls_verify=not kwargs.get(
233
+ "_tls_no_verify", False
234
+ ), # by default - verify cert and host
235
+ tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
236
+ tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
237
+ tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
238
+ tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
239
+ tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
240
+ )
241
+
226
242
  self.thrift_backend = ThriftBackend(
227
243
  self.host,
228
244
  self.port,
229
245
  http_path,
230
246
  (http_headers or []) + base_headers,
231
247
  auth_provider,
248
+ ssl_options=self._ssl_options,
232
249
  _use_arrow_native_complex_types=_use_arrow_native_complex_types,
233
250
  **kwargs,
234
251
  )
@@ -1132,6 +1149,18 @@ class ResultSet:
1132
1149
  self.results = results
1133
1150
  self.has_more_rows = has_more_rows
1134
1151
 
1152
+ def _convert_columnar_table(self, table):
1153
+ column_names = [c[0] for c in self.description]
1154
+ ResultRow = Row(*column_names)
1155
+ result = []
1156
+ for row_index in range(table.num_rows):
1157
+ curr_row = []
1158
+ for col_index in range(table.num_columns):
1159
+ curr_row.append(table.get_item(col_index, row_index))
1160
+ result.append(ResultRow(*curr_row))
1161
+
1162
+ return result
1163
+
1135
1164
  def _convert_arrow_table(self, table):
1136
1165
  column_names = [c[0] for c in self.description]
1137
1166
  ResultRow = Row(*column_names)
@@ -1199,6 +1228,48 @@ class ResultSet:
1199
1228
 
1200
1229
  return results
1201
1230
 
1231
+ def merge_columnar(self, result1, result2):
1232
+ """
1233
+ Function to merge / combining the columnar results into a single result
1234
+ :param result1:
1235
+ :param result2:
1236
+ :return:
1237
+ """
1238
+
1239
+ if result1.column_names != result2.column_names:
1240
+ raise ValueError("The columns in the results don't match")
1241
+
1242
+ merged_result = [
1243
+ result1.column_table[i] + result2.column_table[i]
1244
+ for i in range(result1.num_columns)
1245
+ ]
1246
+ return ColumnTable(merged_result, result1.column_names)
1247
+
1248
+ def fetchmany_columnar(self, size: int):
1249
+ """
1250
+ Fetch the next set of rows of a query result, returning a Columnar Table.
1251
+ An empty sequence is returned when no more rows are available.
1252
+ """
1253
+ if size < 0:
1254
+ raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
1255
+
1256
+ results = self.results.next_n_rows(size)
1257
+ n_remaining_rows = size - results.num_rows
1258
+ self._next_row_index += results.num_rows
1259
+
1260
+ while (
1261
+ n_remaining_rows > 0
1262
+ and not self.has_been_closed_server_side
1263
+ and self.has_more_rows
1264
+ ):
1265
+ self._fill_results_buffer()
1266
+ partial_results = self.results.next_n_rows(n_remaining_rows)
1267
+ results = self.merge_columnar(results, partial_results)
1268
+ n_remaining_rows -= partial_results.num_rows
1269
+ self._next_row_index += partial_results.num_rows
1270
+
1271
+ return results
1272
+
1202
1273
  def fetchall_arrow(self) -> "pyarrow.Table":
1203
1274
  """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
1204
1275
  results = self.results.remaining_rows()
@@ -1212,12 +1283,30 @@ class ResultSet:
1212
1283
 
1213
1284
  return results
1214
1285
 
1286
+ def fetchall_columnar(self):
1287
+ """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
1288
+ results = self.results.remaining_rows()
1289
+ self._next_row_index += results.num_rows
1290
+
1291
+ while not self.has_been_closed_server_side and self.has_more_rows:
1292
+ self._fill_results_buffer()
1293
+ partial_results = self.results.remaining_rows()
1294
+ results = self.merge_columnar(results, partial_results)
1295
+ self._next_row_index += partial_results.num_rows
1296
+
1297
+ return results
1298
+
1215
1299
  def fetchone(self) -> Optional[Row]:
1216
1300
  """
1217
1301
  Fetch the next row of a query result set, returning a single sequence,
1218
1302
  or None when no more data is available.
1219
1303
  """
1220
- res = self._convert_arrow_table(self.fetchmany_arrow(1))
1304
+
1305
+ if isinstance(self.results, ColumnQueue):
1306
+ res = self._convert_columnar_table(self.fetchmany_columnar(1))
1307
+ else:
1308
+ res = self._convert_arrow_table(self.fetchmany_arrow(1))
1309
+
1221
1310
  if len(res) > 0:
1222
1311
  return res[0]
1223
1312
  else:
@@ -1227,7 +1316,10 @@ class ResultSet:
1227
1316
  """
1228
1317
  Fetch all (remaining) rows of a query result, returning them as a list of rows.
1229
1318
  """
1230
- return self._convert_arrow_table(self.fetchall_arrow())
1319
+ if isinstance(self.results, ColumnQueue):
1320
+ return self._convert_columnar_table(self.fetchall_columnar())
1321
+ else:
1322
+ return self._convert_arrow_table(self.fetchall_arrow())
1231
1323
 
1232
1324
  def fetchmany(self, size: int) -> List[Row]:
1233
1325
  """
@@ -1235,7 +1327,10 @@ class ResultSet:
1235
1327
 
1236
1328
  An empty sequence is returned when no more rows are available.
1237
1329
  """
1238
- return self._convert_arrow_table(self.fetchmany_arrow(size))
1330
+ if isinstance(self.results, ColumnQueue):
1331
+ return self._convert_columnar_table(self.fetchmany_columnar(size))
1332
+ else:
1333
+ return self._convert_arrow_table(self.fetchmany_arrow(size))
1239
1334
 
1240
1335
  def close(self) -> None:
1241
1336
  """
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
 
3
- from ssl import SSLContext
4
3
  from concurrent.futures import ThreadPoolExecutor, Future
5
4
  from typing import List, Union
6
5
 
@@ -9,6 +8,8 @@ from databricks.sql.cloudfetch.downloader import (
9
8
  DownloadableResultSettings,
10
9
  DownloadedFile,
11
10
  )
11
+ from databricks.sql.types import SSLOptions
12
+
12
13
  from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
13
14
 
14
15
  logger = logging.getLogger(__name__)
@@ -20,7 +21,7 @@ class ResultFileDownloadManager:
20
21
  links: List[TSparkArrowResultLink],
21
22
  max_download_threads: int,
22
23
  lz4_compressed: bool,
23
- ssl_context: SSLContext,
24
+ ssl_options: SSLOptions,
24
25
  ):
25
26
  self._pending_links: List[TSparkArrowResultLink] = []
26
27
  for link in links:
@@ -38,7 +39,7 @@ class ResultFileDownloadManager:
38
39
  self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
39
40
 
40
41
  self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
41
- self._ssl_context = ssl_context
42
+ self._ssl_options = ssl_options
42
43
 
43
44
  def get_next_downloaded_file(
44
45
  self, next_row_offset: int
@@ -95,7 +96,7 @@ class ResultFileDownloadManager:
95
96
  handler = ResultSetDownloadHandler(
96
97
  settings=self._downloadable_result_settings,
97
98
  link=link,
98
- ssl_context=self._ssl_context,
99
+ ssl_options=self._ssl_options,
99
100
  )
100
101
  task = self._thread_pool.submit(handler.run)
101
102
  self._download_tasks.append(task)
@@ -3,13 +3,12 @@ from dataclasses import dataclass
3
3
 
4
4
  import requests
5
5
  from requests.adapters import HTTPAdapter, Retry
6
- from ssl import SSLContext, CERT_NONE
7
6
  import lz4.frame
8
7
  import time
9
8
 
10
9
  from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
11
-
12
10
  from databricks.sql.exc import Error
11
+ from databricks.sql.types import SSLOptions
13
12
 
14
13
  logger = logging.getLogger(__name__)
15
14
 
@@ -66,11 +65,11 @@ class ResultSetDownloadHandler:
66
65
  self,
67
66
  settings: DownloadableResultSettings,
68
67
  link: TSparkArrowResultLink,
69
- ssl_context: SSLContext,
68
+ ssl_options: SSLOptions,
70
69
  ):
71
70
  self.settings = settings
72
71
  self.link = link
73
- self._ssl_context = ssl_context
72
+ self._ssl_options = ssl_options
74
73
 
75
74
  def run(self) -> DownloadedFile:
76
75
  """
@@ -95,14 +94,14 @@ class ResultSetDownloadHandler:
95
94
  session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
96
95
  session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
97
96
 
98
- ssl_verify = self._ssl_context.verify_mode != CERT_NONE
99
-
100
97
  try:
101
98
  # Get the file via HTTP request
102
99
  response = session.get(
103
100
  self.link.fileLink,
104
101
  timeout=self.settings.download_timeout,
105
- verify=ssl_verify,
102
+ verify=self._ssl_options.tls_verify,
103
+ headers=self.link.httpHeaders
104
+ # TODO: Pass cert from `self._ssl_options`
106
105
  )
107
106
  response.raise_for_status()
108
107
 
@@ -5,9 +5,12 @@ import math
5
5
  import time
6
6
  import uuid
7
7
  import threading
8
- from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
9
8
  from typing import List, Union
10
9
 
10
+ try:
11
+ import pyarrow
12
+ except ImportError:
13
+ pyarrow = None
11
14
  import thrift.transport.THttpClient
12
15
  import thrift.protocol.TBinaryProtocol
13
16
  import thrift.transport.TSocket
@@ -35,11 +38,7 @@ from databricks.sql.utils import (
35
38
  convert_decimals_in_arrow_table,
36
39
  convert_column_based_set_to_arrow_table,
37
40
  )
38
-
39
- try:
40
- import pyarrow
41
- except ImportError:
42
- pyarrow = None
41
+ from databricks.sql.types import SSLOptions
43
42
 
44
43
  logger = logging.getLogger(__name__)
45
44
 
@@ -89,6 +88,7 @@ class ThriftBackend:
89
88
  http_path: str,
90
89
  http_headers,
91
90
  auth_provider: AuthProvider,
91
+ ssl_options: SSLOptions,
92
92
  staging_allowed_local_path: Union[None, str, List[str]] = None,
93
93
  **kwargs,
94
94
  ):
@@ -97,16 +97,6 @@ class ThriftBackend:
97
97
  # Tag to add to User-Agent header. For use by partners.
98
98
  # _username, _password
99
99
  # Username and password Basic authentication (no official support)
100
- # _tls_no_verify
101
- # Set to True (Boolean) to completely disable SSL verification.
102
- # _tls_verify_hostname
103
- # Set to False (Boolean) to disable SSL hostname verification, but check certificate.
104
- # _tls_trusted_ca_file
105
- # Set to the path of the file containing trusted CA certificates for server certificate
106
- # verification. If not provide, uses system truststore.
107
- # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
108
- # Set client SSL certificate.
109
- # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
110
100
  # _connection_uri
111
101
  # Overrides server_hostname and http_path.
112
102
  # RETRY/ATTEMPT POLICY
@@ -166,29 +156,7 @@ class ThriftBackend:
166
156
  # Cloud fetch
167
157
  self.max_download_threads = kwargs.get("max_download_threads", 10)
168
158
 
169
- # Configure tls context
170
- ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
171
- if kwargs.get("_tls_no_verify") is True:
172
- ssl_context.check_hostname = False
173
- ssl_context.verify_mode = CERT_NONE
174
- elif kwargs.get("_tls_verify_hostname") is False:
175
- ssl_context.check_hostname = False
176
- ssl_context.verify_mode = CERT_REQUIRED
177
- else:
178
- ssl_context.check_hostname = True
179
- ssl_context.verify_mode = CERT_REQUIRED
180
-
181
- tls_client_cert_file = kwargs.get("_tls_client_cert_file")
182
- tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
183
- tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
184
- if tls_client_cert_file:
185
- ssl_context.load_cert_chain(
186
- certfile=tls_client_cert_file,
187
- keyfile=tls_client_cert_key_file,
188
- password=tls_client_cert_key_password,
189
- )
190
-
191
- self._ssl_context = ssl_context
159
+ self._ssl_options = ssl_options
192
160
 
193
161
  self._auth_provider = auth_provider
194
162
 
@@ -229,7 +197,7 @@ class ThriftBackend:
229
197
  self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
230
198
  auth_provider=self._auth_provider,
231
199
  uri_or_host=uri,
232
- ssl_context=self._ssl_context,
200
+ ssl_options=self._ssl_options,
233
201
  **additional_transport_args, # type: ignore
234
202
  )
235
203
 
@@ -656,12 +624,6 @@ class ThriftBackend:
656
624
 
657
625
  @staticmethod
658
626
  def _hive_schema_to_arrow_schema(t_table_schema):
659
-
660
- if pyarrow is None:
661
- raise ImportError(
662
- "pyarrow is required to convert Hive schema to Arrow schema"
663
- )
664
-
665
627
  def map_type(t_type_entry):
666
628
  if t_type_entry.primitiveEntry:
667
629
  return {
@@ -767,12 +729,17 @@ class ThriftBackend:
767
729
  description = self._hive_schema_to_description(
768
730
  t_result_set_metadata_resp.schema
769
731
  )
770
- schema_bytes = (
771
- t_result_set_metadata_resp.arrowSchema
772
- or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
773
- .serialize()
774
- .to_pybytes()
775
- )
732
+
733
+ if pyarrow:
734
+ schema_bytes = (
735
+ t_result_set_metadata_resp.arrowSchema
736
+ or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
737
+ .serialize()
738
+ .to_pybytes()
739
+ )
740
+ else:
741
+ schema_bytes = None
742
+
776
743
  lz4_compressed = t_result_set_metadata_resp.lz4Compressed
777
744
  is_staging_operation = t_result_set_metadata_resp.isStagingOperation
778
745
  if direct_results and direct_results.resultSet:
@@ -786,7 +753,7 @@ class ThriftBackend:
786
753
  max_download_threads=self.max_download_threads,
787
754
  lz4_compressed=lz4_compressed,
788
755
  description=description,
789
- ssl_context=self._ssl_context,
756
+ ssl_options=self._ssl_options,
790
757
  )
791
758
  else:
792
759
  arrow_queue_opt = None
@@ -1018,7 +985,7 @@ class ThriftBackend:
1018
985
  max_download_threads=self.max_download_threads,
1019
986
  lz4_compressed=lz4_compressed,
1020
987
  description=description,
1021
- ssl_context=self._ssl_context,
988
+ ssl_options=self._ssl_options,
1022
989
  )
1023
990
 
1024
991
  return queue, resp.hasMoreRows
@@ -19,6 +19,54 @@
19
19
  from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
20
20
  import datetime
21
21
  import decimal
22
+ from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
23
+
24
+
25
+ class SSLOptions:
26
+ tls_verify: bool
27
+ tls_verify_hostname: bool
28
+ tls_trusted_ca_file: Optional[str]
29
+ tls_client_cert_file: Optional[str]
30
+ tls_client_cert_key_file: Optional[str]
31
+ tls_client_cert_key_password: Optional[str]
32
+
33
+ def __init__(
34
+ self,
35
+ tls_verify: bool = True,
36
+ tls_verify_hostname: bool = True,
37
+ tls_trusted_ca_file: Optional[str] = None,
38
+ tls_client_cert_file: Optional[str] = None,
39
+ tls_client_cert_key_file: Optional[str] = None,
40
+ tls_client_cert_key_password: Optional[str] = None,
41
+ ):
42
+ self.tls_verify = tls_verify
43
+ self.tls_verify_hostname = tls_verify_hostname
44
+ self.tls_trusted_ca_file = tls_trusted_ca_file
45
+ self.tls_client_cert_file = tls_client_cert_file
46
+ self.tls_client_cert_key_file = tls_client_cert_key_file
47
+ self.tls_client_cert_key_password = tls_client_cert_key_password
48
+
49
+ def create_ssl_context(self) -> SSLContext:
50
+ ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
51
+
52
+ if self.tls_verify is False:
53
+ ssl_context.check_hostname = False
54
+ ssl_context.verify_mode = CERT_NONE
55
+ elif self.tls_verify_hostname is False:
56
+ ssl_context.check_hostname = False
57
+ ssl_context.verify_mode = CERT_REQUIRED
58
+ else:
59
+ ssl_context.check_hostname = True
60
+ ssl_context.verify_mode = CERT_REQUIRED
61
+
62
+ if self.tls_client_cert_file:
63
+ ssl_context.load_cert_chain(
64
+ certfile=self.tls_client_cert_file,
65
+ keyfile=self.tls_client_cert_key_file,
66
+ password=self.tls_client_cert_key_password,
67
+ )
68
+
69
+ return ssl_context
22
70
 
23
71
 
24
72
  class Row(tuple):
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import pytz
3
4
  import datetime
4
5
  import decimal
5
6
  from abc import ABC, abstractmethod
@@ -9,10 +10,14 @@ from decimal import Decimal
9
10
  from enum import Enum
10
11
  from typing import Any, Dict, List, Optional, Union
11
12
  import re
12
- from ssl import SSLContext
13
13
 
14
14
  import lz4.frame
15
15
 
16
+ try:
17
+ import pyarrow
18
+ except ImportError:
19
+ pyarrow = None
20
+
16
21
  from databricks.sql import OperationalError, exc
17
22
  from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
18
23
  from databricks.sql.thrift_api.TCLIService.ttypes import (
@@ -20,17 +25,14 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
20
25
  TSparkArrowResultLink,
21
26
  TSparkRowSetType,
22
27
  )
28
+ from databricks.sql.types import SSLOptions
23
29
 
24
30
  from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
25
31
 
26
- BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
27
-
28
32
  import logging
29
33
 
30
- try:
31
- import pyarrow
32
- except ImportError:
33
- pyarrow = None
34
+ BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
35
+ DEFAULT_ERROR_CONTEXT = "Unknown error"
34
36
 
35
37
  logger = logging.getLogger(__name__)
36
38
 
@@ -52,7 +54,7 @@ class ResultSetQueueFactory(ABC):
52
54
  t_row_set: TRowSet,
53
55
  arrow_schema_bytes: bytes,
54
56
  max_download_threads: int,
55
- ssl_context: SSLContext,
57
+ ssl_options: SSLOptions,
56
58
  lz4_compressed: bool = True,
57
59
  description: Optional[List[List[Any]]] = None,
58
60
  ) -> ResultSetQueue:
@@ -66,7 +68,7 @@ class ResultSetQueueFactory(ABC):
66
68
  lz4_compressed (bool): Whether result data has been lz4 compressed.
67
69
  description (List[List[Any]]): Hive table schema description.
68
70
  max_download_threads (int): Maximum number of downloader thread pool threads.
69
- ssl_context (SSLContext): SSLContext object for CloudFetchQueue
71
+ ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
70
72
 
71
73
  Returns:
72
74
  ResultSetQueue
@@ -80,13 +82,15 @@ class ResultSetQueueFactory(ABC):
80
82
  )
81
83
  return ArrowQueue(converted_arrow_table, n_valid_rows)
82
84
  elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
83
- arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
85
+ column_table, column_names = convert_column_based_set_to_column_table(
84
86
  t_row_set.columns, description
85
87
  )
86
- converted_arrow_table = convert_decimals_in_arrow_table(
87
- arrow_table, description
88
+
89
+ converted_column_table = convert_to_assigned_datatypes_in_column_table(
90
+ column_table, description
88
91
  )
89
- return ArrowQueue(converted_arrow_table, n_valid_rows)
92
+
93
+ return ColumnQueue(ColumnTable(converted_column_table, column_names))
90
94
  elif row_set_type == TSparkRowSetType.URL_BASED_SET:
91
95
  return CloudFetchQueue(
92
96
  schema_bytes=arrow_schema_bytes,
@@ -95,12 +99,65 @@ class ResultSetQueueFactory(ABC):
95
99
  lz4_compressed=lz4_compressed,
96
100
  description=description,
97
101
  max_download_threads=max_download_threads,
98
- ssl_context=ssl_context,
102
+ ssl_options=ssl_options,
99
103
  )
100
104
  else:
101
105
  raise AssertionError("Row set type is not valid")
102
106
 
103
107
 
108
+ class ColumnTable:
109
+ def __init__(self, column_table, column_names):
110
+ self.column_table = column_table
111
+ self.column_names = column_names
112
+
113
+ @property
114
+ def num_rows(self):
115
+ if len(self.column_table) == 0:
116
+ return 0
117
+ else:
118
+ return len(self.column_table[0])
119
+
120
+ @property
121
+ def num_columns(self):
122
+ return len(self.column_names)
123
+
124
+ def get_item(self, col_index, row_index):
125
+ return self.column_table[col_index][row_index]
126
+
127
+ def slice(self, curr_index, length):
128
+ sliced_column_table = [
129
+ column[curr_index : curr_index + length] for column in self.column_table
130
+ ]
131
+ return ColumnTable(sliced_column_table, self.column_names)
132
+
133
+ def __eq__(self, other):
134
+ return (
135
+ self.column_table == other.column_table
136
+ and self.column_names == other.column_names
137
+ )
138
+
139
+
140
+ class ColumnQueue(ResultSetQueue):
141
+ def __init__(self, column_table: ColumnTable):
142
+ self.column_table = column_table
143
+ self.cur_row_index = 0
144
+ self.n_valid_rows = column_table.num_rows
145
+
146
+ def next_n_rows(self, num_rows):
147
+ length = min(num_rows, self.n_valid_rows - self.cur_row_index)
148
+
149
+ slice = self.column_table.slice(self.cur_row_index, length)
150
+ self.cur_row_index += slice.num_rows
151
+ return slice
152
+
153
+ def remaining_rows(self):
154
+ slice = self.column_table.slice(
155
+ self.cur_row_index, self.n_valid_rows - self.cur_row_index
156
+ )
157
+ self.cur_row_index += slice.num_rows
158
+ return slice
159
+
160
+
104
161
  class ArrowQueue(ResultSetQueue):
105
162
  def __init__(
106
163
  self,
@@ -141,7 +198,7 @@ class CloudFetchQueue(ResultSetQueue):
141
198
  self,
142
199
  schema_bytes,
143
200
  max_download_threads: int,
144
- ssl_context: SSLContext,
201
+ ssl_options: SSLOptions,
145
202
  start_row_offset: int = 0,
146
203
  result_links: Optional[List[TSparkArrowResultLink]] = None,
147
204
  lz4_compressed: bool = True,
@@ -164,7 +221,7 @@ class CloudFetchQueue(ResultSetQueue):
164
221
  self.result_links = result_links
165
222
  self.lz4_compressed = lz4_compressed
166
223
  self.description = description
167
- self._ssl_context = ssl_context
224
+ self._ssl_options = ssl_options
168
225
 
169
226
  logger.debug(
170
227
  "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
@@ -182,7 +239,7 @@ class CloudFetchQueue(ResultSetQueue):
182
239
  links=result_links or [],
183
240
  max_download_threads=self.max_download_threads,
184
241
  lz4_compressed=self.lz4_compressed,
185
- ssl_context=self._ssl_context,
242
+ ssl_options=self._ssl_options,
186
243
  )
187
244
 
188
245
  self.table = self._create_next_table()
@@ -361,7 +418,12 @@ class RequestErrorInfo(
361
418
  user_friendly_error_message = "{}: {}".format(
362
419
  user_friendly_error_message, self.error_message
363
420
  )
364
- return user_friendly_error_message
421
+ try:
422
+ error_context = str(self.error)
423
+ except:
424
+ error_context = DEFAULT_ERROR_CONTEXT
425
+
426
+ return user_friendly_error_message + ". " + error_context
365
427
 
366
428
 
367
429
  # Taken from PyHive
@@ -566,6 +628,37 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
566
628
  return table
567
629
 
568
630
 
631
def _parse_utc_timestamp(value):
    """Parse a ``YYYY-MM-DD HH:MM:SS[.ffffff]`` string into a UTC-aware datetime.

    The fractional-seconds part is optional: strptime's ``.%f`` directive
    requires the dot and at least one digit, so whole-second values such as
    ``2024-01-01 00:00:00`` are retried without it.

    Raises:
        ValueError: if the value matches neither form.
    """
    try:
        parsed = datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        parsed = datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
    return parsed.replace(tzinfo=pytz.UTC)


def convert_to_assigned_datatypes_in_column_table(column_table, description):
    """Convert string-encoded columns to Python types according to ``description``.

    ``description[i][1]`` is the type name of column ``i``:

    * ``"decimal"``   -> ``decimal.Decimal``
    * ``"date"``      -> ``datetime.date`` (ISO format)
    * ``"timestamp"`` -> UTC-aware ``datetime.datetime``

    Converted columns become tuples; ``None`` values are kept as ``None``.
    Columns of any other type are passed through unchanged (same object).

    Returns:
        A new list with one entry per input column.
    """
    converted_column_table = []
    for i, col in enumerate(column_table):
        column_type = description[i][1]
        if column_type == "decimal":
            converter = Decimal
        elif column_type == "date":
            converter = datetime.date.fromisoformat
        elif column_type == "timestamp":
            converter = _parse_utc_timestamp
        else:
            converted_column_table.append(col)
            continue

        converted_column_table.append(
            tuple(v if v is None else converter(v) for v in col)
        )

    return converted_column_table
660
+
661
+
569
662
  def convert_column_based_set_to_arrow_table(columns, description):
570
663
  arrow_table = pyarrow.Table.from_arrays(
571
664
  [_convert_column_to_arrow_array(c) for c in columns],
@@ -577,6 +670,13 @@ def convert_column_based_set_to_arrow_table(columns, description):
577
670
  return arrow_table, arrow_table.num_rows
578
671
 
579
672
 
673
def convert_column_based_set_to_column_table(columns, description):
    """Turn thrift TColumns into plain Python columns.

    Returns:
        A ``(column_table, column_names)`` pair: ``column_table`` is a list
        with one tuple of values per TColumn (nulls as ``None``), and
        ``column_names`` is the first field of each ``description`` entry.
    """
    names = [entry[0] for entry in description]
    converted = list(map(_convert_column_to_list, columns))
    return converted, names
678
+
679
+
580
680
  def _convert_column_to_arrow_array(t_col):
581
681
  """
582
682
  Return a pyarrow array from the values in a TColumn instance.
@@ -601,6 +701,26 @@ def _convert_column_to_arrow_array(t_col):
601
701
  raise OperationalError("Empty TColumn instance {}".format(t_col))
602
702
 
603
703
 
704
def _convert_column_to_list(t_col):
    """Return the values of a thrift TColumn as a tuple, with nulls as ``None``.

    The typed value wrappers are checked in the order listed below, and the
    first populated one is converted.

    Raises:
        OperationalError: if none of the supported wrappers is populated.
    """
    candidate_wrappers = (
        getattr(t_col, attr_name)
        for attr_name in (
            "boolVal",
            "byteVal",
            "i16Val",
            "i32Val",
            "i64Val",
            "doubleVal",
            "stringVal",
            "binaryVal",
        )
    )
    for value_wrapper in candidate_wrappers:
        if value_wrapper:
            return _create_python_tuple(value_wrapper)

    raise OperationalError("Empty TColumn instance {}".format(t_col))
722
+
723
+
604
724
  def _create_arrow_array(t_col_value_wrapper, arrow_type):
605
725
  result = t_col_value_wrapper.values
606
726
  nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
@@ -615,3 +735,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
615
735
  result[i] = None
616
736
 
617
737
  return pyarrow.array(result, type=arrow_type)
738
+
739
+
740
def _create_python_tuple(t_col_value_wrapper):
    """Return the wrapper's values as a tuple, with null positions replaced by ``None``.

    ``t_col_value_wrapper`` exposes ``values`` and ``nulls``, a bytes bitfield
    where the bit selected by ``nulls[i >> 3] & BIT_MASKS[i & 0x7]`` marks
    value ``i`` as null. The wrapper itself is not modified.
    """
    values = t_col_value_wrapper.values
    nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
    assert isinstance(nulls, bytes)

    # The number of bits in nulls can be both larger or smaller than the number
    # of values, so only positions covered by both can be marked null.
    null_span = min(len(values), len(nulls) * 8)

    return tuple(
        None if i < null_span and nulls[i >> 3] & BIT_MASKS[i & 0x7] else value
        for i, value in enumerate(values)
    )
@@ -0,0 +1,6 @@
1
"""Compatibility shim that re-exports the databricks-sqlalchemy plugin when it is installed."""

try:
    from databricks_sqlalchemy import *
except ImportError:
    # Only a missing plugin is expected here; catching everything would also
    # hide KeyboardInterrupt/SystemExit and genuine errors raised while an
    # installed plugin is being imported.
    import warnings

    warnings.warn("Install databricks-sqlalchemy plugin before using this")
@@ -1,6 +0,0 @@
1
- try:
2
- from databricks_sqlalchemy import *
3
- except:
4
- import warnings
5
-
6
- warnings.warn("Install databricks-sqlalchemy plugin before using this")