databricks-sql-connector 4.0.5__tar.gz → 4.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/CHANGELOG.md +24 -0
  2. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/PKG-INFO +2 -1
  3. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/pyproject.toml +25 -3
  4. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/__init__.py +1 -1
  5. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/auth.py +27 -38
  6. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/authenticators.py +96 -4
  7. databricks_sql_connector-4.1.0/src/databricks/sql/auth/common.py +127 -0
  8. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/oauth.py +116 -23
  9. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/retry.py +9 -3
  10. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/thrift_http_client.py +25 -24
  11. databricks_sql_connector-4.1.0/src/databricks/sql/backend/databricks_client.py +347 -0
  12. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/backend.py +825 -0
  13. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/models/__init__.py +52 -0
  14. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/models/base.py +82 -0
  15. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/models/requests.py +133 -0
  16. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/models/responses.py +196 -0
  17. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/queue.py +391 -0
  18. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/result_set.py +266 -0
  19. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/utils/constants.py +67 -0
  20. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/utils/conversion.py +173 -0
  21. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/utils/filters.py +289 -0
  22. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/utils/http_client.py +294 -0
  23. databricks_sql_connector-4.1.0/src/databricks/sql/backend/sea/utils/normalize.py +50 -0
  24. {databricks_sql_connector-4.0.5/src/databricks/sql → databricks_sql_connector-4.1.0/src/databricks/sql/backend}/thrift_backend.py +395 -203
  25. databricks_sql_connector-4.1.0/src/databricks/sql/backend/types.py +427 -0
  26. databricks_sql_connector-4.1.0/src/databricks/sql/backend/utils/__init__.py +3 -0
  27. databricks_sql_connector-4.1.0/src/databricks/sql/backend/utils/guid_utils.py +23 -0
  28. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/client.py +281 -520
  29. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/cloudfetch/download_manager.py +43 -9
  30. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/cloudfetch/downloader.py +83 -58
  31. databricks_sql_connector-4.1.0/src/databricks/sql/common/feature_flag.py +187 -0
  32. databricks_sql_connector-4.1.0/src/databricks/sql/common/http.py +40 -0
  33. databricks_sql_connector-4.1.0/src/databricks/sql/common/http_utils.py +100 -0
  34. databricks_sql_connector-4.1.0/src/databricks/sql/common/unified_http_client.py +309 -0
  35. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/exc.py +13 -2
  36. databricks_sql_connector-4.1.0/src/databricks/sql/result_set.py +439 -0
  37. databricks_sql_connector-4.1.0/src/databricks/sql/session.py +195 -0
  38. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/latency_logger.py +216 -0
  39. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/models/endpoint_models.py +39 -0
  40. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/models/enums.py +44 -0
  41. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/models/event.py +162 -0
  42. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/models/frontend_logs.py +65 -0
  43. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/telemetry_client.py +561 -0
  44. databricks_sql_connector-4.1.0/src/databricks/sql/telemetry/utils.py +69 -0
  45. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/types.py +1 -0
  46. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/utils.py +231 -71
  47. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/LICENSE +0 -0
  48. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/README.md +0 -0
  49. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/__init__.py +0 -0
  50. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/__init__.py +0 -0
  51. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/endpoint.py +0 -0
  52. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/auth/oauth_http_handler.py +0 -0
  53. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/experimental/__init__.py +0 -0
  54. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/experimental/oauth_persistence.py +0 -0
  55. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/parameters/__init__.py +0 -0
  56. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/parameters/native.py +0 -0
  57. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/parameters/py.typed +0 -0
  58. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/py.typed +0 -0
  59. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote +0 -0
  60. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/TCLIService/TCLIService.py +0 -0
  61. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/TCLIService/__init__.py +0 -0
  62. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/TCLIService/constants.py +0 -0
  63. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/TCLIService/ttypes.py +0 -0
  64. {databricks_sql_connector-4.0.5 → databricks_sql_connector-4.1.0}/src/databricks/sql/thrift_api/__init__.py +0 -0
@@ -1,5 +1,29 @@
1
1
  # Release History
2
2
 
3
+ # 4.1.0 (2025-08-18)
4
+ - Removed Codeowners (databricks/databricks-sql-python#623 by @jprakash-db)
5
+ - Azure Service Principal Credential Provider (databricks/databricks-sql-python#621 by @jprakash-db)
6
+ - Add optional telemetry support to the python connector (databricks/databricks-sql-python#628 by @saishreeeee)
7
+ - Fix potential resource leak in `CloudFetchQueue` (databricks/databricks-sql-python#624 by @varun-edachali-dbx)
8
+ - Generalise Backend Layer (databricks/databricks-sql-python#604 by @varun-edachali-dbx)
9
+ - Arrow performance optimizations (databricks/databricks-sql-python#638 by @jprakash-db)
10
+ - Connection errors to unauthenticated telemetry endpoint (databricks/databricks-sql-python#619 by @saishreeeee)
11
+ - SEA: Execution Phase (databricks/databricks-sql-python#645 by @varun-edachali-dbx)
12
+ - Add retry mechanism to telemetry requests (databricks/databricks-sql-python#617 by @saishreeeee)
13
+ - SEA: Fetch Phase (databricks/databricks-sql-python#650 by @varun-edachali-dbx)
14
+ - added logs for cloud fetch speed (databricks/databricks-sql-python#654 by @shivam2680)
15
+ - Make telemetry batch size configurable and add time-based flush (databricks/databricks-sql-python#622 by @saishreeeee)
16
+ - Normalise type code (databricks/databricks-sql-python#652 by @varun-edachali-dbx)
17
+ - Testing for telemetry (databricks/databricks-sql-python#616 by @saishreeeee)
18
+ - Bug fixes in telemetry (databricks/databricks-sql-python#659 by @saishreeeee)
19
+ - Telemetry server-side flag integration (databricks/databricks-sql-python#646 by @saishreeeee)
20
+ - Enhance SEA HTTP Client (databricks/databricks-sql-python#618 by @varun-edachali-dbx)
21
+ - SEA: Allow large metadata responses (databricks/databricks-sql-python#653 by @varun-edachali-dbx)
22
+ - Added code coverage workflow to test the code coverage from unit and e2e tests (databricks/databricks-sql-python#657 by @msrathore-db)
23
+ - Concat tables to be backward compatible (databricks/databricks-sql-python#647 by @jprakash-db)
24
+ - Refactor codebase to use a unified http client (databricks/databricks-sql-python#673 by @vikrantpuppala)
25
+ - Add kerberos support for proxy auth (databricks/databricks-sql-python#675 by @vikrantpuppala)
26
+
3
27
  # 4.0.5 (2025-06-24)
4
28
  - Fix: Reverted change in cursor close handling which led to errors impacting users (databricks/databricks-sql-python#613 by @madhav-db)
5
29
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: databricks-sql-connector
3
- Version: 4.0.5
3
+ Version: 4.1.0
4
4
  Summary: Databricks SQL Connector for Python
5
5
  License: Apache-2.0
6
6
  Author: Databricks
@@ -22,6 +22,7 @@ Requires-Dist: pandas (>=1.2.5,<2.3.0) ; python_version >= "3.8" and python_vers
22
22
  Requires-Dist: pandas (>=2.2.3,<2.3.0) ; python_version >= "3.13"
23
23
  Requires-Dist: pyarrow (>=14.0.1) ; (python_version >= "3.8" and python_version < "3.13") and (extra == "pyarrow")
24
24
  Requires-Dist: pyarrow (>=18.0.0) ; (python_version >= "3.13") and (extra == "pyarrow")
25
+ Requires-Dist: pyjwt (>=2.0.0,<3.0.0)
25
26
  Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
26
27
  Requires-Dist: requests (>=2.18.1,<3.0.0)
27
28
  Requires-Dist: thrift (>=0.16.0,<0.21.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "databricks-sql-connector"
3
- version = "4.0.5"
3
+ version = "4.1.0"
4
4
  description = "Databricks SQL Connector for Python"
5
5
  authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
6
6
  license = "Apache-2.0"
@@ -20,21 +20,25 @@ requests = "^2.18.1"
20
20
  oauthlib = "^3.1.0"
21
21
  openpyxl = "^3.0.10"
22
22
  urllib3 = ">=1.26"
23
+ python-dateutil = "^2.8.0"
23
24
  pyarrow = [
24
25
  { version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
25
26
  { version = ">=18.0.0", python = ">=3.13", optional=true }
26
27
  ]
27
- python-dateutil = "^2.8.0"
28
+ pyjwt = "^2.0.0"
29
+ requests-kerberos = {version = "^0.15.0", optional = true}
30
+
28
31
 
29
32
  [tool.poetry.extras]
30
33
  pyarrow = ["pyarrow"]
31
34
 
32
- [tool.poetry.dev-dependencies]
35
+ [tool.poetry.group.dev.dependencies]
33
36
  pytest = "^7.1.2"
34
37
  mypy = "^1.10.1"
35
38
  pylint = ">=2.12.0"
36
39
  black = "^22.3.0"
37
40
  pytest-dotenv = "^0.5.2"
41
+ pytest-cov = "^4.0.0"
38
42
  numpy = [
39
43
  { version = ">=1.16.6", python = ">=3.8,<3.11" },
40
44
  { version = ">=1.23.4", python = ">=3.11" },
@@ -62,3 +66,21 @@ log_cli = "false"
62
66
  log_cli_level = "INFO"
63
67
  testpaths = ["tests"]
64
68
  env_files = ["test.env"]
69
+
70
+ [tool.coverage.run]
71
+ source = ["src"]
72
+ branch = true
73
+ omit = [
74
+ "*/tests/*",
75
+ "*/test_*",
76
+ "*/__pycache__/*",
77
+ "*/thrift_api/*",
78
+ ]
79
+
80
+ [tool.coverage.report]
81
+ precision = 2
82
+ show_missing = true
83
+ skip_covered = false
84
+
85
+ [tool.coverage.xml]
86
+ output = "coverage.xml"
@@ -68,7 +68,7 @@ DATETIME = DBAPITypeObject("timestamp")
68
68
  DATE = DBAPITypeObject("date")
69
69
  ROWID = DBAPITypeObject()
70
70
 
71
- __version__ = "4.0.5"
71
+ __version__ = "4.1.0"
72
72
  USER_AGENT_NAME = "PyDatabricksSqlConnector"
73
73
 
74
74
  # These two functions are pyhive legacy
@@ -1,4 +1,3 @@
1
- from enum import Enum
2
1
  from typing import Optional, List
3
2
 
4
3
  from databricks.sql.auth.authenticators import (
@@ -6,46 +5,26 @@ from databricks.sql.auth.authenticators import (
6
5
  AccessTokenAuthProvider,
7
6
  ExternalAuthProvider,
8
7
  DatabricksOAuthProvider,
8
+ AzureServicePrincipalCredentialProvider,
9
9
  )
10
+ from databricks.sql.auth.common import AuthType, ClientContext
10
11
 
11
12
 
12
- class AuthType(Enum):
13
- DATABRICKS_OAUTH = "databricks-oauth"
14
- AZURE_OAUTH = "azure-oauth"
15
- # other supported types (access_token) can be inferred
16
- # we can add more types as needed later
17
-
18
-
19
- class ClientContext:
20
- def __init__(
21
- self,
22
- hostname: str,
23
- access_token: Optional[str] = None,
24
- auth_type: Optional[str] = None,
25
- oauth_scopes: Optional[List[str]] = None,
26
- oauth_client_id: Optional[str] = None,
27
- oauth_redirect_port_range: Optional[List[int]] = None,
28
- use_cert_as_auth: Optional[str] = None,
29
- tls_client_cert_file: Optional[str] = None,
30
- oauth_persistence=None,
31
- credentials_provider=None,
32
- ):
33
- self.hostname = hostname
34
- self.access_token = access_token
35
- self.auth_type = auth_type
36
- self.oauth_scopes = oauth_scopes
37
- self.oauth_client_id = oauth_client_id
38
- self.oauth_redirect_port_range = oauth_redirect_port_range
39
- self.use_cert_as_auth = use_cert_as_auth
40
- self.tls_client_cert_file = tls_client_cert_file
41
- self.oauth_persistence = oauth_persistence
42
- self.credentials_provider = credentials_provider
43
-
44
-
45
- def get_auth_provider(cfg: ClientContext):
13
+ def get_auth_provider(cfg: ClientContext, http_client):
46
14
  if cfg.credentials_provider:
47
15
  return ExternalAuthProvider(cfg.credentials_provider)
48
- if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
16
+ elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17
+ return ExternalAuthProvider(
18
+ AzureServicePrincipalCredentialProvider(
19
+ cfg.hostname,
20
+ cfg.azure_client_id,
21
+ cfg.azure_client_secret,
22
+ http_client,
23
+ cfg.azure_tenant_id,
24
+ cfg.azure_workspace_resource_id,
25
+ )
26
+ )
27
+ elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
49
28
  assert cfg.oauth_redirect_port_range is not None
50
29
  assert cfg.oauth_client_id is not None
51
30
  assert cfg.oauth_scopes is not None
@@ -56,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
56
35
  cfg.oauth_redirect_port_range,
57
36
  cfg.oauth_client_id,
58
37
  cfg.oauth_scopes,
38
+ http_client,
59
39
  cfg.auth_type,
60
40
  )
61
41
  elif cfg.access_token is not None:
@@ -75,6 +55,8 @@ def get_auth_provider(cfg: ClientContext):
75
55
  cfg.oauth_redirect_port_range,
76
56
  cfg.oauth_client_id,
77
57
  cfg.oauth_scopes,
58
+ http_client,
59
+ cfg.auth_type or AuthType.DATABRICKS_OAUTH.value,
78
60
  )
79
61
  else:
80
62
  raise RuntimeError("No valid authentication settings!")
@@ -101,11 +83,14 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
101
83
  )
102
84
 
103
85
 
104
- def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
86
+ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
87
+ # TODO : unify all the auth mechanisms with the Python SDK
88
+
105
89
  auth_type = kwargs.get("auth_type")
106
90
  (client_id, redirect_port_range) = get_client_id_and_redirect_port(
107
91
  auth_type == AuthType.AZURE_OAUTH.value
108
92
  )
93
+
109
94
  if kwargs.get("username") or kwargs.get("password"):
110
95
  raise ValueError(
111
96
  "Username/password authentication is no longer supported. "
@@ -120,10 +105,14 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
120
105
  tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
121
106
  oauth_scopes=PYSQL_OAUTH_SCOPES,
122
107
  oauth_client_id=kwargs.get("oauth_client_id") or client_id,
108
+ azure_client_id=kwargs.get("azure_client_id"),
109
+ azure_client_secret=kwargs.get("azure_client_secret"),
110
+ azure_tenant_id=kwargs.get("azure_tenant_id"),
111
+ azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
123
112
  oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
124
113
  if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
125
114
  else redirect_port_range,
126
115
  oauth_persistence=kwargs.get("experimental_oauth_persistence"),
127
116
  credentials_provider=kwargs.get("credentials_provider"),
128
117
  )
129
- return get_auth_provider(cfg)
118
+ return get_auth_provider(cfg, http_client)
@@ -1,10 +1,18 @@
1
1
  import abc
2
- import base64
3
2
  import logging
4
3
  from typing import Callable, Dict, List
5
-
6
- from databricks.sql.auth.oauth import OAuthManager
7
- from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
4
+ from databricks.sql.common.http import HttpHeader
5
+ from databricks.sql.auth.oauth import (
6
+ OAuthManager,
7
+ RefreshableTokenSource,
8
+ ClientCredentialsTokenSource,
9
+ )
10
+ from databricks.sql.auth.endpoint import get_oauth_endpoints
11
+ from databricks.sql.auth.common import (
12
+ AuthType,
13
+ get_effective_azure_login_app_id,
14
+ get_azure_tenant_id_from_host,
15
+ )
8
16
 
9
17
  # Private API: this is an evolving interface and it will change in the future.
10
18
  # Please must not depend on it in your applications.
@@ -55,6 +63,7 @@ class DatabricksOAuthProvider(AuthProvider):
55
63
  redirect_port_range: List[int],
56
64
  client_id: str,
57
65
  scopes: List[str],
66
+ http_client,
58
67
  auth_type: str = "databricks-oauth",
59
68
  ):
60
69
  try:
@@ -71,6 +80,7 @@ class DatabricksOAuthProvider(AuthProvider):
71
80
  port_range=redirect_port_range,
72
81
  client_id=client_id,
73
82
  idp_endpoint=idp_endpoint,
83
+ http_client=http_client,
74
84
  )
75
85
  self._hostname = hostname
76
86
  self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
@@ -146,3 +156,85 @@ class ExternalAuthProvider(AuthProvider):
146
156
  headers = self._header_factory()
147
157
  for k, v in headers.items():
148
158
  request_headers[k] = v
159
+
160
+
161
class AzureServicePrincipalCredentialProvider(CredentialsProvider):
    """
    Credentials provider for Azure Databricks using an Azure AD (Entra ID)
    service principal and the OAuth 2.0 client-credentials grant.

    Each header factory produced by calling this provider fetches two access
    tokens from ``<AZURE_AAD_ENDPOINT>/<tenant>/oauth2/token`` with the same
    client ID and secret:

    * one for the Azure Databricks login application (chosen from the
      hostname), sent as ``<token_type> <access_token>`` in the
      Authorization header;
    * one for the Azure management resource, sent as
      ``X-Databricks-Azure-SP-Management-Token``.

    When a workspace resource ID is configured it is also sent as
    ``X-Databricks-Azure-Workspace-Resource-Id``. Token caching/refresh is
    delegated to the ``ClientCredentialsTokenSource`` instances.

    Attributes:
        hostname (str): The Databricks workspace hostname.
        azure_client_id (str): The service principal's client (application) ID.
        azure_client_secret (str): The service principal's client secret.
        azure_tenant_id (str): The Azure AD tenant ID. If not supplied (or
            empty), it is discovered from the workspace during ``__init__``,
            which performs an HTTP request.
        azure_workspace_resource_id (str, optional): The Azure resource ID of
            the workspace.
    """

    AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
    AZURE_TOKEN_ENDPOINT = "oauth2/token"

    AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"

    DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
    DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
        "X-Databricks-Azure-Workspace-Resource-Id"
    )

    def __init__(
        self,
        hostname,
        azure_client_id,
        azure_client_secret,
        http_client,
        azure_tenant_id=None,
        azure_workspace_resource_id=None,
    ):
        self.hostname = hostname
        self.azure_client_id = azure_client_id
        self.azure_client_secret = azure_client_secret
        self.azure_workspace_resource_id = azure_workspace_resource_id
        # Only hit the network to discover the tenant when the caller did not
        # provide a (non-empty) one.
        if azure_tenant_id:
            self.azure_tenant_id = azure_tenant_id
        else:
            self.azure_tenant_id = get_azure_tenant_id_from_host(
                hostname, http_client
            )
        self._http_client = http_client

    def auth_type(self) -> str:
        """Return the auth-type identifier for Azure SP machine-to-machine auth."""
        return AuthType.AZURE_SP_M2M.value

    def get_token_source(self, resource: str) -> RefreshableTokenSource:
        """Build a client-credentials token source scoped to ``resource``."""
        token_url = (
            f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}"
        )
        return ClientCredentialsTokenSource(
            token_url=token_url,
            client_id=self.azure_client_id,
            client_secret=self.azure_client_secret,
            http_client=self._http_client,
            extra_params={"resource": resource},
        )

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """Create the token sources and return a factory producing auth headers."""
        login_app_id = get_effective_azure_login_app_id(self.hostname)
        databricks_source = self.get_token_source(resource=login_app_id)
        management_source = self.get_token_source(
            resource=self.AZURE_MANAGED_RESOURCE
        )

        def header_factory() -> Dict[str, str]:
            # Fetch the Databricks token first, then the management token.
            databricks_token = databricks_source.get_token()
            management_token = management_source.get_token()

            headers = {}
            headers[HttpHeader.AUTHORIZATION.value] = (
                f"{databricks_token.token_type} {databricks_token.access_token}"
            )
            headers[self.DATABRICKS_AZURE_SP_TOKEN_HEADER] = (
                management_token.access_token
            )
            # The workspace resource ID header is optional.
            if self.azure_workspace_resource_id:
                headers[self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER] = (
                    self.azure_workspace_resource_id
                )
            return headers

        return header_factory
@@ -0,0 +1,127 @@
1
+ from enum import Enum
2
+ import logging
3
+ from typing import Optional, List
4
+ from urllib.parse import urlparse
5
+ from databricks.sql.auth.retry import DatabricksRetryPolicy
6
+ from databricks.sql.common.http import HttpMethod
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class AuthType(Enum):
    """Explicitly selectable authentication modes (the ``auth_type`` setting).

    Other modes (e.g. a plain access token) are inferred from the supplied
    connection arguments rather than named here.
    """

    # Interactive OAuth against the Databricks identity provider.
    DATABRICKS_OAUTH = "databricks-oauth"
    # Interactive OAuth against Azure AD.
    AZURE_OAUTH = "azure-oauth"
    # Azure service principal, machine-to-machine (client credentials).
    AZURE_SP_M2M = "azure-sp-m2m"
15
+
16
+
17
class AzureAppId(Enum):
    """Azure Databricks login application IDs keyed by hostname domain suffix.

    Each value is a ``(domain_suffix, application_id)`` tuple.

    NOTE: declaration order matters. Lookups scan members in order and stop at
    the first suffix contained in the hostname; ``.azuredatabricks.net`` (PROD)
    is a substring of the DEV and STAGING suffixes, so PROD must stay last.
    """

    DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
    STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
    PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")
21
+
22
+
23
def _default_if_none(value, default):
    """Return ``value``, or ``default`` when ``value`` is ``None``.

    Used instead of ``value or default`` so that explicit falsy settings such
    as ``0``, ``0.0`` or ``[]`` are honoured rather than silently replaced by
    the default.
    """
    return default if value is None else value


class ClientContext:
    """Authentication and HTTP-client settings for a connection.

    Auth-related arguments are stored unchanged. HTTP-client arguments that
    are left as ``None`` fall back to these defaults:

    * ``retry_stop_after_attempts_count``: 5
    * ``retry_delay_min``: 1.0
    * ``retry_delay_max``: 10.0
    * ``retry_stop_after_attempts_duration``: 300.0
    * ``retry_delay_default``: 5.0
    * ``retry_dangerous_codes``: a fresh empty list
    * ``pool_connections``: 10
    * ``pool_maxsize``: 20

    Explicitly passed values, including zero, are kept as given.
    ``ssl_options``, ``socket_timeout``, ``proxy_auth_method`` and
    ``user_agent`` are stored as-is (possibly ``None``).
    """

    def __init__(
        self,
        hostname: str,
        access_token: Optional[str] = None,
        auth_type: Optional[str] = None,
        oauth_scopes: Optional[List[str]] = None,
        oauth_client_id: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        azure_client_secret: Optional[str] = None,
        azure_tenant_id: Optional[str] = None,
        azure_workspace_resource_id: Optional[str] = None,
        oauth_redirect_port_range: Optional[List[int]] = None,
        use_cert_as_auth: Optional[str] = None,
        tls_client_cert_file: Optional[str] = None,
        oauth_persistence=None,
        credentials_provider=None,
        # HTTP client configuration parameters
        ssl_options=None,  # SSLOptions type
        socket_timeout: Optional[float] = None,
        retry_stop_after_attempts_count: Optional[int] = None,
        retry_delay_min: Optional[float] = None,
        retry_delay_max: Optional[float] = None,
        retry_stop_after_attempts_duration: Optional[float] = None,
        retry_delay_default: Optional[float] = None,
        retry_dangerous_codes: Optional[List[int]] = None,
        proxy_auth_method: Optional[str] = None,
        pool_connections: Optional[int] = None,
        pool_maxsize: Optional[int] = None,
        user_agent: Optional[str] = None,
    ):
        self.hostname = hostname
        self.access_token = access_token
        self.auth_type = auth_type
        self.oauth_scopes = oauth_scopes
        self.oauth_client_id = oauth_client_id
        self.azure_client_id = azure_client_id
        self.azure_client_secret = azure_client_secret
        self.azure_tenant_id = azure_tenant_id
        self.azure_workspace_resource_id = azure_workspace_resource_id
        self.oauth_redirect_port_range = oauth_redirect_port_range
        self.use_cert_as_auth = use_cert_as_auth
        self.tls_client_cert_file = tls_client_cert_file
        self.oauth_persistence = oauth_persistence
        self.credentials_provider = credentials_provider

        # HTTP client configuration. `is None` defaulting (not `or`) keeps
        # explicit zero values such as retry_delay_min=0.0.
        self.ssl_options = ssl_options
        self.socket_timeout = socket_timeout
        self.retry_stop_after_attempts_count = _default_if_none(
            retry_stop_after_attempts_count, 5
        )
        self.retry_delay_min = _default_if_none(retry_delay_min, 1.0)
        self.retry_delay_max = _default_if_none(retry_delay_max, 10.0)
        self.retry_stop_after_attempts_duration = _default_if_none(
            retry_stop_after_attempts_duration, 300.0
        )
        self.retry_delay_default = _default_if_none(retry_delay_default, 5.0)
        # A new list per instance so defaults are never shared between contexts.
        self.retry_dangerous_codes = _default_if_none(retry_dangerous_codes, [])
        self.proxy_auth_method = proxy_auth_method
        self.pool_connections = _default_if_none(pool_connections, 10)
        self.pool_maxsize = _default_if_none(pool_maxsize, 20)
        self.user_agent = user_agent
84
+
85
+
86
def get_effective_azure_login_app_id(hostname) -> str:
    """
    Return the Azure Databricks login application ID to request tokens for.

    ``AzureAppId`` members are scanned in declaration order (DEV, STAGING,
    PROD). The first member whose domain suffix occurs anywhere in
    ``hostname`` supplies the application ID. If no suffix occurs, the PROD
    application ID is returned as the default Databricks resource ID.
    """
    candidates = (
        app_id
        for domain, app_id in (member.value for member in AzureAppId)
        if domain in hostname
    )
    # Fall back to the production (default Databricks) resource ID.
    return next(candidates, AzureAppId.PROD.value[1])
100
+
101
+
102
def get_azure_tenant_id_from_host(host: str, http_client) -> str:
    """
    Load the Azure tenant ID from the Azure Databricks login page.

    Sends a GET request to ``<host>/aad/auth``. The workspace redirects to the
    Azure login page, whose URL path begins with the tenant ID:
    ``https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...``.
    The tenant ID is read from the last redirect recorded in the response's
    retry history.

    Args:
        host: Workspace base URL; a trailing slash is tolerated.
        http_client: Object whose ``request_context(method, url)`` is a context
            manager yielding the response. The response is assumed to be
            urllib3-style, exposing ``retries.history`` entries with a
            ``redirect_location`` (TODO confirm against the client).

    Returns:
        The tenant ID: the first path segment of the final redirect URL.

    Raises:
        ValueError: if no redirect was recorded, or the redirect URL's path
            does not start with a non-empty tenant segment.
    """
    # Avoid producing "...//aad/auth" when host already ends with "/".
    login_url = f"{host.rstrip('/')}/aad/auth"
    logger.debug("Loading tenant ID from %s", login_url)

    with http_client.request_context(HttpMethod.GET, login_url) as resp:
        # An empty (or missing) history means no redirect was followed, so
        # there is no login-page URL to read the tenant from. Guard this
        # instead of letting history[-1] raise IndexError/AttributeError.
        retries = getattr(resp, "retries", None)
        history = getattr(retries, "history", None) or ()
        entra_id_endpoint = history[-1].redirect_location if history else None
        if entra_id_endpoint is None:
            raise ValueError(
                f"No redirect Location found in response from {login_url} "
                f"(status: {getattr(resp, 'status', 'unknown')})"
            )

    # The domain may change depending on the Azure cloud (e.g.
    # login.microsoftonline.us for US Government cloud); only the path is used.
    url = urlparse(entra_id_endpoint)
    path_segments = url.path.split("/")
    # "/<tenant>/oauth2/authorize" splits to ["", "<tenant>", "oauth2", ...].
    tenant_id = path_segments[1] if len(path_segments) > 1 else ""
    if not tenant_id:
        raise ValueError(f"Invalid path in Location header: {url.path}")
    return tenant_id