databricks-sql-connector 4.0.4__tar.gz → 4.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/CHANGELOG.md +3 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/PKG-INFO +4 -2
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/pyproject.toml +25 -3
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/__init__.py +1 -1
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/auth.py +27 -38
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/authenticators.py +96 -4
- databricks_sql_connector-4.0.6/src/databricks/sql/auth/common.py +127 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/oauth.py +116 -23
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/retry.py +9 -3
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/thrift_http_client.py +25 -24
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/databricks_client.py +347 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/backend.py +825 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/__init__.py +52 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/base.py +82 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/requests.py +133 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/responses.py +196 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/queue.py +391 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/result_set.py +266 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/constants.py +67 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/conversion.py +173 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/filters.py +289 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/http_client.py +294 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/normalize.py +50 -0
- {databricks_sql_connector-4.0.4/src/databricks/sql → databricks_sql_connector-4.0.6/src/databricks/sql/backend}/thrift_backend.py +395 -203
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/types.py +427 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/utils/__init__.py +3 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/backend/utils/guid_utils.py +23 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/client.py +283 -549
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/cloudfetch/download_manager.py +43 -9
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/cloudfetch/downloader.py +83 -58
- databricks_sql_connector-4.0.6/src/databricks/sql/common/feature_flag.py +187 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/common/http.py +40 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/common/http_utils.py +100 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/common/unified_http_client.py +309 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/exc.py +13 -2
- databricks_sql_connector-4.0.6/src/databricks/sql/result_set.py +439 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/session.py +195 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/latency_logger.py +216 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/endpoint_models.py +39 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/enums.py +44 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/event.py +162 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/frontend_logs.py +65 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/telemetry_client.py +561 -0
- databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/utils.py +69 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/types.py +1 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/utils.py +231 -71
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/LICENSE +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/README.md +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/__init__.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/__init__.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/endpoint.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/experimental/__init__.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/__init__.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/native.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/py.typed +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/py.typed +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
- {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/__init__.py +0 -0
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
# Release History
|
|
2
2
|
|
|
3
|
+
# 4.0.5 (2025-06-24)
|
|
4
|
+
- Fix: Reverted change in cursor close handling which led to errors impacting users (databricks/databricks-sql-python#613 by @madhav-db)
|
|
5
|
+
|
|
3
6
|
# 4.0.4 (2025-06-16)
|
|
4
7
|
|
|
5
8
|
- Update thrift client library after cleaning up unused fields and structs (databricks/databricks-sql-python#553 by @vikrantpuppala)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
2
|
Name: databricks-sql-connector
|
|
3
|
-
Version: 4.0.4
|
|
3
|
+
Version: 4.0.6
|
|
4
4
|
Summary: Databricks SQL Connector for Python
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Databricks
|
|
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.10
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
17
|
Provides-Extra: pyarrow
|
|
17
18
|
Requires-Dist: lz4 (>=4.0.2,<5.0.0)
|
|
18
19
|
Requires-Dist: oauthlib (>=3.1.0,<4.0.0)
|
|
@@ -21,6 +22,7 @@ Requires-Dist: pandas (>=1.2.5,<2.3.0) ; python_version >= "3.8" and python_vers
|
|
|
21
22
|
Requires-Dist: pandas (>=2.2.3,<2.3.0) ; python_version >= "3.13"
|
|
22
23
|
Requires-Dist: pyarrow (>=14.0.1) ; (python_version >= "3.8" and python_version < "3.13") and (extra == "pyarrow")
|
|
23
24
|
Requires-Dist: pyarrow (>=18.0.0) ; (python_version >= "3.13") and (extra == "pyarrow")
|
|
25
|
+
Requires-Dist: pyjwt (>=2.0.0,<3.0.0)
|
|
24
26
|
Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
|
|
25
27
|
Requires-Dist: requests (>=2.18.1,<3.0.0)
|
|
26
28
|
Requires-Dist: thrift (>=0.16.0,<0.21.0)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "databricks-sql-connector"
|
|
3
|
-
version = "4.0.4"
|
|
3
|
+
version = "4.0.6"
|
|
4
4
|
description = "Databricks SQL Connector for Python"
|
|
5
5
|
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -20,21 +20,25 @@ requests = "^2.18.1"
|
|
|
20
20
|
oauthlib = "^3.1.0"
|
|
21
21
|
openpyxl = "^3.0.10"
|
|
22
22
|
urllib3 = ">=1.26"
|
|
23
|
+
python-dateutil = "^2.8.0"
|
|
23
24
|
pyarrow = [
|
|
24
25
|
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
|
|
25
26
|
{ version = ">=18.0.0", python = ">=3.13", optional=true }
|
|
26
27
|
]
|
|
27
|
-
|
|
28
|
+
pyjwt = "^2.0.0"
|
|
29
|
+
requests-kerberos = {version = "^0.15.0", optional = true}
|
|
30
|
+
|
|
28
31
|
|
|
29
32
|
[tool.poetry.extras]
|
|
30
33
|
pyarrow = ["pyarrow"]
|
|
31
34
|
|
|
32
|
-
[tool.poetry.dev-dependencies]
|
|
35
|
+
[tool.poetry.group.dev.dependencies]
|
|
33
36
|
pytest = "^7.1.2"
|
|
34
37
|
mypy = "^1.10.1"
|
|
35
38
|
pylint = ">=2.12.0"
|
|
36
39
|
black = "^22.3.0"
|
|
37
40
|
pytest-dotenv = "^0.5.2"
|
|
41
|
+
pytest-cov = "^4.0.0"
|
|
38
42
|
numpy = [
|
|
39
43
|
{ version = ">=1.16.6", python = ">=3.8,<3.11" },
|
|
40
44
|
{ version = ">=1.23.4", python = ">=3.11" },
|
|
@@ -62,3 +66,21 @@ log_cli = "false"
|
|
|
62
66
|
log_cli_level = "INFO"
|
|
63
67
|
testpaths = ["tests"]
|
|
64
68
|
env_files = ["test.env"]
|
|
69
|
+
|
|
70
|
+
[tool.coverage.run]
|
|
71
|
+
source = ["src"]
|
|
72
|
+
branch = true
|
|
73
|
+
omit = [
|
|
74
|
+
"*/tests/*",
|
|
75
|
+
"*/test_*",
|
|
76
|
+
"*/__pycache__/*",
|
|
77
|
+
"*/thrift_api/*",
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
[tool.coverage.report]
|
|
81
|
+
precision = 2
|
|
82
|
+
show_missing = true
|
|
83
|
+
skip_covered = false
|
|
84
|
+
|
|
85
|
+
[tool.coverage.xml]
|
|
86
|
+
output = "coverage.xml"
|
{databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/auth.py
RENAMED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
1
|
from typing import Optional, List
|
|
3
2
|
|
|
4
3
|
from databricks.sql.auth.authenticators import (
|
|
@@ -6,46 +5,26 @@ from databricks.sql.auth.authenticators import (
|
|
|
6
5
|
AccessTokenAuthProvider,
|
|
7
6
|
ExternalAuthProvider,
|
|
8
7
|
DatabricksOAuthProvider,
|
|
8
|
+
AzureServicePrincipalCredentialProvider,
|
|
9
9
|
)
|
|
10
|
+
from databricks.sql.auth.common import AuthType, ClientContext
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
|
|
13
|
-
DATABRICKS_OAUTH = "databricks-oauth"
|
|
14
|
-
AZURE_OAUTH = "azure-oauth"
|
|
15
|
-
# other supported types (access_token) can be inferred
|
|
16
|
-
# we can add more types as needed later
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class ClientContext:
|
|
20
|
-
def __init__(
|
|
21
|
-
self,
|
|
22
|
-
hostname: str,
|
|
23
|
-
access_token: Optional[str] = None,
|
|
24
|
-
auth_type: Optional[str] = None,
|
|
25
|
-
oauth_scopes: Optional[List[str]] = None,
|
|
26
|
-
oauth_client_id: Optional[str] = None,
|
|
27
|
-
oauth_redirect_port_range: Optional[List[int]] = None,
|
|
28
|
-
use_cert_as_auth: Optional[str] = None,
|
|
29
|
-
tls_client_cert_file: Optional[str] = None,
|
|
30
|
-
oauth_persistence=None,
|
|
31
|
-
credentials_provider=None,
|
|
32
|
-
):
|
|
33
|
-
self.hostname = hostname
|
|
34
|
-
self.access_token = access_token
|
|
35
|
-
self.auth_type = auth_type
|
|
36
|
-
self.oauth_scopes = oauth_scopes
|
|
37
|
-
self.oauth_client_id = oauth_client_id
|
|
38
|
-
self.oauth_redirect_port_range = oauth_redirect_port_range
|
|
39
|
-
self.use_cert_as_auth = use_cert_as_auth
|
|
40
|
-
self.tls_client_cert_file = tls_client_cert_file
|
|
41
|
-
self.oauth_persistence = oauth_persistence
|
|
42
|
-
self.credentials_provider = credentials_provider
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def get_auth_provider(cfg: ClientContext):
|
|
13
|
+
def get_auth_provider(cfg: ClientContext, http_client):
|
|
46
14
|
if cfg.credentials_provider:
|
|
47
15
|
return ExternalAuthProvider(cfg.credentials_provider)
|
|
48
|
-
|
|
16
|
+
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
|
|
17
|
+
return ExternalAuthProvider(
|
|
18
|
+
AzureServicePrincipalCredentialProvider(
|
|
19
|
+
cfg.hostname,
|
|
20
|
+
cfg.azure_client_id,
|
|
21
|
+
cfg.azure_client_secret,
|
|
22
|
+
http_client,
|
|
23
|
+
cfg.azure_tenant_id,
|
|
24
|
+
cfg.azure_workspace_resource_id,
|
|
25
|
+
)
|
|
26
|
+
)
|
|
27
|
+
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
|
|
49
28
|
assert cfg.oauth_redirect_port_range is not None
|
|
50
29
|
assert cfg.oauth_client_id is not None
|
|
51
30
|
assert cfg.oauth_scopes is not None
|
|
@@ -56,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
|
|
|
56
35
|
cfg.oauth_redirect_port_range,
|
|
57
36
|
cfg.oauth_client_id,
|
|
58
37
|
cfg.oauth_scopes,
|
|
38
|
+
http_client,
|
|
59
39
|
cfg.auth_type,
|
|
60
40
|
)
|
|
61
41
|
elif cfg.access_token is not None:
|
|
@@ -75,6 +55,8 @@ def get_auth_provider(cfg: ClientContext):
|
|
|
75
55
|
cfg.oauth_redirect_port_range,
|
|
76
56
|
cfg.oauth_client_id,
|
|
77
57
|
cfg.oauth_scopes,
|
|
58
|
+
http_client,
|
|
59
|
+
cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
|
|
78
60
|
)
|
|
79
61
|
else:
|
|
80
62
|
raise RuntimeError("No valid authentication settings!")
|
|
@@ -101,11 +83,14 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
|
|
|
101
83
|
)
|
|
102
84
|
|
|
103
85
|
|
|
104
|
-
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
|
|
86
|
+
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
|
|
87
|
+
# TODO : unify all the auth mechanisms with the Python SDK
|
|
88
|
+
|
|
105
89
|
auth_type = kwargs.get("auth_type")
|
|
106
90
|
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
|
|
107
91
|
auth_type == AuthType.AZURE_OAUTH.value
|
|
108
92
|
)
|
|
93
|
+
|
|
109
94
|
if kwargs.get("username") or kwargs.get("password"):
|
|
110
95
|
raise ValueError(
|
|
111
96
|
"Username/password authentication is no longer supported. "
|
|
@@ -120,10 +105,14 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
|
|
|
120
105
|
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
|
|
121
106
|
oauth_scopes=PYSQL_OAUTH_SCOPES,
|
|
122
107
|
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
|
|
108
|
+
azure_client_id=kwargs.get("azure_client_id"),
|
|
109
|
+
azure_client_secret=kwargs.get("azure_client_secret"),
|
|
110
|
+
azure_tenant_id=kwargs.get("azure_tenant_id"),
|
|
111
|
+
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
|
|
123
112
|
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
|
|
124
113
|
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
|
|
125
114
|
else redirect_port_range,
|
|
126
115
|
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
|
|
127
116
|
credentials_provider=kwargs.get("credentials_provider"),
|
|
128
117
|
)
|
|
129
|
-
return get_auth_provider(cfg)
|
|
118
|
+
return get_auth_provider(cfg, http_client)
|
|
@@ -1,10 +1,18 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
import base64
|
|
3
2
|
import logging
|
|
4
3
|
from typing import Callable, Dict, List
|
|
5
|
-
|
|
6
|
-
from databricks.sql.auth.oauth import
|
|
7
|
-
|
|
4
|
+
from databricks.sql.common.http import HttpHeader
|
|
5
|
+
from databricks.sql.auth.oauth import (
|
|
6
|
+
OAuthManager,
|
|
7
|
+
RefreshableTokenSource,
|
|
8
|
+
ClientCredentialsTokenSource,
|
|
9
|
+
)
|
|
10
|
+
from databricks.sql.auth.endpoint import get_oauth_endpoints
|
|
11
|
+
from databricks.sql.auth.common import (
|
|
12
|
+
AuthType,
|
|
13
|
+
get_effective_azure_login_app_id,
|
|
14
|
+
get_azure_tenant_id_from_host,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
# Private API: this is an evolving interface and it will change in the future.
|
|
10
18
|
# Please must not depend on it in your applications.
|
|
@@ -55,6 +63,7 @@ class DatabricksOAuthProvider(AuthProvider):
|
|
|
55
63
|
redirect_port_range: List[int],
|
|
56
64
|
client_id: str,
|
|
57
65
|
scopes: List[str],
|
|
66
|
+
http_client,
|
|
58
67
|
auth_type: str = "databricks-oauth",
|
|
59
68
|
):
|
|
60
69
|
try:
|
|
@@ -71,6 +80,7 @@ class DatabricksOAuthProvider(AuthProvider):
|
|
|
71
80
|
port_range=redirect_port_range,
|
|
72
81
|
client_id=client_id,
|
|
73
82
|
idp_endpoint=idp_endpoint,
|
|
83
|
+
http_client=http_client,
|
|
74
84
|
)
|
|
75
85
|
self._hostname = hostname
|
|
76
86
|
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
|
|
@@ -146,3 +156,85 @@ class ExternalAuthProvider(AuthProvider):
|
|
|
146
156
|
headers = self._header_factory()
|
|
147
157
|
for k, v in headers.items():
|
|
148
158
|
request_headers[k] = v
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class AzureServicePrincipalCredentialProvider(CredentialsProvider):
|
|
162
|
+
"""
|
|
163
|
+
A credential provider for Azure Service Principal authentication with Databricks.
|
|
164
|
+
|
|
165
|
+
This class implements the CredentialsProvider protocol to authenticate requests
|
|
166
|
+
to Databricks REST APIs using Azure Active Directory (AAD) service principal
|
|
167
|
+
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
|
|
168
|
+
from Azure AD and automatically refreshes them when they expire.
|
|
169
|
+
|
|
170
|
+
Attributes:
|
|
171
|
+
hostname (str): The Databricks workspace hostname.
|
|
172
|
+
azure_client_id (str): The Azure service principal's client ID.
|
|
173
|
+
azure_client_secret (str): The Azure service principal's client secret.
|
|
174
|
+
azure_tenant_id (str): The Azure AD tenant ID.
|
|
175
|
+
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
|
|
176
|
+
"""
|
|
177
|
+
|
|
178
|
+
AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
|
|
179
|
+
AZURE_TOKEN_ENDPOINT = "oauth2/token"
|
|
180
|
+
|
|
181
|
+
AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"
|
|
182
|
+
|
|
183
|
+
DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
|
|
184
|
+
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
|
|
185
|
+
"X-Databricks-Azure-Workspace-Resource-Id"
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
def __init__(
|
|
189
|
+
self,
|
|
190
|
+
hostname,
|
|
191
|
+
azure_client_id,
|
|
192
|
+
azure_client_secret,
|
|
193
|
+
http_client,
|
|
194
|
+
azure_tenant_id=None,
|
|
195
|
+
azure_workspace_resource_id=None,
|
|
196
|
+
):
|
|
197
|
+
self.hostname = hostname
|
|
198
|
+
self.azure_client_id = azure_client_id
|
|
199
|
+
self.azure_client_secret = azure_client_secret
|
|
200
|
+
self.azure_workspace_resource_id = azure_workspace_resource_id
|
|
201
|
+
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
|
|
202
|
+
hostname, http_client
|
|
203
|
+
)
|
|
204
|
+
self._http_client = http_client
|
|
205
|
+
|
|
206
|
+
def auth_type(self) -> str:
|
|
207
|
+
return AuthType.AZURE_SP_M2M.value
|
|
208
|
+
|
|
209
|
+
def get_token_source(self, resource: str) -> RefreshableTokenSource:
|
|
210
|
+
return ClientCredentialsTokenSource(
|
|
211
|
+
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
|
|
212
|
+
client_id=self.azure_client_id,
|
|
213
|
+
client_secret=self.azure_client_secret,
|
|
214
|
+
http_client=self._http_client,
|
|
215
|
+
extra_params={"resource": resource},
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
def __call__(self, *args, **kwargs) -> HeaderFactory:
|
|
219
|
+
inner = self.get_token_source(
|
|
220
|
+
resource=get_effective_azure_login_app_id(self.hostname)
|
|
221
|
+
)
|
|
222
|
+
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
|
|
223
|
+
|
|
224
|
+
def header_factory() -> Dict[str, str]:
|
|
225
|
+
inner_token = inner.get_token()
|
|
226
|
+
cloud_token = cloud.get_token()
|
|
227
|
+
|
|
228
|
+
headers = {
|
|
229
|
+
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
|
|
230
|
+
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
if self.azure_workspace_resource_id:
|
|
234
|
+
headers[
|
|
235
|
+
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
|
|
236
|
+
] = self.azure_workspace_resource_id
|
|
237
|
+
|
|
238
|
+
return headers
|
|
239
|
+
|
|
240
|
+
return header_factory
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
from databricks.sql.auth.retry import DatabricksRetryPolicy
|
|
6
|
+
from databricks.sql.common.http import HttpMethod
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AuthType(Enum):
|
|
12
|
+
DATABRICKS_OAUTH = "databricks-oauth"
|
|
13
|
+
AZURE_OAUTH = "azure-oauth"
|
|
14
|
+
AZURE_SP_M2M = "azure-sp-m2m"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AzureAppId(Enum):
|
|
18
|
+
DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
|
|
19
|
+
STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
|
|
20
|
+
PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ClientContext:
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
hostname: str,
|
|
27
|
+
access_token: Optional[str] = None,
|
|
28
|
+
auth_type: Optional[str] = None,
|
|
29
|
+
oauth_scopes: Optional[List[str]] = None,
|
|
30
|
+
oauth_client_id: Optional[str] = None,
|
|
31
|
+
azure_client_id: Optional[str] = None,
|
|
32
|
+
azure_client_secret: Optional[str] = None,
|
|
33
|
+
azure_tenant_id: Optional[str] = None,
|
|
34
|
+
azure_workspace_resource_id: Optional[str] = None,
|
|
35
|
+
oauth_redirect_port_range: Optional[List[int]] = None,
|
|
36
|
+
use_cert_as_auth: Optional[str] = None,
|
|
37
|
+
tls_client_cert_file: Optional[str] = None,
|
|
38
|
+
oauth_persistence=None,
|
|
39
|
+
credentials_provider=None,
|
|
40
|
+
# HTTP client configuration parameters
|
|
41
|
+
ssl_options=None, # SSLOptions type
|
|
42
|
+
socket_timeout: Optional[float] = None,
|
|
43
|
+
retry_stop_after_attempts_count: Optional[int] = None,
|
|
44
|
+
retry_delay_min: Optional[float] = None,
|
|
45
|
+
retry_delay_max: Optional[float] = None,
|
|
46
|
+
retry_stop_after_attempts_duration: Optional[float] = None,
|
|
47
|
+
retry_delay_default: Optional[float] = None,
|
|
48
|
+
retry_dangerous_codes: Optional[List[int]] = None,
|
|
49
|
+
proxy_auth_method: Optional[str] = None,
|
|
50
|
+
pool_connections: Optional[int] = None,
|
|
51
|
+
pool_maxsize: Optional[int] = None,
|
|
52
|
+
user_agent: Optional[str] = None,
|
|
53
|
+
):
|
|
54
|
+
self.hostname = hostname
|
|
55
|
+
self.access_token = access_token
|
|
56
|
+
self.auth_type = auth_type
|
|
57
|
+
self.oauth_scopes = oauth_scopes
|
|
58
|
+
self.oauth_client_id = oauth_client_id
|
|
59
|
+
self.azure_client_id = azure_client_id
|
|
60
|
+
self.azure_client_secret = azure_client_secret
|
|
61
|
+
self.azure_tenant_id = azure_tenant_id
|
|
62
|
+
self.azure_workspace_resource_id = azure_workspace_resource_id
|
|
63
|
+
self.oauth_redirect_port_range = oauth_redirect_port_range
|
|
64
|
+
self.use_cert_as_auth = use_cert_as_auth
|
|
65
|
+
self.tls_client_cert_file = tls_client_cert_file
|
|
66
|
+
self.oauth_persistence = oauth_persistence
|
|
67
|
+
self.credentials_provider = credentials_provider
|
|
68
|
+
|
|
69
|
+
# HTTP client configuration
|
|
70
|
+
self.ssl_options = ssl_options
|
|
71
|
+
self.socket_timeout = socket_timeout
|
|
72
|
+
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
|
|
73
|
+
self.retry_delay_min = retry_delay_min or 1.0
|
|
74
|
+
self.retry_delay_max = retry_delay_max or 10.0
|
|
75
|
+
self.retry_stop_after_attempts_duration = (
|
|
76
|
+
retry_stop_after_attempts_duration or 300.0
|
|
77
|
+
)
|
|
78
|
+
self.retry_delay_default = retry_delay_default or 5.0
|
|
79
|
+
self.retry_dangerous_codes = retry_dangerous_codes or []
|
|
80
|
+
self.proxy_auth_method = proxy_auth_method
|
|
81
|
+
self.pool_connections = pool_connections or 10
|
|
82
|
+
self.pool_maxsize = pool_maxsize or 20
|
|
83
|
+
self.user_agent = user_agent
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_effective_azure_login_app_id(hostname) -> str:
|
|
87
|
+
"""
|
|
88
|
+
Get the effective Azure login app ID for a given hostname.
|
|
89
|
+
This function determines the appropriate Azure login app ID based on the hostname.
|
|
90
|
+
If the hostname does not match any of these domains, it returns the default Databricks resource ID.
|
|
91
|
+
|
|
92
|
+
"""
|
|
93
|
+
for azure_app_id in AzureAppId:
|
|
94
|
+
domain, app_id = azure_app_id.value
|
|
95
|
+
if domain in hostname:
|
|
96
|
+
return app_id
|
|
97
|
+
|
|
98
|
+
# default databricks resource id
|
|
99
|
+
return AzureAppId.PROD.value[1]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
|
|
103
|
+
"""
|
|
104
|
+
Load the Azure tenant ID from the Azure Databricks login page.
|
|
105
|
+
|
|
106
|
+
This function retrieves the Azure tenant ID by making a request to the Databricks
|
|
107
|
+
Azure Active Directory (AAD) authentication endpoint. The endpoint redirects to
|
|
108
|
+
the Azure login page, and the tenant ID is extracted from the redirect URL.
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
login_url = f"{host}/aad/auth"
|
|
112
|
+
logger.debug("Loading tenant ID from %s", login_url)
|
|
113
|
+
|
|
114
|
+
with http_client.request_context(HttpMethod.GET, login_url) as resp:
|
|
115
|
+
entra_id_endpoint = resp.retries.history[-1].redirect_location
|
|
116
|
+
if entra_id_endpoint is None:
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f"No Location header in response from {login_url}: {entra_id_endpoint}"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
# The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
|
|
122
|
+
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
|
|
123
|
+
url = urlparse(entra_id_endpoint)
|
|
124
|
+
path_segments = url.path.split("/")
|
|
125
|
+
if len(path_segments) < 2:
|
|
126
|
+
raise ValueError(f"Invalid path in Location header: {url.path}")
|
|
127
|
+
return path_segments[1]
|
{databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/oauth.py
RENAMED
|
@@ -6,33 +6,59 @@ import secrets
|
|
|
6
6
|
import webbrowser
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
8
|
from http.server import HTTPServer
|
|
9
|
-
from typing import List
|
|
9
|
+
from typing import List, Optional
|
|
10
10
|
|
|
11
11
|
import oauthlib.oauth2
|
|
12
|
-
import requests
|
|
13
12
|
from oauthlib.oauth2.rfc6749.errors import OAuth2Error
|
|
14
|
-
from
|
|
15
|
-
|
|
13
|
+
from databricks.sql.common.http import HttpMethod, HttpHeader
|
|
14
|
+
from databricks.sql.common.http import OAuthResponse
|
|
16
15
|
from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
|
|
17
16
|
from databricks.sql.auth.endpoint import OAuthEndpointCollection
|
|
17
|
+
from abc import abstractmethod, ABC
|
|
18
|
+
from urllib.parse import urlencode
|
|
19
|
+
import jwt
|
|
20
|
+
import time
|
|
18
21
|
|
|
19
22
|
logger = logging.getLogger(__name__)
|
|
20
23
|
|
|
21
24
|
|
|
22
|
-
class
|
|
23
|
-
"""
|
|
25
|
+
class Token:
|
|
26
|
+
"""
|
|
27
|
+
A class to represent a token.
|
|
24
28
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
29
|
+
Attributes:
|
|
30
|
+
access_token (str): The access token string.
|
|
31
|
+
token_type (str): The type of token (e.g., "Bearer").
|
|
32
|
+
refresh_token (str): The refresh token string.
|
|
33
|
+
"""
|
|
28
34
|
|
|
29
|
-
|
|
35
|
+
def __init__(self, access_token: str, token_type: str, refresh_token: str):
|
|
36
|
+
self.access_token = access_token
|
|
37
|
+
self.token_type = token_type
|
|
38
|
+
self.refresh_token = refresh_token
|
|
30
39
|
|
|
31
|
-
|
|
32
|
-
|
|
40
|
+
def is_expired(self) -> bool:
|
|
41
|
+
try:
|
|
42
|
+
decoded_token = jwt.decode(
|
|
43
|
+
self.access_token, options={"verify_signature": False}
|
|
44
|
+
)
|
|
45
|
+
exp_time = decoded_token.get("exp")
|
|
46
|
+
current_time = time.time()
|
|
47
|
+
buffer_time = 30 # 30 seconds buffer
|
|
48
|
+
return exp_time and (exp_time - buffer_time) <= current_time
|
|
49
|
+
except Exception as e:
|
|
50
|
+
logger.error("Failed to decode token: %s", e)
|
|
51
|
+
raise e
|
|
33
52
|
|
|
34
|
-
|
|
35
|
-
|
|
53
|
+
|
|
54
|
+
class RefreshableTokenSource(ABC):
|
|
55
|
+
@abstractmethod
|
|
56
|
+
def get_token(self) -> Token:
|
|
57
|
+
pass
|
|
58
|
+
|
|
59
|
+
@abstractmethod
|
|
60
|
+
def refresh(self) -> Token:
|
|
61
|
+
pass
|
|
36
62
|
|
|
37
63
|
|
|
38
64
|
class OAuthManager:
|
|
@@ -41,11 +67,13 @@ class OAuthManager:
|
|
|
41
67
|
port_range: List[int],
|
|
42
68
|
client_id: str,
|
|
43
69
|
idp_endpoint: OAuthEndpointCollection,
|
|
70
|
+
http_client,
|
|
44
71
|
):
|
|
45
72
|
self.port_range = port_range
|
|
46
73
|
self.client_id = client_id
|
|
47
74
|
self.redirect_port = None
|
|
48
75
|
self.idp_endpoint = idp_endpoint
|
|
76
|
+
self.http_client = http_client
|
|
49
77
|
|
|
50
78
|
@staticmethod
|
|
51
79
|
def __token_urlsafe(nbytes=32):
|
|
@@ -59,8 +87,11 @@ class OAuthManager:
|
|
|
59
87
|
known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
|
|
60
88
|
|
|
61
89
|
try:
|
|
62
|
-
response =
|
|
63
|
-
|
|
90
|
+
response = self.http_client.request(HttpMethod.GET, url=known_config_url)
|
|
91
|
+
# Convert urllib3 response to requests-like response for compatibility
|
|
92
|
+
response.status_code = response.status
|
|
93
|
+
response.json = lambda: json.loads(response.data.decode())
|
|
94
|
+
except Exception as e:
|
|
64
95
|
logger.error(
|
|
65
96
|
f"Unable to fetch OAuth configuration from {known_config_url}.\n"
|
|
66
97
|
"Verify it is a valid workspace URL and that OAuth is "
|
|
@@ -78,7 +109,7 @@ class OAuthManager:
|
|
|
78
109
|
raise RuntimeError(msg)
|
|
79
110
|
try:
|
|
80
111
|
return response.json()
|
|
81
|
-
except
|
|
112
|
+
except Exception as e:
|
|
82
113
|
logger.error(
|
|
83
114
|
f"Unable to decode OAuth configuration from {known_config_url}.\n"
|
|
84
115
|
"Verify it is a valid workspace URL and that OAuth is "
|
|
@@ -159,16 +190,17 @@ class OAuthManager:
|
|
|
159
190
|
data = f"{token_request_body}&code_verifier={verifier}"
|
|
160
191
|
return self.__send_token_request(token_request_url, data)
|
|
161
192
|
|
|
162
|
-
|
|
163
|
-
def __send_token_request(token_request_url, data):
|
|
193
|
+
def __send_token_request(self, token_request_url, data):
|
|
164
194
|
headers = {
|
|
165
195
|
"Accept": "application/json",
|
|
166
196
|
"Content-Type": "application/x-www-form-urlencoded",
|
|
167
197
|
}
|
|
168
|
-
|
|
169
|
-
|
|
198
|
+
# Use unified HTTP client
|
|
199
|
+
response = self.http_client.request(
|
|
200
|
+
HttpMethod.POST, url=token_request_url, body=data, headers=headers
|
|
170
201
|
)
|
|
171
|
-
|
|
202
|
+
# Convert urllib3 response to dict for compatibility
|
|
203
|
+
return json.loads(response.data.decode())
|
|
172
204
|
|
|
173
205
|
def __send_refresh_token_request(self, hostname, refresh_token):
|
|
174
206
|
oauth_config = self.__fetch_well_known_config(hostname)
|
|
@@ -177,7 +209,7 @@ class OAuthManager:
|
|
|
177
209
|
token_request_body = client.prepare_refresh_body(
|
|
178
210
|
refresh_token=refresh_token, client_id=client.client_id
|
|
179
211
|
)
|
|
180
|
-
return OAuthManager.__send_token_request(token_request_url, token_request_body)
|
|
212
|
+
return self.__send_token_request(token_request_url, token_request_body)
|
|
181
213
|
|
|
182
214
|
@staticmethod
|
|
183
215
|
def __get_tokens_from_response(oauth_response):
|
|
@@ -258,3 +290,64 @@ class OAuthManager:
|
|
|
258
290
|
client, token_request_url, redirect_url, code, verifier
|
|
259
291
|
)
|
|
260
292
|
return self.__get_tokens_from_response(oauth_response)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class ClientCredentialsTokenSource(RefreshableTokenSource):
|
|
296
|
+
"""
|
|
297
|
+
A token source that uses client credentials to get a token from the token endpoint.
|
|
298
|
+
It will refresh the token if it is expired.
|
|
299
|
+
|
|
300
|
+
Attributes:
|
|
301
|
+
token_url (str): The URL of the token endpoint.
|
|
302
|
+
client_id (str): The client ID.
|
|
303
|
+
client_secret (str): The client secret.
|
|
304
|
+
"""
|
|
305
|
+
|
|
306
|
+
def __init__(
|
|
307
|
+
self,
|
|
308
|
+
token_url,
|
|
309
|
+
client_id,
|
|
310
|
+
client_secret,
|
|
311
|
+
http_client,
|
|
312
|
+
extra_params: dict = {},
|
|
313
|
+
):
|
|
314
|
+
self.client_id = client_id
|
|
315
|
+
self.client_secret = client_secret
|
|
316
|
+
self.token_url = token_url
|
|
317
|
+
self.extra_params = extra_params
|
|
318
|
+
self.token: Optional[Token] = None
|
|
319
|
+
self._http_client = http_client
|
|
320
|
+
|
|
321
|
+
def get_token(self) -> Token:
|
|
322
|
+
if self.token is None or self.token.is_expired():
|
|
323
|
+
self.token = self.refresh()
|
|
324
|
+
return self.token
|
|
325
|
+
|
|
326
|
+
def refresh(self) -> Token:
|
|
327
|
+
logger.info("Refreshing OAuth token using client credentials flow")
|
|
328
|
+
headers = {
|
|
329
|
+
HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
|
|
330
|
+
}
|
|
331
|
+
data = urlencode(
|
|
332
|
+
{
|
|
333
|
+
"grant_type": "client_credentials",
|
|
334
|
+
"client_id": self.client_id,
|
|
335
|
+
"client_secret": self.client_secret,
|
|
336
|
+
**self.extra_params,
|
|
337
|
+
}
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
response = self._http_client.request(
|
|
341
|
+
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
|
|
342
|
+
)
|
|
343
|
+
if response.status == 200:
|
|
344
|
+
oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8")))
|
|
345
|
+
return Token(
|
|
346
|
+
oauth_response.access_token,
|
|
347
|
+
oauth_response.token_type,
|
|
348
|
+
oauth_response.refresh_token,
|
|
349
|
+
)
|
|
350
|
+
else:
|
|
351
|
+
raise Exception(
|
|
352
|
+
f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
|
|
353
|
+
)
|