databricks-sql-connector 4.0.0b2__tar.gz → 4.0.0b4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/CHANGELOG.md +20 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/PKG-INFO +21 -8
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/README.md +20 -3
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/pyproject.toml +2 -11
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/__init__.py +1 -1
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/thrift_http_client.py +25 -16
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/client.py +111 -9
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/cloudfetch/download_manager.py +5 -4
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/cloudfetch/downloader.py +6 -7
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_backend.py +21 -54
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/types.py +48 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/utils.py +154 -18
- databricks_sql_connector-4.0.0b2/src/databricks/sqlalchemy/__init__.py +0 -6
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/LICENSE +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/auth.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/authenticators.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/endpoint.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/oauth.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/auth/retry.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/exc.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/experimental/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/parameters/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/parameters/native.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/parameters/py.typed +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/py.typed +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
- {databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/thrift_api/__init__.py +0 -0
|
@@ -1,5 +1,25 @@
|
|
|
1
1
|
# Release History
|
|
2
2
|
|
|
3
|
+
# 4.0.0
|
|
4
|
+
|
|
5
|
+
- Split the connector into two separate packages: `databricks-sql-connector` and `databricks-sqlalchemy`. The `databricks-sql-connector` package contains the core functionality of the connector, while the `databricks-sqlalchemy` package contains the SQLAlchemy dialect for the connector.
|
|
6
|
+
- The PyArrow dependency is now optional in `databricks-sql-connector`. Users who need Arrow support must install `pyarrow` explicitly.
|
|
7
|
+
|
|
8
|
+
# 3.6.0 (2024-10-25)
|
|
9
|
+
|
|
10
|
+
- Support encryption headers in the cloud fetch request (https://github.com/databricks/databricks-sql-python/pull/460 by @jackyhu-db)
|
|
11
|
+
|
|
12
|
+
# 3.5.0 (2024-10-18)
|
|
13
|
+
|
|
14
|
+
- Create a non pyarrow flow to handle small results for the column set (databricks/databricks-sql-python#440 by @jprakash-db)
|
|
15
|
+
- Fix: On non-retryable error, ensure PySQL includes useful information in error (databricks/databricks-sql-python#447 by @shivam2680)
|
|
16
|
+
|
|
17
|
+
# 3.4.0 (2024-08-27)
|
|
18
|
+
|
|
19
|
+
- Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal)
|
|
20
|
+
- Make OAuth as the default authenticator if no authentication setting is provided (databricks/databricks-sql-python#419 by @jackyhu-db)
|
|
21
|
+
- Fix (regression): use SSL options with HTTPS connection pool (databricks/databricks-sql-python#425 by @kravets-levko)
|
|
22
|
+
|
|
3
23
|
# 3.3.0 (2024-07-18)
|
|
4
24
|
|
|
5
25
|
- Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: databricks-sql-connector
|
|
3
|
-
Version: 4.0.0b2
|
|
3
|
+
Version: 4.0.0b4
|
|
4
4
|
Summary: Databricks SQL Connector for Python
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Databricks
|
|
@@ -13,11 +13,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.10
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
-
Provides-Extra: alembic
|
|
17
|
-
Provides-Extra: databricks-sqlalchemy
|
|
18
16
|
Provides-Extra: pyarrow
|
|
19
|
-
Requires-Dist: alembic (>=1.0.11,<2.0.0) ; extra == "alembic"
|
|
20
|
-
Requires-Dist: databricks-sqlalchemy (>=2.0.0) ; extra == "databricks-sqlalchemy" or extra == "alembic"
|
|
21
17
|
Requires-Dist: lz4 (>=4.0.2,<5.0.0)
|
|
22
18
|
Requires-Dist: numpy (>=1.16.6,<2.0.0) ; python_version >= "3.8" and python_version < "3.11"
|
|
23
19
|
Requires-Dist: numpy (>=1.23.4,<2.0.0) ; python_version >= "3.11"
|
|
@@ -37,9 +33,9 @@ Description-Content-Type: text/markdown
|
|
|
37
33
|
[](https://pypi.org/project/databricks-sql-connector/)
|
|
38
34
|
[](https://pepy.tech/project/databricks-sql-connector)
|
|
39
35
|
|
|
40
|
-
The Databricks SQL Connector for Python allows you to develop Python applications that connect to Databricks clusters and SQL warehouses. It is a Thrift-based client with no dependencies on ODBC or JDBC. It conforms to the [Python DB API 2.0 specification](https://www.python.org/dev/peps/pep-0249/)
|
|
36
|
+
The Databricks SQL Connector for Python allows you to develop Python applications that connect to Databricks clusters and SQL warehouses. It is a Thrift-based client with no dependencies on ODBC or JDBC. It conforms to the [Python DB API 2.0 specification](https://www.python.org/dev/peps/pep-0249/).
|
|
41
37
|
|
|
42
|
-
This connector uses Arrow as the data-exchange format, and supports APIs to directly fetch Arrow tables. Arrow tables are wrapped in the `ArrowQueue` class to provide a natural API to get several rows at a time.
|
|
38
|
+
This connector uses Arrow as the data-exchange format, and supports APIs (e.g. `fetchmany_arrow`) to directly fetch Arrow tables. Arrow tables are wrapped in the `ArrowQueue` class to provide a natural API to get several rows at a time. [PyArrow](https://arrow.apache.org/docs/python/index.html) is required to enable and use these APIs; you can install it via `pip install pyarrow` or `pip install databricks-sql-connector[pyarrow]`.
|
|
43
39
|
|
|
44
40
|
You are welcome to file an issue here for general use cases. You can also contact Databricks Support [here](https://help.databricks.com).
|
|
45
41
|
|
|
@@ -56,7 +52,12 @@ For the latest documentation, see
|
|
|
56
52
|
|
|
57
53
|
## Quickstart
|
|
58
54
|
|
|
59
|
-
|
|
55
|
+
### Installing the core library
|
|
56
|
+
Install using `pip install databricks-sql-connector`
|
|
57
|
+
|
|
58
|
+
### Installing the core library with PyArrow
|
|
59
|
+
Install using `pip install databricks-sql-connector[pyarrow]`
|
|
60
|
+
|
|
60
61
|
|
|
61
62
|
```bash
|
|
62
63
|
export DATABRICKS_HOST=********.databricks.com
|
|
@@ -94,6 +95,18 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789
|
|
|
94
95
|
> to authenticate the target Databricks user account and needs to open the browser for authentication. So it
|
|
95
96
|
> can only run on the user's machine.
|
|
96
97
|
|
|
98
|
+
## SQLAlchemy
|
|
99
|
+
Starting from `databricks-sql-connector` version 4.0.0, SQLAlchemy support has been extracted to a new library, `databricks-sqlalchemy`.
|
|
100
|
+
|
|
101
|
+
- GitHub repository [databricks-sqlalchemy github](https://github.com/databricks/databricks-sqlalchemy)
|
|
102
|
+
- PyPI [databricks-sqlalchemy pypi](https://pypi.org/project/databricks-sqlalchemy/)
|
|
103
|
+
|
|
104
|
+
### Quick SQLAlchemy guide
|
|
105
|
+
Users can now choose between using the SQLAlchemy v1 or SQLAlchemy v2 dialects with the connector core.
|
|
106
|
+
|
|
107
|
+
- Install the latest SQLAlchemy v1 using `pip install databricks-sqlalchemy~=1.0`
|
|
108
|
+
- Install SQLAlchemy v2 using `pip install databricks-sqlalchemy`
|
|
109
|
+
|
|
97
110
|
|
|
98
111
|
## Contributing
|
|
99
112
|
|
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
[](https://pypi.org/project/databricks-sql-connector/)
|
|
4
4
|
[](https://pepy.tech/project/databricks-sql-connector)
|
|
5
5
|
|
|
6
|
-
The Databricks SQL Connector for Python allows you to develop Python applications that connect to Databricks clusters and SQL warehouses. It is a Thrift-based client with no dependencies on ODBC or JDBC. It conforms to the [Python DB API 2.0 specification](https://www.python.org/dev/peps/pep-0249/)
|
|
6
|
+
The Databricks SQL Connector for Python allows you to develop Python applications that connect to Databricks clusters and SQL warehouses. It is a Thrift-based client with no dependencies on ODBC or JDBC. It conforms to the [Python DB API 2.0 specification](https://www.python.org/dev/peps/pep-0249/).
|
|
7
7
|
|
|
8
|
-
This connector uses Arrow as the data-exchange format, and supports APIs to directly fetch Arrow tables. Arrow tables are wrapped in the `ArrowQueue` class to provide a natural API to get several rows at a time.
|
|
8
|
+
This connector uses Arrow as the data-exchange format, and supports APIs (e.g. `fetchmany_arrow`) to directly fetch Arrow tables. Arrow tables are wrapped in the `ArrowQueue` class to provide a natural API to get several rows at a time. [PyArrow](https://arrow.apache.org/docs/python/index.html) is required to enable and use these APIs; you can install it via `pip install pyarrow` or `pip install databricks-sql-connector[pyarrow]`.
|
|
9
9
|
|
|
10
10
|
You are welcome to file an issue here for general use cases. You can also contact Databricks Support [here](https://help.databricks.com).
|
|
11
11
|
|
|
@@ -22,7 +22,12 @@ For the latest documentation, see
|
|
|
22
22
|
|
|
23
23
|
## Quickstart
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
### Installing the core library
|
|
26
|
+
Install using `pip install databricks-sql-connector`
|
|
27
|
+
|
|
28
|
+
### Installing the core library with PyArrow
|
|
29
|
+
Install using `pip install databricks-sql-connector[pyarrow]`
|
|
30
|
+
|
|
26
31
|
|
|
27
32
|
```bash
|
|
28
33
|
export DATABRICKS_HOST=********.databricks.com
|
|
@@ -60,6 +65,18 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789
|
|
|
60
65
|
> to authenticate the target Databricks user account and needs to open the browser for authentication. So it
|
|
61
66
|
> can only run on the user's machine.
|
|
62
67
|
|
|
68
|
+
## SQLAlchemy
|
|
69
|
+
Starting from `databricks-sql-connector` version 4.0.0, SQLAlchemy support has been extracted to a new library, `databricks-sqlalchemy`.
|
|
70
|
+
|
|
71
|
+
- GitHub repository [databricks-sqlalchemy github](https://github.com/databricks/databricks-sqlalchemy)
|
|
72
|
+
- PyPI [databricks-sqlalchemy pypi](https://pypi.org/project/databricks-sqlalchemy/)
|
|
73
|
+
|
|
74
|
+
### Quick SQLAlchemy guide
|
|
75
|
+
Users can now choose between using the SQLAlchemy v1 or SQLAlchemy v2 dialects with the connector core.
|
|
76
|
+
|
|
77
|
+
- Install the latest SQLAlchemy v1 using `pip install databricks-sqlalchemy~=1.0`
|
|
78
|
+
- Install SQLAlchemy v2 using `pip install databricks-sqlalchemy`
|
|
79
|
+
|
|
63
80
|
|
|
64
81
|
## Contributing
|
|
65
82
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "databricks-sql-connector"
|
|
3
|
-
version = "4.0.0.b2"
|
|
3
|
+
version = "4.0.0.b4"
|
|
4
4
|
description = "Databricks SQL Connector for Python"
|
|
5
5
|
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -23,15 +23,9 @@ numpy = [
|
|
|
23
23
|
]
|
|
24
24
|
openpyxl = "^3.0.10"
|
|
25
25
|
urllib3 = ">=1.26"
|
|
26
|
-
|
|
27
|
-
databricks-sqlalchemy = { version = ">=2.0.0", optional = true }
|
|
28
26
|
pyarrow = { version = ">=14.0.1,<17", optional=true }
|
|
29
|
-
alembic = { version = "^1.0.11", optional = true }
|
|
30
|
-
|
|
31
27
|
|
|
32
28
|
[tool.poetry.extras]
|
|
33
|
-
databricks-sqlalchemy = ["databricks-sqlalchemy"]
|
|
34
|
-
alembic = ["databricks-sqlalchemy", "alembic"]
|
|
35
29
|
pyarrow = ["pyarrow"]
|
|
36
30
|
|
|
37
31
|
[tool.poetry.dev-dependencies]
|
|
@@ -45,9 +39,6 @@ pytest-dotenv = "^0.5.2"
|
|
|
45
39
|
"Homepage" = "https://github.com/databricks/databricks-sql-python"
|
|
46
40
|
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"
|
|
47
41
|
|
|
48
|
-
[tool.poetry.plugins."sqlalchemy.dialects"]
|
|
49
|
-
"databricks" = "databricks.sqlalchemy:DatabricksDialect"
|
|
50
|
-
|
|
51
42
|
[build-system]
|
|
52
43
|
requires = ["poetry-core>=1.0.0"]
|
|
53
44
|
build-backend = "poetry.core.masonry.api"
|
|
@@ -64,5 +55,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"}
|
|
|
64
55
|
minversion = "6.0"
|
|
65
56
|
log_cli = "false"
|
|
66
57
|
log_cli_level = "INFO"
|
|
67
|
-
testpaths = ["tests"
|
|
58
|
+
testpaths = ["tests"]
|
|
68
59
|
env_files = ["test.env"]
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import logging
|
|
3
3
|
import urllib.parse
|
|
4
|
-
from typing import Dict, Union
|
|
4
|
+
from typing import Dict, Union, Optional
|
|
5
5
|
|
|
6
6
|
import six
|
|
7
7
|
import thrift
|
|
8
8
|
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
9
|
import ssl
|
|
12
10
|
import warnings
|
|
13
11
|
from http.client import HTTPResponse
|
|
@@ -16,6 +14,9 @@ from io import BytesIO
|
|
|
16
14
|
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
|
|
17
15
|
from urllib3.util import make_headers
|
|
18
16
|
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
|
|
17
|
+
from databricks.sql.types import SSLOptions
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
@@ -25,13 +26,12 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
25
26
|
uri_or_host,
|
|
26
27
|
port=None,
|
|
27
28
|
path=None,
|
|
28
|
-
|
|
29
|
-
cert_file=None,
|
|
30
|
-
key_file=None,
|
|
31
|
-
ssl_context=None,
|
|
29
|
+
ssl_options: Optional[SSLOptions] = None,
|
|
32
30
|
max_connections: int = 1,
|
|
33
31
|
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
|
|
34
32
|
):
|
|
33
|
+
self._ssl_options = ssl_options
|
|
34
|
+
|
|
35
35
|
if port is not None:
|
|
36
36
|
warnings.warn(
|
|
37
37
|
"Please use the THttpClient('http{s}://host:port/path') constructor",
|
|
@@ -48,13 +48,11 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
48
48
|
self.scheme = parsed.scheme
|
|
49
49
|
assert self.scheme in ("http", "https")
|
|
50
50
|
if self.scheme == "https":
|
|
51
|
-
self.
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
else ssl_context
|
|
57
|
-
)
|
|
51
|
+
if self._ssl_options is not None:
|
|
52
|
+
# TODO: Not sure if those options are used anywhere - need to double-check
|
|
53
|
+
self.certfile = self._ssl_options.tls_client_cert_file
|
|
54
|
+
self.keyfile = self._ssl_options.tls_client_cert_key_file
|
|
55
|
+
self.context = self._ssl_options.create_ssl_context()
|
|
58
56
|
self.port = parsed.port
|
|
59
57
|
self.host = parsed.hostname
|
|
60
58
|
self.path = parsed.path
|
|
@@ -109,12 +107,23 @@ class THttpClient(thrift.transport.THttpClient.THttpClient):
|
|
|
109
107
|
def open(self):
|
|
110
108
|
|
|
111
109
|
# self.__pool replaces the self.__http used by the original THttpClient
|
|
110
|
+
_pool_kwargs = {"maxsize": self.max_connections}
|
|
111
|
+
|
|
112
112
|
if self.scheme == "http":
|
|
113
113
|
pool_class = HTTPConnectionPool
|
|
114
114
|
elif self.scheme == "https":
|
|
115
115
|
pool_class = HTTPSConnectionPool
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
_pool_kwargs.update(
|
|
117
|
+
{
|
|
118
|
+
"cert_reqs": ssl.CERT_REQUIRED
|
|
119
|
+
if self._ssl_options.tls_verify
|
|
120
|
+
else ssl.CERT_NONE,
|
|
121
|
+
"ca_certs": self._ssl_options.tls_trusted_ca_file,
|
|
122
|
+
"cert_file": self._ssl_options.tls_client_cert_file,
|
|
123
|
+
"key_file": self._ssl_options.tls_client_cert_key_file,
|
|
124
|
+
"key_password": self._ssl_options.tls_client_cert_key_password,
|
|
125
|
+
}
|
|
126
|
+
)
|
|
118
127
|
|
|
119
128
|
if self.using_proxy():
|
|
120
129
|
proxy_manager = ProxyManager(
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/client.py
RENAMED
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
|
|
2
2
|
|
|
3
3
|
import pandas
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import pyarrow
|
|
7
|
+
except ImportError:
|
|
8
|
+
pyarrow = None
|
|
4
9
|
import requests
|
|
5
10
|
import json
|
|
6
11
|
import os
|
|
@@ -21,6 +26,8 @@ from databricks.sql.utils import (
|
|
|
21
26
|
ParamEscaper,
|
|
22
27
|
inject_parameters,
|
|
23
28
|
transform_paramstyle,
|
|
29
|
+
ColumnTable,
|
|
30
|
+
ColumnQueue,
|
|
24
31
|
)
|
|
25
32
|
from databricks.sql.parameters.native import (
|
|
26
33
|
DbsqlParameterBase,
|
|
@@ -34,7 +41,7 @@ from databricks.sql.parameters.native import (
|
|
|
34
41
|
)
|
|
35
42
|
|
|
36
43
|
|
|
37
|
-
from databricks.sql.types import Row
|
|
44
|
+
from databricks.sql.types import Row, SSLOptions
|
|
38
45
|
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
|
|
39
46
|
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
|
|
40
47
|
|
|
@@ -42,13 +49,16 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
|
42
49
|
TSparkParameter,
|
|
43
50
|
)
|
|
44
51
|
|
|
45
|
-
try:
|
|
46
|
-
import pyarrow
|
|
47
|
-
except ImportError:
|
|
48
|
-
pyarrow = None
|
|
49
52
|
|
|
50
53
|
logger = logging.getLogger(__name__)
|
|
51
54
|
|
|
55
|
+
if pyarrow is None:
|
|
56
|
+
logger.warning(
|
|
57
|
+
"[WARN] pyarrow is not installed by default since databricks-sql-connector 4.0.0,"
|
|
58
|
+
"any arrow specific api (e.g. fetchmany_arrow) and cloud fetch will be disabled."
|
|
59
|
+
"If you need these features, please run pip install pyarrow or pip install databricks-sql-connector[pyarrow] to install"
|
|
60
|
+
)
|
|
61
|
+
|
|
52
62
|
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
|
|
53
63
|
DEFAULT_ARRAY_SIZE = 100000
|
|
54
64
|
|
|
@@ -181,8 +191,9 @@ class Connection:
|
|
|
181
191
|
# _tls_trusted_ca_file
|
|
182
192
|
# Set to the path of the file containing trusted CA certificates for server certificate
|
|
183
193
|
# verification. If not provide, uses system truststore.
|
|
184
|
-
# _tls_client_cert_file, _tls_client_cert_key_file
|
|
194
|
+
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
|
|
185
195
|
# Set client SSL certificate.
|
|
196
|
+
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
|
|
186
197
|
# _retry_stop_after_attempts_count
|
|
187
198
|
# The maximum number of attempts during a request retry sequence (defaults to 24)
|
|
188
199
|
# _socket_timeout
|
|
@@ -223,12 +234,25 @@ class Connection:
|
|
|
223
234
|
|
|
224
235
|
base_headers = [("User-Agent", useragent_header)]
|
|
225
236
|
|
|
237
|
+
self._ssl_options = SSLOptions(
|
|
238
|
+
# Double negation is generally a bad thing, but we have to keep backward compatibility
|
|
239
|
+
tls_verify=not kwargs.get(
|
|
240
|
+
"_tls_no_verify", False
|
|
241
|
+
), # by default - verify cert and host
|
|
242
|
+
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
|
|
243
|
+
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
|
|
244
|
+
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
|
|
245
|
+
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
|
|
246
|
+
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
|
|
247
|
+
)
|
|
248
|
+
|
|
226
249
|
self.thrift_backend = ThriftBackend(
|
|
227
250
|
self.host,
|
|
228
251
|
self.port,
|
|
229
252
|
http_path,
|
|
230
253
|
(http_headers or []) + base_headers,
|
|
231
254
|
auth_provider,
|
|
255
|
+
ssl_options=self._ssl_options,
|
|
232
256
|
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
|
|
233
257
|
**kwargs,
|
|
234
258
|
)
|
|
@@ -1132,6 +1156,18 @@ class ResultSet:
|
|
|
1132
1156
|
self.results = results
|
|
1133
1157
|
self.has_more_rows = has_more_rows
|
|
1134
1158
|
|
|
1159
|
+
def _convert_columnar_table(self, table):
|
|
1160
|
+
column_names = [c[0] for c in self.description]
|
|
1161
|
+
ResultRow = Row(*column_names)
|
|
1162
|
+
result = []
|
|
1163
|
+
for row_index in range(table.num_rows):
|
|
1164
|
+
curr_row = []
|
|
1165
|
+
for col_index in range(table.num_columns):
|
|
1166
|
+
curr_row.append(table.get_item(col_index, row_index))
|
|
1167
|
+
result.append(ResultRow(*curr_row))
|
|
1168
|
+
|
|
1169
|
+
return result
|
|
1170
|
+
|
|
1135
1171
|
def _convert_arrow_table(self, table):
|
|
1136
1172
|
column_names = [c[0] for c in self.description]
|
|
1137
1173
|
ResultRow = Row(*column_names)
|
|
@@ -1199,6 +1235,48 @@ class ResultSet:
|
|
|
1199
1235
|
|
|
1200
1236
|
return results
|
|
1201
1237
|
|
|
1238
|
+
def merge_columnar(self, result1, result2):
|
|
1239
|
+
"""
|
|
1240
|
+
Function to merge / combining the columnar results into a single result
|
|
1241
|
+
:param result1:
|
|
1242
|
+
:param result2:
|
|
1243
|
+
:return:
|
|
1244
|
+
"""
|
|
1245
|
+
|
|
1246
|
+
if result1.column_names != result2.column_names:
|
|
1247
|
+
raise ValueError("The columns in the results don't match")
|
|
1248
|
+
|
|
1249
|
+
merged_result = [
|
|
1250
|
+
result1.column_table[i] + result2.column_table[i]
|
|
1251
|
+
for i in range(result1.num_columns)
|
|
1252
|
+
]
|
|
1253
|
+
return ColumnTable(merged_result, result1.column_names)
|
|
1254
|
+
|
|
1255
|
+
def fetchmany_columnar(self, size: int):
|
|
1256
|
+
"""
|
|
1257
|
+
Fetch the next set of rows of a query result, returning a Columnar Table.
|
|
1258
|
+
An empty sequence is returned when no more rows are available.
|
|
1259
|
+
"""
|
|
1260
|
+
if size < 0:
|
|
1261
|
+
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
|
|
1262
|
+
|
|
1263
|
+
results = self.results.next_n_rows(size)
|
|
1264
|
+
n_remaining_rows = size - results.num_rows
|
|
1265
|
+
self._next_row_index += results.num_rows
|
|
1266
|
+
|
|
1267
|
+
while (
|
|
1268
|
+
n_remaining_rows > 0
|
|
1269
|
+
and not self.has_been_closed_server_side
|
|
1270
|
+
and self.has_more_rows
|
|
1271
|
+
):
|
|
1272
|
+
self._fill_results_buffer()
|
|
1273
|
+
partial_results = self.results.next_n_rows(n_remaining_rows)
|
|
1274
|
+
results = self.merge_columnar(results, partial_results)
|
|
1275
|
+
n_remaining_rows -= partial_results.num_rows
|
|
1276
|
+
self._next_row_index += partial_results.num_rows
|
|
1277
|
+
|
|
1278
|
+
return results
|
|
1279
|
+
|
|
1202
1280
|
def fetchall_arrow(self) -> "pyarrow.Table":
|
|
1203
1281
|
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
|
|
1204
1282
|
results = self.results.remaining_rows()
|
|
@@ -1212,12 +1290,30 @@ class ResultSet:
|
|
|
1212
1290
|
|
|
1213
1291
|
return results
|
|
1214
1292
|
|
|
1293
|
+
def fetchall_columnar(self):
|
|
1294
|
+
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
|
|
1295
|
+
results = self.results.remaining_rows()
|
|
1296
|
+
self._next_row_index += results.num_rows
|
|
1297
|
+
|
|
1298
|
+
while not self.has_been_closed_server_side and self.has_more_rows:
|
|
1299
|
+
self._fill_results_buffer()
|
|
1300
|
+
partial_results = self.results.remaining_rows()
|
|
1301
|
+
results = self.merge_columnar(results, partial_results)
|
|
1302
|
+
self._next_row_index += partial_results.num_rows
|
|
1303
|
+
|
|
1304
|
+
return results
|
|
1305
|
+
|
|
1215
1306
|
def fetchone(self) -> Optional[Row]:
|
|
1216
1307
|
"""
|
|
1217
1308
|
Fetch the next row of a query result set, returning a single sequence,
|
|
1218
1309
|
or None when no more data is available.
|
|
1219
1310
|
"""
|
|
1220
|
-
|
|
1311
|
+
|
|
1312
|
+
if isinstance(self.results, ColumnQueue):
|
|
1313
|
+
res = self._convert_columnar_table(self.fetchmany_columnar(1))
|
|
1314
|
+
else:
|
|
1315
|
+
res = self._convert_arrow_table(self.fetchmany_arrow(1))
|
|
1316
|
+
|
|
1221
1317
|
if len(res) > 0:
|
|
1222
1318
|
return res[0]
|
|
1223
1319
|
else:
|
|
@@ -1227,7 +1323,10 @@ class ResultSet:
|
|
|
1227
1323
|
"""
|
|
1228
1324
|
Fetch all (remaining) rows of a query result, returning them as a list of rows.
|
|
1229
1325
|
"""
|
|
1230
|
-
|
|
1326
|
+
if isinstance(self.results, ColumnQueue):
|
|
1327
|
+
return self._convert_columnar_table(self.fetchall_columnar())
|
|
1328
|
+
else:
|
|
1329
|
+
return self._convert_arrow_table(self.fetchall_arrow())
|
|
1231
1330
|
|
|
1232
1331
|
def fetchmany(self, size: int) -> List[Row]:
|
|
1233
1332
|
"""
|
|
@@ -1235,7 +1334,10 @@ class ResultSet:
|
|
|
1235
1334
|
|
|
1236
1335
|
An empty sequence is returned when no more rows are available.
|
|
1237
1336
|
"""
|
|
1238
|
-
|
|
1337
|
+
if isinstance(self.results, ColumnQueue):
|
|
1338
|
+
return self._convert_columnar_table(self.fetchmany_columnar(size))
|
|
1339
|
+
else:
|
|
1340
|
+
return self._convert_arrow_table(self.fetchmany_arrow(size))
|
|
1239
1341
|
|
|
1240
1342
|
def close(self) -> None:
|
|
1241
1343
|
"""
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
|
|
3
|
-
from ssl import SSLContext
|
|
4
3
|
from concurrent.futures import ThreadPoolExecutor, Future
|
|
5
4
|
from typing import List, Union
|
|
6
5
|
|
|
@@ -9,6 +8,8 @@ from databricks.sql.cloudfetch.downloader import (
|
|
|
9
8
|
DownloadableResultSettings,
|
|
10
9
|
DownloadedFile,
|
|
11
10
|
)
|
|
11
|
+
from databricks.sql.types import SSLOptions
|
|
12
|
+
|
|
12
13
|
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
|
|
13
14
|
|
|
14
15
|
logger = logging.getLogger(__name__)
|
|
@@ -20,7 +21,7 @@ class ResultFileDownloadManager:
|
|
|
20
21
|
links: List[TSparkArrowResultLink],
|
|
21
22
|
max_download_threads: int,
|
|
22
23
|
lz4_compressed: bool,
|
|
23
|
-
|
|
24
|
+
ssl_options: SSLOptions,
|
|
24
25
|
):
|
|
25
26
|
self._pending_links: List[TSparkArrowResultLink] = []
|
|
26
27
|
for link in links:
|
|
@@ -38,7 +39,7 @@ class ResultFileDownloadManager:
|
|
|
38
39
|
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
|
|
39
40
|
|
|
40
41
|
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
|
|
41
|
-
self.
|
|
42
|
+
self._ssl_options = ssl_options
|
|
42
43
|
|
|
43
44
|
def get_next_downloaded_file(
|
|
44
45
|
self, next_row_offset: int
|
|
@@ -95,7 +96,7 @@ class ResultFileDownloadManager:
|
|
|
95
96
|
handler = ResultSetDownloadHandler(
|
|
96
97
|
settings=self._downloadable_result_settings,
|
|
97
98
|
link=link,
|
|
98
|
-
|
|
99
|
+
ssl_options=self._ssl_options,
|
|
99
100
|
)
|
|
100
101
|
task = self._thread_pool.submit(handler.run)
|
|
101
102
|
self._download_tasks.append(task)
|
|
@@ -3,13 +3,12 @@ from dataclasses import dataclass
|
|
|
3
3
|
|
|
4
4
|
import requests
|
|
5
5
|
from requests.adapters import HTTPAdapter, Retry
|
|
6
|
-
from ssl import SSLContext, CERT_NONE
|
|
7
6
|
import lz4.frame
|
|
8
7
|
import time
|
|
9
8
|
|
|
10
9
|
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
|
|
11
|
-
|
|
12
10
|
from databricks.sql.exc import Error
|
|
11
|
+
from databricks.sql.types import SSLOptions
|
|
13
12
|
|
|
14
13
|
logger = logging.getLogger(__name__)
|
|
15
14
|
|
|
@@ -66,11 +65,11 @@ class ResultSetDownloadHandler:
|
|
|
66
65
|
self,
|
|
67
66
|
settings: DownloadableResultSettings,
|
|
68
67
|
link: TSparkArrowResultLink,
|
|
69
|
-
|
|
68
|
+
ssl_options: SSLOptions,
|
|
70
69
|
):
|
|
71
70
|
self.settings = settings
|
|
72
71
|
self.link = link
|
|
73
|
-
self.
|
|
72
|
+
self._ssl_options = ssl_options
|
|
74
73
|
|
|
75
74
|
def run(self) -> DownloadedFile:
|
|
76
75
|
"""
|
|
@@ -95,14 +94,14 @@ class ResultSetDownloadHandler:
|
|
|
95
94
|
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
|
|
96
95
|
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
|
|
97
96
|
|
|
98
|
-
ssl_verify = self._ssl_context.verify_mode != CERT_NONE
|
|
99
|
-
|
|
100
97
|
try:
|
|
101
98
|
# Get the file via HTTP request
|
|
102
99
|
response = session.get(
|
|
103
100
|
self.link.fileLink,
|
|
104
101
|
timeout=self.settings.download_timeout,
|
|
105
|
-
verify=
|
|
102
|
+
verify=self._ssl_options.tls_verify,
|
|
103
|
+
headers=self.link.httpHeaders
|
|
104
|
+
# TODO: Pass cert from `self._ssl_options`
|
|
106
105
|
)
|
|
107
106
|
response.raise_for_status()
|
|
108
107
|
|
|
@@ -5,9 +5,12 @@ import math
|
|
|
5
5
|
import time
|
|
6
6
|
import uuid
|
|
7
7
|
import threading
|
|
8
|
-
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
|
|
9
8
|
from typing import List, Union
|
|
10
9
|
|
|
10
|
+
try:
|
|
11
|
+
import pyarrow
|
|
12
|
+
except ImportError:
|
|
13
|
+
pyarrow = None
|
|
11
14
|
import thrift.transport.THttpClient
|
|
12
15
|
import thrift.protocol.TBinaryProtocol
|
|
13
16
|
import thrift.transport.TSocket
|
|
@@ -35,11 +38,7 @@ from databricks.sql.utils import (
|
|
|
35
38
|
convert_decimals_in_arrow_table,
|
|
36
39
|
convert_column_based_set_to_arrow_table,
|
|
37
40
|
)
|
|
38
|
-
|
|
39
|
-
try:
|
|
40
|
-
import pyarrow
|
|
41
|
-
except ImportError:
|
|
42
|
-
pyarrow = None
|
|
41
|
+
from databricks.sql.types import SSLOptions
|
|
43
42
|
|
|
44
43
|
logger = logging.getLogger(__name__)
|
|
45
44
|
|
|
@@ -89,6 +88,7 @@ class ThriftBackend:
|
|
|
89
88
|
http_path: str,
|
|
90
89
|
http_headers,
|
|
91
90
|
auth_provider: AuthProvider,
|
|
91
|
+
ssl_options: SSLOptions,
|
|
92
92
|
staging_allowed_local_path: Union[None, str, List[str]] = None,
|
|
93
93
|
**kwargs,
|
|
94
94
|
):
|
|
@@ -97,16 +97,6 @@ class ThriftBackend:
|
|
|
97
97
|
# Tag to add to User-Agent header. For use by partners.
|
|
98
98
|
# _username, _password
|
|
99
99
|
# Username and password Basic authentication (no official support)
|
|
100
|
-
# _tls_no_verify
|
|
101
|
-
# Set to True (Boolean) to completely disable SSL verification.
|
|
102
|
-
# _tls_verify_hostname
|
|
103
|
-
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
|
|
104
|
-
# _tls_trusted_ca_file
|
|
105
|
-
# Set to the path of the file containing trusted CA certificates for server certificate
|
|
106
|
-
# verification. If not provide, uses system truststore.
|
|
107
|
-
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
|
|
108
|
-
# Set client SSL certificate.
|
|
109
|
-
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
|
|
110
100
|
# _connection_uri
|
|
111
101
|
# Overrides server_hostname and http_path.
|
|
112
102
|
# RETRY/ATTEMPT POLICY
|
|
@@ -166,29 +156,7 @@ class ThriftBackend:
|
|
|
166
156
|
# Cloud fetch
|
|
167
157
|
self.max_download_threads = kwargs.get("max_download_threads", 10)
|
|
168
158
|
|
|
169
|
-
|
|
170
|
-
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
|
|
171
|
-
if kwargs.get("_tls_no_verify") is True:
|
|
172
|
-
ssl_context.check_hostname = False
|
|
173
|
-
ssl_context.verify_mode = CERT_NONE
|
|
174
|
-
elif kwargs.get("_tls_verify_hostname") is False:
|
|
175
|
-
ssl_context.check_hostname = False
|
|
176
|
-
ssl_context.verify_mode = CERT_REQUIRED
|
|
177
|
-
else:
|
|
178
|
-
ssl_context.check_hostname = True
|
|
179
|
-
ssl_context.verify_mode = CERT_REQUIRED
|
|
180
|
-
|
|
181
|
-
tls_client_cert_file = kwargs.get("_tls_client_cert_file")
|
|
182
|
-
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
|
|
183
|
-
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
|
|
184
|
-
if tls_client_cert_file:
|
|
185
|
-
ssl_context.load_cert_chain(
|
|
186
|
-
certfile=tls_client_cert_file,
|
|
187
|
-
keyfile=tls_client_cert_key_file,
|
|
188
|
-
password=tls_client_cert_key_password,
|
|
189
|
-
)
|
|
190
|
-
|
|
191
|
-
self._ssl_context = ssl_context
|
|
159
|
+
self._ssl_options = ssl_options
|
|
192
160
|
|
|
193
161
|
self._auth_provider = auth_provider
|
|
194
162
|
|
|
@@ -229,7 +197,7 @@ class ThriftBackend:
|
|
|
229
197
|
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
|
|
230
198
|
auth_provider=self._auth_provider,
|
|
231
199
|
uri_or_host=uri,
|
|
232
|
-
|
|
200
|
+
ssl_options=self._ssl_options,
|
|
233
201
|
**additional_transport_args, # type: ignore
|
|
234
202
|
)
|
|
235
203
|
|
|
@@ -656,12 +624,6 @@ class ThriftBackend:
|
|
|
656
624
|
|
|
657
625
|
@staticmethod
|
|
658
626
|
def _hive_schema_to_arrow_schema(t_table_schema):
|
|
659
|
-
|
|
660
|
-
if pyarrow is None:
|
|
661
|
-
raise ImportError(
|
|
662
|
-
"pyarrow is required to convert Hive schema to Arrow schema"
|
|
663
|
-
)
|
|
664
|
-
|
|
665
627
|
def map_type(t_type_entry):
|
|
666
628
|
if t_type_entry.primitiveEntry:
|
|
667
629
|
return {
|
|
@@ -767,12 +729,17 @@ class ThriftBackend:
|
|
|
767
729
|
description = self._hive_schema_to_description(
|
|
768
730
|
t_result_set_metadata_resp.schema
|
|
769
731
|
)
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
732
|
+
|
|
733
|
+
if pyarrow:
|
|
734
|
+
schema_bytes = (
|
|
735
|
+
t_result_set_metadata_resp.arrowSchema
|
|
736
|
+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
|
|
737
|
+
.serialize()
|
|
738
|
+
.to_pybytes()
|
|
739
|
+
)
|
|
740
|
+
else:
|
|
741
|
+
schema_bytes = None
|
|
742
|
+
|
|
776
743
|
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
|
|
777
744
|
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
|
|
778
745
|
if direct_results and direct_results.resultSet:
|
|
@@ -786,7 +753,7 @@ class ThriftBackend:
|
|
|
786
753
|
max_download_threads=self.max_download_threads,
|
|
787
754
|
lz4_compressed=lz4_compressed,
|
|
788
755
|
description=description,
|
|
789
|
-
|
|
756
|
+
ssl_options=self._ssl_options,
|
|
790
757
|
)
|
|
791
758
|
else:
|
|
792
759
|
arrow_queue_opt = None
|
|
@@ -1018,7 +985,7 @@ class ThriftBackend:
|
|
|
1018
985
|
max_download_threads=self.max_download_threads,
|
|
1019
986
|
lz4_compressed=lz4_compressed,
|
|
1020
987
|
description=description,
|
|
1021
|
-
|
|
988
|
+
ssl_options=self._ssl_options,
|
|
1022
989
|
)
|
|
1023
990
|
|
|
1024
991
|
return queue, resp.hasMoreRows
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/types.py
RENAMED
|
@@ -19,6 +19,54 @@
|
|
|
19
19
|
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
|
|
20
20
|
import datetime
|
|
21
21
|
import decimal
|
|
22
|
+
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SSLOptions:
|
|
26
|
+
tls_verify: bool
|
|
27
|
+
tls_verify_hostname: bool
|
|
28
|
+
tls_trusted_ca_file: Optional[str]
|
|
29
|
+
tls_client_cert_file: Optional[str]
|
|
30
|
+
tls_client_cert_key_file: Optional[str]
|
|
31
|
+
tls_client_cert_key_password: Optional[str]
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
tls_verify: bool = True,
|
|
36
|
+
tls_verify_hostname: bool = True,
|
|
37
|
+
tls_trusted_ca_file: Optional[str] = None,
|
|
38
|
+
tls_client_cert_file: Optional[str] = None,
|
|
39
|
+
tls_client_cert_key_file: Optional[str] = None,
|
|
40
|
+
tls_client_cert_key_password: Optional[str] = None,
|
|
41
|
+
):
|
|
42
|
+
self.tls_verify = tls_verify
|
|
43
|
+
self.tls_verify_hostname = tls_verify_hostname
|
|
44
|
+
self.tls_trusted_ca_file = tls_trusted_ca_file
|
|
45
|
+
self.tls_client_cert_file = tls_client_cert_file
|
|
46
|
+
self.tls_client_cert_key_file = tls_client_cert_key_file
|
|
47
|
+
self.tls_client_cert_key_password = tls_client_cert_key_password
|
|
48
|
+
|
|
49
|
+
def create_ssl_context(self) -> SSLContext:
|
|
50
|
+
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)
|
|
51
|
+
|
|
52
|
+
if self.tls_verify is False:
|
|
53
|
+
ssl_context.check_hostname = False
|
|
54
|
+
ssl_context.verify_mode = CERT_NONE
|
|
55
|
+
elif self.tls_verify_hostname is False:
|
|
56
|
+
ssl_context.check_hostname = False
|
|
57
|
+
ssl_context.verify_mode = CERT_REQUIRED
|
|
58
|
+
else:
|
|
59
|
+
ssl_context.check_hostname = True
|
|
60
|
+
ssl_context.verify_mode = CERT_REQUIRED
|
|
61
|
+
|
|
62
|
+
if self.tls_client_cert_file:
|
|
63
|
+
ssl_context.load_cert_chain(
|
|
64
|
+
certfile=self.tls_client_cert_file,
|
|
65
|
+
keyfile=self.tls_client_cert_key_file,
|
|
66
|
+
password=self.tls_client_cert_key_password,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
return ssl_context
|
|
22
70
|
|
|
23
71
|
|
|
24
72
|
class Row(tuple):
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/utils.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import pytz
|
|
3
4
|
import datetime
|
|
4
5
|
import decimal
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
@@ -9,10 +10,14 @@ from decimal import Decimal
|
|
|
9
10
|
from enum import Enum
|
|
10
11
|
from typing import Any, Dict, List, Optional, Union
|
|
11
12
|
import re
|
|
12
|
-
from ssl import SSLContext
|
|
13
13
|
|
|
14
14
|
import lz4.frame
|
|
15
15
|
|
|
16
|
+
try:
|
|
17
|
+
import pyarrow
|
|
18
|
+
except ImportError:
|
|
19
|
+
pyarrow = None
|
|
20
|
+
|
|
16
21
|
from databricks.sql import OperationalError, exc
|
|
17
22
|
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
|
|
18
23
|
from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
@@ -20,17 +25,14 @@ from databricks.sql.thrift_api.TCLIService.ttypes import (
|
|
|
20
25
|
TSparkArrowResultLink,
|
|
21
26
|
TSparkRowSetType,
|
|
22
27
|
)
|
|
28
|
+
from databricks.sql.types import SSLOptions
|
|
23
29
|
|
|
24
30
|
from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter
|
|
25
31
|
|
|
26
|
-
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
|
|
27
|
-
|
|
28
32
|
import logging
|
|
29
33
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
except ImportError:
|
|
33
|
-
pyarrow = None
|
|
34
|
+
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
|
|
35
|
+
DEFAULT_ERROR_CONTEXT = "Unknown error"
|
|
34
36
|
|
|
35
37
|
logger = logging.getLogger(__name__)
|
|
36
38
|
|
|
@@ -52,7 +54,7 @@ class ResultSetQueueFactory(ABC):
|
|
|
52
54
|
t_row_set: TRowSet,
|
|
53
55
|
arrow_schema_bytes: bytes,
|
|
54
56
|
max_download_threads: int,
|
|
55
|
-
|
|
57
|
+
ssl_options: SSLOptions,
|
|
56
58
|
lz4_compressed: bool = True,
|
|
57
59
|
description: Optional[List[List[Any]]] = None,
|
|
58
60
|
) -> ResultSetQueue:
|
|
@@ -66,7 +68,7 @@ class ResultSetQueueFactory(ABC):
|
|
|
66
68
|
lz4_compressed (bool): Whether result data has been lz4 compressed.
|
|
67
69
|
description (List[List[Any]]): Hive table schema description.
|
|
68
70
|
max_download_threads (int): Maximum number of downloader thread pool threads.
|
|
69
|
-
|
|
71
|
+
ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
|
|
70
72
|
|
|
71
73
|
Returns:
|
|
72
74
|
ResultSetQueue
|
|
@@ -80,13 +82,15 @@ class ResultSetQueueFactory(ABC):
|
|
|
80
82
|
)
|
|
81
83
|
return ArrowQueue(converted_arrow_table, n_valid_rows)
|
|
82
84
|
elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
|
|
83
|
-
|
|
85
|
+
column_table, column_names = convert_column_based_set_to_column_table(
|
|
84
86
|
t_row_set.columns, description
|
|
85
87
|
)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
+
|
|
89
|
+
converted_column_table = convert_to_assigned_datatypes_in_column_table(
|
|
90
|
+
column_table, description
|
|
88
91
|
)
|
|
89
|
-
|
|
92
|
+
|
|
93
|
+
return ColumnQueue(ColumnTable(converted_column_table, column_names))
|
|
90
94
|
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
|
|
91
95
|
return CloudFetchQueue(
|
|
92
96
|
schema_bytes=arrow_schema_bytes,
|
|
@@ -95,12 +99,65 @@ class ResultSetQueueFactory(ABC):
|
|
|
95
99
|
lz4_compressed=lz4_compressed,
|
|
96
100
|
description=description,
|
|
97
101
|
max_download_threads=max_download_threads,
|
|
98
|
-
|
|
102
|
+
ssl_options=ssl_options,
|
|
99
103
|
)
|
|
100
104
|
else:
|
|
101
105
|
raise AssertionError("Row set type is not valid")
|
|
102
106
|
|
|
103
107
|
|
|
108
|
+
class ColumnTable:
|
|
109
|
+
def __init__(self, column_table, column_names):
|
|
110
|
+
self.column_table = column_table
|
|
111
|
+
self.column_names = column_names
|
|
112
|
+
|
|
113
|
+
@property
|
|
114
|
+
def num_rows(self):
|
|
115
|
+
if len(self.column_table) == 0:
|
|
116
|
+
return 0
|
|
117
|
+
else:
|
|
118
|
+
return len(self.column_table[0])
|
|
119
|
+
|
|
120
|
+
@property
|
|
121
|
+
def num_columns(self):
|
|
122
|
+
return len(self.column_names)
|
|
123
|
+
|
|
124
|
+
def get_item(self, col_index, row_index):
|
|
125
|
+
return self.column_table[col_index][row_index]
|
|
126
|
+
|
|
127
|
+
def slice(self, curr_index, length):
|
|
128
|
+
sliced_column_table = [
|
|
129
|
+
column[curr_index : curr_index + length] for column in self.column_table
|
|
130
|
+
]
|
|
131
|
+
return ColumnTable(sliced_column_table, self.column_names)
|
|
132
|
+
|
|
133
|
+
def __eq__(self, other):
|
|
134
|
+
return (
|
|
135
|
+
self.column_table == other.column_table
|
|
136
|
+
and self.column_names == other.column_names
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class ColumnQueue(ResultSetQueue):
|
|
141
|
+
def __init__(self, column_table: ColumnTable):
|
|
142
|
+
self.column_table = column_table
|
|
143
|
+
self.cur_row_index = 0
|
|
144
|
+
self.n_valid_rows = column_table.num_rows
|
|
145
|
+
|
|
146
|
+
def next_n_rows(self, num_rows):
|
|
147
|
+
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
|
|
148
|
+
|
|
149
|
+
slice = self.column_table.slice(self.cur_row_index, length)
|
|
150
|
+
self.cur_row_index += slice.num_rows
|
|
151
|
+
return slice
|
|
152
|
+
|
|
153
|
+
def remaining_rows(self):
|
|
154
|
+
slice = self.column_table.slice(
|
|
155
|
+
self.cur_row_index, self.n_valid_rows - self.cur_row_index
|
|
156
|
+
)
|
|
157
|
+
self.cur_row_index += slice.num_rows
|
|
158
|
+
return slice
|
|
159
|
+
|
|
160
|
+
|
|
104
161
|
class ArrowQueue(ResultSetQueue):
|
|
105
162
|
def __init__(
|
|
106
163
|
self,
|
|
@@ -141,7 +198,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
141
198
|
self,
|
|
142
199
|
schema_bytes,
|
|
143
200
|
max_download_threads: int,
|
|
144
|
-
|
|
201
|
+
ssl_options: SSLOptions,
|
|
145
202
|
start_row_offset: int = 0,
|
|
146
203
|
result_links: Optional[List[TSparkArrowResultLink]] = None,
|
|
147
204
|
lz4_compressed: bool = True,
|
|
@@ -164,7 +221,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
164
221
|
self.result_links = result_links
|
|
165
222
|
self.lz4_compressed = lz4_compressed
|
|
166
223
|
self.description = description
|
|
167
|
-
self.
|
|
224
|
+
self._ssl_options = ssl_options
|
|
168
225
|
|
|
169
226
|
logger.debug(
|
|
170
227
|
"Initialize CloudFetch loader, row set start offset: {}, file list:".format(
|
|
@@ -182,7 +239,7 @@ class CloudFetchQueue(ResultSetQueue):
|
|
|
182
239
|
links=result_links or [],
|
|
183
240
|
max_download_threads=self.max_download_threads,
|
|
184
241
|
lz4_compressed=self.lz4_compressed,
|
|
185
|
-
|
|
242
|
+
ssl_options=self._ssl_options,
|
|
186
243
|
)
|
|
187
244
|
|
|
188
245
|
self.table = self._create_next_table()
|
|
@@ -361,7 +418,12 @@ class RequestErrorInfo(
|
|
|
361
418
|
user_friendly_error_message = "{}: {}".format(
|
|
362
419
|
user_friendly_error_message, self.error_message
|
|
363
420
|
)
|
|
364
|
-
|
|
421
|
+
try:
|
|
422
|
+
error_context = str(self.error)
|
|
423
|
+
except:
|
|
424
|
+
error_context = DEFAULT_ERROR_CONTEXT
|
|
425
|
+
|
|
426
|
+
return user_friendly_error_message + ". " + error_context
|
|
365
427
|
|
|
366
428
|
|
|
367
429
|
# Taken from PyHive
|
|
@@ -566,6 +628,37 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
|
|
|
566
628
|
return table
|
|
567
629
|
|
|
568
630
|
|
|
631
|
+
def convert_to_assigned_datatypes_in_column_table(column_table, description):
|
|
632
|
+
|
|
633
|
+
converted_column_table = []
|
|
634
|
+
for i, col in enumerate(column_table):
|
|
635
|
+
if description[i][1] == "decimal":
|
|
636
|
+
converted_column_table.append(
|
|
637
|
+
tuple(v if v is None else Decimal(v) for v in col)
|
|
638
|
+
)
|
|
639
|
+
elif description[i][1] == "date":
|
|
640
|
+
converted_column_table.append(
|
|
641
|
+
tuple(v if v is None else datetime.date.fromisoformat(v) for v in col)
|
|
642
|
+
)
|
|
643
|
+
elif description[i][1] == "timestamp":
|
|
644
|
+
converted_column_table.append(
|
|
645
|
+
tuple(
|
|
646
|
+
(
|
|
647
|
+
v
|
|
648
|
+
if v is None
|
|
649
|
+
else datetime.datetime.strptime(
|
|
650
|
+
v, "%Y-%m-%d %H:%M:%S.%f"
|
|
651
|
+
).replace(tzinfo=pytz.UTC)
|
|
652
|
+
)
|
|
653
|
+
for v in col
|
|
654
|
+
)
|
|
655
|
+
)
|
|
656
|
+
else:
|
|
657
|
+
converted_column_table.append(col)
|
|
658
|
+
|
|
659
|
+
return converted_column_table
|
|
660
|
+
|
|
661
|
+
|
|
569
662
|
def convert_column_based_set_to_arrow_table(columns, description):
|
|
570
663
|
arrow_table = pyarrow.Table.from_arrays(
|
|
571
664
|
[_convert_column_to_arrow_array(c) for c in columns],
|
|
@@ -577,6 +670,13 @@ def convert_column_based_set_to_arrow_table(columns, description):
|
|
|
577
670
|
return arrow_table, arrow_table.num_rows
|
|
578
671
|
|
|
579
672
|
|
|
673
|
+
def convert_column_based_set_to_column_table(columns, description):
|
|
674
|
+
column_names = [c[0] for c in description]
|
|
675
|
+
column_table = [_convert_column_to_list(c) for c in columns]
|
|
676
|
+
|
|
677
|
+
return column_table, column_names
|
|
678
|
+
|
|
679
|
+
|
|
580
680
|
def _convert_column_to_arrow_array(t_col):
|
|
581
681
|
"""
|
|
582
682
|
Return a pyarrow array from the values in a TColumn instance.
|
|
@@ -601,6 +701,26 @@ def _convert_column_to_arrow_array(t_col):
|
|
|
601
701
|
raise OperationalError("Empty TColumn instance {}".format(t_col))
|
|
602
702
|
|
|
603
703
|
|
|
704
|
+
def _convert_column_to_list(t_col):
|
|
705
|
+
SUPPORTED_FIELD_TYPES = (
|
|
706
|
+
"boolVal",
|
|
707
|
+
"byteVal",
|
|
708
|
+
"i16Val",
|
|
709
|
+
"i32Val",
|
|
710
|
+
"i64Val",
|
|
711
|
+
"doubleVal",
|
|
712
|
+
"stringVal",
|
|
713
|
+
"binaryVal",
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
for field in SUPPORTED_FIELD_TYPES:
|
|
717
|
+
wrapper = getattr(t_col, field)
|
|
718
|
+
if wrapper:
|
|
719
|
+
return _create_python_tuple(wrapper)
|
|
720
|
+
|
|
721
|
+
raise OperationalError("Empty TColumn instance {}".format(t_col))
|
|
722
|
+
|
|
723
|
+
|
|
604
724
|
def _create_arrow_array(t_col_value_wrapper, arrow_type):
|
|
605
725
|
result = t_col_value_wrapper.values
|
|
606
726
|
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
|
|
@@ -615,3 +735,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
|
|
|
615
735
|
result[i] = None
|
|
616
736
|
|
|
617
737
|
return pyarrow.array(result, type=arrow_type)
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def _create_python_tuple(t_col_value_wrapper):
|
|
741
|
+
result = t_col_value_wrapper.values
|
|
742
|
+
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
|
|
743
|
+
assert isinstance(nulls, bytes)
|
|
744
|
+
|
|
745
|
+
# The number of bits in nulls can be both larger or smaller than the number of
|
|
746
|
+
# elements in result, so take the minimum of both to iterate over.
|
|
747
|
+
length = min(len(result), len(nulls) * 8)
|
|
748
|
+
|
|
749
|
+
for i in range(length):
|
|
750
|
+
if nulls[i >> 3] & BIT_MASKS[i & 0x7]:
|
|
751
|
+
result[i] = None
|
|
752
|
+
|
|
753
|
+
return tuple(result)
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/exc.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{databricks_sql_connector-4.0.0b2 → databricks_sql_connector-4.0.0b4}/src/databricks/sql/py.typed
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|