databricks-sdk 0.44.1__py3-none-any.whl → 0.46.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

Files changed (63)
  1. databricks/sdk/__init__.py +135 -116
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +156 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +306 -206
  12. databricks/sdk/data_plane.py +75 -50
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +191 -16
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +379 -198
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5263 -3536
  40. databricks/sdk/service/dashboards.py +1331 -924
  41. databricks/sdk/service/files.py +446 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3839 -2256
  46. databricks/sdk/service/oauth2.py +910 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2060 -1290
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.46.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.1.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/top_level.txt +0 -0
@@ -26,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
26
26
 
27
27
  CredentialsProvider = Callable[[], Dict[str, str]]
28
28
 
29
- logger = logging.getLogger('databricks.sdk')
29
+ logger = logging.getLogger("databricks.sdk")
30
30
 
31
31
 
32
32
  class OAuthCredentialsProvider:
33
- """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """
33
+ """OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
34
34
 
35
- def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
35
+ def __init__(
36
+ self,
37
+ credentials_provider: CredentialsProvider,
38
+ token_provider: Callable[[], Token],
39
+ ):
36
40
  self._credentials_provider = credentials_provider
37
41
  self._token_provider = token_provider
38
42
 
@@ -44,45 +48,49 @@ class OAuthCredentialsProvider:
44
48
 
45
49
 
46
50
  class CredentialsStrategy(abc.ABC):
47
- """ CredentialsProvider is the protocol (call-side interface)
48
- for authenticating requests to Databricks REST APIs"""
51
+ """CredentialsProvider is the protocol (call-side interface)
52
+ for authenticating requests to Databricks REST APIs"""
49
53
 
50
54
  @abc.abstractmethod
51
- def auth_type(self) -> str:
52
- ...
55
+ def auth_type(self) -> str: ...
53
56
 
54
57
  @abc.abstractmethod
55
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
56
- ...
58
+ def __call__(self, cfg: "Config") -> CredentialsProvider: ...
57
59
 
58
60
 
59
61
  class OauthCredentialsStrategy(CredentialsStrategy):
60
- """ OauthCredentialsProvider is a CredentialsProvider which
62
+ """OauthCredentialsProvider is a CredentialsProvider which
61
63
  supports Oauth tokens"""
62
64
 
63
- def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
65
+ def __init__(
66
+ self,
67
+ auth_type: str,
68
+ headers_provider: Callable[["Config"], OAuthCredentialsProvider],
69
+ ):
64
70
  self._headers_provider = headers_provider
65
71
  self._auth_type = auth_type
66
72
 
67
73
  def auth_type(self) -> str:
68
74
  return self._auth_type
69
75
 
70
- def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
76
+ def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
71
77
  return self._headers_provider(cfg)
72
78
 
73
- def oauth_token(self, cfg: 'Config') -> Token:
79
+ def oauth_token(self, cfg: "Config") -> Token:
74
80
  return self._headers_provider(cfg).oauth_token()
75
81
 
76
82
 
77
83
  def credentials_strategy(name: str, require: List[str]):
78
- """ Given the function that receives a Config and returns RequestVisitor,
84
+ """Given the function that receives a Config and returns RequestVisitor,
79
85
  create CredentialsProvider with a given name and required configuration
80
- attribute names to be present for this function to be called. """
86
+ attribute names to be present for this function to be called."""
81
87
 
82
- def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
88
+ def inner(
89
+ func: Callable[["Config"], CredentialsProvider],
90
+ ) -> CredentialsStrategy:
83
91
 
84
92
  @functools.wraps(func)
85
- def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
93
+ def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
86
94
  for attr in require:
87
95
  getattr(cfg, attr)
88
96
  if not getattr(cfg, attr):
@@ -96,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
96
104
 
97
105
 
98
106
  def oauth_credentials_strategy(name: str, require: List[str]):
99
- """ Given the function that receives a Config and returns an OauthHeaderFactory,
107
+ """Given the function that receives a Config and returns an OauthHeaderFactory,
100
108
  create an OauthCredentialsProvider with a given name and required configuration
101
- attribute names to be present for this function to be called. """
109
+ attribute names to be present for this function to be called."""
102
110
 
103
- def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:
111
+ def inner(
112
+ func: Callable[["Config"], OAuthCredentialsProvider],
113
+ ) -> OauthCredentialsStrategy:
104
114
 
105
115
  @functools.wraps(func)
106
- def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
116
+ def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
107
117
  for attr in require:
108
118
  if not getattr(cfg, attr):
109
119
  return None
@@ -114,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
114
124
  return inner
115
125
 
116
126
 
117
- @credentials_strategy('basic', ['host', 'username', 'password'])
118
- def basic_auth(cfg: 'Config') -> CredentialsProvider:
119
- """ Given username and password, add base64-encoded Basic credentials """
120
- encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
121
- static_credentials = {'Authorization': f'Basic {encoded}'}
127
+ @credentials_strategy("basic", ["host", "username", "password"])
128
+ def basic_auth(cfg: "Config") -> CredentialsProvider:
129
+ """Given username and password, add base64-encoded Basic credentials"""
130
+ encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
131
+ static_credentials = {"Authorization": f"Basic {encoded}"}
122
132
 
123
133
  def inner() -> Dict[str, str]:
124
134
  return static_credentials
@@ -126,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
126
136
  return inner
127
137
 
128
138
 
129
- @credentials_strategy('pat', ['host', 'token'])
130
- def pat_auth(cfg: 'Config') -> CredentialsProvider:
131
- """ Adds Databricks Personal Access Token to every request """
132
- static_credentials = {'Authorization': f'Bearer {cfg.token}'}
139
+ @credentials_strategy("pat", ["host", "token"])
140
+ def pat_auth(cfg: "Config") -> CredentialsProvider:
141
+ """Adds Databricks Personal Access Token to every request"""
142
+ static_credentials = {"Authorization": f"Bearer {cfg.token}"}
133
143
 
134
144
  def inner() -> Dict[str, str]:
135
145
  return static_credentials
@@ -137,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
137
147
  return inner
138
148
 
139
149
 
140
- @credentials_strategy('runtime', [])
141
- def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
142
- if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
150
+ @credentials_strategy("runtime", [])
151
+ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
152
+ if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
143
153
  return None
144
154
 
145
155
  # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
@@ -148,36 +158,45 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
148
158
  from databricks.sdk.runtime import (init_runtime_legacy_auth,
149
159
  init_runtime_native_auth,
150
160
  init_runtime_repl_auth)
151
- for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
161
+
162
+ for init in [
163
+ init_runtime_native_auth,
164
+ init_runtime_repl_auth,
165
+ init_runtime_legacy_auth,
166
+ ]:
152
167
  if init is None:
153
168
  continue
154
169
  host, inner = init()
155
170
  if host is None:
156
- logger.debug(f'[{init.__name__}] no host detected')
171
+ logger.debug(f"[{init.__name__}] no host detected")
157
172
  continue
158
173
  cfg.host = host
159
- logger.debug(f'[{init.__name__}] runtime native auth configured')
174
+ logger.debug(f"[{init.__name__}] runtime native auth configured")
160
175
  return inner
161
176
  return None
162
177
 
163
178
 
164
- @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
165
- def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
166
- """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
167
- if /oidc/.well-known/oauth-authorization-server is available on the given host. """
179
+ @oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
180
+ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
181
+ """Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
182
+ if /oidc/.well-known/oauth-authorization-server is available on the given host.
183
+ """
168
184
  oidc = cfg.oidc_endpoints
169
185
  if oidc is None:
170
186
  return None
171
187
 
172
- token_source = ClientCredentials(client_id=cfg.client_id,
173
- client_secret=cfg.client_secret,
174
- token_url=oidc.token_endpoint,
175
- scopes=["all-apis"],
176
- use_header=True)
188
+ token_source = ClientCredentials(
189
+ client_id=cfg.client_id,
190
+ client_secret=cfg.client_secret,
191
+ token_url=oidc.token_endpoint,
192
+ scopes=["all-apis"],
193
+ use_header=True,
194
+ disable_async=not cfg.enable_experimental_async_token_refresh,
195
+ )
177
196
 
178
197
  def inner() -> Dict[str, str]:
179
198
  token = token_source.token()
180
- return {'Authorization': f'{token.token_type} {token.access_token}'}
199
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
181
200
 
182
201
  def token() -> Token:
183
202
  return token_source.token()
@@ -185,9 +204,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
185
204
  return OAuthCredentialsProvider(inner, token)
186
205
 
187
206
 
188
- @credentials_strategy('external-browser', ['host', 'auth_type'])
189
- def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
190
- if cfg.auth_type != 'external-browser':
207
+ @credentials_strategy("external-browser", ["host", "auth_type"])
208
+ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
209
+ if cfg.auth_type != "external-browser":
191
210
  return None
192
211
 
193
212
  client_id, client_secret = None, None
@@ -198,17 +217,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
198
217
  client_id = cfg.azure_client
199
218
  client_secret = cfg.azure_client_secret
200
219
  if not client_id:
201
- client_id = 'databricks-cli'
220
+ client_id = "databricks-cli"
202
221
 
203
222
  # Load cached credentials from disk if they exist. Note that these are
204
223
  # local to the Python SDK and not reused by other SDKs.
205
224
  oidc_endpoints = cfg.oidc_endpoints
206
- redirect_url = 'http://localhost:8020'
207
- token_cache = TokenCache(host=cfg.host,
208
- oidc_endpoints=oidc_endpoints,
209
- client_id=client_id,
210
- client_secret=client_secret,
211
- redirect_url=redirect_url)
225
+ redirect_url = "http://localhost:8020"
226
+ token_cache = TokenCache(
227
+ host=cfg.host,
228
+ oidc_endpoints=oidc_endpoints,
229
+ client_id=client_id,
230
+ client_secret=client_secret,
231
+ redirect_url=redirect_url,
232
+ )
212
233
  credentials = token_cache.load()
213
234
  if credentials:
214
235
  try:
@@ -219,12 +240,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
219
240
  return credentials(cfg)
220
241
  # TODO: We should ideally use more specific exceptions.
221
242
  except Exception as e:
222
- logger.warning(f'Failed to refresh cached token: {e}. Initiating new OAuth login flow')
223
-
224
- oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
225
- client_id=client_id,
226
- redirect_url=redirect_url,
227
- client_secret=client_secret)
243
+ logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
244
+
245
+ oauth_client = OAuthClient(
246
+ oidc_endpoints=oidc_endpoints,
247
+ client_id=client_id,
248
+ redirect_url=redirect_url,
249
+ client_secret=client_secret,
250
+ )
228
251
  consent = oauth_client.initiate_consent()
229
252
  if not consent:
230
253
  return None
@@ -234,33 +257,42 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
234
257
  return credentials(cfg)
235
258
 
236
259
 
237
- def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]):
238
- """ Resolves Azure Databricks workspace URL from ARM Resource ID """
260
+ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
261
+ """Resolves Azure Databricks workspace URL from ARM Resource ID"""
239
262
  if cfg.host:
240
263
  return
241
264
  if not cfg.azure_workspace_resource_id:
242
265
  return
243
266
  arm = cfg.arm_environment.resource_manager_endpoint
244
267
  token = token_source_for(arm).token()
245
- resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
246
- headers={"Authorization": f"Bearer {token.access_token}"})
268
+ resp = requests.get(
269
+ f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
270
+ headers={"Authorization": f"Bearer {token.access_token}"},
271
+ )
247
272
  if not resp.ok:
248
273
  raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
249
274
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
250
275
 
251
276
 
252
- @oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
253
- def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
254
- """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
255
- to every request, while automatically resolving different Azure environment endpoints. """
277
+ @oauth_credentials_strategy(
278
+ "azure-client-secret",
279
+ ["is_azure", "azure_client_id", "azure_client_secret"],
280
+ )
281
+ def azure_service_principal(cfg: "Config") -> CredentialsProvider:
282
+ """Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
283
+ to every request, while automatically resolving different Azure environment endpoints.
284
+ """
256
285
 
257
286
  def token_source_for(resource: str) -> TokenSource:
258
287
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
259
- return ClientCredentials(client_id=cfg.azure_client_id,
260
- client_secret=cfg.azure_client_secret,
261
- token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
262
- endpoint_params={"resource": resource},
263
- use_params=True)
288
+ return ClientCredentials(
289
+ client_id=cfg.azure_client_id,
290
+ client_secret=cfg.azure_client_secret,
291
+ token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
292
+ endpoint_params={"resource": resource},
293
+ use_params=True,
294
+ disable_async=not cfg.enable_experimental_async_token_refresh,
295
+ )
264
296
 
265
297
  _ensure_host_present(cfg, token_source_for)
266
298
  cfg.load_azure_tenant_id()
@@ -269,7 +301,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
269
301
  cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
270
302
 
271
303
  def refreshed_headers() -> Dict[str, str]:
272
- headers = {'Authorization': f"Bearer {inner.token().access_token}", }
304
+ headers = {
305
+ "Authorization": f"Bearer {inner.token().access_token}",
306
+ }
273
307
  add_workspace_id_header(cfg, headers)
274
308
  add_sp_management_token(cloud, headers)
275
309
  return headers
@@ -280,9 +314,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
280
314
  return OAuthCredentialsProvider(refreshed_headers, token)
281
315
 
282
316
 
283
- @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
284
- def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
285
- if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
317
+ @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
318
+ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
319
+ if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
286
320
  # not in GitHub actions
287
321
  return None
288
322
 
@@ -292,7 +326,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
292
326
  return None
293
327
 
294
328
  # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
295
- headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
329
+ headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
296
330
  endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
297
331
  response = requests.get(endpoint, headers=headers)
298
332
  if not response.ok:
@@ -300,30 +334,35 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
300
334
 
301
335
  # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
302
336
  response_json = response.json()
303
- if 'value' not in response_json:
337
+ if "value" not in response_json:
304
338
  return None
305
339
 
306
- logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id)
340
+ logger.info(
341
+ "Configured AAD token for GitHub Actions OIDC (%s)",
342
+ cfg.azure_client_id,
343
+ )
307
344
  params = {
308
- 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
309
- 'resource': cfg.effective_azure_login_app_id,
310
- 'client_assertion': response_json['value'],
345
+ "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
346
+ "resource": cfg.effective_azure_login_app_id,
347
+ "client_assertion": response_json["value"],
311
348
  }
312
349
  aad_endpoint = cfg.arm_environment.active_directory_endpoint
313
350
  if not cfg.azure_tenant_id:
314
351
  # detect Azure AD Tenant ID if it's not specified directly
315
352
  token_endpoint = cfg.oidc_endpoints.token_endpoint
316
- cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
353
+ cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
317
354
  inner = ClientCredentials(
318
355
  client_id=cfg.azure_client_id,
319
- client_secret="", # we have no (rotatable) secrets in OIDC flow
356
+ client_secret="", # we have no (rotatable) secrets in OIDC flow
320
357
  token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
321
358
  endpoint_params=params,
322
- use_params=True)
359
+ use_params=True,
360
+ disable_async=not cfg.enable_experimental_async_token_refresh,
361
+ )
323
362
 
324
363
  def refreshed_headers() -> Dict[str, str]:
325
364
  token = inner.token()
326
- return {'Authorization': f'{token.token_type} {token.access_token}'}
365
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
327
366
 
328
367
  def token() -> Token:
329
368
  return inner.token()
@@ -331,29 +370,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
331
370
  return OAuthCredentialsProvider(refreshed_headers, token)
332
371
 
333
372
 
334
- GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
373
+ GcpScopes = [
374
+ "https://www.googleapis.com/auth/cloud-platform",
375
+ "https://www.googleapis.com/auth/compute",
376
+ ]
335
377
 
336
378
 
337
- @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
338
- def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
379
+ @oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
380
+ def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
339
381
  if not cfg.is_gcp:
340
382
  return None
341
383
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
342
384
  # Obtain the id token by providing the json file path and target audience.
343
- if (os.path.isfile(cfg.google_credentials)):
385
+ if os.path.isfile(cfg.google_credentials):
344
386
  with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
345
387
  account_info = json.load(json_file)
346
388
  else:
347
389
  # If the file doesn't exist, assume that the config is the actual JSON content.
348
390
  account_info = json.loads(cfg.google_credentials)
349
391
 
350
- credentials = service_account.IDTokenCredentials.from_service_account_info(info=account_info,
351
- target_audience=cfg.host)
392
+ credentials = service_account.IDTokenCredentials.from_service_account_info(
393
+ info=account_info, target_audience=cfg.host
394
+ )
352
395
 
353
396
  request = Request()
354
397
 
355
- gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
356
- scopes=GcpScopes)
398
+ gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
357
399
 
358
400
  def token() -> Token:
359
401
  credentials.refresh(request)
@@ -361,7 +403,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
361
403
 
362
404
  def refreshed_headers() -> Dict[str, str]:
363
405
  credentials.refresh(request)
364
- headers = {'Authorization': f'Bearer {credentials.token}'}
406
+ headers = {"Authorization": f"Bearer {credentials.token}"}
365
407
  if cfg.is_account_client:
366
408
  gcp_credentials.refresh(request)
367
409
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
@@ -370,24 +412,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
370
412
  return OAuthCredentialsProvider(refreshed_headers, token)
371
413
 
372
414
 
373
- @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
374
- def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
415
+ @oauth_credentials_strategy("google-id", ["host", "google_service_account"])
416
+ def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
375
417
  if not cfg.is_gcp:
376
418
  return None
377
419
  credentials, _project_id = google.auth.default()
378
420
 
379
421
  # Create the impersonated credential.
380
- target_credentials = impersonated_credentials.Credentials(source_credentials=credentials,
381
- target_principal=cfg.google_service_account,
382
- target_scopes=[])
422
+ target_credentials = impersonated_credentials.Credentials(
423
+ source_credentials=credentials,
424
+ target_principal=cfg.google_service_account,
425
+ target_scopes=[],
426
+ )
383
427
 
384
428
  # Set the impersonated credential, target audience and token options.
385
- id_creds = impersonated_credentials.IDTokenCredentials(target_credentials,
386
- target_audience=cfg.host,
387
- include_email=True)
429
+ id_creds = impersonated_credentials.IDTokenCredentials(
430
+ target_credentials, target_audience=cfg.host, include_email=True
431
+ )
388
432
 
389
433
  gcp_impersonated_credentials = impersonated_credentials.Credentials(
390
- source_credentials=credentials, target_principal=cfg.google_service_account, target_scopes=GcpScopes)
434
+ source_credentials=credentials,
435
+ target_principal=cfg.google_service_account,
436
+ target_scopes=GcpScopes,
437
+ )
391
438
 
392
439
  request = Request()
393
440
 
@@ -397,7 +444,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
397
444
 
398
445
  def refreshed_headers() -> Dict[str, str]:
399
446
  id_creds.refresh(request)
400
- headers = {'Authorization': f'Bearer {id_creds.token}'}
447
+ headers = {"Authorization": f"Bearer {id_creds.token}"}
401
448
  if cfg.is_account_client:
402
449
  gcp_impersonated_credentials.refresh(request)
403
450
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
@@ -408,8 +455,15 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
408
455
 
409
456
  class CliTokenSource(Refreshable):
410
457
 
411
- def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
412
- super().__init__()
458
+ def __init__(
459
+ self,
460
+ cmd: List[str],
461
+ token_type_field: str,
462
+ access_token_field: str,
463
+ expiry_field: str,
464
+ disable_async: bool = True,
465
+ ):
466
+ super().__init__(disable_async=disable_async)
413
467
  self._cmd = cmd
414
468
  self._token_type_field = token_type_field
415
469
  self._access_token_field = access_token_field
@@ -431,52 +485,74 @@ class CliTokenSource(Refreshable):
431
485
  out = _run_subprocess(self._cmd, capture_output=True, check=True)
432
486
  it = json.loads(out.stdout.decode())
433
487
  expires_on = self._parse_expiry(it[self._expiry_field])
434
- return Token(access_token=it[self._access_token_field],
435
- token_type=it[self._token_type_field],
436
- expiry=expires_on)
488
+ return Token(
489
+ access_token=it[self._access_token_field],
490
+ token_type=it[self._token_type_field],
491
+ expiry=expires_on,
492
+ )
437
493
  except ValueError as e:
438
494
  raise ValueError(f"cannot unmarshal CLI result: {e}")
439
495
  except subprocess.CalledProcessError as e:
440
496
  stdout = e.stdout.decode().strip()
441
497
  stderr = e.stderr.decode().strip()
442
498
  message = stdout or stderr
443
- raise IOError(f'cannot get access token: {message}') from e
499
+ raise IOError(f"cannot get access token: {message}") from e
444
500
 
445
501
 
446
- def _run_subprocess(popenargs,
447
- input=None,
448
- capture_output=True,
449
- timeout=None,
450
- check=False,
451
- **kwargs) -> subprocess.CompletedProcess:
502
+ def _run_subprocess(
503
+ popenargs,
504
+ input=None,
505
+ capture_output=True,
506
+ timeout=None,
507
+ check=False,
508
+ **kwargs,
509
+ ) -> subprocess.CompletedProcess:
452
510
  """Runs subprocess with given arguments.
453
- This handles OS-specific modifications that need to be made to the invocation of subprocess.run."""
454
- kwargs['shell'] = sys.platform.startswith('win')
511
+ This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
512
+ """
513
+ kwargs["shell"] = sys.platform.startswith("win")
455
514
  # windows requires shell=True to be able to execute 'az login' or other commands
456
515
  # cannot use shell=True all the time, as it breaks macOS
457
516
  logging.debug(f'Running command: {" ".join(popenargs)}')
458
- return subprocess.run(popenargs,
459
- input=input,
460
- capture_output=capture_output,
461
- timeout=timeout,
462
- check=check,
463
- **kwargs)
517
+ return subprocess.run(
518
+ popenargs,
519
+ input=input,
520
+ capture_output=capture_output,
521
+ timeout=timeout,
522
+ check=check,
523
+ **kwargs,
524
+ )
464
525
 
465
526
 
466
527
  class AzureCliTokenSource(CliTokenSource):
467
- """ Obtain the token granted by `az login` CLI command """
468
-
469
- def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
470
- cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
528
+ """Obtain the token granted by `az login` CLI command"""
529
+
530
+ def __init__(
531
+ self,
532
+ resource: str,
533
+ subscription: Optional[str] = None,
534
+ tenant: Optional[str] = None,
535
+ ):
536
+ cmd = [
537
+ "az",
538
+ "account",
539
+ "get-access-token",
540
+ "--resource",
541
+ resource,
542
+ "--output",
543
+ "json",
544
+ ]
471
545
  if subscription is not None:
472
546
  cmd.append("--subscription")
473
547
  cmd.append(subscription)
474
548
  if tenant and not self.__is_cli_using_managed_identity():
475
549
  cmd.extend(["--tenant", tenant])
476
- super().__init__(cmd=cmd,
477
- token_type_field='tokenType',
478
- access_token_field='accessToken',
479
- expiry_field='expiresOn')
550
+ super().__init__(
551
+ cmd=cmd,
552
+ token_type_field="tokenType",
553
+ access_token_field="accessToken",
554
+ expiry_field="expiresOn",
555
+ )
480
556
 
481
557
  @staticmethod
482
558
  def __is_cli_using_managed_identity() -> bool:
@@ -489,7 +565,8 @@ class AzureCliTokenSource(CliTokenSource):
489
565
  if user is None:
490
566
  return False
491
567
  return user.get("type") == "servicePrincipal" and user.get("name") in [
492
- 'systemAssignedIdentity', 'userAssignedIdentity'
568
+ "systemAssignedIdentity",
569
+ "userAssignedIdentity",
493
570
  ]
494
571
  except subprocess.CalledProcessError as e:
495
572
  logger.debug("Failed to get account information from Azure CLI", exc_info=e)
@@ -512,15 +589,13 @@ class AzureCliTokenSource(CliTokenSource):
512
589
  guaranteed to be unique within a tenant and should be used only for display purposes.
513
590
  - 'upn' - The username of the user.
514
591
  """
515
- return 'upn' in self.token().jwt_claims()
592
+ return "upn" in self.token().jwt_claims()
516
593
 
517
594
  @staticmethod
518
- def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
595
+ def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
519
596
  subscription = AzureCliTokenSource.get_subscription(cfg)
520
597
  if subscription is not None:
521
- token_source = AzureCliTokenSource(resource,
522
- subscription=subscription,
523
- tenant=cfg.azure_tenant_id)
598
+ token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
524
599
  try:
525
600
  # This will fail if the user has access to the workspace, but not to the subscription
526
601
  # itself.
@@ -535,32 +610,32 @@ class AzureCliTokenSource(CliTokenSource):
535
610
  return token_source
536
611
 
537
612
  @staticmethod
538
- def get_subscription(cfg: 'Config') -> Optional[str]:
613
+ def get_subscription(cfg: "Config") -> Optional[str]:
539
614
  resource = cfg.azure_workspace_resource_id
540
615
  if resource is None or resource == "":
541
616
  return None
542
- components = resource.split('/')
617
+ components = resource.split("/")
543
618
  if len(components) < 3:
544
619
  logger.warning("Invalid azure workspace resource ID")
545
620
  return None
546
621
  return components[2]
547
622
 
548
623
 
549
- @credentials_strategy('azure-cli', ['is_azure'])
550
- def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
551
- """ Adds refreshed OAuth token granted by `az login` command to every request. """
624
+ @credentials_strategy("azure-cli", ["is_azure"])
625
+ def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
626
+ """Adds refreshed OAuth token granted by `az login` command to every request."""
552
627
  cfg.load_azure_tenant_id()
553
628
  token_source = None
554
629
  mgmt_token_source = None
555
630
  try:
556
631
  token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
557
632
  except FileNotFoundError:
558
- doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
559
- logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
633
+ doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
634
+ logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
560
635
  return None
561
636
  except OSError as e:
562
- logger.debug('skipping Azure CLI auth', exc_info=e)
563
- logger.debug('This may happen if you are attempting to login to a dev or staging workspace')
637
+ logger.debug("skipping Azure CLI auth", exc_info=e)
638
+ logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
564
639
  return None
565
640
 
566
641
  if not token_source.is_human_user():
@@ -568,7 +643,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
568
643
  management_endpoint = cfg.arm_environment.service_management_endpoint
569
644
  mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
570
645
  except Exception as e:
571
- logger.debug(f'Not including service management token in headers', exc_info=e)
646
+ logger.debug(
647
+ f"Not including service management token in headers",
648
+ exc_info=e,
649
+ )
572
650
  mgmt_token_source = None
573
651
 
574
652
  _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
@@ -576,7 +654,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
576
654
 
577
655
  def inner() -> Dict[str, str]:
578
656
  token = token_source.token()
579
- headers = {'Authorization': f'{token.token_type} {token.access_token}'}
657
+ headers = {"Authorization": f"{token.token_type} {token.access_token}"}
580
658
  add_workspace_id_header(cfg, headers)
581
659
  if mgmt_token_source:
582
660
  add_sp_management_token(mgmt_token_source, headers)
@@ -586,12 +664,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
586
664
 
587
665
 
588
666
  class DatabricksCliTokenSource(CliTokenSource):
589
- """ Obtain the token granted by `databricks auth login` CLI command """
667
+ """Obtain the token granted by `databricks auth login` CLI command"""
590
668
 
591
- def __init__(self, cfg: 'Config'):
592
- args = ['auth', 'token', '--host', cfg.host]
669
+ def __init__(self, cfg: "Config"):
670
+ args = ["auth", "token", "--host", cfg.host]
593
671
  if cfg.is_account_client:
594
- args += ['--account-id', cfg.account_id]
672
+ args += ["--account-id", cfg.account_id]
595
673
 
596
674
  cli_path = cfg.databricks_cli_path
597
675
 
@@ -611,10 +689,13 @@ class DatabricksCliTokenSource(CliTokenSource):
611
689
  elif cli_path.count("/") == 0:
612
690
  cli_path = self.__class__._find_executable(cli_path)
613
691
 
614
- super().__init__(cmd=[cli_path, *args],
615
- token_type_field='token_type',
616
- access_token_field='access_token',
617
- expiry_field='expiry')
692
+ super().__init__(
693
+ cmd=[cli_path, *args],
694
+ token_type_field="token_type",
695
+ access_token_field="access_token",
696
+ expiry_field="expiry",
697
+ disable_async=not cfg.enable_experimental_async_token_refresh,
698
+ )
618
699
 
619
700
  @staticmethod
620
701
  def _find_executable(name) -> str:
@@ -636,8 +717,8 @@ class DatabricksCliTokenSource(CliTokenSource):
636
717
  raise err
637
718
 
638
719
 
639
- @oauth_credentials_strategy('databricks-cli', ['host'])
640
- def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
720
+ @oauth_credentials_strategy("databricks-cli", ["host"])
721
+ def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
641
722
  try:
642
723
  token_source = DatabricksCliTokenSource(cfg)
643
724
  except FileNotFoundError as e:
@@ -647,8 +728,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
647
728
  try:
648
729
  token_source.token()
649
730
  except IOError as e:
650
- if 'databricks OAuth is not' in str(e):
651
- logger.debug(f'OAuth not configured or not available: {e}')
731
+ if "databricks OAuth is not" in str(e):
732
+ logger.debug(f"OAuth not configured or not available: {e}")
652
733
  return None
653
734
  raise e
654
735
 
@@ -656,7 +737,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
656
737
 
657
738
  def inner() -> Dict[str, str]:
658
739
  token = token_source.token()
659
- return {'Authorization': f'{token.token_type} {token.access_token}'}
740
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
660
741
 
661
742
  def token() -> Token:
662
743
  return token_source.token()
@@ -665,13 +746,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
665
746
 
666
747
 
667
748
  class MetadataServiceTokenSource(Refreshable):
668
- """ Obtain the token granted by Databricks Metadata Service """
749
+ """Obtain the token granted by Databricks Metadata Service"""
750
+
669
751
  METADATA_SERVICE_VERSION = "1"
670
752
  METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
671
753
  METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
672
- _metadata_service_timeout = 10 # seconds
754
+ _metadata_service_timeout = 10 # seconds
673
755
 
674
- def __init__(self, cfg: 'Config'):
756
+ def __init__(self, cfg: "Config"):
675
757
  super().__init__()
676
758
  self.url = cfg.metadata_service_url
677
759
  self.host = cfg.host
@@ -682,13 +764,14 @@ class MetadataServiceTokenSource(Refreshable):
682
764
  timeout=self._metadata_service_timeout,
683
765
  headers={
684
766
  self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
685
- self.METADATA_SERVICE_HOST_HEADER: self.host
767
+ self.METADATA_SERVICE_HOST_HEADER: self.host,
686
768
  },
687
769
  proxies={
688
770
  # Explicitly exclude localhost from being proxied. This is necessary
689
771
  # for Metadata URLs which typically point to localhost.
690
772
  "no_proxy": "localhost,127.0.0.1"
691
- })
773
+ },
774
+ )
692
775
  json_resp: dict[str, Union[str, float]] = resp.json()
693
776
  access_token = json_resp.get("access_token", None)
694
777
  if access_token is None:
@@ -706,9 +789,9 @@ class MetadataServiceTokenSource(Refreshable):
706
789
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
707
790
 
708
791
 
709
- @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
710
- def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
711
- """ Adds refreshed token granted by Databricks Metadata Service to every request. """
792
+ @credentials_strategy("metadata-service", ["host", "metadata_service_url"])
793
+ def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
794
+ """Adds refreshed token granted by Databricks Metadata Service to every request."""
712
795
 
713
796
  token_source = MetadataServiceTokenSource(cfg)
714
797
  token_source.token()
@@ -716,14 +799,14 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
716
799
 
717
800
  def inner() -> Dict[str, str]:
718
801
  token = token_source.token()
719
- return {'Authorization': f'{token.token_type} {token.access_token}'}
802
+ return {"Authorization": f"{token.token_type} {token.access_token}"}
720
803
 
721
804
  return inner
722
805
 
723
806
 
724
807
  # This Code is derived from Mlflow DatabricksModelServingConfigProvider
725
808
  # https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
726
- class ModelServingAuthProvider():
809
+ class ModelServingAuthProvider:
727
810
  USER_CREDENTIALS = "user_credentials"
728
811
 
729
812
  _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
@@ -731,7 +814,7 @@ class ModelServingAuthProvider():
731
814
  def __init__(self, credential_type: Optional[str]):
732
815
  self.expiry_time = -1
733
816
  self.current_token = None
734
- self.refresh_duration = 300 # 300 Seconds
817
+ self.refresh_duration = 300 # 300 Seconds
735
818
  self.credential_type = credential_type
736
819
 
737
820
  def should_fetch_model_serving_environment_oauth() -> bool:
@@ -740,10 +823,14 @@ class ModelServingAuthProvider():
740
823
  Additionally check if the oauth token file path exists
741
824
  """
742
825
 
743
- is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
744
- or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
745
- return (is_in_model_serving_env == "true"
746
- and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
826
+ is_in_model_serving_env = (
827
+ os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
828
+ or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
829
+ or "false"
830
+ )
831
+ return is_in_model_serving_env == "true" and os.path.isfile(
832
+ ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
833
+ )
747
834
 
748
835
  def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
749
836
  # Use Cached value if it is valid
@@ -758,8 +845,10 @@ class ModelServingAuthProvider():
758
845
  except Exception as e:
759
846
  # sleep and retry in case of any race conditions with OAuth refreshing
760
847
  if should_retry:
761
- logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
762
- exc_info=e)
848
+ logger.warning(
849
+ "Unable to read oauth token on first attmept in Model Serving Environment",
850
+ exc_info=e,
851
+ )
763
852
  time.sleep(0.5)
764
853
  return self._get_model_dependency_oauth_token(should_retry=False)
765
854
  else:
@@ -769,8 +858,8 @@ class ModelServingAuthProvider():
769
858
  return self.current_token
770
859
 
771
860
  def _get_invokers_token(self):
772
- current_thread = threading.current_thread()
773
- thread_data = current_thread.__dict__
861
+ main_thread = threading.main_thread()
862
+ thread_data = main_thread.__dict__
774
863
  invokers_token = None
775
864
  if "invokers_token" in thread_data:
776
865
  invokers_token = thread_data["invokers_token"]
@@ -785,8 +874,7 @@ class ModelServingAuthProvider():
785
874
  return None
786
875
 
787
876
  # read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
788
- host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
789
- "DB_MODEL_SERVING_HOST_URL")
877
+ host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")
790
878
 
791
879
  if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
792
880
  return (host, self._get_invokers_token())
@@ -794,8 +882,7 @@ class ModelServingAuthProvider():
794
882
  return (host, self._get_model_dependency_oauth_token())
795
883
 
796
884
 
797
- def model_serving_auth_visitor(cfg: 'Config',
798
- credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
885
+ def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
799
886
  try:
800
887
  model_serving_auth_provider = ModelServingAuthProvider(credential_type)
801
888
  host, token = model_serving_auth_provider.get_databricks_host_token()
@@ -806,7 +893,10 @@ def model_serving_auth_visitor(cfg: 'Config',
806
893
  if cfg.host is None:
807
894
  cfg.host = host
808
895
  except Exception as e:
809
- logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
896
+ logger.warning(
897
+ "Unable to get auth from Databricks Model Serving Environment",
898
+ exc_info=e,
899
+ )
810
900
  return None
811
901
  logger.info("Using Databricks Model Serving Authentication")
812
902
 
@@ -818,8 +908,8 @@ def model_serving_auth_visitor(cfg: 'Config',
818
908
  return inner
819
909
 
820
910
 
821
- @credentials_strategy('model-serving', [])
822
- def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
911
+ @credentials_strategy("model-serving", [])
912
+ def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
823
913
  if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
824
914
  logger.debug("model-serving: Not in Databricks Model Serving, skipping")
825
915
  return None
@@ -828,20 +918,30 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
828
918
 
829
919
 
830
920
  class DefaultCredentials:
831
- """ Select the first applicable credential provider from the chain """
921
+ """Select the first applicable credential provider from the chain"""
832
922
 
833
923
  def __init__(self) -> None:
834
- self._auth_type = 'default'
924
+ self._auth_type = "default"
835
925
  self._auth_providers = [
836
- pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
837
- github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
838
- google_credentials, google_id, model_serving_auth
926
+ pat_auth,
927
+ basic_auth,
928
+ metadata_service,
929
+ oauth_service_principal,
930
+ azure_service_principal,
931
+ github_oidc_azure,
932
+ azure_cli,
933
+ external_browser,
934
+ databricks_cli,
935
+ runtime_native_auth,
936
+ google_credentials,
937
+ google_id,
938
+ model_serving_auth,
839
939
  ]
840
940
 
841
941
  def auth_type(self) -> str:
842
942
  return self._auth_type
843
943
 
844
- def oauth_token(self, cfg: 'Config') -> Token:
944
+ def oauth_token(self, cfg: "Config") -> Token:
845
945
  for provider in self._auth_providers:
846
946
  auth_type = provider.auth_type()
847
947
  if auth_type != self._auth_type:
@@ -849,14 +949,14 @@ class DefaultCredentials:
849
949
  continue
850
950
  return provider.oauth_token(cfg)
851
951
 
852
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
952
+ def __call__(self, cfg: "Config") -> CredentialsProvider:
853
953
  for provider in self._auth_providers:
854
954
  auth_type = provider.auth_type()
855
955
  if cfg.auth_type and auth_type != cfg.auth_type:
856
956
  # ignore other auth types if one is explicitly enforced
857
957
  logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
858
958
  continue
859
- logger.debug(f'Attempting to configure auth: {auth_type}')
959
+ logger.debug(f"Attempting to configure auth: {auth_type}")
860
960
  try:
861
961
  header_factory = provider(cfg)
862
962
  if not header_factory:
@@ -864,18 +964,18 @@ class DefaultCredentials:
864
964
  self._auth_type = auth_type
865
965
  return header_factory
866
966
  except Exception as e:
867
- raise ValueError(f'{auth_type}: {e}') from e
967
+ raise ValueError(f"{auth_type}: {e}") from e
868
968
  auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
869
969
  raise ValueError(
870
- f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
970
+ f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
871
971
  )
872
972
 
873
973
 
874
974
  class ModelServingUserCredentials(CredentialsStrategy):
875
975
  """
876
- This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
877
- In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
878
- In any other environments, the class defaults to the DefaultCredentialStrategy.
976
+ This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
977
+ In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
978
+ In any other environments, the class defaults to the DefaultCredentialStrategy.
879
979
  To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
880
980
 
881
981
  invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
@@ -891,7 +991,7 @@ class ModelServingUserCredentials(CredentialsStrategy):
891
991
  else:
892
992
  return self.default_credentials.auth_type()
893
993
 
894
- def __call__(self, cfg: 'Config') -> CredentialsProvider:
994
+ def __call__(self, cfg: "Config") -> CredentialsProvider:
895
995
  if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
896
996
  header_factory = model_serving_auth_visitor(cfg, self.credential_type)
897
997
  if not header_factory: