databricks-sdk: databricks_sdk-0.44.1-py3-none-any.whl → databricks_sdk-0.45.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

Files changed (63)
  1. databricks/sdk/__init__.py +123 -115
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +152 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +300 -205
  12. databricks/sdk/data_plane.py +86 -3
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +191 -16
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +372 -196
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5241 -3531
  40. databricks/sdk/service/dashboards.py +1313 -923
  41. databricks/sdk/service/files.py +442 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3364 -2255
  46. databricks/sdk/service/oauth2.py +922 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2040 -1278
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.45.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.1.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
@@ -26,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
26
26
 
27
27
  CredentialsProvider = Callable[[], Dict[str, str]]
28
28
 
29
- logger = logging.getLogger('databricks.sdk')
29
+ logger = logging.getLogger("databricks.sdk")
30
30
 
31
31
 
32
32
  class OAuthCredentialsProvider:
33
- """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """
33
+ """OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
34
34
 
35
- def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
35
+ def __init__(
36
+ self,
37
+ credentials_provider: CredentialsProvider,
38
+ token_provider: Callable[[], Token],
39
+ ):
36
40
  self._credentials_provider = credentials_provider
37
41
  self._token_provider = token_provider
38
42
 
@@ -44,45 +48,49 @@ class OAuthCredentialsProvider:
44
48
 
45
49
 
46
50
  class CredentialsStrategy(abc.ABC):
47
- """ CredentialsProvider is the protocol (call-side interface)
48
- for authenticating requests to Databricks REST APIs"""
51
+ """CredentialsProvider is the protocol (call-side interface)
52
+ for authenticating requests to Databricks REST APIs"""
49
53
 
50
54
  @abc.abstractmethod
51
- def auth_type(self) -> str:
52
- ...
55
+ def auth_type(self) -> str: ...
53
56
 
54
57
  @abc.abstractmethod
55
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
56
- ...
58
+ def __call__(self, cfg: "Config") -> CredentialsProvider: ...
57
59
 
58
60
 
59
61
  class OauthCredentialsStrategy(CredentialsStrategy):
60
- """ OauthCredentialsProvider is a CredentialsProvider which
62
+ """OauthCredentialsProvider is a CredentialsProvider which
61
63
  supports Oauth tokens"""
62
64
 
63
- def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
65
+ def __init__(
66
+ self,
67
+ auth_type: str,
68
+ headers_provider: Callable[["Config"], OAuthCredentialsProvider],
69
+ ):
64
70
  self._headers_provider = headers_provider
65
71
  self._auth_type = auth_type
66
72
 
67
73
  def auth_type(self) -> str:
68
74
  return self._auth_type
69
75
 
70
- def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
76
+ def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
71
77
  return self._headers_provider(cfg)
72
78
 
73
- def oauth_token(self, cfg: 'Config') -> Token:
79
+ def oauth_token(self, cfg: "Config") -> Token:
74
80
  return self._headers_provider(cfg).oauth_token()
75
81
 
76
82
 
77
83
  def credentials_strategy(name: str, require: List[str]):
78
- """ Given the function that receives a Config and returns RequestVisitor,
84
+ """Given the function that receives a Config and returns RequestVisitor,
79
85
  create CredentialsProvider with a given name and required configuration
80
- attribute names to be present for this function to be called. """
86
+ attribute names to be present for this function to be called."""
81
87
 
82
- def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
88
+ def inner(
89
+ func: Callable[["Config"], CredentialsProvider],
90
+ ) -> CredentialsStrategy:
83
91
 
84
92
  @functools.wraps(func)
85
- def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
93
+ def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
86
94
  for attr in require:
87
95
  getattr(cfg, attr)
88
96
  if not getattr(cfg, attr):
@@ -96,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
96
104
 
97
105
 
98
106
  def oauth_credentials_strategy(name: str, require: List[str]):
99
- """ Given the function that receives a Config and returns an OauthHeaderFactory,
107
+ """Given the function that receives a Config and returns an OauthHeaderFactory,
100
108
  create an OauthCredentialsProvider with a given name and required configuration
101
- attribute names to be present for this function to be called. """
109
+ attribute names to be present for this function to be called."""
102
110
 
103
- def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:
111
+ def inner(
112
+ func: Callable[["Config"], OAuthCredentialsProvider],
113
+ ) -> OauthCredentialsStrategy:
104
114
 
105
115
  @functools.wraps(func)
106
- def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
116
+ def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
107
117
  for attr in require:
108
118
  if not getattr(cfg, attr):
109
119
  return None
@@ -114,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
114
124
  return inner
115
125
 
116
126
 
117
- @credentials_strategy('basic', ['host', 'username', 'password'])
118
- def basic_auth(cfg: 'Config') -> CredentialsProvider:
119
- """ Given username and password, add base64-encoded Basic credentials """
120
- encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
121
- static_credentials = {'Authorization': f'Basic {encoded}'}
127
+ @credentials_strategy("basic", ["host", "username", "password"])
128
+ def basic_auth(cfg: "Config") -> CredentialsProvider:
129
+ """Given username and password, add base64-encoded Basic credentials"""
130
+ encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
131
+ static_credentials = {"Authorization": f"Basic {encoded}"}
122
132
 
123
133
  def inner() -> Dict[str, str]:
124
134
  return static_credentials
@@ -126,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
126
136
  return inner
127
137
 
128
138
 
129
- @credentials_strategy('pat', ['host', 'token'])
130
- def pat_auth(cfg: 'Config') -> CredentialsProvider:
131
- """ Adds Databricks Personal Access Token to every request """
132
- static_credentials = {'Authorization': f'Bearer {cfg.token}'}
139
+ @credentials_strategy("pat", ["host", "token"])
140
+ def pat_auth(cfg: "Config") -> CredentialsProvider:
141
+ """Adds Databricks Personal Access Token to every request"""
142
+ static_credentials = {"Authorization": f"Bearer {cfg.token}"}
133
143
 
134
144
  def inner() -> Dict[str, str]:
135
145
  return static_credentials
@@ -137,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
137
147
  return inner
138
148
 
139
149
 
140
- @credentials_strategy('runtime', [])
141
- def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
142
- if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
150
+ @credentials_strategy("runtime", [])
151
+ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
152
+ if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
143
153
  return None
144
154
 
145
155
  # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
@@ -148,36 +158,44 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
148
158
  from databricks.sdk.runtime import (init_runtime_legacy_auth,
149
159
  init_runtime_native_auth,
150
160
  init_runtime_repl_auth)
151
- for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
161
+
162
+ for init in [
163
+ init_runtime_native_auth,
164
+ init_runtime_repl_auth,
165
+ init_runtime_legacy_auth,
166
+ ]:
152
167
  if init is None:
153
168
  continue
154
169
  host, inner = init()
155
170
  if host is None:
156
- logger.debug(f'[{init.__name__}] no host detected')
171
+ logger.debug(f"[{init.__name__}] no host detected")
157
172
  continue
158
173
  cfg.host = host
159
- logger.debug(f'[{init.__name__}] runtime native auth configured')
174
+ logger.debug(f"[{init.__name__}] runtime native auth configured")
160
175
  return inner
161
176
  return None
162
177
 
163
178
 
164
- @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
165
- def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
166
- """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
167
- if /oidc/.well-known/oauth-authorization-server is available on the given host. """
179
+ @oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
180
+ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
181
+ """Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
182
+ if /oidc/.well-known/oauth-authorization-server is available on the given host.
183
+ """
168
184
  oidc = cfg.oidc_endpoints
169
185
  if oidc is None:
170
186
  return None
171
187
 
172
- token_source = ClientCredentials(client_id=cfg.client_id,
173
- client_secret=cfg.client_secret,
174
- token_url=oidc.token_endpoint,
175
- scopes=["all-apis"],
176
- use_header=True)
188
+ token_source = ClientCredentials(
189
+ client_id=cfg.client_id,
190
+ client_secret=cfg.client_secret,
191
+ token_url=oidc.token_endpoint,
192
+ scopes=["all-apis"],
193
+ use_header=True,
194
+ )
177
195
 
178
196
  def inner() -> Dict[str, str]:
179
197
  token = token_source.token()
180
- return {'Authorization': f'{token.token_type} {token.access_token}'}
198
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
181
199
 
182
200
  def token() -> Token:
183
201
  return token_source.token()
@@ -185,9 +203,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
185
203
  return OAuthCredentialsProvider(inner, token)
186
204
 
187
205
 
188
- @credentials_strategy('external-browser', ['host', 'auth_type'])
189
- def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
190
- if cfg.auth_type != 'external-browser':
206
+ @credentials_strategy("external-browser", ["host", "auth_type"])
207
+ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
208
+ if cfg.auth_type != "external-browser":
191
209
  return None
192
210
 
193
211
  client_id, client_secret = None, None
@@ -198,17 +216,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
198
216
  client_id = cfg.azure_client
199
217
  client_secret = cfg.azure_client_secret
200
218
  if not client_id:
201
- client_id = 'databricks-cli'
219
+ client_id = "databricks-cli"
202
220
 
203
221
  # Load cached credentials from disk if they exist. Note that these are
204
222
  # local to the Python SDK and not reused by other SDKs.
205
223
  oidc_endpoints = cfg.oidc_endpoints
206
- redirect_url = 'http://localhost:8020'
207
- token_cache = TokenCache(host=cfg.host,
208
- oidc_endpoints=oidc_endpoints,
209
- client_id=client_id,
210
- client_secret=client_secret,
211
- redirect_url=redirect_url)
224
+ redirect_url = "http://localhost:8020"
225
+ token_cache = TokenCache(
226
+ host=cfg.host,
227
+ oidc_endpoints=oidc_endpoints,
228
+ client_id=client_id,
229
+ client_secret=client_secret,
230
+ redirect_url=redirect_url,
231
+ )
212
232
  credentials = token_cache.load()
213
233
  if credentials:
214
234
  try:
@@ -219,12 +239,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
219
239
  return credentials(cfg)
220
240
  # TODO: We should ideally use more specific exceptions.
221
241
  except Exception as e:
222
- logger.warning(f'Failed to refresh cached token: {e}. Initiating new OAuth login flow')
223
-
224
- oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
225
- client_id=client_id,
226
- redirect_url=redirect_url,
227
- client_secret=client_secret)
242
+ logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
243
+
244
+ oauth_client = OAuthClient(
245
+ oidc_endpoints=oidc_endpoints,
246
+ client_id=client_id,
247
+ redirect_url=redirect_url,
248
+ client_secret=client_secret,
249
+ )
228
250
  consent = oauth_client.initiate_consent()
229
251
  if not consent:
230
252
  return None
@@ -234,33 +256,41 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
234
256
  return credentials(cfg)
235
257
 
236
258
 
237
- def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]):
238
- """ Resolves Azure Databricks workspace URL from ARM Resource ID """
259
+ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
260
+ """Resolves Azure Databricks workspace URL from ARM Resource ID"""
239
261
  if cfg.host:
240
262
  return
241
263
  if not cfg.azure_workspace_resource_id:
242
264
  return
243
265
  arm = cfg.arm_environment.resource_manager_endpoint
244
266
  token = token_source_for(arm).token()
245
- resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
246
- headers={"Authorization": f"Bearer {token.access_token}"})
267
+ resp = requests.get(
268
+ f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
269
+ headers={"Authorization": f"Bearer {token.access_token}"},
270
+ )
247
271
  if not resp.ok:
248
272
  raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
249
273
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
250
274
 
251
275
 
252
- @oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
253
- def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
254
- """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
255
- to every request, while automatically resolving different Azure environment endpoints. """
276
+ @oauth_credentials_strategy(
277
+ "azure-client-secret",
278
+ ["is_azure", "azure_client_id", "azure_client_secret"],
279
+ )
280
+ def azure_service_principal(cfg: "Config") -> CredentialsProvider:
281
+ """Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
282
+ to every request, while automatically resolving different Azure environment endpoints.
283
+ """
256
284
 
257
285
  def token_source_for(resource: str) -> TokenSource:
258
286
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
259
- return ClientCredentials(client_id=cfg.azure_client_id,
260
- client_secret=cfg.azure_client_secret,
261
- token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
262
- endpoint_params={"resource": resource},
263
- use_params=True)
287
+ return ClientCredentials(
288
+ client_id=cfg.azure_client_id,
289
+ client_secret=cfg.azure_client_secret,
290
+ token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
291
+ endpoint_params={"resource": resource},
292
+ use_params=True,
293
+ )
264
294
 
265
295
  _ensure_host_present(cfg, token_source_for)
266
296
  cfg.load_azure_tenant_id()
@@ -269,7 +299,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
269
299
  cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
270
300
 
271
301
  def refreshed_headers() -> Dict[str, str]:
272
- headers = {'Authorization': f"Bearer {inner.token().access_token}", }
302
+ headers = {
303
+ "Authorization": f"Bearer {inner.token().access_token}",
304
+ }
273
305
  add_workspace_id_header(cfg, headers)
274
306
  add_sp_management_token(cloud, headers)
275
307
  return headers
@@ -280,9 +312,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
280
312
  return OAuthCredentialsProvider(refreshed_headers, token)
281
313
 
282
314
 
283
- @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
284
- def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
285
- if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
315
+ @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
316
+ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
317
+ if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
286
318
  # not in GitHub actions
287
319
  return None
288
320
 
@@ -292,7 +324,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
292
324
  return None
293
325
 
294
326
  # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
295
- headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
327
+ headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
296
328
  endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
297
329
  response = requests.get(endpoint, headers=headers)
298
330
  if not response.ok:
@@ -300,30 +332,34 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
300
332
 
301
333
  # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
302
334
  response_json = response.json()
303
- if 'value' not in response_json:
335
+ if "value" not in response_json:
304
336
  return None
305
337
 
306
- logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id)
338
+ logger.info(
339
+ "Configured AAD token for GitHub Actions OIDC (%s)",
340
+ cfg.azure_client_id,
341
+ )
307
342
  params = {
308
- 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
309
- 'resource': cfg.effective_azure_login_app_id,
310
- 'client_assertion': response_json['value'],
343
+ "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
344
+ "resource": cfg.effective_azure_login_app_id,
345
+ "client_assertion": response_json["value"],
311
346
  }
312
347
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
313
348
  if not cfg.azure_tenant_id:
314
349
  # detect Azure AD Tenant ID if it's not specified directly
315
350
  token_endpoint = cfg.oidc_endpoints.token_endpoint
316
- cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
351
+ cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
317
352
  inner = ClientCredentials(
318
353
  client_id=cfg.azure_client_id,
319
- client_secret="", # we have no (rotatable) secrets in OIDC flow
354
+ client_secret="", # we have no (rotatable) secrets in OIDC flow
320
355
  token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
321
356
  endpoint_params=params,
322
- use_params=True)
357
+ use_params=True,
358
+ )
323
359
 
324
360
  def refreshed_headers() -> Dict[str, str]:
325
361
  token = inner.token()
326
- return {'Authorization': f'{token.token_type} {token.access_token}'}
362
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
327
363
 
328
364
  def token() -> Token:
329
365
  return inner.token()
@@ -331,29 +367,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
331
367
  return OAuthCredentialsProvider(refreshed_headers, token)
332
368
 
333
369
 
334
- GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
370
+ GcpScopes = [
371
+ "https://www.googleapis.com/auth/cloud-platform",
372
+ "https://www.googleapis.com/auth/compute",
373
+ ]
335
374
 
336
375
 
337
- @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
338
- def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
376
+ @oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
377
+ def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
339
378
  if not cfg.is_gcp:
340
379
  return None
341
380
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
342
381
  # Obtain the id token by providing the json file path and target audience.
343
- if (os.path.isfile(cfg.google_credentials)):
382
+ if os.path.isfile(cfg.google_credentials):
344
383
  with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
345
384
  account_info = json.load(json_file)
346
385
  else:
347
386
  # If the file doesn't exist, assume that the config is the actual JSON content.
348
387
  account_info = json.loads(cfg.google_credentials)
349
388
 
350
- credentials = service_account.IDTokenCredentials.from_service_account_info(info=account_info,
351
- target_audience=cfg.host)
389
+ credentials = service_account.IDTokenCredentials.from_service_account_info(
390
+ info=account_info, target_audience=cfg.host
391
+ )
352
392
 
353
393
  request = Request()
354
394
 
355
- gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
356
- scopes=GcpScopes)
395
+ gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
357
396
 
358
397
  def token() -> Token:
359
398
  credentials.refresh(request)
@@ -361,7 +400,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
361
400
 
362
401
  def refreshed_headers() -> Dict[str, str]:
363
402
  credentials.refresh(request)
364
- headers = {'Authorization': f'Bearer {credentials.token}'}
403
+ headers = {"Authorization": f"Bearer {credentials.token}"}
365
404
  if cfg.is_account_client:
366
405
  gcp_credentials.refresh(request)
367
406
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
@@ -370,24 +409,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
370
409
  return OAuthCredentialsProvider(refreshed_headers, token)
371
410
 
372
411
 
373
- @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
374
- def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
412
+ @oauth_credentials_strategy("google-id", ["host", "google_service_account"])
413
+ def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
375
414
  if not cfg.is_gcp:
376
415
  return None
377
416
  credentials, _project_id = google.auth.default()
378
417
 
379
418
  # Create the impersonated credential.
380
- target_credentials = impersonated_credentials.Credentials(source_credentials=credentials,
381
- target_principal=cfg.google_service_account,
382
- target_scopes=[])
419
+ target_credentials = impersonated_credentials.Credentials(
420
+ source_credentials=credentials,
421
+ target_principal=cfg.google_service_account,
422
+ target_scopes=[],
423
+ )
383
424
 
384
425
  # Set the impersonated credential, target audience and token options.
385
- id_creds = impersonated_credentials.IDTokenCredentials(target_credentials,
386
- target_audience=cfg.host,
387
- include_email=True)
426
+ id_creds = impersonated_credentials.IDTokenCredentials(
427
+ target_credentials, target_audience=cfg.host, include_email=True
428
+ )
388
429
 
389
430
  gcp_impersonated_credentials = impersonated_credentials.Credentials(
390
- source_credentials=credentials, target_principal=cfg.google_service_account, target_scopes=GcpScopes)
431
+ source_credentials=credentials,
432
+ target_principal=cfg.google_service_account,
433
+ target_scopes=GcpScopes,
434
+ )
391
435
 
392
436
  request = Request()
393
437
 
@@ -397,7 +441,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
397
441
 
398
442
  def refreshed_headers() -> Dict[str, str]:
399
443
  id_creds.refresh(request)
400
- headers = {'Authorization': f'Bearer {id_creds.token}'}
444
+ headers = {"Authorization": f"Bearer {id_creds.token}"}
401
445
  if cfg.is_account_client:
402
446
  gcp_impersonated_credentials.refresh(request)
403
447
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
@@ -408,7 +452,13 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
408
452
 
409
453
  class CliTokenSource(Refreshable):
410
454
 
411
- def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
455
+ def __init__(
456
+ self,
457
+ cmd: List[str],
458
+ token_type_field: str,
459
+ access_token_field: str,
460
+ expiry_field: str,
461
+ ):
412
462
  super().__init__()
413
463
  self._cmd = cmd
414
464
  self._token_type_field = token_type_field
@@ -431,52 +481,74 @@ class CliTokenSource(Refreshable):
431
481
  out = _run_subprocess(self._cmd, capture_output=True, check=True)
432
482
  it = json.loads(out.stdout.decode())
433
483
  expires_on = self._parse_expiry(it[self._expiry_field])
434
- return Token(access_token=it[self._access_token_field],
435
- token_type=it[self._token_type_field],
436
- expiry=expires_on)
484
+ return Token(
485
+ access_token=it[self._access_token_field],
486
+ token_type=it[self._token_type_field],
487
+ expiry=expires_on,
488
+ )
437
489
  except ValueError as e:
438
490
  raise ValueError(f"cannot unmarshal CLI result: {e}")
439
491
  except subprocess.CalledProcessError as e:
440
492
  stdout = e.stdout.decode().strip()
441
493
  stderr = e.stderr.decode().strip()
442
494
  message = stdout or stderr
443
- raise IOError(f'cannot get access token: {message}') from e
495
+ raise IOError(f"cannot get access token: {message}") from e
444
496
 
445
497
 
446
- def _run_subprocess(popenargs,
447
- input=None,
448
- capture_output=True,
449
- timeout=None,
450
- check=False,
451
- **kwargs) -> subprocess.CompletedProcess:
498
+ def _run_subprocess(
499
+ popenargs,
500
+ input=None,
501
+ capture_output=True,
502
+ timeout=None,
503
+ check=False,
504
+ **kwargs,
505
+ ) -> subprocess.CompletedProcess:
452
506
  """Runs subprocess with given arguments.
453
- This handles OS-specific modifications that need to be made to the invocation of subprocess.run."""
454
- kwargs['shell'] = sys.platform.startswith('win')
507
+ This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
508
+ """
509
+ kwargs["shell"] = sys.platform.startswith("win")
455
510
  # windows requires shell=True to be able to execute 'az login' or other commands
456
511
  # cannot use shell=True all the time, as it breaks macOS
457
512
  logging.debug(f'Running command: {" ".join(popenargs)}')
458
- return subprocess.run(popenargs,
459
- input=input,
460
- capture_output=capture_output,
461
- timeout=timeout,
462
- check=check,
463
- **kwargs)
513
+ return subprocess.run(
514
+ popenargs,
515
+ input=input,
516
+ capture_output=capture_output,
517
+ timeout=timeout,
518
+ check=check,
519
+ **kwargs,
520
+ )
464
521
 
465
522
 
466
523
  class AzureCliTokenSource(CliTokenSource):
467
- """ Obtain the token granted by `az login` CLI command """
468
-
469
- def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
470
- cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
524
+ """Obtain the token granted by `az login` CLI command"""
525
+
526
+ def __init__(
527
+ self,
528
+ resource: str,
529
+ subscription: Optional[str] = None,
530
+ tenant: Optional[str] = None,
531
+ ):
532
+ cmd = [
533
+ "az",
534
+ "account",
535
+ "get-access-token",
536
+ "--resource",
537
+ resource,
538
+ "--output",
539
+ "json",
540
+ ]
471
541
  if subscription is not None:
472
542
  cmd.append("--subscription")
473
543
  cmd.append(subscription)
474
544
  if tenant and not self.__is_cli_using_managed_identity():
475
545
  cmd.extend(["--tenant", tenant])
476
- super().__init__(cmd=cmd,
477
- token_type_field='tokenType',
478
- access_token_field='accessToken',
479
- expiry_field='expiresOn')
546
+ super().__init__(
547
+ cmd=cmd,
548
+ token_type_field="tokenType",
549
+ access_token_field="accessToken",
550
+ expiry_field="expiresOn",
551
+ )
480
552
 
481
553
  @staticmethod
482
554
  def __is_cli_using_managed_identity() -> bool:
@@ -489,7 +561,8 @@ class AzureCliTokenSource(CliTokenSource):
489
561
  if user is None:
490
562
  return False
491
563
  return user.get("type") == "servicePrincipal" and user.get("name") in [
492
- 'systemAssignedIdentity', 'userAssignedIdentity'
564
+ "systemAssignedIdentity",
565
+ "userAssignedIdentity",
493
566
  ]
494
567
  except subprocess.CalledProcessError as e:
495
568
  logger.debug("Failed to get account information from Azure CLI", exc_info=e)
@@ -512,15 +585,13 @@ class AzureCliTokenSource(CliTokenSource):
512
585
  guaranteed to be unique within a tenant and should be used only for display purposes.
513
586
  - 'upn' - The username of the user.
514
587
  """
515
- return 'upn' in self.token().jwt_claims()
588
+ return "upn" in self.token().jwt_claims()
516
589
 
517
590
  @staticmethod
518
- def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
591
+ def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
519
592
  subscription = AzureCliTokenSource.get_subscription(cfg)
520
593
  if subscription is not None:
521
- token_source = AzureCliTokenSource(resource,
522
- subscription=subscription,
523
- tenant=cfg.azure_tenant_id)
594
+ token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
524
595
  try:
525
596
  # This will fail if the user has access to the workspace, but not to the subscription
526
597
  # itself.
@@ -535,32 +606,32 @@ class AzureCliTokenSource(CliTokenSource):
535
606
  return token_source
536
607
 
537
608
  @staticmethod
538
- def get_subscription(cfg: 'Config') -> Optional[str]:
609
+ def get_subscription(cfg: "Config") -> Optional[str]:
539
610
  resource = cfg.azure_workspace_resource_id
540
611
  if resource is None or resource == "":
541
612
  return None
542
- components = resource.split('/')
613
+ components = resource.split("/")
543
614
  if len(components) < 3:
544
615
  logger.warning("Invalid azure workspace resource ID")
545
616
  return None
546
617
  return components[2]
547
618
 
548
619
 
549
- @credentials_strategy('azure-cli', ['is_azure'])
550
- def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
551
- """ Adds refreshed OAuth token granted by `az login` command to every request. """
620
+ @credentials_strategy("azure-cli", ["is_azure"])
621
+ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
622
+ """Adds refreshed OAuth token granted by `az login` command to every request."""
552
623
  cfg.load_azure_tenant_id()
553
624
  token_source = None
554
625
  mgmt_token_source = None
555
626
  try:
556
627
  token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
557
628
  except FileNotFoundError:
558
- doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
559
- logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
629
+ doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
630
+ logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
560
631
  return None
561
632
  except OSError as e:
562
- logger.debug('skipping Azure CLI auth', exc_info=e)
563
- logger.debug('This may happen if you are attempting to login to a dev or staging workspace')
633
+ logger.debug("skipping Azure CLI auth", exc_info=e)
634
+ logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
564
635
  return None
565
636
 
566
637
  if not token_source.is_human_user():
@@ -568,7 +639,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
568
639
  management_endpoint = cfg.arm_environment.service_management_endpoint
569
640
  mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
570
641
  except Exception as e:
571
- logger.debug(f'Not including service management token in headers', exc_info=e)
642
+ logger.debug(
643
+ f"Not including service management token in headers",
644
+ exc_info=e,
645
+ )
572
646
  mgmt_token_source = None
573
647
 
574
648
  _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
@@ -576,7 +650,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
576
650
 
577
651
  def inner() -> Dict[str, str]:
578
652
  token = token_source.token()
579
- headers = {'Authorization': f'{token.token_type} {token.access_token}'}
653
+ headers = {"Authorization": f"{token.token_type} {token.access_token}"}
580
654
  add_workspace_id_header(cfg, headers)
581
655
  if mgmt_token_source:
582
656
  add_sp_management_token(mgmt_token_source, headers)
@@ -586,12 +660,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
586
660
 
587
661
 
588
662
  class DatabricksCliTokenSource(CliTokenSource):
589
- """ Obtain the token granted by `databricks auth login` CLI command """
663
+ """Obtain the token granted by `databricks auth login` CLI command"""
590
664
 
591
- def __init__(self, cfg: 'Config'):
592
- args = ['auth', 'token', '--host', cfg.host]
665
+ def __init__(self, cfg: "Config"):
666
+ args = ["auth", "token", "--host", cfg.host]
593
667
  if cfg.is_account_client:
594
- args += ['--account-id', cfg.account_id]
668
+ args += ["--account-id", cfg.account_id]
595
669
 
596
670
  cli_path = cfg.databricks_cli_path
597
671
 
@@ -611,10 +685,12 @@ class DatabricksCliTokenSource(CliTokenSource):
611
685
  elif cli_path.count("/") == 0:
612
686
  cli_path = self.__class__._find_executable(cli_path)
613
687
 
614
- super().__init__(cmd=[cli_path, *args],
615
- token_type_field='token_type',
616
- access_token_field='access_token',
617
- expiry_field='expiry')
688
+ super().__init__(
689
+ cmd=[cli_path, *args],
690
+ token_type_field="token_type",
691
+ access_token_field="access_token",
692
+ expiry_field="expiry",
693
+ )
618
694
 
619
695
  @staticmethod
620
696
  def _find_executable(name) -> str:
@@ -636,8 +712,8 @@ class DatabricksCliTokenSource(CliTokenSource):
636
712
  raise err
637
713
 
638
714
 
639
- @oauth_credentials_strategy('databricks-cli', ['host'])
640
- def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
715
+ @oauth_credentials_strategy("databricks-cli", ["host"])
716
+ def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
641
717
  try:
642
718
  token_source = DatabricksCliTokenSource(cfg)
643
719
  except FileNotFoundError as e:
@@ -647,8 +723,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
647
723
  try:
648
724
  token_source.token()
649
725
  except IOError as e:
650
- if 'databricks OAuth is not' in str(e):
651
- logger.debug(f'OAuth not configured or not available: {e}')
726
+ if "databricks OAuth is not" in str(e):
727
+ logger.debug(f"OAuth not configured or not available: {e}")
652
728
  return None
653
729
  raise e
654
730
 
@@ -656,7 +732,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
656
732
 
657
733
  def inner() -> Dict[str, str]:
658
734
  token = token_source.token()
659
- return {'Authorization': f'{token.token_type} {token.access_token}'}
735
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
660
736
 
661
737
  def token() -> Token:
662
738
  return token_source.token()
@@ -665,13 +741,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
665
741
 
666
742
 
667
743
  class MetadataServiceTokenSource(Refreshable):
668
- """ Obtain the token granted by Databricks Metadata Service """
744
+ """Obtain the token granted by Databricks Metadata Service"""
745
+
669
746
  METADATA_SERVICE_VERSION = "1"
670
747
  METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
671
748
  METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
672
- _metadata_service_timeout = 10 # seconds
749
+ _metadata_service_timeout = 10 # seconds
673
750
 
674
- def __init__(self, cfg: 'Config'):
751
+ def __init__(self, cfg: "Config"):
675
752
  super().__init__()
676
753
  self.url = cfg.metadata_service_url
677
754
  self.host = cfg.host
@@ -682,13 +759,14 @@ class MetadataServiceTokenSource(Refreshable):
682
759
  timeout=self._metadata_service_timeout,
683
760
  headers={
684
761
  self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
685
- self.METADATA_SERVICE_HOST_HEADER: self.host
762
+ self.METADATA_SERVICE_HOST_HEADER: self.host,
686
763
  },
687
764
  proxies={
688
765
  # Explicitly exclude localhost from being proxied. This is necessary
689
766
  # for Metadata URLs which typically point to localhost.
690
767
  "no_proxy": "localhost,127.0.0.1"
691
- })
768
+ },
769
+ )
692
770
  json_resp: dict[str, Union[str, float]] = resp.json()
693
771
  access_token = json_resp.get("access_token", None)
694
772
  if access_token is None:
@@ -706,9 +784,9 @@ class MetadataServiceTokenSource(Refreshable):
706
784
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
707
785
 
708
786
 
709
- @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
710
- def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
711
- """ Adds refreshed token granted by Databricks Metadata Service to every request. """
787
+ @credentials_strategy("metadata-service", ["host", "metadata_service_url"])
788
+ def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
789
+ """Adds refreshed token granted by Databricks Metadata Service to every request."""
712
790
 
713
791
  token_source = MetadataServiceTokenSource(cfg)
714
792
  token_source.token()
@@ -716,14 +794,14 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
716
794
 
717
795
  def inner() -> Dict[str, str]:
718
796
  token = token_source.token()
719
- return {'Authorization': f'{token.token_type} {token.access_token}'}
797
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
720
798
 
721
799
  return inner
722
800
 
723
801
 
724
802
  # This Code is derived from Mlflow DatabricksModelServingConfigProvider
725
803
  # https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
726
- class ModelServingAuthProvider():
804
+ class ModelServingAuthProvider:
727
805
  USER_CREDENTIALS = "user_credentials"
728
806
 
729
807
  _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
@@ -731,7 +809,7 @@ class ModelServingAuthProvider():
731
809
  def __init__(self, credential_type: Optional[str]):
732
810
  self.expiry_time = -1
733
811
  self.current_token = None
734
- self.refresh_duration = 300 # 300 Seconds
812
+ self.refresh_duration = 300 # 300 Seconds
735
813
  self.credential_type = credential_type
736
814
 
737
815
  def should_fetch_model_serving_environment_oauth() -> bool:
@@ -740,10 +818,14 @@ class ModelServingAuthProvider():
740
818
  Additionally check if the oauth token file path exists
741
819
  """
742
820
 
743
- is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
744
- or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
745
- return (is_in_model_serving_env == "true"
746
- and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
821
+ is_in_model_serving_env = (
822
+ os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
823
+ or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
824
+ or "false"
825
+ )
826
+ return is_in_model_serving_env == "true" and os.path.isfile(
827
+ ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
828
+ )
747
829
 
748
830
  def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
749
831
  # Use Cached value if it is valid
@@ -758,8 +840,10 @@ class ModelServingAuthProvider():
758
840
  except Exception as e:
759
841
  # sleep and retry in case of any race conditions with OAuth refreshing
760
842
  if should_retry:
761
- logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
762
- exc_info=e)
843
+ logger.warning(
844
+ "Unable to read oauth token on first attmept in Model Serving Environment",
845
+ exc_info=e,
846
+ )
763
847
  time.sleep(0.5)
764
848
  return self._get_model_dependency_oauth_token(should_retry=False)
765
849
  else:
@@ -769,8 +853,8 @@ class ModelServingAuthProvider():
769
853
  return self.current_token
770
854
 
771
855
  def _get_invokers_token(self):
772
- current_thread = threading.current_thread()
773
- thread_data = current_thread.__dict__
856
+ main_thread = threading.main_thread()
857
+ thread_data = main_thread.__dict__
774
858
  invokers_token = None
775
859
  if "invokers_token" in thread_data:
776
860
  invokers_token = thread_data["invokers_token"]
@@ -785,8 +869,7 @@ class ModelServingAuthProvider():
785
869
  return None
786
870
 
787
871
  # read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
788
- host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
789
- "DB_MODEL_SERVING_HOST_URL")
872
+ host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")
790
873
 
791
874
  if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
792
875
  return (host, self._get_invokers_token())
@@ -794,8 +877,7 @@ class ModelServingAuthProvider():
794
877
  return (host, self._get_model_dependency_oauth_token())
795
878
 
796
879
 
797
- def model_serving_auth_visitor(cfg: 'Config',
798
- credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
880
+ def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
799
881
  try:
800
882
  model_serving_auth_provider = ModelServingAuthProvider(credential_type)
801
883
  host, token = model_serving_auth_provider.get_databricks_host_token()
@@ -806,7 +888,10 @@ def model_serving_auth_visitor(cfg: 'Config',
806
888
  if cfg.host is None:
807
889
  cfg.host = host
808
890
  except Exception as e:
809
- logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
891
+ logger.warning(
892
+ "Unable to get auth from Databricks Model Serving Environment",
893
+ exc_info=e,
894
+ )
810
895
  return None
811
896
  logger.info("Using Databricks Model Serving Authentication")
812
897
 
@@ -818,8 +903,8 @@ def model_serving_auth_visitor(cfg: 'Config',
818
903
  return inner
819
904
 
820
905
 
821
- @credentials_strategy('model-serving', [])
822
- def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
906
+ @credentials_strategy("model-serving", [])
907
+ def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
823
908
  if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
824
909
  logger.debug("model-serving: Not in Databricks Model Serving, skipping")
825
910
  return None
@@ -828,20 +913,30 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
828
913
 
829
914
 
830
915
  class DefaultCredentials:
831
- """ Select the first applicable credential provider from the chain """
916
+ """Select the first applicable credential provider from the chain"""
832
917
 
833
918
  def __init__(self) -> None:
834
- self._auth_type = 'default'
919
+ self._auth_type = "default"
835
920
  self._auth_providers = [
836
- pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
837
- github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
838
- google_credentials, google_id, model_serving_auth
921
+ pat_auth,
922
+ basic_auth,
923
+ metadata_service,
924
+ oauth_service_principal,
925
+ azure_service_principal,
926
+ github_oidc_azure,
927
+ azure_cli,
928
+ external_browser,
929
+ databricks_cli,
930
+ runtime_native_auth,
931
+ google_credentials,
932
+ google_id,
933
+ model_serving_auth,
839
934
  ]
840
935
 
841
936
  def auth_type(self) -> str:
842
937
  return self._auth_type
843
938
 
844
- def oauth_token(self, cfg: 'Config') -> Token:
939
+ def oauth_token(self, cfg: "Config") -> Token:
845
940
  for provider in self._auth_providers:
846
941
  auth_type = provider.auth_type()
847
942
  if auth_type != self._auth_type:
@@ -849,14 +944,14 @@ class DefaultCredentials:
849
944
  continue
850
945
  return provider.oauth_token(cfg)
851
946
 
852
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
947
+ def __call__(self, cfg: "Config") -> CredentialsProvider:
853
948
  for provider in self._auth_providers:
854
949
  auth_type = provider.auth_type()
855
950
  if cfg.auth_type and auth_type != cfg.auth_type:
856
951
  # ignore other auth types if one is explicitly enforced
857
952
  logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
858
953
  continue
859
- logger.debug(f'Attempting to configure auth: {auth_type}')
954
+ logger.debug(f"Attempting to configure auth: {auth_type}")
860
955
  try:
861
956
  header_factory = provider(cfg)
862
957
  if not header_factory:
@@ -864,18 +959,18 @@ class DefaultCredentials:
864
959
  self._auth_type = auth_type
865
960
  return header_factory
866
961
  except Exception as e:
867
- raise ValueError(f'{auth_type}: {e}') from e
962
+ raise ValueError(f"{auth_type}: {e}") from e
868
963
  auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
869
964
  raise ValueError(
870
- f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
965
+ f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
871
966
  )
872
967
 
873
968
 
874
969
  class ModelServingUserCredentials(CredentialsStrategy):
875
970
  """
876
- This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
877
- In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
878
- In any other environments, the class defaults to the DefaultCredentialStrategy.
971
+ This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
972
+ In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
973
+ In any other environments, the class defaults to the DefaultCredentialStrategy.
879
974
  To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
880
975
 
881
976
  invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
@@ -891,7 +986,7 @@ class ModelServingUserCredentials(CredentialsStrategy):
891
986
  else:
892
987
  return self.default_credentials.auth_type()
893
988
 
894
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
989
+ def __call__(self, cfg: "Config") -> CredentialsProvider:
895
990
  if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
896
991
  header_factory = model_serving_auth_visitor(cfg, self.credential_type)
897
992
  if not header_factory: