databricks-sql-connector 4.0.4.tar.gz → 4.0.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/CHANGELOG.md +3 -0
  2. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/PKG-INFO +4 -2
  3. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/pyproject.toml +25 -3
  4. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/__init__.py +1 -1
  5. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/auth.py +27 -38
  6. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/authenticators.py +96 -4
  7. databricks_sql_connector-4.0.6/src/databricks/sql/auth/common.py +127 -0
  8. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/oauth.py +116 -23
  9. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/retry.py +9 -3
  10. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/thrift_http_client.py +25 -24
  11. databricks_sql_connector-4.0.6/src/databricks/sql/backend/databricks_client.py +347 -0
  12. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/backend.py +825 -0
  13. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/__init__.py +52 -0
  14. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/base.py +82 -0
  15. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/requests.py +133 -0
  16. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/models/responses.py +196 -0
  17. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/queue.py +391 -0
  18. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/result_set.py +266 -0
  19. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/constants.py +67 -0
  20. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/conversion.py +173 -0
  21. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/filters.py +289 -0
  22. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/http_client.py +294 -0
  23. databricks_sql_connector-4.0.6/src/databricks/sql/backend/sea/utils/normalize.py +50 -0
  24. {databricks_sql_connector-4.0.4/src/databricks/sql → databricks_sql_connector-4.0.6/src/databricks/sql/backend}/thrift_backend.py +395 -203
  25. databricks_sql_connector-4.0.6/src/databricks/sql/backend/types.py +427 -0
  26. databricks_sql_connector-4.0.6/src/databricks/sql/backend/utils/__init__.py +3 -0
  27. databricks_sql_connector-4.0.6/src/databricks/sql/backend/utils/guid_utils.py +23 -0
  28. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/client.py +283 -549
  29. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/cloudfetch/download_manager.py +43 -9
  30. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/cloudfetch/downloader.py +83 -58
  31. databricks_sql_connector-4.0.6/src/databricks/sql/common/feature_flag.py +187 -0
  32. databricks_sql_connector-4.0.6/src/databricks/sql/common/http.py +40 -0
  33. databricks_sql_connector-4.0.6/src/databricks/sql/common/http_utils.py +100 -0
  34. databricks_sql_connector-4.0.6/src/databricks/sql/common/unified_http_client.py +309 -0
  35. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/exc.py +13 -2
  36. databricks_sql_connector-4.0.6/src/databricks/sql/result_set.py +439 -0
  37. databricks_sql_connector-4.0.6/src/databricks/sql/session.py +195 -0
  38. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/latency_logger.py +216 -0
  39. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/endpoint_models.py +39 -0
  40. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/enums.py +44 -0
  41. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/event.py +162 -0
  42. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/models/frontend_logs.py +65 -0
  43. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/telemetry_client.py +561 -0
  44. databricks_sql_connector-4.0.6/src/databricks/sql/telemetry/utils.py +69 -0
  45. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/types.py +1 -0
  46. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/utils.py +231 -71
  47. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/LICENSE +0 -0
  48. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/README.md +0 -0
  49. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/__init__.py +0 -0
  50. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/__init__.py +0 -0
  51. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/endpoint.py +0 -0
  52. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
  53. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/experimental/__init__.py +0 -0
  54. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
  55. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/__init__.py +0 -0
  56. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/native.py +0 -0
  57. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/parameters/py.typed +0 -0
  58. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/py.typed +0 -0
  59. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
  60. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
  61. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
  62. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
  63. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
  64. {databricks_sql_connector-4.0.4 → databricks_sql_connector-4.0.6}/src/databricks/sql/thrift_api/__init__.py +0 -0
@@ -1,5 +1,8 @@
1
1
  # Release History
2
2
 
3
+ # 4.0.5 (2025-06-24)
4
+ - Fix: Reverted change in cursor close handling which led to errors impacting users (databricks/databricks-sql-python#613 by @madhav-db)
5
+
3
6
  # 4.0.4 (2025-06-16)
4
7
 
5
8
  - Update thrift client library after cleaning up unused fields and structs (databricks/databricks-sql-python#553 by @vikrantpuppala)
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: databricks-sql-connector
3
- Version: 4.0.4
3
+ Version: 4.0.6
4
4
  Summary: Databricks SQL Connector for Python
5
5
  License: Apache-2.0
6
6
  Author: Databricks
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.9
13
13
  Classifier: Programming Language :: Python :: 3.10
14
14
  Classifier: Programming Language :: Python :: 3.11
15
15
  Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
16
17
  Provides-Extra: pyarrow
17
18
  Requires-Dist: lz4 (>=4.0.2,<5.0.0)
18
19
  Requires-Dist: oauthlib (>=3.1.0,<4.0.0)
@@ -21,6 +22,7 @@ Requires-Dist: pandas (>=1.2.5,<2.3.0) ; python_version >= "3.8" and python_vers
21
22
  Requires-Dist: pandas (>=2.2.3,<2.3.0) ; python_version >= "3.13"
22
23
  Requires-Dist: pyarrow (>=14.0.1) ; (python_version >= "3.8" and python_version < "3.13") and (extra == "pyarrow")
23
24
  Requires-Dist: pyarrow (>=18.0.0) ; (python_version >= "3.13") and (extra == "pyarrow")
25
+ Requires-Dist: pyjwt (>=2.0.0,<3.0.0)
24
26
  Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
25
27
  Requires-Dist: requests (>=2.18.1,<3.0.0)
26
28
  Requires-Dist: thrift (>=0.16.0,<0.21.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "databricks-sql-connector"
3
- version = "4.0.4"
3
+ version = "4.0.6"
4
4
  description = "Databricks SQL Connector for Python"
5
5
  authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
6
6
  license = "Apache-2.0"
@@ -20,21 +20,25 @@ requests = "^2.18.1"
20
20
  oauthlib = "^3.1.0"
21
21
  openpyxl = "^3.0.10"
22
22
  urllib3 = ">=1.26"
23
+ python-dateutil = "^2.8.0"
23
24
  pyarrow = [
24
25
  { version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
25
26
  { version = ">=18.0.0", python = ">=3.13", optional=true }
26
27
  ]
27
- python-dateutil = "^2.8.0"
28
+ pyjwt = "^2.0.0"
29
+ requests-kerberos = {version = "^0.15.0", optional = true}
30
+
28
31
 
29
32
  [tool.poetry.extras]
30
33
  pyarrow = ["pyarrow"]
31
34
 
32
- [tool.poetry.dev-dependencies]
35
+ [tool.poetry.group.dev.dependencies]
33
36
  pytest = "^7.1.2"
34
37
  mypy = "^1.10.1"
35
38
  pylint = ">=2.12.0"
36
39
  black = "^22.3.0"
37
40
  pytest-dotenv = "^0.5.2"
41
+ pytest-cov = "^4.0.0"
38
42
  numpy = [
39
43
  { version = ">=1.16.6", python = ">=3.8,<3.11" },
40
44
  { version = ">=1.23.4", python = ">=3.11" },
@@ -62,3 +66,21 @@ log_cli = "false"
62
66
  log_cli_level = "INFO"
63
67
  testpaths = ["tests"]
64
68
  env_files = ["test.env"]
69
+
70
+ [tool.coverage.run]
71
+ source = ["src"]
72
+ branch = true
73
+ omit = [
74
+ "*/tests/*",
75
+ "*/test_*",
76
+ "*/__pycache__/*",
77
+ "*/thrift_api/*",
78
+ ]
79
+
80
+ [tool.coverage.report]
81
+ precision = 2
82
+ show_missing = true
83
+ skip_covered = false
84
+
85
+ [tool.coverage.xml]
86
+ output = "coverage.xml"
@@ -68,7 +68,7 @@ DATETIME = DBAPITypeObject("timestamp")
68
68
  DATE = DBAPITypeObject("date")
69
69
  ROWID = DBAPITypeObject()
70
70
 
71
- __version__ = "4.0.4"
71
+ __version__ = "4.0.5"
72
72
  USER_AGENT_NAME = "PyDatabricksSqlConnector"
73
73
 
74
74
  # These two functions are pyhive legacy
@@ -1,4 +1,3 @@
1
- from enum import Enum
2
1
  from typing import Optional, List
3
2
 
4
3
  from databricks.sql.auth.authenticators import (
@@ -6,46 +5,26 @@ from databricks.sql.auth.authenticators import (
6
5
  AccessTokenAuthProvider,
7
6
  ExternalAuthProvider,
8
7
  DatabricksOAuthProvider,
8
+ AzureServicePrincipalCredentialProvider,
9
9
  )
10
+ from databricks.sql.auth.common import AuthType, ClientContext
10
11
 
11
12
 
12
- class AuthType(Enum):
13
- DATABRICKS_OAUTH = "databricks-oauth"
14
- AZURE_OAUTH = "azure-oauth"
15
- # other supported types (access_token) can be inferred
16
- # we can add more types as needed later
17
-
18
-
19
- class ClientContext:
20
- def __init__(
21
- self,
22
- hostname: str,
23
- access_token: Optional[str] = None,
24
- auth_type: Optional[str] = None,
25
- oauth_scopes: Optional[List[str]] = None,
26
- oauth_client_id: Optional[str] = None,
27
- oauth_redirect_port_range: Optional[List[int]] = None,
28
- use_cert_as_auth: Optional[str] = None,
29
- tls_client_cert_file: Optional[str] = None,
30
- oauth_persistence=None,
31
- credentials_provider=None,
32
- ):
33
- self.hostname = hostname
34
- self.access_token = access_token
35
- self.auth_type = auth_type
36
- self.oauth_scopes = oauth_scopes
37
- self.oauth_client_id = oauth_client_id
38
- self.oauth_redirect_port_range = oauth_redirect_port_range
39
- self.use_cert_as_auth = use_cert_as_auth
40
- self.tls_client_cert_file = tls_client_cert_file
41
- self.oauth_persistence = oauth_persistence
42
- self.credentials_provider = credentials_provider
43
-
44
-
45
- def get_auth_provider(cfg: ClientContext):
13
+ def get_auth_provider(cfg: ClientContext, http_client):
46
14
  if cfg.credentials_provider:
47
15
  return ExternalAuthProvider(cfg.credentials_provider)
48
- if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
16
+ elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17
+ return ExternalAuthProvider(
18
+ AzureServicePrincipalCredentialProvider(
19
+ cfg.hostname,
20
+ cfg.azure_client_id,
21
+ cfg.azure_client_secret,
22
+ http_client,
23
+ cfg.azure_tenant_id,
24
+ cfg.azure_workspace_resource_id,
25
+ )
26
+ )
27
+ elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
49
28
  assert cfg.oauth_redirect_port_range is not None
50
29
  assert cfg.oauth_client_id is not None
51
30
  assert cfg.oauth_scopes is not None
@@ -56,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
56
35
  cfg.oauth_redirect_port_range,
57
36
  cfg.oauth_client_id,
58
37
  cfg.oauth_scopes,
38
+ http_client,
59
39
  cfg.auth_type,
60
40
  )
61
41
  elif cfg.access_token is not None:
@@ -75,6 +55,8 @@ def get_auth_provider(cfg: ClientContext):
75
55
  cfg.oauth_redirect_port_range,
76
56
  cfg.oauth_client_id,
77
57
  cfg.oauth_scopes,
58
+ http_client,
59
+ cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
78
60
  )
79
61
  else:
80
62
  raise RuntimeError("No valid authentication settings!")
@@ -101,11 +83,14 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
101
83
  )
102
84
 
103
85
 
104
- def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
86
+ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
87
+ # TODO : unify all the auth mechanisms with the Python SDK
88
+
105
89
  auth_type = kwargs.get("auth_type")
106
90
  (client_id, redirect_port_range) = get_client_id_and_redirect_port(
107
91
  auth_type == AuthType.AZURE_OAUTH.value
108
92
  )
93
+
109
94
  if kwargs.get("username") or kwargs.get("password"):
110
95
  raise ValueError(
111
96
  "Username/password authentication is no longer supported. "
@@ -120,10 +105,14 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
120
105
  tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121
106
  oauth_scopes=PYSQL_OAUTH_SCOPES,
122
107
  oauth_client_id=kwargs.get("oauth_client_id") or client_id,
108
+ azure_client_id=kwargs.get("azure_client_id"),
109
+ azure_client_secret=kwargs.get("azure_client_secret"),
110
+ azure_tenant_id=kwargs.get("azure_tenant_id"),
111
+ azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123
112
  oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124
113
  if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
125
114
  else redirect_port_range,
126
115
  oauth_persistence=kwargs.get("experimental_oauth_persistence"),
127
116
  credentials_provider=kwargs.get("credentials_provider"),
128
117
  )
129
- return get_auth_provider(cfg)
118
+ return get_auth_provider(cfg, http_client)
@@ -1,10 +1,18 @@
1
1
  import abc
2
- import base64
3
2
  import logging
4
3
  from typing import Callable, Dict, List
5
-
6
- from databricks.sql.auth.oauth import OAuthManager
7
- from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
4
+ from databricks.sql.common.http import HttpHeader
5
+ from databricks.sql.auth.oauth import (
6
+ OAuthManager,
7
+ RefreshableTokenSource,
8
+ ClientCredentialsTokenSource,
9
+ )
10
+ from databricks.sql.auth.endpoint import get_oauth_endpoints
11
+ from databricks.sql.auth.common import (
12
+ AuthType,
13
+ get_effective_azure_login_app_id,
14
+ get_azure_tenant_id_from_host,
15
+ )
8
16
 
9
17
  # Private API: this is an evolving interface and it will change in the future.
10
18
  # Please must not depend on it in your applications.
@@ -55,6 +63,7 @@ class DatabricksOAuthProvider(AuthProvider):
55
63
  redirect_port_range: List[int],
56
64
  client_id: str,
57
65
  scopes: List[str],
66
+ http_client,
58
67
  auth_type: str = "databricks-oauth",
59
68
  ):
60
69
  try:
@@ -71,6 +80,7 @@ class DatabricksOAuthProvider(AuthProvider):
71
80
  port_range=redirect_port_range,
72
81
  client_id=client_id,
73
82
  idp_endpoint=idp_endpoint,
83
+ http_client=http_client,
74
84
  )
75
85
  self._hostname = hostname
76
86
  self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
@@ -146,3 +156,85 @@ class ExternalAuthProvider(AuthProvider):
146
156
  headers = self._header_factory()
147
157
  for k, v in headers.items():
148
158
  request_headers[k] = v
159
+
160
+
161
+ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
162
+ """
163
+ A credential provider for Azure Service Principal authentication with Databricks.
164
+
165
+ This class implements the CredentialsProvider protocol to authenticate requests
166
+ to Databricks REST APIs using Azure Active Directory (AAD) service principal
167
+ credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
168
+ from Azure AD and automatically refreshes them when they expire.
169
+
170
+ Attributes:
171
+ hostname (str): The Databricks workspace hostname.
172
+ azure_client_id (str): The Azure service principal's client ID.
173
+ azure_client_secret (str): The Azure service principal's client secret.
174
+ azure_tenant_id (str): The Azure AD tenant ID.
175
+ azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
176
+ """
177
+
178
+ AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
179
+ AZURE_TOKEN_ENDPOINT = "oauth2/token"
180
+
181
+ AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"
182
+
183
+ DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
184
+ DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
185
+ "X-Databricks-Azure-Workspace-Resource-Id"
186
+ )
187
+
188
+ def __init__(
189
+ self,
190
+ hostname,
191
+ azure_client_id,
192
+ azure_client_secret,
193
+ http_client,
194
+ azure_tenant_id=None,
195
+ azure_workspace_resource_id=None,
196
+ ):
197
+ self.hostname = hostname
198
+ self.azure_client_id = azure_client_id
199
+ self.azure_client_secret = azure_client_secret
200
+ self.azure_workspace_resource_id = azure_workspace_resource_id
201
+ self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
202
+ hostname, http_client
203
+ )
204
+ self._http_client = http_client
205
+
206
+ def auth_type(self) -> str:
207
+ return AuthType.AZURE_SP_M2M.value
208
+
209
+ def get_token_source(self, resource: str) -> RefreshableTokenSource:
210
+ return ClientCredentialsTokenSource(
211
+ token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
212
+ client_id=self.azure_client_id,
213
+ client_secret=self.azure_client_secret,
214
+ http_client=self._http_client,
215
+ extra_params={"resource": resource},
216
+ )
217
+
218
+ def __call__(self, *args, **kwargs) -> HeaderFactory:
219
+ inner = self.get_token_source(
220
+ resource=get_effective_azure_login_app_id(self.hostname)
221
+ )
222
+ cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
223
+
224
+ def header_factory() -> Dict[str, str]:
225
+ inner_token = inner.get_token()
226
+ cloud_token = cloud.get_token()
227
+
228
+ headers = {
229
+ HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
230
+ self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
231
+ }
232
+
233
+ if self.azure_workspace_resource_id:
234
+ headers[
235
+ self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
236
+ ] = self.azure_workspace_resource_id
237
+
238
+ return headers
239
+
240
+ return header_factory
@@ -0,0 +1,127 @@
1
+ from enum import Enum
2
+ import logging
3
+ from typing import Optional, List
4
+ from urllib.parse import urlparse
5
+ from databricks.sql.auth.retry import DatabricksRetryPolicy
6
+ from databricks.sql.common.http import HttpMethod
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class AuthType(Enum):
12
+ DATABRICKS_OAUTH = "databricks-oauth"
13
+ AZURE_OAUTH = "azure-oauth"
14
+ AZURE_SP_M2M = "azure-sp-m2m"
15
+
16
+
17
+ class AzureAppId(Enum):
18
+ DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
19
+ STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
20
+ PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
21
+
22
+
23
+ class ClientContext:
24
+ def __init__(
25
+ self,
26
+ hostname: str,
27
+ access_token: Optional[str] = None,
28
+ auth_type: Optional[str] = None,
29
+ oauth_scopes: Optional[List[str]] = None,
30
+ oauth_client_id: Optional[str] = None,
31
+ azure_client_id: Optional[str] = None,
32
+ azure_client_secret: Optional[str] = None,
33
+ azure_tenant_id: Optional[str] = None,
34
+ azure_workspace_resource_id: Optional[str] = None,
35
+ oauth_redirect_port_range: Optional[List[int]] = None,
36
+ use_cert_as_auth: Optional[str] = None,
37
+ tls_client_cert_file: Optional[str] = None,
38
+ oauth_persistence=None,
39
+ credentials_provider=None,
40
+ # HTTP client configuration parameters
41
+ ssl_options=None, # SSLOptions type
42
+ socket_timeout: Optional[float] = None,
43
+ retry_stop_after_attempts_count: Optional[int] = None,
44
+ retry_delay_min: Optional[float] = None,
45
+ retry_delay_max: Optional[float] = None,
46
+ retry_stop_after_attempts_duration: Optional[float] = None,
47
+ retry_delay_default: Optional[float] = None,
48
+ retry_dangerous_codes: Optional[List[int]] = None,
49
+ proxy_auth_method: Optional[str] = None,
50
+ pool_connections: Optional[int] = None,
51
+ pool_maxsize: Optional[int] = None,
52
+ user_agent: Optional[str] = None,
53
+ ):
54
+ self.hostname = hostname
55
+ self.access_token = access_token
56
+ self.auth_type = auth_type
57
+ self.oauth_scopes = oauth_scopes
58
+ self.oauth_client_id = oauth_client_id
59
+ self.azure_client_id = azure_client_id
60
+ self.azure_client_secret = azure_client_secret
61
+ self.azure_tenant_id = azure_tenant_id
62
+ self.azure_workspace_resource_id = azure_workspace_resource_id
63
+ self.oauth_redirect_port_range = oauth_redirect_port_range
64
+ self.use_cert_as_auth = use_cert_as_auth
65
+ self.tls_client_cert_file = tls_client_cert_file
66
+ self.oauth_persistence = oauth_persistence
67
+ self.credentials_provider = credentials_provider
68
+
69
+ # HTTP client configuration
70
+ self.ssl_options = ssl_options
71
+ self.socket_timeout = socket_timeout
72
+ self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
73
+ self.retry_delay_min = retry_delay_min or 1.0
74
+ self.retry_delay_max = retry_delay_max or 10.0
75
+ self.retry_stop_after_attempts_duration = (
76
+ retry_stop_after_attempts_duration or 300.0
77
+ )
78
+ self.retry_delay_default = retry_delay_default or 5.0
79
+ self.retry_dangerous_codes = retry_dangerous_codes or []
80
+ self.proxy_auth_method = proxy_auth_method
81
+ self.pool_connections = pool_connections or 10
82
+ self.pool_maxsize = pool_maxsize or 20
83
+ self.user_agent = user_agent
84
+
85
+
86
+ def get_effective_azure_login_app_id(hostname) -> str:
87
+ """
88
+ Get the effective Azure login app ID for a given hostname.
89
+ This function determines the appropriate Azure login app ID based on the hostname.
90
+ If the hostname does not match any of these domains, it returns the default Databricks resource ID.
91
+
92
+ """
93
+ for azure_app_id in AzureAppId:
94
+ domain, app_id = azure_app_id.value
95
+ if domain in hostname:
96
+ return app_id
97
+
98
+ # default databricks resource id
99
+ return AzureAppId.PROD.value[1]
100
+
101
+
102
+ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
103
+ """
104
+ Load the Azure tenant ID from the Azure Databricks login page.
105
+
106
+ This function retrieves the Azure tenant ID by making a request to the Databricks
107
+ Azure Active Directory (AAD) authentication endpoint. The endpoint redirects to
108
+ the Azure login page, and the tenant ID is extracted from the redirect URL.
109
+ """
110
+
111
+ login_url = f"{host}/aad/auth"
112
+ logger.debug("Loading tenant ID from %s", login_url)
113
+
114
+ with http_client.request_context(HttpMethod.GET, login_url) as resp:
115
+ entra_id_endpoint = resp.retries.history[-1].redirect_location
116
+ if entra_id_endpoint is None:
117
+ raise ValueError(
118
+ f"No Location header in response from {login_url}: {entra_id_endpoint}"
119
+ )
120
+
121
+ # The final redirect URL has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
122
+ # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
123
+ url = urlparse(entra_id_endpoint)
124
+ path_segments = url.path.split("/")
125
+ if len(path_segments) < 2:
126
+ raise ValueError(f"Invalid path in Location header: {url.path}")
127
+ return path_segments[1]
@@ -6,33 +6,59 @@ import secrets
6
6
  import webbrowser
7
7
  from datetime import datetime, timezone
8
8
  from http.server import HTTPServer
9
- from typing import List
9
+ from typing import List, Optional
10
10
 
11
11
  import oauthlib.oauth2
12
- import requests
13
12
  from oauthlib.oauth2.rfc6749.errors import OAuth2Error
14
- from requests.exceptions import RequestException
15
-
13
+ from databricks.sql.common.http import HttpMethod, HttpHeader
14
+ from databricks.sql.common.http import OAuthResponse
16
15
  from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler
17
16
  from databricks.sql.auth.endpoint import OAuthEndpointCollection
17
+ from abc import abstractmethod, ABC
18
+ from urllib.parse import urlencode
19
+ import jwt
20
+ import time
18
21
 
19
22
  logger = logging.getLogger(__name__)
20
23
 
21
24
 
22
- class IgnoreNetrcAuth(requests.auth.AuthBase):
23
- """This auth method is a no-op.
25
+ class Token:
26
+ """
27
+ A class to represent a token.
24
28
 
25
- We use it to force requestslib to not use .netrc to write auth headers
26
- when making .post() requests to the oauth token endpoints, since these
27
- don't require authentication.
29
+ Attributes:
30
+ access_token (str): The access token string.
31
+ token_type (str): The type of token (e.g., "Bearer").
32
+ refresh_token (str): The refresh token string.
33
+ """
28
34
 
29
- In cases where .netrc is outdated or corrupt, these requests will fail.
35
+ def __init__(self, access_token: str, token_type: str, refresh_token: str):
36
+ self.access_token = access_token
37
+ self.token_type = token_type
38
+ self.refresh_token = refresh_token
30
39
 
31
- See issue #121
32
- """
40
+ def is_expired(self) -> bool:
41
+ try:
42
+ decoded_token = jwt.decode(
43
+ self.access_token, options={"verify_signature": False}
44
+ )
45
+ exp_time = decoded_token.get("exp")
46
+ current_time = time.time()
47
+ buffer_time = 30 # 30 seconds buffer
48
+ return exp_time and (exp_time - buffer_time) <= current_time
49
+ except Exception as e:
50
+ logger.error("Failed to decode token: %s", e)
51
+ raise e
33
52
 
34
- def __call__(self, r):
35
- return r
53
+
54
+ class RefreshableTokenSource(ABC):
55
+ @abstractmethod
56
+ def get_token(self) -> Token:
57
+ pass
58
+
59
+ @abstractmethod
60
+ def refresh(self) -> Token:
61
+ pass
36
62
 
37
63
 
38
64
  class OAuthManager:
@@ -41,11 +67,13 @@ class OAuthManager:
41
67
  port_range: List[int],
42
68
  client_id: str,
43
69
  idp_endpoint: OAuthEndpointCollection,
70
+ http_client,
44
71
  ):
45
72
  self.port_range = port_range
46
73
  self.client_id = client_id
47
74
  self.redirect_port = None
48
75
  self.idp_endpoint = idp_endpoint
76
+ self.http_client = http_client
49
77
 
50
78
  @staticmethod
51
79
  def __token_urlsafe(nbytes=32):
@@ -59,8 +87,11 @@ class OAuthManager:
59
87
  known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
60
88
 
61
89
  try:
62
- response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth())
63
- except RequestException as e:
90
+ response = self.http_client.request(HttpMethod.GET, url=known_config_url)
91
+ # Convert urllib3 response to requests-like response for compatibility
92
+ response.status_code = response.status
93
+ response.json = lambda: json.loads(response.data.decode())
94
+ except Exception as e:
64
95
  logger.error(
65
96
  f"Unable to fetch OAuth configuration from {known_config_url}.\n"
66
97
  "Verify it is a valid workspace URL and that OAuth is "
@@ -78,7 +109,7 @@ class OAuthManager:
78
109
  raise RuntimeError(msg)
79
110
  try:
80
111
  return response.json()
81
- except requests.exceptions.JSONDecodeError as e:
112
+ except Exception as e:
82
113
  logger.error(
83
114
  f"Unable to decode OAuth configuration from {known_config_url}.\n"
84
115
  "Verify it is a valid workspace URL and that OAuth is "
@@ -159,16 +190,17 @@ class OAuthManager:
159
190
  data = f"{token_request_body}&code_verifier={verifier}"
160
191
  return self.__send_token_request(token_request_url, data)
161
192
 
162
- @staticmethod
163
- def __send_token_request(token_request_url, data):
193
+ def __send_token_request(self, token_request_url, data):
164
194
  headers = {
165
195
  "Accept": "application/json",
166
196
  "Content-Type": "application/x-www-form-urlencoded",
167
197
  }
168
- response = requests.post(
169
- url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth()
198
+ # Use unified HTTP client
199
+ response = self.http_client.request(
200
+ HttpMethod.POST, url=token_request_url, body=data, headers=headers
170
201
  )
171
- return response.json()
202
+ # Convert urllib3 response to dict for compatibility
203
+ return json.loads(response.data.decode())
172
204
 
173
205
  def __send_refresh_token_request(self, hostname, refresh_token):
174
206
  oauth_config = self.__fetch_well_known_config(hostname)
@@ -177,7 +209,7 @@ class OAuthManager:
177
209
  token_request_body = client.prepare_refresh_body(
178
210
  refresh_token=refresh_token, client_id=client.client_id
179
211
  )
180
- return OAuthManager.__send_token_request(token_request_url, token_request_body)
212
+ return self.__send_token_request(token_request_url, token_request_body)
181
213
 
182
214
  @staticmethod
183
215
  def __get_tokens_from_response(oauth_response):
@@ -258,3 +290,64 @@ class OAuthManager:
258
290
  client, token_request_url, redirect_url, code, verifier
259
291
  )
260
292
  return self.__get_tokens_from_response(oauth_response)
293
+
294
+
295
+ class ClientCredentialsTokenSource(RefreshableTokenSource):
296
+ """
297
+ A token source that uses client credentials to get a token from the token endpoint.
298
+ It will refresh the token if it is expired.
299
+
300
+ Attributes:
301
+ token_url (str): The URL of the token endpoint.
302
+ client_id (str): The client ID.
303
+ client_secret (str): The client secret.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ token_url,
309
+ client_id,
310
+ client_secret,
311
+ http_client,
312
+ extra_params: dict = {},
313
+ ):
314
+ self.client_id = client_id
315
+ self.client_secret = client_secret
316
+ self.token_url = token_url
317
+ self.extra_params = extra_params
318
+ self.token: Optional[Token] = None
319
+ self._http_client = http_client
320
+
321
+ def get_token(self) -> Token:
322
+ if self.token is None or self.token.is_expired():
323
+ self.token = self.refresh()
324
+ return self.token
325
+
326
+ def refresh(self) -> Token:
327
+ logger.info("Refreshing OAuth token using client credentials flow")
328
+ headers = {
329
+ HttpHeader.CONTENT_TYPE.value: "application/x-www-form-urlencoded",
330
+ }
331
+ data = urlencode(
332
+ {
333
+ "grant_type": "client_credentials",
334
+ "client_id": self.client_id,
335
+ "client_secret": self.client_secret,
336
+ **self.extra_params,
337
+ }
338
+ )
339
+
340
+ response = self._http_client.request(
341
+ method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
342
+ )
343
+ if response.status == 200:
344
+ oauth_response = OAuthResponse(**json.loads(response.data.decode("utf-8")))
345
+ return Token(
346
+ oauth_response.access_token,
347
+ oauth_response.token_type,
348
+ oauth_response.refresh_token,
349
+ )
350
+ else:
351
+ raise Exception(
352
+ f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
353
+ )