databricks-sdk 0.44.1__py3-none-any.whl → 0.46.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +135 -116
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +156 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +306 -206
- databricks/sdk/data_plane.py +75 -50
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +191 -16
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +379 -198
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5263 -3536
- databricks/sdk/service/dashboards.py +1331 -924
- databricks/sdk/service/files.py +446 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3839 -2256
- databricks/sdk/service/oauth2.py +910 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2060 -1290
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.46.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.1.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/top_level.txt +0 -0
|
@@ -26,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
|
|
|
26
26
|
|
|
27
27
|
CredentialsProvider = Callable[[], Dict[str, str]]
|
|
28
28
|
|
|
29
|
-
logger = logging.getLogger(
|
|
29
|
+
logger = logging.getLogger("databricks.sdk")
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class OAuthCredentialsProvider:
|
|
33
|
-
"""
|
|
33
|
+
"""OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
|
|
34
34
|
|
|
35
|
-
def __init__(
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
credentials_provider: CredentialsProvider,
|
|
38
|
+
token_provider: Callable[[], Token],
|
|
39
|
+
):
|
|
36
40
|
self._credentials_provider = credentials_provider
|
|
37
41
|
self._token_provider = token_provider
|
|
38
42
|
|
|
@@ -44,45 +48,49 @@ class OAuthCredentialsProvider:
|
|
|
44
48
|
|
|
45
49
|
|
|
46
50
|
class CredentialsStrategy(abc.ABC):
|
|
47
|
-
"""
|
|
48
|
-
|
|
51
|
+
"""CredentialsProvider is the protocol (call-side interface)
|
|
52
|
+
for authenticating requests to Databricks REST APIs"""
|
|
49
53
|
|
|
50
54
|
@abc.abstractmethod
|
|
51
|
-
def auth_type(self) -> str:
|
|
52
|
-
...
|
|
55
|
+
def auth_type(self) -> str: ...
|
|
53
56
|
|
|
54
57
|
@abc.abstractmethod
|
|
55
|
-
def __call__(self, cfg:
|
|
56
|
-
...
|
|
58
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider: ...
|
|
57
59
|
|
|
58
60
|
|
|
59
61
|
class OauthCredentialsStrategy(CredentialsStrategy):
|
|
60
|
-
"""
|
|
62
|
+
"""OauthCredentialsProvider is a CredentialsProvider which
|
|
61
63
|
supports Oauth tokens"""
|
|
62
64
|
|
|
63
|
-
def __init__(
|
|
65
|
+
def __init__(
|
|
66
|
+
self,
|
|
67
|
+
auth_type: str,
|
|
68
|
+
headers_provider: Callable[["Config"], OAuthCredentialsProvider],
|
|
69
|
+
):
|
|
64
70
|
self._headers_provider = headers_provider
|
|
65
71
|
self._auth_type = auth_type
|
|
66
72
|
|
|
67
73
|
def auth_type(self) -> str:
|
|
68
74
|
return self._auth_type
|
|
69
75
|
|
|
70
|
-
def __call__(self, cfg:
|
|
76
|
+
def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
|
|
71
77
|
return self._headers_provider(cfg)
|
|
72
78
|
|
|
73
|
-
def oauth_token(self, cfg:
|
|
79
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
74
80
|
return self._headers_provider(cfg).oauth_token()
|
|
75
81
|
|
|
76
82
|
|
|
77
83
|
def credentials_strategy(name: str, require: List[str]):
|
|
78
|
-
"""
|
|
84
|
+
"""Given the function that receives a Config and returns RequestVisitor,
|
|
79
85
|
create CredentialsProvider with a given name and required configuration
|
|
80
|
-
attribute names to be present for this function to be called.
|
|
86
|
+
attribute names to be present for this function to be called."""
|
|
81
87
|
|
|
82
|
-
def inner(
|
|
88
|
+
def inner(
|
|
89
|
+
func: Callable[["Config"], CredentialsProvider],
|
|
90
|
+
) -> CredentialsStrategy:
|
|
83
91
|
|
|
84
92
|
@functools.wraps(func)
|
|
85
|
-
def wrapper(cfg:
|
|
93
|
+
def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
86
94
|
for attr in require:
|
|
87
95
|
getattr(cfg, attr)
|
|
88
96
|
if not getattr(cfg, attr):
|
|
@@ -96,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
|
|
|
96
104
|
|
|
97
105
|
|
|
98
106
|
def oauth_credentials_strategy(name: str, require: List[str]):
|
|
99
|
-
"""
|
|
107
|
+
"""Given the function that receives a Config and returns an OauthHeaderFactory,
|
|
100
108
|
create an OauthCredentialsProvider with a given name and required configuration
|
|
101
|
-
attribute names to be present for this function to be called.
|
|
109
|
+
attribute names to be present for this function to be called."""
|
|
102
110
|
|
|
103
|
-
def inner(
|
|
111
|
+
def inner(
|
|
112
|
+
func: Callable[["Config"], OAuthCredentialsProvider],
|
|
113
|
+
) -> OauthCredentialsStrategy:
|
|
104
114
|
|
|
105
115
|
@functools.wraps(func)
|
|
106
|
-
def wrapper(cfg:
|
|
116
|
+
def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
|
|
107
117
|
for attr in require:
|
|
108
118
|
if not getattr(cfg, attr):
|
|
109
119
|
return None
|
|
@@ -114,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
|
|
|
114
124
|
return inner
|
|
115
125
|
|
|
116
126
|
|
|
117
|
-
@credentials_strategy(
|
|
118
|
-
def basic_auth(cfg:
|
|
119
|
-
"""
|
|
120
|
-
encoded = base64.b64encode(f
|
|
121
|
-
static_credentials = {
|
|
127
|
+
@credentials_strategy("basic", ["host", "username", "password"])
|
|
128
|
+
def basic_auth(cfg: "Config") -> CredentialsProvider:
|
|
129
|
+
"""Given username and password, add base64-encoded Basic credentials"""
|
|
130
|
+
encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
|
|
131
|
+
static_credentials = {"Authorization": f"Basic {encoded}"}
|
|
122
132
|
|
|
123
133
|
def inner() -> Dict[str, str]:
|
|
124
134
|
return static_credentials
|
|
@@ -126,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
126
136
|
return inner
|
|
127
137
|
|
|
128
138
|
|
|
129
|
-
@credentials_strategy(
|
|
130
|
-
def pat_auth(cfg:
|
|
131
|
-
"""
|
|
132
|
-
static_credentials = {
|
|
139
|
+
@credentials_strategy("pat", ["host", "token"])
|
|
140
|
+
def pat_auth(cfg: "Config") -> CredentialsProvider:
|
|
141
|
+
"""Adds Databricks Personal Access Token to every request"""
|
|
142
|
+
static_credentials = {"Authorization": f"Bearer {cfg.token}"}
|
|
133
143
|
|
|
134
144
|
def inner() -> Dict[str, str]:
|
|
135
145
|
return static_credentials
|
|
@@ -137,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
137
147
|
return inner
|
|
138
148
|
|
|
139
149
|
|
|
140
|
-
@credentials_strategy(
|
|
141
|
-
def runtime_native_auth(cfg:
|
|
142
|
-
if
|
|
150
|
+
@credentials_strategy("runtime", [])
|
|
151
|
+
def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
152
|
+
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
|
|
143
153
|
return None
|
|
144
154
|
|
|
145
155
|
# This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
|
|
@@ -148,36 +158,45 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
148
158
|
from databricks.sdk.runtime import (init_runtime_legacy_auth,
|
|
149
159
|
init_runtime_native_auth,
|
|
150
160
|
init_runtime_repl_auth)
|
|
151
|
-
|
|
161
|
+
|
|
162
|
+
for init in [
|
|
163
|
+
init_runtime_native_auth,
|
|
164
|
+
init_runtime_repl_auth,
|
|
165
|
+
init_runtime_legacy_auth,
|
|
166
|
+
]:
|
|
152
167
|
if init is None:
|
|
153
168
|
continue
|
|
154
169
|
host, inner = init()
|
|
155
170
|
if host is None:
|
|
156
|
-
logger.debug(f
|
|
171
|
+
logger.debug(f"[{init.__name__}] no host detected")
|
|
157
172
|
continue
|
|
158
173
|
cfg.host = host
|
|
159
|
-
logger.debug(f
|
|
174
|
+
logger.debug(f"[{init.__name__}] runtime native auth configured")
|
|
160
175
|
return inner
|
|
161
176
|
return None
|
|
162
177
|
|
|
163
178
|
|
|
164
|
-
@oauth_credentials_strategy(
|
|
165
|
-
def oauth_service_principal(cfg:
|
|
166
|
-
"""
|
|
167
|
-
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
179
|
+
@oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
|
|
180
|
+
def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
181
|
+
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
|
|
182
|
+
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
183
|
+
"""
|
|
168
184
|
oidc = cfg.oidc_endpoints
|
|
169
185
|
if oidc is None:
|
|
170
186
|
return None
|
|
171
187
|
|
|
172
|
-
token_source = ClientCredentials(
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
188
|
+
token_source = ClientCredentials(
|
|
189
|
+
client_id=cfg.client_id,
|
|
190
|
+
client_secret=cfg.client_secret,
|
|
191
|
+
token_url=oidc.token_endpoint,
|
|
192
|
+
scopes=["all-apis"],
|
|
193
|
+
use_header=True,
|
|
194
|
+
disable_async=not cfg.enable_experimental_async_token_refresh,
|
|
195
|
+
)
|
|
177
196
|
|
|
178
197
|
def inner() -> Dict[str, str]:
|
|
179
198
|
token = token_source.token()
|
|
180
|
-
return {
|
|
199
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
181
200
|
|
|
182
201
|
def token() -> Token:
|
|
183
202
|
return token_source.token()
|
|
@@ -185,9 +204,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
185
204
|
return OAuthCredentialsProvider(inner, token)
|
|
186
205
|
|
|
187
206
|
|
|
188
|
-
@credentials_strategy(
|
|
189
|
-
def external_browser(cfg:
|
|
190
|
-
if cfg.auth_type !=
|
|
207
|
+
@credentials_strategy("external-browser", ["host", "auth_type"])
|
|
208
|
+
def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
209
|
+
if cfg.auth_type != "external-browser":
|
|
191
210
|
return None
|
|
192
211
|
|
|
193
212
|
client_id, client_secret = None, None
|
|
@@ -198,17 +217,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
198
217
|
client_id = cfg.azure_client
|
|
199
218
|
client_secret = cfg.azure_client_secret
|
|
200
219
|
if not client_id:
|
|
201
|
-
client_id =
|
|
220
|
+
client_id = "databricks-cli"
|
|
202
221
|
|
|
203
222
|
# Load cached credentials from disk if they exist. Note that these are
|
|
204
223
|
# local to the Python SDK and not reused by other SDKs.
|
|
205
224
|
oidc_endpoints = cfg.oidc_endpoints
|
|
206
|
-
redirect_url =
|
|
207
|
-
token_cache = TokenCache(
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
225
|
+
redirect_url = "http://localhost:8020"
|
|
226
|
+
token_cache = TokenCache(
|
|
227
|
+
host=cfg.host,
|
|
228
|
+
oidc_endpoints=oidc_endpoints,
|
|
229
|
+
client_id=client_id,
|
|
230
|
+
client_secret=client_secret,
|
|
231
|
+
redirect_url=redirect_url,
|
|
232
|
+
)
|
|
212
233
|
credentials = token_cache.load()
|
|
213
234
|
if credentials:
|
|
214
235
|
try:
|
|
@@ -219,12 +240,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
219
240
|
return credentials(cfg)
|
|
220
241
|
# TODO: We should ideally use more specific exceptions.
|
|
221
242
|
except Exception as e:
|
|
222
|
-
logger.warning(f
|
|
223
|
-
|
|
224
|
-
oauth_client = OAuthClient(
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
243
|
+
logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
|
|
244
|
+
|
|
245
|
+
oauth_client = OAuthClient(
|
|
246
|
+
oidc_endpoints=oidc_endpoints,
|
|
247
|
+
client_id=client_id,
|
|
248
|
+
redirect_url=redirect_url,
|
|
249
|
+
client_secret=client_secret,
|
|
250
|
+
)
|
|
228
251
|
consent = oauth_client.initiate_consent()
|
|
229
252
|
if not consent:
|
|
230
253
|
return None
|
|
@@ -234,33 +257,42 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
234
257
|
return credentials(cfg)
|
|
235
258
|
|
|
236
259
|
|
|
237
|
-
def _ensure_host_present(cfg:
|
|
238
|
-
"""
|
|
260
|
+
def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
|
|
261
|
+
"""Resolves Azure Databricks workspace URL from ARM Resource ID"""
|
|
239
262
|
if cfg.host:
|
|
240
263
|
return
|
|
241
264
|
if not cfg.azure_workspace_resource_id:
|
|
242
265
|
return
|
|
243
266
|
arm = cfg.arm_environment.resource_manager_endpoint
|
|
244
267
|
token = token_source_for(arm).token()
|
|
245
|
-
resp = requests.get(
|
|
246
|
-
|
|
268
|
+
resp = requests.get(
|
|
269
|
+
f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
|
|
270
|
+
headers={"Authorization": f"Bearer {token.access_token}"},
|
|
271
|
+
)
|
|
247
272
|
if not resp.ok:
|
|
248
273
|
raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
|
|
249
274
|
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
|
|
250
275
|
|
|
251
276
|
|
|
252
|
-
@oauth_credentials_strategy(
|
|
253
|
-
|
|
254
|
-
"""
|
|
255
|
-
|
|
277
|
+
@oauth_credentials_strategy(
|
|
278
|
+
"azure-client-secret",
|
|
279
|
+
["is_azure", "azure_client_id", "azure_client_secret"],
|
|
280
|
+
)
|
|
281
|
+
def azure_service_principal(cfg: "Config") -> CredentialsProvider:
|
|
282
|
+
"""Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
|
|
283
|
+
to every request, while automatically resolving different Azure environment endpoints.
|
|
284
|
+
"""
|
|
256
285
|
|
|
257
286
|
def token_source_for(resource: str) -> TokenSource:
|
|
258
287
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
259
|
-
return ClientCredentials(
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
288
|
+
return ClientCredentials(
|
|
289
|
+
client_id=cfg.azure_client_id,
|
|
290
|
+
client_secret=cfg.azure_client_secret,
|
|
291
|
+
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
292
|
+
endpoint_params={"resource": resource},
|
|
293
|
+
use_params=True,
|
|
294
|
+
disable_async=not cfg.enable_experimental_async_token_refresh,
|
|
295
|
+
)
|
|
264
296
|
|
|
265
297
|
_ensure_host_present(cfg, token_source_for)
|
|
266
298
|
cfg.load_azure_tenant_id()
|
|
@@ -269,7 +301,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
269
301
|
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
|
|
270
302
|
|
|
271
303
|
def refreshed_headers() -> Dict[str, str]:
|
|
272
|
-
headers = {
|
|
304
|
+
headers = {
|
|
305
|
+
"Authorization": f"Bearer {inner.token().access_token}",
|
|
306
|
+
}
|
|
273
307
|
add_workspace_id_header(cfg, headers)
|
|
274
308
|
add_sp_management_token(cloud, headers)
|
|
275
309
|
return headers
|
|
@@ -280,9 +314,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
280
314
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
281
315
|
|
|
282
316
|
|
|
283
|
-
@oauth_credentials_strategy(
|
|
284
|
-
def github_oidc_azure(cfg:
|
|
285
|
-
if
|
|
317
|
+
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
|
|
318
|
+
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
319
|
+
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
|
|
286
320
|
# not in GitHub actions
|
|
287
321
|
return None
|
|
288
322
|
|
|
@@ -292,7 +326,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
292
326
|
return None
|
|
293
327
|
|
|
294
328
|
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
|
|
295
|
-
headers = {
|
|
329
|
+
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
|
|
296
330
|
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
|
|
297
331
|
response = requests.get(endpoint, headers=headers)
|
|
298
332
|
if not response.ok:
|
|
@@ -300,30 +334,35 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
300
334
|
|
|
301
335
|
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
|
|
302
336
|
response_json = response.json()
|
|
303
|
-
if
|
|
337
|
+
if "value" not in response_json:
|
|
304
338
|
return None
|
|
305
339
|
|
|
306
|
-
logger.info(
|
|
340
|
+
logger.info(
|
|
341
|
+
"Configured AAD token for GitHub Actions OIDC (%s)",
|
|
342
|
+
cfg.azure_client_id,
|
|
343
|
+
)
|
|
307
344
|
params = {
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
345
|
+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
|
346
|
+
"resource": cfg.effective_azure_login_app_id,
|
|
347
|
+
"client_assertion": response_json["value"],
|
|
311
348
|
}
|
|
312
349
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
313
350
|
if not cfg.azure_tenant_id:
|
|
314
351
|
# detect Azure AD Tenant ID if it's not specified directly
|
|
315
352
|
token_endpoint = cfg.oidc_endpoints.token_endpoint
|
|
316
|
-
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint,
|
|
353
|
+
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
|
|
317
354
|
inner = ClientCredentials(
|
|
318
355
|
client_id=cfg.azure_client_id,
|
|
319
|
-
client_secret="",
|
|
356
|
+
client_secret="", # we have no (rotatable) secrets in OIDC flow
|
|
320
357
|
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
321
358
|
endpoint_params=params,
|
|
322
|
-
use_params=True
|
|
359
|
+
use_params=True,
|
|
360
|
+
disable_async=not cfg.enable_experimental_async_token_refresh,
|
|
361
|
+
)
|
|
323
362
|
|
|
324
363
|
def refreshed_headers() -> Dict[str, str]:
|
|
325
364
|
token = inner.token()
|
|
326
|
-
return {
|
|
365
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
327
366
|
|
|
328
367
|
def token() -> Token:
|
|
329
368
|
return inner.token()
|
|
@@ -331,29 +370,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
331
370
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
332
371
|
|
|
333
372
|
|
|
334
|
-
GcpScopes = [
|
|
373
|
+
GcpScopes = [
|
|
374
|
+
"https://www.googleapis.com/auth/cloud-platform",
|
|
375
|
+
"https://www.googleapis.com/auth/compute",
|
|
376
|
+
]
|
|
335
377
|
|
|
336
378
|
|
|
337
|
-
@oauth_credentials_strategy(
|
|
338
|
-
def google_credentials(cfg:
|
|
379
|
+
@oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
|
|
380
|
+
def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
339
381
|
if not cfg.is_gcp:
|
|
340
382
|
return None
|
|
341
383
|
# Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
|
|
342
384
|
# Obtain the id token by providing the json file path and target audience.
|
|
343
|
-
if
|
|
385
|
+
if os.path.isfile(cfg.google_credentials):
|
|
344
386
|
with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
|
|
345
387
|
account_info = json.load(json_file)
|
|
346
388
|
else:
|
|
347
389
|
# If the file doesn't exist, assume that the config is the actual JSON content.
|
|
348
390
|
account_info = json.loads(cfg.google_credentials)
|
|
349
391
|
|
|
350
|
-
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
351
|
-
|
|
392
|
+
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
393
|
+
info=account_info, target_audience=cfg.host
|
|
394
|
+
)
|
|
352
395
|
|
|
353
396
|
request = Request()
|
|
354
397
|
|
|
355
|
-
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
|
|
356
|
-
scopes=GcpScopes)
|
|
398
|
+
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
|
|
357
399
|
|
|
358
400
|
def token() -> Token:
|
|
359
401
|
credentials.refresh(request)
|
|
@@ -361,7 +403,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
361
403
|
|
|
362
404
|
def refreshed_headers() -> Dict[str, str]:
|
|
363
405
|
credentials.refresh(request)
|
|
364
|
-
headers = {
|
|
406
|
+
headers = {"Authorization": f"Bearer {credentials.token}"}
|
|
365
407
|
if cfg.is_account_client:
|
|
366
408
|
gcp_credentials.refresh(request)
|
|
367
409
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
|
|
@@ -370,24 +412,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
370
412
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
371
413
|
|
|
372
414
|
|
|
373
|
-
@oauth_credentials_strategy(
|
|
374
|
-
def google_id(cfg:
|
|
415
|
+
@oauth_credentials_strategy("google-id", ["host", "google_service_account"])
|
|
416
|
+
def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
375
417
|
if not cfg.is_gcp:
|
|
376
418
|
return None
|
|
377
419
|
credentials, _project_id = google.auth.default()
|
|
378
420
|
|
|
379
421
|
# Create the impersonated credential.
|
|
380
|
-
target_credentials = impersonated_credentials.Credentials(
|
|
381
|
-
|
|
382
|
-
|
|
422
|
+
target_credentials = impersonated_credentials.Credentials(
|
|
423
|
+
source_credentials=credentials,
|
|
424
|
+
target_principal=cfg.google_service_account,
|
|
425
|
+
target_scopes=[],
|
|
426
|
+
)
|
|
383
427
|
|
|
384
428
|
# Set the impersonated credential, target audience and token options.
|
|
385
|
-
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
386
|
-
|
|
387
|
-
|
|
429
|
+
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
430
|
+
target_credentials, target_audience=cfg.host, include_email=True
|
|
431
|
+
)
|
|
388
432
|
|
|
389
433
|
gcp_impersonated_credentials = impersonated_credentials.Credentials(
|
|
390
|
-
source_credentials=credentials,
|
|
434
|
+
source_credentials=credentials,
|
|
435
|
+
target_principal=cfg.google_service_account,
|
|
436
|
+
target_scopes=GcpScopes,
|
|
437
|
+
)
|
|
391
438
|
|
|
392
439
|
request = Request()
|
|
393
440
|
|
|
@@ -397,7 +444,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
397
444
|
|
|
398
445
|
def refreshed_headers() -> Dict[str, str]:
|
|
399
446
|
id_creds.refresh(request)
|
|
400
|
-
headers = {
|
|
447
|
+
headers = {"Authorization": f"Bearer {id_creds.token}"}
|
|
401
448
|
if cfg.is_account_client:
|
|
402
449
|
gcp_impersonated_credentials.refresh(request)
|
|
403
450
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
|
|
@@ -408,8 +455,15 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
408
455
|
|
|
409
456
|
class CliTokenSource(Refreshable):
|
|
410
457
|
|
|
411
|
-
def __init__(
|
|
412
|
-
|
|
458
|
+
def __init__(
|
|
459
|
+
self,
|
|
460
|
+
cmd: List[str],
|
|
461
|
+
token_type_field: str,
|
|
462
|
+
access_token_field: str,
|
|
463
|
+
expiry_field: str,
|
|
464
|
+
disable_async: bool = True,
|
|
465
|
+
):
|
|
466
|
+
super().__init__(disable_async=disable_async)
|
|
413
467
|
self._cmd = cmd
|
|
414
468
|
self._token_type_field = token_type_field
|
|
415
469
|
self._access_token_field = access_token_field
|
|
@@ -431,52 +485,74 @@ class CliTokenSource(Refreshable):
|
|
|
431
485
|
out = _run_subprocess(self._cmd, capture_output=True, check=True)
|
|
432
486
|
it = json.loads(out.stdout.decode())
|
|
433
487
|
expires_on = self._parse_expiry(it[self._expiry_field])
|
|
434
|
-
return Token(
|
|
435
|
-
|
|
436
|
-
|
|
488
|
+
return Token(
|
|
489
|
+
access_token=it[self._access_token_field],
|
|
490
|
+
token_type=it[self._token_type_field],
|
|
491
|
+
expiry=expires_on,
|
|
492
|
+
)
|
|
437
493
|
except ValueError as e:
|
|
438
494
|
raise ValueError(f"cannot unmarshal CLI result: {e}")
|
|
439
495
|
except subprocess.CalledProcessError as e:
|
|
440
496
|
stdout = e.stdout.decode().strip()
|
|
441
497
|
stderr = e.stderr.decode().strip()
|
|
442
498
|
message = stdout or stderr
|
|
443
|
-
raise IOError(f
|
|
499
|
+
raise IOError(f"cannot get access token: {message}") from e
|
|
444
500
|
|
|
445
501
|
|
|
446
|
-
def _run_subprocess(
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
502
|
+
def _run_subprocess(
|
|
503
|
+
popenargs,
|
|
504
|
+
input=None,
|
|
505
|
+
capture_output=True,
|
|
506
|
+
timeout=None,
|
|
507
|
+
check=False,
|
|
508
|
+
**kwargs,
|
|
509
|
+
) -> subprocess.CompletedProcess:
|
|
452
510
|
"""Runs subprocess with given arguments.
|
|
453
|
-
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
454
|
-
|
|
511
|
+
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
512
|
+
"""
|
|
513
|
+
kwargs["shell"] = sys.platform.startswith("win")
|
|
455
514
|
# windows requires shell=True to be able to execute 'az login' or other commands
|
|
456
515
|
# cannot use shell=True all the time, as it breaks macOS
|
|
457
516
|
logging.debug(f'Running command: {" ".join(popenargs)}')
|
|
458
|
-
return subprocess.run(
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
517
|
+
return subprocess.run(
|
|
518
|
+
popenargs,
|
|
519
|
+
input=input,
|
|
520
|
+
capture_output=capture_output,
|
|
521
|
+
timeout=timeout,
|
|
522
|
+
check=check,
|
|
523
|
+
**kwargs,
|
|
524
|
+
)
|
|
464
525
|
|
|
465
526
|
|
|
466
527
|
class AzureCliTokenSource(CliTokenSource):
|
|
467
|
-
"""
|
|
468
|
-
|
|
469
|
-
def __init__(
|
|
470
|
-
|
|
528
|
+
"""Obtain the token granted by `az login` CLI command"""
|
|
529
|
+
|
|
530
|
+
def __init__(
|
|
531
|
+
self,
|
|
532
|
+
resource: str,
|
|
533
|
+
subscription: Optional[str] = None,
|
|
534
|
+
tenant: Optional[str] = None,
|
|
535
|
+
):
|
|
536
|
+
cmd = [
|
|
537
|
+
"az",
|
|
538
|
+
"account",
|
|
539
|
+
"get-access-token",
|
|
540
|
+
"--resource",
|
|
541
|
+
resource,
|
|
542
|
+
"--output",
|
|
543
|
+
"json",
|
|
544
|
+
]
|
|
471
545
|
if subscription is not None:
|
|
472
546
|
cmd.append("--subscription")
|
|
473
547
|
cmd.append(subscription)
|
|
474
548
|
if tenant and not self.__is_cli_using_managed_identity():
|
|
475
549
|
cmd.extend(["--tenant", tenant])
|
|
476
|
-
super().__init__(
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
550
|
+
super().__init__(
|
|
551
|
+
cmd=cmd,
|
|
552
|
+
token_type_field="tokenType",
|
|
553
|
+
access_token_field="accessToken",
|
|
554
|
+
expiry_field="expiresOn",
|
|
555
|
+
)
|
|
480
556
|
|
|
481
557
|
@staticmethod
|
|
482
558
|
def __is_cli_using_managed_identity() -> bool:
|
|
@@ -489,7 +565,8 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
489
565
|
if user is None:
|
|
490
566
|
return False
|
|
491
567
|
return user.get("type") == "servicePrincipal" and user.get("name") in [
|
|
492
|
-
|
|
568
|
+
"systemAssignedIdentity",
|
|
569
|
+
"userAssignedIdentity",
|
|
493
570
|
]
|
|
494
571
|
except subprocess.CalledProcessError as e:
|
|
495
572
|
logger.debug("Failed to get account information from Azure CLI", exc_info=e)
|
|
@@ -512,15 +589,13 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
512
589
|
guaranteed to be unique within a tenant and should be used only for display purposes.
|
|
513
590
|
- 'upn' - The username of the user.
|
|
514
591
|
"""
|
|
515
|
-
return
|
|
592
|
+
return "upn" in self.token().jwt_claims()
|
|
516
593
|
|
|
517
594
|
@staticmethod
|
|
518
|
-
def for_resource(cfg:
|
|
595
|
+
def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
|
|
519
596
|
subscription = AzureCliTokenSource.get_subscription(cfg)
|
|
520
597
|
if subscription is not None:
|
|
521
|
-
token_source = AzureCliTokenSource(resource,
|
|
522
|
-
subscription=subscription,
|
|
523
|
-
tenant=cfg.azure_tenant_id)
|
|
598
|
+
token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
|
|
524
599
|
try:
|
|
525
600
|
# This will fail if the user has access to the workspace, but not to the subscription
|
|
526
601
|
# itself.
|
|
@@ -535,32 +610,32 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
535
610
|
return token_source
|
|
536
611
|
|
|
537
612
|
@staticmethod
|
|
538
|
-
def get_subscription(cfg:
|
|
613
|
+
def get_subscription(cfg: "Config") -> Optional[str]:
|
|
539
614
|
resource = cfg.azure_workspace_resource_id
|
|
540
615
|
if resource is None or resource == "":
|
|
541
616
|
return None
|
|
542
|
-
components = resource.split(
|
|
617
|
+
components = resource.split("/")
|
|
543
618
|
if len(components) < 3:
|
|
544
619
|
logger.warning("Invalid azure workspace resource ID")
|
|
545
620
|
return None
|
|
546
621
|
return components[2]
|
|
547
622
|
|
|
548
623
|
|
|
549
|
-
@credentials_strategy(
|
|
550
|
-
def azure_cli(cfg:
|
|
551
|
-
"""
|
|
624
|
+
@credentials_strategy("azure-cli", ["is_azure"])
|
|
625
|
+
def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
626
|
+
"""Adds refreshed OAuth token granted by `az login` command to every request."""
|
|
552
627
|
cfg.load_azure_tenant_id()
|
|
553
628
|
token_source = None
|
|
554
629
|
mgmt_token_source = None
|
|
555
630
|
try:
|
|
556
631
|
token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
|
|
557
632
|
except FileNotFoundError:
|
|
558
|
-
doc =
|
|
559
|
-
logger.debug(f
|
|
633
|
+
doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
|
|
634
|
+
logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
|
|
560
635
|
return None
|
|
561
636
|
except OSError as e:
|
|
562
|
-
logger.debug(
|
|
563
|
-
logger.debug(
|
|
637
|
+
logger.debug("skipping Azure CLI auth", exc_info=e)
|
|
638
|
+
logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
|
|
564
639
|
return None
|
|
565
640
|
|
|
566
641
|
if not token_source.is_human_user():
|
|
@@ -568,7 +643,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
568
643
|
management_endpoint = cfg.arm_environment.service_management_endpoint
|
|
569
644
|
mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
|
|
570
645
|
except Exception as e:
|
|
571
|
-
logger.debug(
|
|
646
|
+
logger.debug(
|
|
647
|
+
f"Not including service management token in headers",
|
|
648
|
+
exc_info=e,
|
|
649
|
+
)
|
|
572
650
|
mgmt_token_source = None
|
|
573
651
|
|
|
574
652
|
_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
|
|
@@ -576,7 +654,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
576
654
|
|
|
577
655
|
def inner() -> Dict[str, str]:
|
|
578
656
|
token = token_source.token()
|
|
579
|
-
headers = {
|
|
657
|
+
headers = {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
580
658
|
add_workspace_id_header(cfg, headers)
|
|
581
659
|
if mgmt_token_source:
|
|
582
660
|
add_sp_management_token(mgmt_token_source, headers)
|
|
@@ -586,12 +664,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
586
664
|
|
|
587
665
|
|
|
588
666
|
class DatabricksCliTokenSource(CliTokenSource):
|
|
589
|
-
"""
|
|
667
|
+
"""Obtain the token granted by `databricks auth login` CLI command"""
|
|
590
668
|
|
|
591
|
-
def __init__(self, cfg:
|
|
592
|
-
args = [
|
|
669
|
+
def __init__(self, cfg: "Config"):
|
|
670
|
+
args = ["auth", "token", "--host", cfg.host]
|
|
593
671
|
if cfg.is_account_client:
|
|
594
|
-
args += [
|
|
672
|
+
args += ["--account-id", cfg.account_id]
|
|
595
673
|
|
|
596
674
|
cli_path = cfg.databricks_cli_path
|
|
597
675
|
|
|
@@ -611,10 +689,13 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
611
689
|
elif cli_path.count("/") == 0:
|
|
612
690
|
cli_path = self.__class__._find_executable(cli_path)
|
|
613
691
|
|
|
614
|
-
super().__init__(
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
692
|
+
super().__init__(
|
|
693
|
+
cmd=[cli_path, *args],
|
|
694
|
+
token_type_field="token_type",
|
|
695
|
+
access_token_field="access_token",
|
|
696
|
+
expiry_field="expiry",
|
|
697
|
+
disable_async=not cfg.enable_experimental_async_token_refresh,
|
|
698
|
+
)
|
|
618
699
|
|
|
619
700
|
@staticmethod
|
|
620
701
|
def _find_executable(name) -> str:
|
|
@@ -636,8 +717,8 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
636
717
|
raise err
|
|
637
718
|
|
|
638
719
|
|
|
639
|
-
@oauth_credentials_strategy(
|
|
640
|
-
def databricks_cli(cfg:
|
|
720
|
+
@oauth_credentials_strategy("databricks-cli", ["host"])
|
|
721
|
+
def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
641
722
|
try:
|
|
642
723
|
token_source = DatabricksCliTokenSource(cfg)
|
|
643
724
|
except FileNotFoundError as e:
|
|
@@ -647,8 +728,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
647
728
|
try:
|
|
648
729
|
token_source.token()
|
|
649
730
|
except IOError as e:
|
|
650
|
-
if
|
|
651
|
-
logger.debug(f
|
|
731
|
+
if "databricks OAuth is not" in str(e):
|
|
732
|
+
logger.debug(f"OAuth not configured or not available: {e}")
|
|
652
733
|
return None
|
|
653
734
|
raise e
|
|
654
735
|
|
|
@@ -656,7 +737,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
656
737
|
|
|
657
738
|
def inner() -> Dict[str, str]:
|
|
658
739
|
token = token_source.token()
|
|
659
|
-
return {
|
|
740
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
660
741
|
|
|
661
742
|
def token() -> Token:
|
|
662
743
|
return token_source.token()
|
|
@@ -665,13 +746,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
665
746
|
|
|
666
747
|
|
|
667
748
|
class MetadataServiceTokenSource(Refreshable):
|
|
668
|
-
"""
|
|
749
|
+
"""Obtain the token granted by Databricks Metadata Service"""
|
|
750
|
+
|
|
669
751
|
METADATA_SERVICE_VERSION = "1"
|
|
670
752
|
METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
|
|
671
753
|
METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
|
|
672
|
-
_metadata_service_timeout = 10
|
|
754
|
+
_metadata_service_timeout = 10 # seconds
|
|
673
755
|
|
|
674
|
-
def __init__(self, cfg:
|
|
756
|
+
def __init__(self, cfg: "Config"):
|
|
675
757
|
super().__init__()
|
|
676
758
|
self.url = cfg.metadata_service_url
|
|
677
759
|
self.host = cfg.host
|
|
@@ -682,13 +764,14 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
682
764
|
timeout=self._metadata_service_timeout,
|
|
683
765
|
headers={
|
|
684
766
|
self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
|
|
685
|
-
self.METADATA_SERVICE_HOST_HEADER: self.host
|
|
767
|
+
self.METADATA_SERVICE_HOST_HEADER: self.host,
|
|
686
768
|
},
|
|
687
769
|
proxies={
|
|
688
770
|
# Explicitly exclude localhost from being proxied. This is necessary
|
|
689
771
|
# for Metadata URLs which typically point to localhost.
|
|
690
772
|
"no_proxy": "localhost,127.0.0.1"
|
|
691
|
-
}
|
|
773
|
+
},
|
|
774
|
+
)
|
|
692
775
|
json_resp: dict[str, Union[str, float]] = resp.json()
|
|
693
776
|
access_token = json_resp.get("access_token", None)
|
|
694
777
|
if access_token is None:
|
|
@@ -706,9 +789,9 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
706
789
|
return Token(access_token=access_token, token_type=token_type, expiry=expiry)
|
|
707
790
|
|
|
708
791
|
|
|
709
|
-
@credentials_strategy(
|
|
710
|
-
def metadata_service(cfg:
|
|
711
|
-
"""
|
|
792
|
+
@credentials_strategy("metadata-service", ["host", "metadata_service_url"])
|
|
793
|
+
def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
794
|
+
"""Adds refreshed token granted by Databricks Metadata Service to every request."""
|
|
712
795
|
|
|
713
796
|
token_source = MetadataServiceTokenSource(cfg)
|
|
714
797
|
token_source.token()
|
|
@@ -716,14 +799,14 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
716
799
|
|
|
717
800
|
def inner() -> Dict[str, str]:
|
|
718
801
|
token = token_source.token()
|
|
719
|
-
return {
|
|
802
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
720
803
|
|
|
721
804
|
return inner
|
|
722
805
|
|
|
723
806
|
|
|
724
807
|
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
|
|
725
808
|
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
|
|
726
|
-
class ModelServingAuthProvider
|
|
809
|
+
class ModelServingAuthProvider:
|
|
727
810
|
USER_CREDENTIALS = "user_credentials"
|
|
728
811
|
|
|
729
812
|
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
|
|
@@ -731,7 +814,7 @@ class ModelServingAuthProvider():
|
|
|
731
814
|
def __init__(self, credential_type: Optional[str]):
|
|
732
815
|
self.expiry_time = -1
|
|
733
816
|
self.current_token = None
|
|
734
|
-
self.refresh_duration = 300
|
|
817
|
+
self.refresh_duration = 300 # 300 Seconds
|
|
735
818
|
self.credential_type = credential_type
|
|
736
819
|
|
|
737
820
|
def should_fetch_model_serving_environment_oauth() -> bool:
|
|
@@ -740,10 +823,14 @@ class ModelServingAuthProvider():
|
|
|
740
823
|
Additionally check if the oauth token file path exists
|
|
741
824
|
"""
|
|
742
825
|
|
|
743
|
-
is_in_model_serving_env = (
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
826
|
+
is_in_model_serving_env = (
|
|
827
|
+
os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
|
|
828
|
+
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
|
|
829
|
+
or "false"
|
|
830
|
+
)
|
|
831
|
+
return is_in_model_serving_env == "true" and os.path.isfile(
|
|
832
|
+
ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
|
|
833
|
+
)
|
|
747
834
|
|
|
748
835
|
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
|
|
749
836
|
# Use Cached value if it is valid
|
|
@@ -758,8 +845,10 @@ class ModelServingAuthProvider():
|
|
|
758
845
|
except Exception as e:
|
|
759
846
|
# sleep and retry in case of any race conditions with OAuth refreshing
|
|
760
847
|
if should_retry:
|
|
761
|
-
logger.warning(
|
|
762
|
-
|
|
848
|
+
logger.warning(
|
|
849
|
+
"Unable to read oauth token on first attmept in Model Serving Environment",
|
|
850
|
+
exc_info=e,
|
|
851
|
+
)
|
|
763
852
|
time.sleep(0.5)
|
|
764
853
|
return self._get_model_dependency_oauth_token(should_retry=False)
|
|
765
854
|
else:
|
|
@@ -769,8 +858,8 @@ class ModelServingAuthProvider():
|
|
|
769
858
|
return self.current_token
|
|
770
859
|
|
|
771
860
|
def _get_invokers_token(self):
|
|
772
|
-
|
|
773
|
-
thread_data =
|
|
861
|
+
main_thread = threading.main_thread()
|
|
862
|
+
thread_data = main_thread.__dict__
|
|
774
863
|
invokers_token = None
|
|
775
864
|
if "invokers_token" in thread_data:
|
|
776
865
|
invokers_token = thread_data["invokers_token"]
|
|
@@ -785,8 +874,7 @@ class ModelServingAuthProvider():
|
|
|
785
874
|
return None
|
|
786
875
|
|
|
787
876
|
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
|
|
788
|
-
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
|
|
789
|
-
"DB_MODEL_SERVING_HOST_URL")
|
|
877
|
+
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")
|
|
790
878
|
|
|
791
879
|
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
|
|
792
880
|
return (host, self._get_invokers_token())
|
|
@@ -794,8 +882,7 @@ class ModelServingAuthProvider():
|
|
|
794
882
|
return (host, self._get_model_dependency_oauth_token())
|
|
795
883
|
|
|
796
884
|
|
|
797
|
-
def model_serving_auth_visitor(cfg:
|
|
798
|
-
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
|
|
885
|
+
def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
|
|
799
886
|
try:
|
|
800
887
|
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
|
|
801
888
|
host, token = model_serving_auth_provider.get_databricks_host_token()
|
|
@@ -806,7 +893,10 @@ def model_serving_auth_visitor(cfg: 'Config',
|
|
|
806
893
|
if cfg.host is None:
|
|
807
894
|
cfg.host = host
|
|
808
895
|
except Exception as e:
|
|
809
|
-
logger.warning(
|
|
896
|
+
logger.warning(
|
|
897
|
+
"Unable to get auth from Databricks Model Serving Environment",
|
|
898
|
+
exc_info=e,
|
|
899
|
+
)
|
|
810
900
|
return None
|
|
811
901
|
logger.info("Using Databricks Model Serving Authentication")
|
|
812
902
|
|
|
@@ -818,8 +908,8 @@ def model_serving_auth_visitor(cfg: 'Config',
|
|
|
818
908
|
return inner
|
|
819
909
|
|
|
820
910
|
|
|
821
|
-
@credentials_strategy(
|
|
822
|
-
def model_serving_auth(cfg:
|
|
911
|
+
@credentials_strategy("model-serving", [])
|
|
912
|
+
def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
823
913
|
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
824
914
|
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
|
|
825
915
|
return None
|
|
@@ -828,20 +918,30 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
828
918
|
|
|
829
919
|
|
|
830
920
|
class DefaultCredentials:
|
|
831
|
-
"""
|
|
921
|
+
"""Select the first applicable credential provider from the chain"""
|
|
832
922
|
|
|
833
923
|
def __init__(self) -> None:
|
|
834
|
-
self._auth_type =
|
|
924
|
+
self._auth_type = "default"
|
|
835
925
|
self._auth_providers = [
|
|
836
|
-
pat_auth,
|
|
837
|
-
|
|
838
|
-
|
|
926
|
+
pat_auth,
|
|
927
|
+
basic_auth,
|
|
928
|
+
metadata_service,
|
|
929
|
+
oauth_service_principal,
|
|
930
|
+
azure_service_principal,
|
|
931
|
+
github_oidc_azure,
|
|
932
|
+
azure_cli,
|
|
933
|
+
external_browser,
|
|
934
|
+
databricks_cli,
|
|
935
|
+
runtime_native_auth,
|
|
936
|
+
google_credentials,
|
|
937
|
+
google_id,
|
|
938
|
+
model_serving_auth,
|
|
839
939
|
]
|
|
840
940
|
|
|
841
941
|
def auth_type(self) -> str:
|
|
842
942
|
return self._auth_type
|
|
843
943
|
|
|
844
|
-
def oauth_token(self, cfg:
|
|
944
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
845
945
|
for provider in self._auth_providers:
|
|
846
946
|
auth_type = provider.auth_type()
|
|
847
947
|
if auth_type != self._auth_type:
|
|
@@ -849,14 +949,14 @@ class DefaultCredentials:
|
|
|
849
949
|
continue
|
|
850
950
|
return provider.oauth_token(cfg)
|
|
851
951
|
|
|
852
|
-
def __call__(self, cfg:
|
|
952
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
853
953
|
for provider in self._auth_providers:
|
|
854
954
|
auth_type = provider.auth_type()
|
|
855
955
|
if cfg.auth_type and auth_type != cfg.auth_type:
|
|
856
956
|
# ignore other auth types if one is explicitly enforced
|
|
857
957
|
logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
|
|
858
958
|
continue
|
|
859
|
-
logger.debug(f
|
|
959
|
+
logger.debug(f"Attempting to configure auth: {auth_type}")
|
|
860
960
|
try:
|
|
861
961
|
header_factory = provider(cfg)
|
|
862
962
|
if not header_factory:
|
|
@@ -864,18 +964,18 @@ class DefaultCredentials:
|
|
|
864
964
|
self._auth_type = auth_type
|
|
865
965
|
return header_factory
|
|
866
966
|
except Exception as e:
|
|
867
|
-
raise ValueError(f
|
|
967
|
+
raise ValueError(f"{auth_type}: {e}") from e
|
|
868
968
|
auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
|
|
869
969
|
raise ValueError(
|
|
870
|
-
f
|
|
970
|
+
f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
|
|
871
971
|
)
|
|
872
972
|
|
|
873
973
|
|
|
874
974
|
class ModelServingUserCredentials(CredentialsStrategy):
|
|
875
975
|
"""
|
|
876
|
-
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
|
|
877
|
-
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
|
|
878
|
-
In any other environments, the class defaults to the DefaultCredentialStrategy.
|
|
976
|
+
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
|
|
977
|
+
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
|
|
978
|
+
In any other environments, the class defaults to the DefaultCredentialStrategy.
|
|
879
979
|
To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
|
|
880
980
|
|
|
881
981
|
invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
|
|
@@ -891,7 +991,7 @@ class ModelServingUserCredentials(CredentialsStrategy):
|
|
|
891
991
|
else:
|
|
892
992
|
return self.default_credentials.auth_type()
|
|
893
993
|
|
|
894
|
-
def __call__(self, cfg:
|
|
994
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
895
995
|
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
896
996
|
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
|
|
897
997
|
if not header_factory:
|