databricks-sdk 0.44.0__py3-none-any.whl → 0.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

Files changed (63)
  1. databricks/sdk/__init__.py +123 -115
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +152 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +360 -210
  12. databricks/sdk/data_plane.py +86 -3
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +201 -20
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +372 -196
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5241 -3531
  40. databricks/sdk/service/dashboards.py +1313 -923
  41. databricks/sdk/service/files.py +442 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3364 -2255
  46. databricks/sdk/service/oauth2.py +922 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2040 -1278
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.45.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.0.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ import pathlib
9
9
  import platform
10
10
  import subprocess
11
11
  import sys
12
+ import threading
12
13
  import time
13
14
  from datetime import datetime
14
15
  from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -25,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
25
26
 
26
27
  CredentialsProvider = Callable[[], Dict[str, str]]
27
28
 
28
- logger = logging.getLogger('databricks.sdk')
29
+ logger = logging.getLogger("databricks.sdk")
29
30
 
30
31
 
31
32
  class OAuthCredentialsProvider:
32
- """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """
33
+ """OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
33
34
 
34
- def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
35
+ def __init__(
36
+ self,
37
+ credentials_provider: CredentialsProvider,
38
+ token_provider: Callable[[], Token],
39
+ ):
35
40
  self._credentials_provider = credentials_provider
36
41
  self._token_provider = token_provider
37
42
 
@@ -43,45 +48,49 @@ class OAuthCredentialsProvider:
43
48
 
44
49
 
45
50
  class CredentialsStrategy(abc.ABC):
46
- """ CredentialsProvider is the protocol (call-side interface)
47
- for authenticating requests to Databricks REST APIs"""
51
+ """CredentialsProvider is the protocol (call-side interface)
52
+ for authenticating requests to Databricks REST APIs"""
48
53
 
49
54
  @abc.abstractmethod
50
- def auth_type(self) -> str:
51
- ...
55
+ def auth_type(self) -> str: ...
52
56
 
53
57
  @abc.abstractmethod
54
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
55
- ...
58
+ def __call__(self, cfg: "Config") -> CredentialsProvider: ...
56
59
 
57
60
 
58
61
  class OauthCredentialsStrategy(CredentialsStrategy):
59
- """ OauthCredentialsProvider is a CredentialsProvider which
62
+ """OauthCredentialsProvider is a CredentialsProvider which
60
63
  supports Oauth tokens"""
61
64
 
62
- def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
65
+ def __init__(
66
+ self,
67
+ auth_type: str,
68
+ headers_provider: Callable[["Config"], OAuthCredentialsProvider],
69
+ ):
63
70
  self._headers_provider = headers_provider
64
71
  self._auth_type = auth_type
65
72
 
66
73
  def auth_type(self) -> str:
67
74
  return self._auth_type
68
75
 
69
- def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
76
+ def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
70
77
  return self._headers_provider(cfg)
71
78
 
72
- def oauth_token(self, cfg: 'Config') -> Token:
79
+ def oauth_token(self, cfg: "Config") -> Token:
73
80
  return self._headers_provider(cfg).oauth_token()
74
81
 
75
82
 
76
83
  def credentials_strategy(name: str, require: List[str]):
77
- """ Given the function that receives a Config and returns RequestVisitor,
84
+ """Given the function that receives a Config and returns RequestVisitor,
78
85
  create CredentialsProvider with a given name and required configuration
79
- attribute names to be present for this function to be called. """
86
+ attribute names to be present for this function to be called."""
80
87
 
81
- def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
88
+ def inner(
89
+ func: Callable[["Config"], CredentialsProvider],
90
+ ) -> CredentialsStrategy:
82
91
 
83
92
  @functools.wraps(func)
84
- def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
93
+ def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
85
94
  for attr in require:
86
95
  getattr(cfg, attr)
87
96
  if not getattr(cfg, attr):
@@ -95,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
95
104
 
96
105
 
97
106
  def oauth_credentials_strategy(name: str, require: List[str]):
98
- """ Given the function that receives a Config and returns an OauthHeaderFactory,
107
+ """Given the function that receives a Config and returns an OauthHeaderFactory,
99
108
  create an OauthCredentialsProvider with a given name and required configuration
100
- attribute names to be present for this function to be called. """
109
+ attribute names to be present for this function to be called."""
101
110
 
102
- def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:
111
+ def inner(
112
+ func: Callable[["Config"], OAuthCredentialsProvider],
113
+ ) -> OauthCredentialsStrategy:
103
114
 
104
115
  @functools.wraps(func)
105
- def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
116
+ def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
106
117
  for attr in require:
107
118
  if not getattr(cfg, attr):
108
119
  return None
@@ -113,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
113
124
  return inner
114
125
 
115
126
 
116
- @credentials_strategy('basic', ['host', 'username', 'password'])
117
- def basic_auth(cfg: 'Config') -> CredentialsProvider:
118
- """ Given username and password, add base64-encoded Basic credentials """
119
- encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
120
- static_credentials = {'Authorization': f'Basic {encoded}'}
127
+ @credentials_strategy("basic", ["host", "username", "password"])
128
+ def basic_auth(cfg: "Config") -> CredentialsProvider:
129
+ """Given username and password, add base64-encoded Basic credentials"""
130
+ encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
131
+ static_credentials = {"Authorization": f"Basic {encoded}"}
121
132
 
122
133
  def inner() -> Dict[str, str]:
123
134
  return static_credentials
@@ -125,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
125
136
  return inner
126
137
 
127
138
 
128
- @credentials_strategy('pat', ['host', 'token'])
129
- def pat_auth(cfg: 'Config') -> CredentialsProvider:
130
- """ Adds Databricks Personal Access Token to every request """
131
- static_credentials = {'Authorization': f'Bearer {cfg.token}'}
139
+ @credentials_strategy("pat", ["host", "token"])
140
+ def pat_auth(cfg: "Config") -> CredentialsProvider:
141
+ """Adds Databricks Personal Access Token to every request"""
142
+ static_credentials = {"Authorization": f"Bearer {cfg.token}"}
132
143
 
133
144
  def inner() -> Dict[str, str]:
134
145
  return static_credentials
@@ -136,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
136
147
  return inner
137
148
 
138
149
 
139
- @credentials_strategy('runtime', [])
140
- def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
141
- if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
150
+ @credentials_strategy("runtime", [])
151
+ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
152
+ if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
142
153
  return None
143
154
 
144
155
  # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
@@ -147,36 +158,44 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
147
158
  from databricks.sdk.runtime import (init_runtime_legacy_auth,
148
159
  init_runtime_native_auth,
149
160
  init_runtime_repl_auth)
150
- for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
161
+
162
+ for init in [
163
+ init_runtime_native_auth,
164
+ init_runtime_repl_auth,
165
+ init_runtime_legacy_auth,
166
+ ]:
151
167
  if init is None:
152
168
  continue
153
169
  host, inner = init()
154
170
  if host is None:
155
- logger.debug(f'[{init.__name__}] no host detected')
171
+ logger.debug(f"[{init.__name__}] no host detected")
156
172
  continue
157
173
  cfg.host = host
158
- logger.debug(f'[{init.__name__}] runtime native auth configured')
174
+ logger.debug(f"[{init.__name__}] runtime native auth configured")
159
175
  return inner
160
176
  return None
161
177
 
162
178
 
163
- @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
164
- def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
165
- """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
166
- if /oidc/.well-known/oauth-authorization-server is available on the given host. """
179
+ @oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
180
+ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
181
+ """Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
182
+ if /oidc/.well-known/oauth-authorization-server is available on the given host.
183
+ """
167
184
  oidc = cfg.oidc_endpoints
168
185
  if oidc is None:
169
186
  return None
170
187
 
171
- token_source = ClientCredentials(client_id=cfg.client_id,
172
- client_secret=cfg.client_secret,
173
- token_url=oidc.token_endpoint,
174
- scopes=["all-apis"],
175
- use_header=True)
188
+ token_source = ClientCredentials(
189
+ client_id=cfg.client_id,
190
+ client_secret=cfg.client_secret,
191
+ token_url=oidc.token_endpoint,
192
+ scopes=["all-apis"],
193
+ use_header=True,
194
+ )
176
195
 
177
196
  def inner() -> Dict[str, str]:
178
197
  token = token_source.token()
179
- return {'Authorization': f'{token.token_type} {token.access_token}'}
198
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
180
199
 
181
200
  def token() -> Token:
182
201
  return token_source.token()
@@ -184,9 +203,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
184
203
  return OAuthCredentialsProvider(inner, token)
185
204
 
186
205
 
187
- @credentials_strategy('external-browser', ['host', 'auth_type'])
188
- def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
189
- if cfg.auth_type != 'external-browser':
206
+ @credentials_strategy("external-browser", ["host", "auth_type"])
207
+ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
208
+ if cfg.auth_type != "external-browser":
190
209
  return None
191
210
 
192
211
  client_id, client_secret = None, None
@@ -197,17 +216,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
197
216
  client_id = cfg.azure_client
198
217
  client_secret = cfg.azure_client_secret
199
218
  if not client_id:
200
- client_id = 'databricks-cli'
219
+ client_id = "databricks-cli"
201
220
 
202
221
  # Load cached credentials from disk if they exist. Note that these are
203
222
  # local to the Python SDK and not reused by other SDKs.
204
223
  oidc_endpoints = cfg.oidc_endpoints
205
- redirect_url = 'http://localhost:8020'
206
- token_cache = TokenCache(host=cfg.host,
207
- oidc_endpoints=oidc_endpoints,
208
- client_id=client_id,
209
- client_secret=client_secret,
210
- redirect_url=redirect_url)
224
+ redirect_url = "http://localhost:8020"
225
+ token_cache = TokenCache(
226
+ host=cfg.host,
227
+ oidc_endpoints=oidc_endpoints,
228
+ client_id=client_id,
229
+ client_secret=client_secret,
230
+ redirect_url=redirect_url,
231
+ )
211
232
  credentials = token_cache.load()
212
233
  if credentials:
213
234
  try:
@@ -218,12 +239,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
218
239
  return credentials(cfg)
219
240
  # TODO: We should ideally use more specific exceptions.
220
241
  except Exception as e:
221
- logger.warning(f'Failed to refresh cached token: {e}. Initiating new OAuth login flow')
222
-
223
- oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
224
- client_id=client_id,
225
- redirect_url=redirect_url,
226
- client_secret=client_secret)
242
+ logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
243
+
244
+ oauth_client = OAuthClient(
245
+ oidc_endpoints=oidc_endpoints,
246
+ client_id=client_id,
247
+ redirect_url=redirect_url,
248
+ client_secret=client_secret,
249
+ )
227
250
  consent = oauth_client.initiate_consent()
228
251
  if not consent:
229
252
  return None
@@ -233,33 +256,41 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
233
256
  return credentials(cfg)
234
257
 
235
258
 
236
- def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]):
237
- """ Resolves Azure Databricks workspace URL from ARM Resource ID """
259
+ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
260
+ """Resolves Azure Databricks workspace URL from ARM Resource ID"""
238
261
  if cfg.host:
239
262
  return
240
263
  if not cfg.azure_workspace_resource_id:
241
264
  return
242
265
  arm = cfg.arm_environment.resource_manager_endpoint
243
266
  token = token_source_for(arm).token()
244
- resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
245
- headers={"Authorization": f"Bearer {token.access_token}"})
267
+ resp = requests.get(
268
+ f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
269
+ headers={"Authorization": f"Bearer {token.access_token}"},
270
+ )
246
271
  if not resp.ok:
247
272
  raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
248
273
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
249
274
 
250
275
 
251
- @oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
252
- def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
253
- """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
254
- to every request, while automatically resolving different Azure environment endpoints. """
276
+ @oauth_credentials_strategy(
277
+ "azure-client-secret",
278
+ ["is_azure", "azure_client_id", "azure_client_secret"],
279
+ )
280
+ def azure_service_principal(cfg: "Config") -> CredentialsProvider:
281
+ """Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
282
+ to every request, while automatically resolving different Azure environment endpoints.
283
+ """
255
284
 
256
285
  def token_source_for(resource: str) -> TokenSource:
257
286
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
258
- return ClientCredentials(client_id=cfg.azure_client_id,
259
- client_secret=cfg.azure_client_secret,
260
- token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
261
- endpoint_params={"resource": resource},
262
- use_params=True)
287
+ return ClientCredentials(
288
+ client_id=cfg.azure_client_id,
289
+ client_secret=cfg.azure_client_secret,
290
+ token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
291
+ endpoint_params={"resource": resource},
292
+ use_params=True,
293
+ )
263
294
 
264
295
  _ensure_host_present(cfg, token_source_for)
265
296
  cfg.load_azure_tenant_id()
@@ -268,7 +299,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
268
299
  cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
269
300
 
270
301
  def refreshed_headers() -> Dict[str, str]:
271
- headers = {'Authorization': f"Bearer {inner.token().access_token}", }
302
+ headers = {
303
+ "Authorization": f"Bearer {inner.token().access_token}",
304
+ }
272
305
  add_workspace_id_header(cfg, headers)
273
306
  add_sp_management_token(cloud, headers)
274
307
  return headers
@@ -279,9 +312,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
279
312
  return OAuthCredentialsProvider(refreshed_headers, token)
280
313
 
281
314
 
282
- @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
283
- def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
284
- if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
315
+ @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
316
+ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
317
+ if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
285
318
  # not in GitHub actions
286
319
  return None
287
320
 
@@ -291,7 +324,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
291
324
  return None
292
325
 
293
326
  # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
294
- headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
327
+ headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
295
328
  endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
296
329
  response = requests.get(endpoint, headers=headers)
297
330
  if not response.ok:
@@ -299,30 +332,34 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
299
332
 
300
333
  # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
301
334
  response_json = response.json()
302
- if 'value' not in response_json:
335
+ if "value" not in response_json:
303
336
  return None
304
337
 
305
- logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id)
338
+ logger.info(
339
+ "Configured AAD token for GitHub Actions OIDC (%s)",
340
+ cfg.azure_client_id,
341
+ )
306
342
  params = {
307
- 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
308
- 'resource': cfg.effective_azure_login_app_id,
309
- 'client_assertion': response_json['value'],
343
+ "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
344
+ "resource": cfg.effective_azure_login_app_id,
345
+ "client_assertion": response_json["value"],
310
346
  }
311
347
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
312
348
  if not cfg.azure_tenant_id:
313
349
  # detect Azure AD Tenant ID if it's not specified directly
314
350
  token_endpoint = cfg.oidc_endpoints.token_endpoint
315
- cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
351
+ cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
316
352
  inner = ClientCredentials(
317
353
  client_id=cfg.azure_client_id,
318
- client_secret="", # we have no (rotatable) secrets in OIDC flow
354
+ client_secret="", # we have no (rotatable) secrets in OIDC flow
319
355
  token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
320
356
  endpoint_params=params,
321
- use_params=True)
357
+ use_params=True,
358
+ )
322
359
 
323
360
  def refreshed_headers() -> Dict[str, str]:
324
361
  token = inner.token()
325
- return {'Authorization': f'{token.token_type} {token.access_token}'}
362
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
326
363
 
327
364
  def token() -> Token:
328
365
  return inner.token()
@@ -330,29 +367,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
330
367
  return OAuthCredentialsProvider(refreshed_headers, token)
331
368
 
332
369
 
333
- GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
370
+ GcpScopes = [
371
+ "https://www.googleapis.com/auth/cloud-platform",
372
+ "https://www.googleapis.com/auth/compute",
373
+ ]
334
374
 
335
375
 
336
- @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
337
- def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
376
+ @oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
377
+ def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
338
378
  if not cfg.is_gcp:
339
379
  return None
340
380
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
341
381
  # Obtain the id token by providing the json file path and target audience.
342
- if (os.path.isfile(cfg.google_credentials)):
382
+ if os.path.isfile(cfg.google_credentials):
343
383
  with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
344
384
  account_info = json.load(json_file)
345
385
  else:
346
386
  # If the file doesn't exist, assume that the config is the actual JSON content.
347
387
  account_info = json.loads(cfg.google_credentials)
348
388
 
349
- credentials = service_account.IDTokenCredentials.from_service_account_info(info=account_info,
350
- target_audience=cfg.host)
389
+ credentials = service_account.IDTokenCredentials.from_service_account_info(
390
+ info=account_info, target_audience=cfg.host
391
+ )
351
392
 
352
393
  request = Request()
353
394
 
354
- gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
355
- scopes=GcpScopes)
395
+ gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
356
396
 
357
397
  def token() -> Token:
358
398
  credentials.refresh(request)
@@ -360,7 +400,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
360
400
 
361
401
  def refreshed_headers() -> Dict[str, str]:
362
402
  credentials.refresh(request)
363
- headers = {'Authorization': f'Bearer {credentials.token}'}
403
+ headers = {"Authorization": f"Bearer {credentials.token}"}
364
404
  if cfg.is_account_client:
365
405
  gcp_credentials.refresh(request)
366
406
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
@@ -369,24 +409,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
369
409
  return OAuthCredentialsProvider(refreshed_headers, token)
370
410
 
371
411
 
372
- @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
373
- def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
412
+ @oauth_credentials_strategy("google-id", ["host", "google_service_account"])
413
+ def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
374
414
  if not cfg.is_gcp:
375
415
  return None
376
416
  credentials, _project_id = google.auth.default()
377
417
 
378
418
  # Create the impersonated credential.
379
- target_credentials = impersonated_credentials.Credentials(source_credentials=credentials,
380
- target_principal=cfg.google_service_account,
381
- target_scopes=[])
419
+ target_credentials = impersonated_credentials.Credentials(
420
+ source_credentials=credentials,
421
+ target_principal=cfg.google_service_account,
422
+ target_scopes=[],
423
+ )
382
424
 
383
425
  # Set the impersonated credential, target audience and token options.
384
- id_creds = impersonated_credentials.IDTokenCredentials(target_credentials,
385
- target_audience=cfg.host,
386
- include_email=True)
426
+ id_creds = impersonated_credentials.IDTokenCredentials(
427
+ target_credentials, target_audience=cfg.host, include_email=True
428
+ )
387
429
 
388
430
  gcp_impersonated_credentials = impersonated_credentials.Credentials(
389
- source_credentials=credentials, target_principal=cfg.google_service_account, target_scopes=GcpScopes)
431
+ source_credentials=credentials,
432
+ target_principal=cfg.google_service_account,
433
+ target_scopes=GcpScopes,
434
+ )
390
435
 
391
436
  request = Request()
392
437
 
@@ -396,7 +441,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
396
441
 
397
442
  def refreshed_headers() -> Dict[str, str]:
398
443
  id_creds.refresh(request)
399
- headers = {'Authorization': f'Bearer {id_creds.token}'}
444
+ headers = {"Authorization": f"Bearer {id_creds.token}"}
400
445
  if cfg.is_account_client:
401
446
  gcp_impersonated_credentials.refresh(request)
402
447
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
@@ -407,7 +452,13 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
407
452
 
408
453
  class CliTokenSource(Refreshable):
409
454
 
410
- def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
455
+ def __init__(
456
+ self,
457
+ cmd: List[str],
458
+ token_type_field: str,
459
+ access_token_field: str,
460
+ expiry_field: str,
461
+ ):
411
462
  super().__init__()
412
463
  self._cmd = cmd
413
464
  self._token_type_field = token_type_field
@@ -430,52 +481,74 @@ class CliTokenSource(Refreshable):
430
481
  out = _run_subprocess(self._cmd, capture_output=True, check=True)
431
482
  it = json.loads(out.stdout.decode())
432
483
  expires_on = self._parse_expiry(it[self._expiry_field])
433
- return Token(access_token=it[self._access_token_field],
434
- token_type=it[self._token_type_field],
435
- expiry=expires_on)
484
+ return Token(
485
+ access_token=it[self._access_token_field],
486
+ token_type=it[self._token_type_field],
487
+ expiry=expires_on,
488
+ )
436
489
  except ValueError as e:
437
490
  raise ValueError(f"cannot unmarshal CLI result: {e}")
438
491
  except subprocess.CalledProcessError as e:
439
492
  stdout = e.stdout.decode().strip()
440
493
  stderr = e.stderr.decode().strip()
441
494
  message = stdout or stderr
442
- raise IOError(f'cannot get access token: {message}') from e
495
+ raise IOError(f"cannot get access token: {message}") from e
443
496
 
444
497
 
445
- def _run_subprocess(popenargs,
446
- input=None,
447
- capture_output=True,
448
- timeout=None,
449
- check=False,
450
- **kwargs) -> subprocess.CompletedProcess:
498
+ def _run_subprocess(
499
+ popenargs,
500
+ input=None,
501
+ capture_output=True,
502
+ timeout=None,
503
+ check=False,
504
+ **kwargs,
505
+ ) -> subprocess.CompletedProcess:
451
506
  """Runs subprocess with given arguments.
452
- This handles OS-specific modifications that need to be made to the invocation of subprocess.run."""
453
- kwargs['shell'] = sys.platform.startswith('win')
507
+ This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
508
+ """
509
+ kwargs["shell"] = sys.platform.startswith("win")
454
510
  # windows requires shell=True to be able to execute 'az login' or other commands
455
511
  # cannot use shell=True all the time, as it breaks macOS
456
512
  logging.debug(f'Running command: {" ".join(popenargs)}')
457
- return subprocess.run(popenargs,
458
- input=input,
459
- capture_output=capture_output,
460
- timeout=timeout,
461
- check=check,
462
- **kwargs)
513
+ return subprocess.run(
514
+ popenargs,
515
+ input=input,
516
+ capture_output=capture_output,
517
+ timeout=timeout,
518
+ check=check,
519
+ **kwargs,
520
+ )
463
521
 
464
522
 
465
523
  class AzureCliTokenSource(CliTokenSource):
466
- """ Obtain the token granted by `az login` CLI command """
467
-
468
- def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
469
- cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
524
+ """Obtain the token granted by `az login` CLI command"""
525
+
526
+ def __init__(
527
+ self,
528
+ resource: str,
529
+ subscription: Optional[str] = None,
530
+ tenant: Optional[str] = None,
531
+ ):
532
+ cmd = [
533
+ "az",
534
+ "account",
535
+ "get-access-token",
536
+ "--resource",
537
+ resource,
538
+ "--output",
539
+ "json",
540
+ ]
470
541
  if subscription is not None:
471
542
  cmd.append("--subscription")
472
543
  cmd.append(subscription)
473
544
  if tenant and not self.__is_cli_using_managed_identity():
474
545
  cmd.extend(["--tenant", tenant])
475
- super().__init__(cmd=cmd,
476
- token_type_field='tokenType',
477
- access_token_field='accessToken',
478
- expiry_field='expiresOn')
546
+ super().__init__(
547
+ cmd=cmd,
548
+ token_type_field="tokenType",
549
+ access_token_field="accessToken",
550
+ expiry_field="expiresOn",
551
+ )
479
552
 
480
553
  @staticmethod
481
554
  def __is_cli_using_managed_identity() -> bool:
@@ -488,7 +561,8 @@ class AzureCliTokenSource(CliTokenSource):
488
561
  if user is None:
489
562
  return False
490
563
  return user.get("type") == "servicePrincipal" and user.get("name") in [
491
- 'systemAssignedIdentity', 'userAssignedIdentity'
564
+ "systemAssignedIdentity",
565
+ "userAssignedIdentity",
492
566
  ]
493
567
  except subprocess.CalledProcessError as e:
494
568
  logger.debug("Failed to get account information from Azure CLI", exc_info=e)
@@ -511,15 +585,13 @@ class AzureCliTokenSource(CliTokenSource):
511
585
  guaranteed to be unique within a tenant and should be used only for display purposes.
512
586
  - 'upn' - The username of the user.
513
587
  """
514
- return 'upn' in self.token().jwt_claims()
588
+ return "upn" in self.token().jwt_claims()
515
589
 
516
590
  @staticmethod
517
- def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
591
+ def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
518
592
  subscription = AzureCliTokenSource.get_subscription(cfg)
519
593
  if subscription is not None:
520
- token_source = AzureCliTokenSource(resource,
521
- subscription=subscription,
522
- tenant=cfg.azure_tenant_id)
594
+ token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
523
595
  try:
524
596
  # This will fail if the user has access to the workspace, but not to the subscription
525
597
  # itself.
@@ -534,32 +606,32 @@ class AzureCliTokenSource(CliTokenSource):
534
606
  return token_source
535
607
 
536
608
  @staticmethod
537
def get_subscription(cfg: "Config") -> Optional[str]:
    """Extract the subscription ID from `cfg.azure_workspace_resource_id`.

    Returns the third `/`-separated component of the resource ID. Returns None
    when the resource ID is unset or empty, and also (after logging a warning)
    when it has fewer than three components.
    """
    resource_id = cfg.azure_workspace_resource_id
    if not resource_id:
        return None
    parts = resource_id.split("/")
    if len(parts) >= 3:
        return parts[2]
    logger.warning("Invalid azure workspace resource ID")
    return None
546
618
 
547
619
 
548
- @credentials_strategy('azure-cli', ['is_azure'])
549
- def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
550
- """ Adds refreshed OAuth token granted by `az login` command to every request. """
620
+ @credentials_strategy("azure-cli", ["is_azure"])
621
+ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
622
+ """Adds refreshed OAuth token granted by `az login` command to every request."""
551
623
  cfg.load_azure_tenant_id()
552
624
  token_source = None
553
625
  mgmt_token_source = None
554
626
  try:
555
627
  token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
556
628
  except FileNotFoundError:
557
- doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
558
- logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
629
+ doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
630
+ logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
559
631
  return None
560
632
  except OSError as e:
561
- logger.debug('skipping Azure CLI auth', exc_info=e)
562
- logger.debug('This may happen if you are attempting to login to a dev or staging workspace')
633
+ logger.debug("skipping Azure CLI auth", exc_info=e)
634
+ logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
563
635
  return None
564
636
 
565
637
  if not token_source.is_human_user():
@@ -567,7 +639,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
567
639
  management_endpoint = cfg.arm_environment.service_management_endpoint
568
640
  mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
569
641
  except Exception as e:
570
- logger.debug(f'Not including service management token in headers', exc_info=e)
642
+ logger.debug(
643
+ f"Not including service management token in headers",
644
+ exc_info=e,
645
+ )
571
646
  mgmt_token_source = None
572
647
 
573
648
  _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
@@ -575,7 +650,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
575
650
 
576
651
  def inner() -> Dict[str, str]:
577
652
  token = token_source.token()
578
- headers = {'Authorization': f'{token.token_type} {token.access_token}'}
653
+ headers = {"Authorization": f"{token.token_type} {token.access_token}"}
579
654
  add_workspace_id_header(cfg, headers)
580
655
  if mgmt_token_source:
581
656
  add_sp_management_token(mgmt_token_source, headers)
@@ -585,12 +660,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
585
660
 
586
661
 
587
662
  class DatabricksCliTokenSource(CliTokenSource):
588
- """ Obtain the token granted by `databricks auth login` CLI command """
663
+ """Obtain the token granted by `databricks auth login` CLI command"""
589
664
 
590
- def __init__(self, cfg: 'Config'):
591
- args = ['auth', 'token', '--host', cfg.host]
665
+ def __init__(self, cfg: "Config"):
666
+ args = ["auth", "token", "--host", cfg.host]
592
667
  if cfg.is_account_client:
593
- args += ['--account-id', cfg.account_id]
668
+ args += ["--account-id", cfg.account_id]
594
669
 
595
670
  cli_path = cfg.databricks_cli_path
596
671
 
@@ -610,10 +685,12 @@ class DatabricksCliTokenSource(CliTokenSource):
610
685
  elif cli_path.count("/") == 0:
611
686
  cli_path = self.__class__._find_executable(cli_path)
612
687
 
613
- super().__init__(cmd=[cli_path, *args],
614
- token_type_field='token_type',
615
- access_token_field='access_token',
616
- expiry_field='expiry')
688
+ super().__init__(
689
+ cmd=[cli_path, *args],
690
+ token_type_field="token_type",
691
+ access_token_field="access_token",
692
+ expiry_field="expiry",
693
+ )
617
694
 
618
695
  @staticmethod
619
696
  def _find_executable(name) -> str:
@@ -635,8 +712,8 @@ class DatabricksCliTokenSource(CliTokenSource):
635
712
  raise err
636
713
 
637
714
 
638
- @oauth_credentials_strategy('databricks-cli', ['host'])
639
- def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
715
+ @oauth_credentials_strategy("databricks-cli", ["host"])
716
+ def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
640
717
  try:
641
718
  token_source = DatabricksCliTokenSource(cfg)
642
719
  except FileNotFoundError as e:
@@ -646,8 +723,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
646
723
  try:
647
724
  token_source.token()
648
725
  except IOError as e:
649
- if 'databricks OAuth is not' in str(e):
650
- logger.debug(f'OAuth not configured or not available: {e}')
726
+ if "databricks OAuth is not" in str(e):
727
+ logger.debug(f"OAuth not configured or not available: {e}")
651
728
  return None
652
729
  raise e
653
730
 
@@ -655,7 +732,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
655
732
 
656
733
  def inner() -> Dict[str, str]:
657
734
  token = token_source.token()
658
- return {'Authorization': f'{token.token_type} {token.access_token}'}
735
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
659
736
 
660
737
  def token() -> Token:
661
738
  return token_source.token()
@@ -664,13 +741,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
664
741
 
665
742
 
666
743
  class MetadataServiceTokenSource(Refreshable):
667
- """ Obtain the token granted by Databricks Metadata Service """
744
+ """Obtain the token granted by Databricks Metadata Service"""
745
+
668
746
  METADATA_SERVICE_VERSION = "1"
669
747
  METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
670
748
  METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
671
- _metadata_service_timeout = 10 # seconds
749
+ _metadata_service_timeout = 10 # seconds
672
750
 
673
- def __init__(self, cfg: 'Config'):
751
+ def __init__(self, cfg: "Config"):
674
752
  super().__init__()
675
753
  self.url = cfg.metadata_service_url
676
754
  self.host = cfg.host
@@ -681,13 +759,14 @@ class MetadataServiceTokenSource(Refreshable):
681
759
  timeout=self._metadata_service_timeout,
682
760
  headers={
683
761
  self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
684
- self.METADATA_SERVICE_HOST_HEADER: self.host
762
+ self.METADATA_SERVICE_HOST_HEADER: self.host,
685
763
  },
686
764
  proxies={
687
765
  # Explicitly exclude localhost from being proxied. This is necessary
688
766
  # for Metadata URLs which typically point to localhost.
689
767
  "no_proxy": "localhost,127.0.0.1"
690
- })
768
+ },
769
+ )
691
770
  json_resp: dict[str, Union[str, float]] = resp.json()
692
771
  access_token = json_resp.get("access_token", None)
693
772
  if access_token is None:
@@ -705,9 +784,9 @@ class MetadataServiceTokenSource(Refreshable):
705
784
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
706
785
 
707
786
 
708
- @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
709
- def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
710
- """ Adds refreshed token granted by Databricks Metadata Service to every request. """
787
+ @credentials_strategy("metadata-service", ["host", "metadata_service_url"])
788
+ def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
789
+ """Adds refreshed token granted by Databricks Metadata Service to every request."""
711
790
 
712
791
  token_source = MetadataServiceTokenSource(cfg)
713
792
  token_source.token()
@@ -715,74 +794,92 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
715
794
 
716
795
  def inner() -> Dict[str, str]:
717
796
  token = token_source.token()
718
- return {'Authorization': f'{token.token_type} {token.access_token}'}
797
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
719
798
 
720
799
  return inner
721
800
 
722
801
 
723
802
  # This Code is derived from Mlflow DatabricksModelServingConfigProvider
724
803
  # https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
725
class ModelServingAuthProvider:
    """Resolve the host and token used by the SDK inside Databricks Model Serving.

    The token returned depends on `credential_type`:

    * `USER_CREDENTIALS` - the invoker's token, read from the `invokers_token`
      entry of the main thread object's `__dict__`;
    * any other value (including None) - the model-dependency OAuth token read
      from the JSON file at `_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH`.
    """

    USER_CREDENTIALS = "user_credentials"

    _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"

    def __init__(self, credential_type: Optional[str] = None):
        """
        :param credential_type: `USER_CREDENTIALS` to use the invoker's token;
            None (the default) to use the model-dependency OAuth token.
        """
        self.expiry_time = -1
        self.current_token = None
        # How long, in seconds, a token read from the file is served from cache.
        self.refresh_duration = 300
        self.credential_type = credential_type

    # Declared static so the check works both as
    # `ModelServingAuthProvider.should_fetch_...()` and on an instance; a plain
    # function without `self` would raise TypeError when called on an instance.
    @staticmethod
    def should_fetch_model_serving_environment_oauth() -> bool:
        """
        Check whether this is the model serving environment
        Additionally check if the oauth token file path exists
        """
        is_in_model_serving_env = (
            os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
            or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
            or "false"
        )
        return is_in_model_serving_env == "true" and os.path.isfile(
            ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
        )

    def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
        """Return the model-dependency OAuth token, re-reading the file once the cache expires.

        :param should_retry: when True, a failed read is retried once after a
            0.5 second pause.
        :raises RuntimeError: if the token cannot be read and no retry is left.
        """
        # Use the cached value while it is still valid
        if self.current_token is not None and self.expiry_time > time.time():
            return self.current_token

        try:
            with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
                oauth_dict = json.load(f)
            self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
            self.expiry_time = time.time() + self.refresh_duration
        except Exception as e:
            if not should_retry:
                raise RuntimeError(
                    "Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
                ) from e
            # sleep and retry in case of any race conditions with OAuth refreshing
            logger.warning(
                "Unable to read oauth token on first attmept in Model Serving Environment",
                exc_info=e,
            )
            time.sleep(0.5)
            return self._get_model_dependency_oauth_token(should_retry=False)
        return self.current_token

    def _get_invokers_token(self):
        """Return the invoker's token stored as `invokers_token` on the main thread object.

        :raises RuntimeError: if the entry is missing or None.
        """
        invokers_token = threading.main_thread().__dict__.get("invokers_token")
        if invokers_token is None:
            raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
        return invokers_token

    def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
        """Return `(host, token)`, or None when not running in Model Serving.

        The host element is None when neither host environment variable is set.
        """
        if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
            return None

        # Prefer DATABRICKS_MODEL_SERVING_HOST_URL, fall back to DB_MODEL_SERVING_HOST_URL
        host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")

        if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
            return (host, self._get_invokers_token())
        return (host, self._get_model_dependency_oauth_token())
777
878
 
778
879
 
779
- @credentials_strategy('model-serving', [])
780
- def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
880
+ def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
781
881
  try:
782
- model_serving_auth_provider = ModelServingAuthProvider()
783
- if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
784
- logger.debug("model-serving: Not in Databricks Model Serving, skipping")
785
- return None
882
+ model_serving_auth_provider = ModelServingAuthProvider(credential_type)
786
883
  host, token = model_serving_auth_provider.get_databricks_host_token()
787
884
  if token is None:
788
885
  raise ValueError(
@@ -791,9 +888,11 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
791
888
  if cfg.host is None:
792
889
  cfg.host = host
793
890
  except Exception as e:
794
- logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
891
+ logger.warning(
892
+ "Unable to get auth from Databricks Model Serving Environment",
893
+ exc_info=e,
894
+ )
795
895
  return None
796
-
797
896
  logger.info("Using Databricks Model Serving Authentication")
798
897
 
799
898
  def inner() -> Dict[str, str]:
@@ -804,21 +903,40 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
804
903
  return inner
805
904
 
806
905
 
906
@credentials_strategy("model-serving", [])
def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
    """Model Serving credentials strategy; returns None outside the Model Serving environment."""
    in_model_serving = ModelServingAuthProvider.should_fetch_model_serving_environment_oauth()
    if in_model_serving:
        # No credential type is passed, so the visitor uses its default.
        return model_serving_auth_visitor(cfg)

    logger.debug("model-serving: Not in Databricks Model Serving, skipping")
    return None
913
+
914
+
807
915
  class DefaultCredentials:
808
- """ Select the first applicable credential provider from the chain """
916
+ """Select the first applicable credential provider from the chain"""
809
917
 
810
918
  def __init__(self) -> None:
811
- self._auth_type = 'default'
919
+ self._auth_type = "default"
812
920
  self._auth_providers = [
813
- pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
814
- github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
815
- google_credentials, google_id, model_serving_auth
921
+ pat_auth,
922
+ basic_auth,
923
+ metadata_service,
924
+ oauth_service_principal,
925
+ azure_service_principal,
926
+ github_oidc_azure,
927
+ azure_cli,
928
+ external_browser,
929
+ databricks_cli,
930
+ runtime_native_auth,
931
+ google_credentials,
932
+ google_id,
933
+ model_serving_auth,
816
934
  ]
817
935
 
818
936
  def auth_type(self) -> str:
819
937
  return self._auth_type
820
938
 
821
- def oauth_token(self, cfg: 'Config') -> Token:
939
+ def oauth_token(self, cfg: "Config") -> Token:
822
940
  for provider in self._auth_providers:
823
941
  auth_type = provider.auth_type()
824
942
  if auth_type != self._auth_type:
@@ -826,14 +944,14 @@ class DefaultCredentials:
826
944
  continue
827
945
  return provider.oauth_token(cfg)
828
946
 
829
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
947
+ def __call__(self, cfg: "Config") -> CredentialsProvider:
830
948
  for provider in self._auth_providers:
831
949
  auth_type = provider.auth_type()
832
950
  if cfg.auth_type and auth_type != cfg.auth_type:
833
951
  # ignore other auth types if one is explicitly enforced
834
952
  logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
835
953
  continue
836
- logger.debug(f'Attempting to configure auth: {auth_type}')
954
+ logger.debug(f"Attempting to configure auth: {auth_type}")
837
955
  try:
838
956
  header_factory = provider(cfg)
839
957
  if not header_factory:
@@ -841,8 +959,40 @@ class DefaultCredentials:
841
959
  self._auth_type = auth_type
842
960
  return header_factory
843
961
  except Exception as e:
844
- raise ValueError(f'{auth_type}: {e}') from e
962
+ raise ValueError(f"{auth_type}: {e}") from e
845
963
  auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
846
964
  raise ValueError(
847
- f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
965
+ f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
848
966
  )
967
+
968
+
969
class ModelServingUserCredentials(CredentialsStrategy):
    """
    Credentials strategy for authenticating the Databricks SDK in the model serving environment using user-specific rights.

    In the model serving environment, the strategy delegates to `model_serving_auth_visitor` with the
    `user_credentials` credential type, which reads the invoker's token from the `invokers_token` entry
    stored on the main thread object (not from a `threading.local`).
    In any other environment, the class falls back to `DefaultCredentials`.

    To use this credential strategy, instantiate the WorkspaceClient with it as follows:

        invokers_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials())
    """

    def __init__(self):
        self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
        # Used for everything outside the model serving environment.
        self.default_credentials = DefaultCredentials()

    def auth_type(self) -> str:
        """Return "model_serving_user_credentials" in model serving, else the default chain's auth type."""
        if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
            return self.default_credentials.auth_type()
        return "model_serving_" + self.credential_type

    def __call__(self, cfg: "Config") -> CredentialsProvider:
        """Return the header factory for `cfg`.

        :raises ValueError: if, inside model serving, no header factory could
            be built for the user credentials.
        """
        if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
            return self.default_credentials(cfg)

        header_factory = model_serving_auth_visitor(cfg, self.credential_type)
        if not header_factory:
            raise ValueError(
                f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
            )
        return header_factory