databricks-sql-connector 3.3.0__tar.gz → 3.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/CHANGELOG.md +6 -0
  2. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/PKG-INFO +7 -9
  3. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/README.md +5 -7
  4. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/pyproject.toml +2 -2
  5. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/__init__.py +1 -1
  6. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/auth.py +14 -1
  7. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/thrift_http_client.py +25 -16
  8. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/client.py +18 -4
  9. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/cloudfetch/download_manager.py +5 -4
  10. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/cloudfetch/downloader.py +5 -7
  11. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_backend.py +6 -37
  12. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/types.py +48 -0
  13. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/utils.py +9 -9
  14. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/LICENSE +0 -0
  15. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/__init__.py +0 -0
  16. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/__init__.py +0 -0
  17. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/authenticators.py +0 -0
  18. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/endpoint.py +0 -0
  19. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/oauth.py +0 -0
  20. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
  21. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/auth/retry.py +0 -0
  22. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/exc.py +0 -0
  23. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/experimental/__init__.py +0 -0
  24. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
  25. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/parameters/__init__.py +0 -0
  26. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/parameters/native.py +0 -0
  27. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/parameters/py.typed +0 -0
  28. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/py.typed +0 -0
  29. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
  30. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
  31. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
  32. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
  33. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
  34. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sql/thrift_api/__init__.py +0 -0
  35. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/README.sqlalchemy.md +0 -0
  36. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/README.tests.md +0 -0
  37. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/__init__.py +0 -0
  38. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/_ddl.py +0 -0
  39. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/_parse.py +0 -0
  40. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/_types.py +0 -0
  41. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/base.py +0 -0
  42. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/py.typed +0 -0
  43. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/requirements.py +0 -0
  44. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/setup.cfg +0 -0
  45. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/_extra.py +0 -0
  46. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/_future.py +0 -0
  47. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/_regression.py +0 -0
  48. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/_unsupported.py +0 -0
  49. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/conftest.py +0 -0
  50. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/overrides/_componentreflectiontest.py +0 -0
  51. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/overrides/_ctetest.py +0 -0
  52. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test/test_suite.py +0 -0
  53. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/__init__.py +0 -0
  54. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/conftest.py +0 -0
  55. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx +0 -0
  56. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +0 -0
  57. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/test_ddl.py +0 -0
  58. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/test_parsing.py +0 -0
  59. {databricks_sql_connector-3.3.0 → databricks_sql_connector-3.4.0}/src/databricks/sqlalchemy/test_local/test_types.py +0 -0
@@ -1,5 +1,11 @@
1
1
  # Release History
2
2
 
3
+ # 3.4.0 (2024-08-27)
4
+
5
+ - Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal)
6
+ - Make OAuth the default authenticator if no authentication settings are provided (databricks/databricks-sql-python#419 by @jackyhu-db)
7
+ - Fix (regression): use SSL options with HTTPS connection pool (databricks/databricks-sql-python#425 by @kravets-levko)
8
+
3
9
  # 3.3.0 (2024-07-18)
4
10
 
5
11
  - Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: databricks-sql-connector
3
- Version: 3.3.0
3
+ Version: 3.4.0
4
4
  Summary: Databricks SQL Connector for Python
5
5
  License: Apache-2.0
6
6
  Author: Databricks
@@ -21,7 +21,7 @@ Requires-Dist: numpy (>=1.16.6,<2.0.0) ; python_version >= "3.8" and python_vers
21
21
  Requires-Dist: numpy (>=1.23.4,<2.0.0) ; python_version >= "3.11"
22
22
  Requires-Dist: oauthlib (>=3.1.0,<4.0.0)
23
23
  Requires-Dist: openpyxl (>=3.0.10,<4.0.0)
24
- Requires-Dist: pandas (>=1.2.5,<2.2.0) ; python_version >= "3.8"
24
+ Requires-Dist: pandas (>=1.2.5,<2.3.0) ; python_version >= "3.8"
25
25
  Requires-Dist: pyarrow (>=14.0.1,<17)
26
26
  Requires-Dist: requests (>=2.18.1,<3.0.0)
27
27
  Requires-Dist: sqlalchemy (>=2.0.21) ; extra == "sqlalchemy" or extra == "alembic"
@@ -57,12 +57,9 @@ For the latest documentation, see
57
57
 
58
58
  Install the library with `pip install databricks-sql-connector`
59
59
 
60
- Note: Don't hard-code authentication secrets into your Python. Use environment variables
61
-
62
60
  ```bash
63
61
  export DATABRICKS_HOST=********.databricks.com
64
62
  export DATABRICKS_HTTP_PATH=/sql/1.0/endpoints/****************
65
- export DATABRICKS_TOKEN=dapi********************************
66
63
  ```
67
64
 
68
65
  Example usage:
@@ -72,12 +69,10 @@ from databricks import sql
72
69
 
73
70
  host = os.getenv("DATABRICKS_HOST")
74
71
  http_path = os.getenv("DATABRICKS_HTTP_PATH")
75
- access_token = os.getenv("DATABRICKS_TOKEN")
76
72
 
77
73
  connection = sql.connect(
78
74
  server_hostname=host,
79
- http_path=http_path,
80
- access_token=access_token)
75
+ http_path=http_path)
81
76
 
82
77
  cursor = connection.cursor()
83
78
  cursor.execute('SELECT :param `p`, * FROM RANGE(10)', {"param": "foo"})
@@ -93,7 +88,10 @@ In the above example:
93
88
  - `server-hostname` is the Databricks instance host name.
94
89
  - `http-path` is the HTTP Path either to a Databricks SQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef),
95
90
  or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
96
- - `personal-access-token` is the Databricks Personal Access Token for the account that will execute commands and queries
91
+
92
+ > Note: This example uses [Databricks OAuth U2M](https://docs.databricks.com/en/dev-tools/auth/oauth-u2m.html)
93
+ > to authenticate the target Databricks user account and needs to open the browser for authentication. So it
94
+ > can only run on the user's machine.
97
95
 
98
96
 
99
97
  ## Contributing
@@ -24,12 +24,9 @@ For the latest documentation, see
24
24
 
25
25
  Install the library with `pip install databricks-sql-connector`
26
26
 
27
- Note: Don't hard-code authentication secrets into your Python. Use environment variables
28
-
29
27
  ```bash
30
28
  export DATABRICKS_HOST=********.databricks.com
31
29
  export DATABRICKS_HTTP_PATH=/sql/1.0/endpoints/****************
32
- export DATABRICKS_TOKEN=dapi********************************
33
30
  ```
34
31
 
35
32
  Example usage:
@@ -39,12 +36,10 @@ from databricks import sql
39
36
 
40
37
  host = os.getenv("DATABRICKS_HOST")
41
38
  http_path = os.getenv("DATABRICKS_HTTP_PATH")
42
- access_token = os.getenv("DATABRICKS_TOKEN")
43
39
 
44
40
  connection = sql.connect(
45
41
  server_hostname=host,
46
- http_path=http_path,
47
- access_token=access_token)
42
+ http_path=http_path)
48
43
 
49
44
  cursor = connection.cursor()
50
45
  cursor.execute('SELECT :param `p`, * FROM RANGE(10)', {"param": "foo"})
@@ -60,7 +55,10 @@ In the above example:
60
55
  - `server-hostname` is the Databricks instance host name.
61
56
  - `http-path` is the HTTP Path either to a Databricks SQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef),
62
57
  or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
63
- - `personal-access-token` is the Databricks Personal Access Token for the account that will execute commands and queries
58
+
59
+ > Note: This example uses [Databricks OAuth U2M](https://docs.databricks.com/en/dev-tools/auth/oauth-u2m.html)
60
+ > to authenticate the target Databricks user account and needs to open the browser for authentication. So it
61
+ > can only run on the user's machine.
64
62
 
65
63
 
66
64
  ## Contributing
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "databricks-sql-connector"
3
- version = "3.3.0"
3
+ version = "3.4.0"
4
4
  description = "Databricks SQL Connector for Python"
5
5
  authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
6
6
  license = "Apache-2.0"
@@ -12,7 +12,7 @@ include = ["CHANGELOG.md"]
12
12
  python = "^3.8.0"
13
13
  thrift = ">=0.16.0,<0.21.0"
14
14
  pandas = [
15
- { version = ">=1.2.5,<2.2.0", python = ">=3.8" }
15
+ { version = ">=1.2.5,<2.3.0", python = ">=3.8" }
16
16
  ]
17
17
  pyarrow = ">=14.0.1,<17"
18
18
 
@@ -68,7 +68,7 @@ DATETIME = DBAPITypeObject("timestamp")
68
68
  DATE = DBAPITypeObject("date")
69
69
  ROWID = DBAPITypeObject()
70
70
 
71
- __version__ = "3.3.0"
71
+ __version__ = "3.4.0"
72
72
  USER_AGENT_NAME = "PyDatabricksSqlConnector"
73
73
 
74
74
  # These two functions are pyhive legacy
@@ -64,7 +64,20 @@ def get_auth_provider(cfg: ClientContext):
64
64
  # no op authenticator. authentication is performed using ssl certificate outside of headers
65
65
  return AuthProvider()
66
66
  else:
67
- raise RuntimeError("No valid authentication settings!")
67
+ if (
68
+ cfg.oauth_redirect_port_range is not None
69
+ and cfg.oauth_client_id is not None
70
+ and cfg.oauth_scopes is not None
71
+ ):
72
+ return DatabricksOAuthProvider(
73
+ cfg.hostname,
74
+ cfg.oauth_persistence,
75
+ cfg.oauth_redirect_port_range,
76
+ cfg.oauth_client_id,
77
+ cfg.oauth_scopes,
78
+ )
79
+ else:
80
+ raise RuntimeError("No valid authentication settings!")
68
81
 
69
82
 
70
83
  PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
@@ -1,13 +1,11 @@
1
1
  import base64
2
2
  import logging
3
3
  import urllib.parse
4
- from typing import Dict, Union
4
+ from typing import Dict, Union, Optional
5
5
 
6
6
  import six
7
7
  import thrift
8
8
 
9
- logger = logging.getLogger(__name__)
10
-
11
9
  import ssl
12
10
  import warnings
13
11
  from http.client import HTTPResponse
@@ -16,6 +14,9 @@ from io import BytesIO
16
14
  from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
17
15
  from urllib3.util import make_headers
18
16
  from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
17
+ from databricks.sql.types import SSLOptions
18
+
19
+ logger = logging.getLogger(__name__)
19
20
 
20
21
 
21
22
  class THttpClient(thrift.transport.THttpClient.THttpClient):
@@ -25,13 +26,12 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
25
26
  uri_or_host,
26
27
  port=None,
27
28
  path=None,
28
- cafile=None,
29
- cert_file=None,
30
- key_file=None,
31
- ssl_context=None,
29
+ ssl_options: Optional[SSLOptions] = None,
32
30
  max_connections: int = 1,
33
31
  retry_policy: Union[DatabricksRetryPolicy, int] = 0,
34
32
  ):
33
+ self._ssl_options = ssl_options
34
+
35
35
  if port is not None:
36
36
  warnings.warn(
37
37
  "Please use the THttpClient('http{s}://host:port/path') constructor",
@@ -48,13 +48,11 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
48
48
  self.scheme = parsed.scheme
49
49
  assert self.scheme in ("http", "https")
50
50
  if self.scheme == "https":
51
- self.certfile = cert_file
52
- self.keyfile = key_file
53
- self.context = (
54
- ssl.create_default_context(cafile=cafile)
55
- if (cafile and not ssl_context)
56
- else ssl_context
57
- )
51
+ if self._ssl_options is not None:
52
+ # TODO: Not sure if those options are used anywhere - need to double-check
53
+ self.certfile = self._ssl_options.tls_client_cert_file
54
+ self.keyfile = self._ssl_options.tls_client_cert_key_file
55
+ self.context = self._ssl_options.create_ssl_context()
58
56
  self.port = parsed.port
59
57
  self.host = parsed.hostname
60
58
  self.path = parsed.path
@@ -109,12 +107,23 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
109
107
  def open(self):
110
108
 
111
109
  # self.__pool replaces the self.__http used by the original THttpClient
110
+ _pool_kwargs = {"maxsize": self.max_connections}
111
+
112
112
  if self.scheme == "http":
113
113
  pool_class = HTTPConnectionPool
114
114
  elif self.scheme == "https":
115
115
  pool_class = HTTPSConnectionPool
116
-
117
- _pool_kwargs = {"maxsize": self.max_connections}
116
+ _pool_kwargs.update(
117
+ {
118
+ "cert_reqs": ssl.CERT_REQUIRED
119
+ if self._ssl_options.tls_verify
120
+ else ssl.CERT_NONE,
121
+ "ca_certs": self._ssl_options.tls_trusted_ca_file,
122
+ "cert_file": self._ssl_options.tls_client_cert_file,
123
+ "key_file": self._ssl_options.tls_client_cert_key_file,
124
+ "key_password": self._ssl_options.tls_client_cert_key_password,
125
+ }
126
+ )
118
127
 
119
128
  if self.using_proxy():
120
129
  proxy_manager = ProxyManager(
@@ -35,7 +35,7 @@ from databricks.sql.parameters.native import (
35
35
  )
36
36
 
37
37
 
38
- from databricks.sql.types import Row
38
+ from databricks.sql.types import Row, SSLOptions
39
39
  from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
40
40
  from databricks.sql.experimental.oauth_persistence import OAuthPersistence
41
41
 
@@ -96,7 +96,7 @@ class Connection:
96
96
  sanitise parameterized inputs to prevent SQL injection. The inline parameter approach is maintained for
97
97
  legacy purposes and will be deprecated in a future release. When this parameter is `True` you will see
98
98
  a warning log message. To suppress this log message, set `use_inline_params="silent"`.
99
- auth_type: `str`, optional
99
+ auth_type: `str`, optional (default is databricks-oauth if neither `access_token` nor `tls_client_cert_file` is set)
100
100
  `databricks-oauth` : to use Databricks OAuth with fine-grained permission scopes, set to `databricks-oauth`.
101
101
  `azure-oauth` : to use Microsoft Entra ID OAuth flow, set to `azure-oauth`.
102
102
 
@@ -178,8 +178,9 @@ class Connection:
178
178
  # _tls_trusted_ca_file
179
179
  # Set to the path of the file containing trusted CA certificates for server certificate
180
180
  # verification. If not provided, uses system truststore.
181
- # _tls_client_cert_file, _tls_client_cert_key_file
181
+ # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
182
182
  # Set client SSL certificate.
183
+ # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
183
184
  # _retry_stop_after_attempts_count
184
185
  # The maximum number of attempts during a request retry sequence (defaults to 24)
185
186
  # _socket_timeout
@@ -220,12 +221,25 @@ class Connection:
220
221
 
221
222
  base_headers = [("User-Agent", useragent_header)]
222
223
 
224
+ self._ssl_options = SSLOptions(
225
+ # Double negation is generally a bad thing, but we have to keep backward compatibility
226
+ tls_verify=not kwargs.get(
227
+ "_tls_no_verify", False
228
+ ), # by default - verify cert and host
229
+ tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
230
+ tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
231
+ tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
232
+ tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
233
+ tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
234
+ )
235
+
223
236
  self.thrift_backend = ThriftBackend(
224
237
  self.host,
225
238
  self.port,
226
239
  http_path,
227
240
  (http_headers or []) + base_headers,
228
241
  auth_provider,
242
+ ssl_options=self._ssl_options,
229
243
  _use_arrow_native_complex_types=_use_arrow_native_complex_types,
230
244
  **kwargs,
231
245
  )
@@ -1164,7 +1178,7 @@ class ResultSet:
1164
1178
  timestamp_as_object=True,
1165
1179
  )
1166
1180
 
1167
- res = df.to_numpy(na_value=None)
1181
+ res = df.to_numpy(na_value=None, dtype="object")
1168
1182
  return [ResultRow(*v) for v in res]
1169
1183
 
1170
1184
  @property
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
 
3
- from ssl import SSLContext
4
3
  from concurrent.futures import ThreadPoolExecutor, Future
5
4
  from typing import List, Union
6
5
 
@@ -9,6 +8,8 @@ from databricks.sql.cloudfetch.downloader import (
9
8
  DownloadableResultSettings,
10
9
  DownloadedFile,
11
10
  )
11
+ from databricks.sql.types import SSLOptions
12
+
12
13
  from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
13
14
 
14
15
  logger = logging.getLogger(__name__)
@@ -20,7 +21,7 @@ class ResultFileDownloadManager:
20
21
  links: List[TSparkArrowResultLink],
21
22
  max_download_threads: int,
22
23
  lz4_compressed: bool,
23
- ssl_context: SSLContext,
24
+ ssl_options: SSLOptions,
24
25
  ):
25
26
  self._pending_links: List[TSparkArrowResultLink] = []
26
27
  for link in links:
@@ -38,7 +39,7 @@ class ResultFileDownloadManager:
38
39
  self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
39
40
 
40
41
  self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
41
- self._ssl_context = ssl_context
42
+ self._ssl_options = ssl_options
42
43
 
43
44
  def get_next_downloaded_file(
44
45
  self, next_row_offset: int
@@ -95,7 +96,7 @@ class ResultFileDownloadManager:
95
96
  handler = ResultSetDownloadHandler(
96
97
  settings=self._downloadable_result_settings,
97
98
  link=link,
98
- ssl_context=self._ssl_context,
99
+ ssl_options=self._ssl_options,
99
100
  )
100
101
  task = self._thread_pool.submit(handler.run)
101
102
  self._download_tasks.append(task)
@@ -3,13 +3,12 @@ from dataclasses import dataclass
3
3
 
4
4
  import requests
5
5
  from requests.adapters import HTTPAdapter, Retry
6
- from ssl import SSLContext, CERT_NONE
7
6
  import lz4.frame
8
7
  import time
9
8
 
10
9
  from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
11
-
12
10
  from databricks.sql.exc import Error
11
+ from databricks.sql.types import SSLOptions
13
12
 
14
13
  logger = logging.getLogger(__name__)
15
14
 
@@ -66,11 +65,11 @@ class ResultSetDownloadHandler:
66
65
  self,
67
66
  settings: DownloadableResultSettings,
68
67
  link: TSparkArrowResultLink,
69
- ssl_context: SSLContext,
68
+ ssl_options: SSLOptions,
70
69
  ):
71
70
  self.settings = settings
72
71
  self.link = link
73
- self._ssl_context = ssl_context
72
+ self._ssl_options = ssl_options
74
73
 
75
74
  def run(self) -> DownloadedFile:
76
75
  """
@@ -95,14 +94,13 @@ class ResultSetDownloadHandler:
95
94
  session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
96
95
  session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
97
96
 
98
- ssl_verify = self._ssl_context.verify_mode != CERT_NONE
99
-
100
97
  try:
101
98
  # Get the file via HTTP request
102
99
  response = session.get(
103
100
  self.link.fileLink,
104
101
  timeout=self.settings.download_timeout,
105
- verify=ssl_verify,
102
+ verify=self._ssl_options.tls_verify,
103
+ # TODO: Pass cert from `self._ssl_options`
106
104
  )
107
105
  response.raise_for_status()
108
106
 
@@ -5,7 +5,6 @@ import math
5
5
  import time
6
6
  import uuid
7
7
  import threading
8
- from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
9
8
  from typing import List, Union
10
9
 
11
10
  import pyarrow
@@ -36,6 +35,7 @@ from databricks.sql.utils import (
36
35
  convert_decimals_in_arrow_table,
37
36
  convert_column_based_set_to_arrow_table,
38
37
  )
38
+ from databricks.sql.types import SSLOptions
39
39
 
40
40
  logger = logging.getLogger(__name__)
41
41
 
@@ -85,6 +85,7 @@ class ThriftBackend:
85
85
  http_path: str,
86
86
  http_headers,
87
87
  auth_provider: AuthProvider,
88
+ ssl_options: SSLOptions,
88
89
  staging_allowed_local_path: Union[None, str, List[str]] = None,
89
90
  **kwargs,
90
91
  ):
@@ -93,16 +94,6 @@ class ThriftBackend:
93
94
  # Tag to add to User-Agent header. For use by partners.
94
95
  # _username, _password
95
96
  # Username and password Basic authentication (no official support)
96
- # _tls_no_verify
97
- # Set to True (Boolean) to completely disable SSL verification.
98
- # _tls_verify_hostname
99
- # Set to False (Boolean) to disable SSL hostname verification, but check certificate.
100
- # _tls_trusted_ca_file
101
- # Set to the path of the file containing trusted CA certificates for server certificate
102
- # verification. If not provide, uses system truststore.
103
- # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
104
- # Set client SSL certificate.
105
- # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
106
97
  # _connection_uri
107
98
  # Overrides server_hostname and http_path.
108
99
  # RETRY/ATTEMPT POLICY
@@ -162,29 +153,7 @@ class ThriftBackend:
162
153
  # Cloud fetch
163
154
  self.max_download_threads = kwargs.get("max_download_threads", 10)
164
155
 
165
- # Configure tls context
166
- ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
167
- if kwargs.get("_tls_no_verify") is True:
168
- ssl_context.check_hostname = False
169
- ssl_context.verify_mode = CERT_NONE
170
- elif kwargs.get("_tls_verify_hostname") is False:
171
- ssl_context.check_hostname = False
172
- ssl_context.verify_mode = CERT_REQUIRED
173
- else:
174
- ssl_context.check_hostname = True
175
- ssl_context.verify_mode = CERT_REQUIRED
176
-
177
- tls_client_cert_file = kwargs.get("_tls_client_cert_file")
178
- tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
179
- tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
180
- if tls_client_cert_file:
181
- ssl_context.load_cert_chain(
182
- certfile=tls_client_cert_file,
183
- keyfile=tls_client_cert_key_file,
184
- password=tls_client_cert_key_password,
185
- )
186
-
187
- self._ssl_context = ssl_context
156
+ self._ssl_options = ssl_options
188
157
 
189
158
  self._auth_provider = auth_provider
190
159
 
@@ -225,7 +194,7 @@ class ThriftBackend:
225
194
  self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
226
195
  auth_provider=self._auth_provider,
227
196
  uri_or_host=uri,
228
- ssl_context=self._ssl_context,
197
+ ssl_options=self._ssl_options,
229
198
  **additional_transport_args, # type: ignore
230
199
  )
231
200
 
@@ -776,7 +745,7 @@ class ThriftBackend:
776
745
  max_download_threads=self.max_download_threads,
777
746
  lz4_compressed=lz4_compressed,
778
747
  description=description,
779
- ssl_context=self._ssl_context,
748
+ ssl_options=self._ssl_options,
780
749
  )
781
750
  else:
782
751
  arrow_queue_opt = None
@@ -1008,7 +977,7 @@ class ThriftBackend:
1008
977
  max_download_threads=self.max_download_threads,
1009
978
  lz4_compressed=lz4_compressed,
1010
979
  description=description,
1011
- ssl_context=self._ssl_context,
980
+ ssl_options=self._ssl_options,
1012
981
  )
1013
982
 
1014
983
  return queue, resp.hasMoreRows
@@ -19,6 +19,54 @@
19
19
  from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
20
20
  import datetime
21
21
  import decimal
22
+ from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
23
+
24
+
25
+ class SSLOptions:
26
+ tls_verify: bool
27
+ tls_verify_hostname: bool
28
+ tls_trusted_ca_file: Optional[str]
29
+ tls_client_cert_file: Optional[str]
30
+ tls_client_cert_key_file: Optional[str]
31
+ tls_client_cert_key_password: Optional[str]
32
+
33
+ def __init__(
34
+ self,
35
+ tls_verify: bool = True,
36
+ tls_verify_hostname: bool = True,
37
+ tls_trusted_ca_file: Optional[str] = None,
38
+ tls_client_cert_file: Optional[str] = None,
39
+ tls_client_cert_key_file: Optional[str] = None,
40
+ tls_client_cert_key_password: Optional[str] = None,
41
+ ):
42
+ self.tls_verify = tls_verify
43
+ self.tls_verify_hostname = tls_verify_hostname
44
+ self.tls_trusted_ca_file = tls_trusted_ca_file
45
+ self.tls_client_cert_file = tls_client_cert_file
46
+ self.tls_client_cert_key_file = tls_client_cert_key_file
47
+ self.tls_client_cert_key_password = tls_client_cert_key_password
48
+
49
+ def create_ssl_context(self) -> SSLContext:
50
+ ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
51
+
52
+ if self.tls_verify is False:
53
+ ssl_context.check_hostname = False
54
+ ssl_context.verify_mode = CERT_NONE
55
+ elif self.tls_verify_hostname is False:
56
+ ssl_context.check_hostname = False
57
+ ssl_context.verify_mode = CERT_REQUIRED
58
+ else:
59
+ ssl_context.check_hostname = True
60
+ ssl_context.verify_mode = CERT_REQUIRED
61
+
62
+ if self.tls_client_cert_file:
63
+ ssl_context.load_cert_chain(
64
+ certfile=self.tls_client_cert_file,
65
+ keyfile=self.tls_client_cert_key_file,
66
+ password=self.tls_client_cert_key_password,
67
+ )
68
+
69
+ return ssl_context
22
70
 
23
71
 
24
72
  class Row(tuple):
@@ -9,7 +9,6 @@ from decimal import Decimal
9
9
  from enum import Enum
10
10
  from typing import Any, Dict, List, Optional, Union
11
11
  import re
12
- from ssl import SSLContext
13
12
 
14
13
  import lz4.frame
15
14
  import pyarrow
@@ -21,13 +20,14 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
21
20
  TSparkArrowResultLink,
22
21
  TSparkRowSetType,
23
22
  )
23
+ from databricks.sql.types import SSLOptions
24
24
 
25
25
  from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
26
26
 
27
- BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
28
-
29
27
  import logging
30
28
 
29
+ BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
30
+
31
31
  logger = logging.getLogger(__name__)
32
32
 
33
33
 
@@ -48,7 +48,7 @@ class ResultSetQueueFactory(ABC):
48
48
  t_row_set: TRowSet,
49
49
  arrow_schema_bytes: bytes,
50
50
  max_download_threads: int,
51
- ssl_context: SSLContext,
51
+ ssl_options: SSLOptions,
52
52
  lz4_compressed: bool = True,
53
53
  description: Optional[List[List[Any]]] = None,
54
54
  ) -> ResultSetQueue:
@@ -62,7 +62,7 @@ class ResultSetQueueFactory(ABC):
62
62
  lz4_compressed (bool): Whether result data has been lz4 compressed.
63
63
  description (List[List[Any]]): Hive table schema description.
64
64
  max_download_threads (int): Maximum number of downloader thread pool threads.
65
- ssl_context (SSLContext): SSLContext object for CloudFetchQueue
65
+ ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
66
66
 
67
67
  Returns:
68
68
  ResultSetQueue
@@ -91,7 +91,7 @@ class ResultSetQueueFactory(ABC):
91
91
  lz4_compressed=lz4_compressed,
92
92
  description=description,
93
93
  max_download_threads=max_download_threads,
94
- ssl_context=ssl_context,
94
+ ssl_options=ssl_options,
95
95
  )
96
96
  else:
97
97
  raise AssertionError("Row set type is not valid")
@@ -137,7 +137,7 @@ class CloudFetchQueue(ResultSetQueue):
137
137
  self,
138
138
  schema_bytes,
139
139
  max_download_threads: int,
140
- ssl_context: SSLContext,
140
+ ssl_options: SSLOptions,
141
141
  start_row_offset: int = 0,
142
142
  result_links: Optional[List[TSparkArrowResultLink]] = None,
143
143
  lz4_compressed: bool = True,
@@ -160,7 +160,7 @@ class CloudFetchQueue(ResultSetQueue):
160
160
  self.result_links = result_links
161
161
  self.lz4_compressed = lz4_compressed
162
162
  self.description = description
163
- self._ssl_context = ssl_context
163
+ self._ssl_options = ssl_options
164
164
 
165
165
  logger.debug(
166
166
  "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
@@ -178,7 +178,7 @@ class CloudFetchQueue(ResultSetQueue):
178
178
  links=result_links or [],
179
179
  max_download_threads=self.max_download_threads,
180
180
  lz4_compressed=self.lz4_compressed,
181
- ssl_context=self._ssl_context,
181
+ ssl_options=self._ssl_options,
182
182
  )
183
183
 
184
184
  self.table = self._create_next_table()