databricks-sdk 0.44.1__py3-none-any.whl → 0.46.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic; more details are available on the package's registry page.

Files changed (63)
  1. databricks/sdk/__init__.py +135 -116
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +156 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +306 -206
  12. databricks/sdk/data_plane.py +75 -50
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +191 -16
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +379 -198
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5263 -3536
  40. databricks/sdk/service/dashboards.py +1331 -924
  41. databricks/sdk/service/files.py +446 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3839 -2256
  46. databricks/sdk/service/oauth2.py +910 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2060 -1290
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.46.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.1.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/top_level.txt +0 -0
databricks/sdk/oauth.py CHANGED
@@ -9,8 +9,10 @@ import threading
9
9
  import urllib.parse
10
10
  import webbrowser
11
11
  from abc import abstractmethod
12
+ from concurrent.futures import ThreadPoolExecutor
12
13
  from dataclasses import dataclass
13
14
  from datetime import datetime, timedelta
15
+ from enum import Enum
14
16
  from http.server import BaseHTTPRequestHandler, HTTPServer
15
17
  from typing import Any, Dict, List, Optional
16
18
 
@@ -21,7 +23,7 @@ from ._base_client import _BaseClient, _fix_host_if_needed
21
23
 
22
24
  # Error code for PKCE flow in Azure Active Directory, that gets additional retry.
23
25
  # See https://stackoverflow.com/a/75466778/277035 for more info
24
- NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327'
26
+ NO_ORIGIN_FOR_SPA_CLIENT_ERROR = "AADSTS9002327"
25
27
 
26
28
  URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
27
29
  JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
@@ -52,28 +54,33 @@ class OidcEndpoints:
52
54
  The endpoints used for OAuth-based authentication in Databricks.
53
55
  """
54
56
 
55
- authorization_endpoint: str # ../v1/authorize
57
+ authorization_endpoint: str # ../v1/authorize
56
58
  """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for
57
59
  the user to login and authorize the client for user-to-machine (U2M) flows."""
58
60
 
59
- token_endpoint: str # ../v1/token
61
+ token_endpoint: str # ../v1/token
60
62
  """The token endpoint for the OAuth flow."""
61
63
 
62
64
  @staticmethod
63
- def from_dict(d: dict) -> 'OidcEndpoints':
64
- return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'),
65
- token_endpoint=d.get('token_endpoint'))
65
+ def from_dict(d: dict) -> "OidcEndpoints":
66
+ return OidcEndpoints(
67
+ authorization_endpoint=d.get("authorization_endpoint"),
68
+ token_endpoint=d.get("token_endpoint"),
69
+ )
66
70
 
67
71
  def as_dict(self) -> dict:
68
- return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint}
72
+ return {
73
+ "authorization_endpoint": self.authorization_endpoint,
74
+ "token_endpoint": self.token_endpoint,
75
+ }
69
76
 
70
77
 
71
78
  @dataclass
72
79
  class Token:
73
80
  access_token: str
74
- token_type: str = None
75
- refresh_token: str = None
76
- expiry: datetime = None
81
+ token_type: Optional[str] = None
82
+ refresh_token: Optional[str] = None
83
+ expiry: Optional[datetime] = None
77
84
 
78
85
  @property
79
86
  def expired(self):
@@ -91,19 +98,24 @@ class Token:
91
98
  return self.access_token and not self.expired
92
99
 
93
100
  def as_dict(self) -> dict:
94
- raw = {'access_token': self.access_token, 'token_type': self.token_type}
101
+ raw = {
102
+ "access_token": self.access_token,
103
+ "token_type": self.token_type,
104
+ }
95
105
  if self.expiry:
96
- raw['expiry'] = self.expiry.isoformat()
106
+ raw["expiry"] = self.expiry.isoformat()
97
107
  if self.refresh_token:
98
- raw['refresh_token'] = self.refresh_token
108
+ raw["refresh_token"] = self.refresh_token
99
109
  return raw
100
110
 
101
111
  @staticmethod
102
- def from_dict(raw: dict) -> 'Token':
103
- return Token(access_token=raw['access_token'],
104
- token_type=raw['token_type'],
105
- expiry=datetime.fromisoformat(raw['expiry']),
106
- refresh_token=raw.get('refresh_token'))
112
+ def from_dict(raw: dict) -> "Token":
113
+ return Token(
114
+ access_token=raw["access_token"],
115
+ token_type=raw["token_type"],
116
+ expiry=datetime.fromisoformat(raw["expiry"]),
117
+ refresh_token=raw.get("refresh_token"),
118
+ )
107
119
 
108
120
  def jwt_claims(self) -> Dict[str, str]:
109
121
  """Get claims from the access token or return an empty dictionary if it is not a JWT token.
@@ -131,7 +143,7 @@ class Token:
131
143
  try:
132
144
  jwt_split = self.access_token.split(".")
133
145
  if len(jwt_split) != 3:
134
- logger.debug(f'Tried to decode access token as JWT, but failed: {len(jwt_split)} components')
146
+ logger.debug(f"Tried to decode access token as JWT, but failed: {len(jwt_split)} components")
135
147
  return {}
136
148
  payload_with_padding = jwt_split[1] + "=="
137
149
  payload_bytes = base64.standard_b64decode(payload_with_padding)
@@ -139,7 +151,7 @@ class Token:
139
151
  claims = json.loads(payload_json)
140
152
  return claims
141
153
  except ValueError as err:
142
- logger.debug(f'Tried to decode access token as JWT, but failed: {err}')
154
+ logger.debug(f"Tried to decode access token as JWT, but failed: {err}")
143
155
  return {}
144
156
 
145
157
 
@@ -150,17 +162,21 @@ class TokenSource:
150
162
  pass
151
163
 
152
164
 
153
- def retrieve_token(client_id,
154
- client_secret,
155
- token_url,
156
- params,
157
- use_params=False,
158
- use_header=False,
159
- headers=None) -> Token:
160
- logger.debug(f'Retrieving token for {client_id}')
165
+ def retrieve_token(
166
+ client_id,
167
+ client_secret,
168
+ token_url,
169
+ params,
170
+ use_params=False,
171
+ use_header=False,
172
+ headers=None,
173
+ ) -> Token:
174
+ logger.debug(f"Retrieving token for {client_id}")
161
175
  if use_params:
162
- if client_id: params["client_id"] = client_id
163
- if client_secret: params["client_secret"] = client_secret
176
+ if client_id:
177
+ params["client_id"] = client_id
178
+ if client_secret:
179
+ params["client_secret"] = client_secret
164
180
  auth = None
165
181
  if use_header:
166
182
  auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
@@ -168,40 +184,156 @@ def retrieve_token(client_id,
168
184
  auth = IgnoreNetrcAuth()
169
185
  resp = requests.post(token_url, params, auth=auth, headers=headers)
170
186
  if not resp.ok:
171
- if resp.headers['Content-Type'].startswith('application/json'):
187
+ if resp.headers["Content-Type"].startswith("application/json"):
172
188
  err = resp.json()
173
- code = err.get('errorCode', err.get('error', 'unknown'))
174
- summary = err.get('errorSummary', err.get('error_description', 'unknown'))
175
- summary = summary.replace("\r\n", ' ')
176
- raise ValueError(f'{code}: {summary}')
189
+ code = err.get("errorCode", err.get("error", "unknown"))
190
+ summary = err.get("errorSummary", err.get("error_description", "unknown"))
191
+ summary = summary.replace("\r\n", " ")
192
+ raise ValueError(f"{code}: {summary}")
177
193
  raise ValueError(resp.content)
178
194
  try:
179
195
  j = resp.json()
180
196
  expires_in = int(j["expires_in"])
181
197
  expiry = datetime.now() + timedelta(seconds=expires_in)
182
- return Token(access_token=j["access_token"],
183
- refresh_token=j.get('refresh_token'),
184
- token_type=j["token_type"],
185
- expiry=expiry)
198
+ return Token(
199
+ access_token=j["access_token"],
200
+ refresh_token=j.get("refresh_token"),
201
+ token_type=j["token_type"],
202
+ expiry=expiry,
203
+ )
186
204
  except Exception as e:
187
205
  raise NotImplementedError(f"Not supported yet: {e}")
188
206
 
189
207
 
190
- class Refreshable(TokenSource):
208
+ class _TokenState(Enum):
209
+ """
210
+ Represents the state of a token. Each token can be in one of
211
+ the following three states:
212
+ - FRESH: The token is valid.
213
+ - STALE: The token is valid but will expire soon.
214
+ - EXPIRED: The token has expired and cannot be used.
215
+ """
191
216
 
192
- def __init__(self, token=None):
193
- self._lock = threading.Lock() # to guard _token
194
- self._token = token
217
+ FRESH = 1 # The token is valid.
218
+ STALE = 2 # The token is valid but will expire soon.
219
+ EXPIRED = 3 # The token has expired and cannot be used.
195
220
 
221
+
222
+ class Refreshable(TokenSource):
223
+ """A token source that supports refreshing expired tokens."""
224
+
225
+ _EXECUTOR = None
226
+ _EXECUTOR_LOCK = threading.Lock()
227
+ _DEFAULT_STALE_DURATION = timedelta(minutes=3)
228
+
229
+ @classmethod
230
+ def _get_executor(cls):
231
+ """Lazy initialization of the ThreadPoolExecutor."""
232
+ if cls._EXECUTOR is None:
233
+ with cls._EXECUTOR_LOCK:
234
+ if cls._EXECUTOR is None:
235
+ # This thread pool has multiple workers because it is shared by all instances of Refreshable.
236
+ cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
237
+ return cls._EXECUTOR
238
+
239
+ def __init__(
240
+ self,
241
+ token: Optional[Token] = None,
242
+ disable_async: bool = True,
243
+ stale_duration: timedelta = _DEFAULT_STALE_DURATION,
244
+ ):
245
+ # Config properties
246
+ self._stale_duration = stale_duration
247
+ self._disable_async = disable_async
248
+ # Lock
249
+ self._lock = threading.Lock()
250
+ # Non Thread safe properties. They should be accessed only when protected by the lock above.
251
+ self._token = token or Token("")
252
+ self._is_refreshing = False
253
+ self._refresh_err = False
254
+
255
+ # This is the main entry point for the Token. Do not access the token
256
+ # using any of the internal functions.
196
257
  def token(self) -> Token:
197
- self._lock.acquire()
198
- try:
199
- if self._token and self._token.valid:
200
- return self._token
201
- self._token = self.refresh()
258
+ """Returns a valid token, blocking if async refresh is disabled."""
259
+ with self._lock:
260
+ if self._disable_async:
261
+ return self._blocking_token()
262
+ return self._async_token()
263
+
264
+ def _async_token(self) -> Token:
265
+ """
266
+ Returns a token.
267
+ If the token is stale, triggers an asynchronous refresh.
268
+ If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
269
+ """
270
+ state = self._token_state()
271
+ token = self._token
272
+
273
+ if state == _TokenState.FRESH:
274
+ return token
275
+ if state == _TokenState.STALE:
276
+ self._trigger_async_refresh()
277
+ return token
278
+ return self._blocking_token()
279
+
280
+ def _token_state(self) -> _TokenState:
281
+ """Returns the current state of the token."""
282
+ if not self._token or not self._token.valid:
283
+ return _TokenState.EXPIRED
284
+ if not self._token.expiry:
285
+ return _TokenState.FRESH
286
+
287
+ lifespan = self._token.expiry - datetime.now()
288
+ if lifespan < timedelta(seconds=0):
289
+ return _TokenState.EXPIRED
290
+ if lifespan < self._stale_duration:
291
+ return _TokenState.STALE
292
+ return _TokenState.FRESH
293
+
294
+ def _blocking_token(self) -> Token:
295
+ """Returns a token, blocking if necessary to refresh it."""
296
+ state = self._token_state()
297
+ # This is important to recover from potential previous failed attempts
298
+ # to refresh the token asynchronously.
299
+ self._refresh_err = False
300
+ self._is_refreshing = False
301
+
302
+ # It's possible that the token got refreshed (either by a _blocking_refresh or
303
+ # an _async_refresh call) while this particular call was waiting to acquire
304
+ # the lock. This check avoids refreshing the token again in such cases.
305
+ if state != _TokenState.EXPIRED:
202
306
  return self._token
203
- finally:
204
- self._lock.release()
307
+
308
+ self._token = self.refresh()
309
+ return self._token
310
+
311
+ def _trigger_async_refresh(self):
312
+ """Starts an asynchronous refresh if none is in progress."""
313
+
314
+ def _refresh_internal():
315
+ new_token = None
316
+ try:
317
+ new_token = self.refresh()
318
+ except Exception as e:
319
+ # This happens on a thread, so we don't want to propagate the error.
320
+ # Instead, if there is no new_token for any reason, we will disable async refresh below
321
+ # But we will do it inside the lock.
322
+ logger.warning(f"Tried to refresh token asynchronously, but failed: {e}")
323
+
324
+ with self._lock:
325
+ if new_token is not None:
326
+ self._token = new_token
327
+ else:
328
+ self._refresh_err = True
329
+ self._is_refreshing = False
330
+
331
+ # The token may have been refreshed by another thread.
332
+ if self._token_state() == _TokenState.FRESH:
333
+ return
334
+ if not self._is_refreshing and not self._refresh_err:
335
+ self._is_refreshing = True
336
+ Refreshable._get_executor().submit(_refresh_internal)
205
337
 
206
338
  @abstractmethod
207
339
  def refresh(self) -> Token:
@@ -219,23 +351,24 @@ class _OAuthCallback(BaseHTTPRequestHandler):
219
351
 
220
352
  def do_GET(self):
221
353
  from urllib.parse import parse_qsl
222
- parts = self.path.split('?')
354
+
355
+ parts = self.path.split("?")
223
356
  if len(parts) != 2:
224
- self.send_error(400, 'Missing Query')
357
+ self.send_error(400, "Missing Query")
225
358
  return
226
359
 
227
360
  query = dict(parse_qsl(parts[1]))
228
361
  self._feedback.append(query)
229
362
 
230
- if 'error' in query:
231
- self.send_error(400, query['error'], query.get('error_description'))
363
+ if "error" in query:
364
+ self.send_error(400, query["error"], query.get("error_description"))
232
365
  return
233
366
 
234
367
  self.send_response(200)
235
- self.send_header('Content-type', 'text/html')
368
+ self.send_header("Content-type", "text/html")
236
369
  self.end_headers()
237
370
  # TODO: show better message
238
- self.wfile.write(b'You can close this tab.')
371
+ self.wfile.write(b"You can close this tab.")
239
372
 
240
373
 
241
374
  def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
@@ -246,8 +379,8 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas
246
379
  :return: The account's OIDC endpoints.
247
380
  """
248
381
  host = _fix_host_if_needed(host)
249
- oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server'
250
- resp = client.do('GET', oidc)
382
+ oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
383
+ resp = client.do("GET", oidc)
251
384
  return OidcEndpoints.from_dict(resp)
252
385
 
253
386
 
@@ -258,12 +391,14 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
258
391
  :return: The workspace's OIDC endpoints.
259
392
  """
260
393
  host = _fix_host_if_needed(host)
261
- oidc = f'{host}/oidc/.well-known/oauth-authorization-server'
262
- resp = client.do('GET', oidc)
394
+ oidc = f"{host}/oidc/.well-known/oauth-authorization-server"
395
+ resp = client.do("GET", oidc)
263
396
  return OidcEndpoints.from_dict(resp)
264
397
 
265
398
 
266
- def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]:
399
+ def get_azure_entra_id_workspace_endpoints(
400
+ host: str,
401
+ ) -> Optional[OidcEndpoints]:
267
402
  """
268
403
  Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
269
404
  using an application registered in Azure Entra ID.
@@ -272,84 +407,103 @@ def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]
272
407
  """
273
408
  # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
274
409
  host = _fix_host_if_needed(host)
275
- res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
276
- real_auth_url = res.headers.get('location')
410
+ res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
411
+ real_auth_url = res.headers.get("location")
277
412
  if not real_auth_url:
278
413
  return None
279
- return OidcEndpoints(authorization_endpoint=real_auth_url,
280
- token_endpoint=real_auth_url.replace('/authorize', '/token'))
414
+ return OidcEndpoints(
415
+ authorization_endpoint=real_auth_url,
416
+ token_endpoint=real_auth_url.replace("/authorize", "/token"),
417
+ )
281
418
 
282
419
 
283
420
  class SessionCredentials(Refreshable):
284
421
 
285
- def __init__(self,
286
- token: Token,
287
- token_endpoint: str,
288
- client_id: str,
289
- client_secret: str = None,
290
- redirect_url: str = None):
422
+ def __init__(
423
+ self,
424
+ token: Token,
425
+ token_endpoint: str,
426
+ client_id: str,
427
+ client_secret: str = None,
428
+ redirect_url: str = None,
429
+ disable_async: bool = True,
430
+ ):
291
431
  self._token_endpoint = token_endpoint
292
432
  self._client_id = client_id
293
433
  self._client_secret = client_secret
294
434
  self._redirect_url = redirect_url
295
- super().__init__(token)
435
+ super().__init__(
436
+ token=token,
437
+ disable_async=disable_async,
438
+ )
296
439
 
297
440
  def as_dict(self) -> dict:
298
- return {'token': self._token.as_dict()}
441
+ return {"token": self.token().as_dict()}
299
442
 
300
443
  @staticmethod
301
- def from_dict(raw: dict,
302
- token_endpoint: str,
303
- client_id: str,
304
- client_secret: str = None,
305
- redirect_url: str = None) -> 'SessionCredentials':
306
- return SessionCredentials(token=Token.from_dict(raw['token']),
307
- token_endpoint=token_endpoint,
308
- client_id=client_id,
309
- client_secret=client_secret,
310
- redirect_url=redirect_url)
444
+ def from_dict(
445
+ raw: dict,
446
+ token_endpoint: str,
447
+ client_id: str,
448
+ client_secret: str = None,
449
+ redirect_url: str = None,
450
+ ) -> "SessionCredentials":
451
+ return SessionCredentials(
452
+ token=Token.from_dict(raw["token"]),
453
+ token_endpoint=token_endpoint,
454
+ client_id=client_id,
455
+ client_secret=client_secret,
456
+ redirect_url=redirect_url,
457
+ )
311
458
 
312
459
  def auth_type(self):
313
460
  """Implementing CredentialsProvider protocol"""
314
461
  # TODO: distinguish between Databricks IDP and Azure AD
315
- return 'oauth'
462
+ return "oauth"
316
463
 
317
464
  def __call__(self, *args, **kwargs):
318
465
  """Implementing CredentialsProvider protocol"""
319
466
 
320
467
  def inner() -> Dict[str, str]:
321
- return {'Authorization': f"Bearer {self.token().access_token}"}
468
+ return {"Authorization": f"Bearer {self.token().access_token}"}
322
469
 
323
470
  return inner
324
471
 
325
472
  def refresh(self) -> Token:
326
473
  refresh_token = self._token.refresh_token
327
474
  if not refresh_token:
328
- raise ValueError('oauth2: token expired and refresh token is not set')
329
- params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token}
475
+ raise ValueError("oauth2: token expired and refresh token is not set")
476
+ params = {
477
+ "grant_type": "refresh_token",
478
+ "refresh_token": refresh_token,
479
+ }
330
480
  headers = {}
331
- if 'microsoft' in self._token_endpoint:
481
+ if "microsoft" in self._token_endpoint:
332
482
  # Tokens issued for the 'Single-Page Application' client-type may
333
483
  # only be redeemed via cross-origin requests
334
- headers = {'Origin': self._redirect_url}
335
- return retrieve_token(client_id=self._client_id,
336
- client_secret=self._client_secret,
337
- token_url=self._token_endpoint,
338
- params=params,
339
- use_params=True,
340
- headers=headers)
484
+ headers = {"Origin": self._redirect_url}
485
+ return retrieve_token(
486
+ client_id=self._client_id,
487
+ client_secret=self._client_secret,
488
+ token_url=self._token_endpoint,
489
+ params=params,
490
+ use_params=True,
491
+ headers=headers,
492
+ )
341
493
 
342
494
 
343
495
  class Consent:
344
496
 
345
- def __init__(self,
346
- state: str,
347
- verifier: str,
348
- authorization_url: str,
349
- redirect_url: str,
350
- token_endpoint: str,
351
- client_id: str,
352
- client_secret: str = None) -> None:
497
+ def __init__(
498
+ self,
499
+ state: str,
500
+ verifier: str,
501
+ authorization_url: str,
502
+ redirect_url: str,
503
+ token_endpoint: str,
504
+ client_id: str,
505
+ client_secret: str = None,
506
+ ) -> None:
353
507
  self._verifier = verifier
354
508
  self._state = state
355
509
  self._authorization_url = authorization_url
@@ -360,12 +514,12 @@ class Consent:
360
514
 
361
515
  def as_dict(self) -> dict:
362
516
  return {
363
- 'state': self._state,
364
- 'verifier': self._verifier,
365
- 'authorization_url': self._authorization_url,
366
- 'redirect_url': self._redirect_url,
367
- 'token_endpoint': self._token_endpoint,
368
- 'client_id': self._client_id,
517
+ "state": self._state,
518
+ "verifier": self._verifier,
519
+ "authorization_url": self._authorization_url,
520
+ "redirect_url": self._redirect_url,
521
+ "token_endpoint": self._token_endpoint,
522
+ "client_id": self._client_id,
369
523
  }
370
524
 
371
525
  @property
@@ -373,65 +527,74 @@ class Consent:
373
527
  return self._authorization_url
374
528
 
375
529
  @staticmethod
376
- def from_dict(raw: dict, client_secret: str = None) -> 'Consent':
377
- return Consent(raw['state'],
378
- raw['verifier'],
379
- authorization_url=raw['authorization_url'],
380
- redirect_url=raw['redirect_url'],
381
- token_endpoint=raw['token_endpoint'],
382
- client_id=raw['client_id'],
383
- client_secret=client_secret)
530
+ def from_dict(raw: dict, client_secret: str = None) -> "Consent":
531
+ return Consent(
532
+ raw["state"],
533
+ raw["verifier"],
534
+ authorization_url=raw["authorization_url"],
535
+ redirect_url=raw["redirect_url"],
536
+ token_endpoint=raw["token_endpoint"],
537
+ client_id=raw["client_id"],
538
+ client_secret=client_secret,
539
+ )
384
540
 
385
541
  def launch_external_browser(self) -> SessionCredentials:
386
542
  redirect_url = urllib.parse.urlparse(self._redirect_url)
387
- if redirect_url.hostname not in ('localhost', '127.0.0.1'):
388
- raise ValueError(f'cannot listen on {redirect_url.hostname}')
543
+ if redirect_url.hostname not in ("localhost", "127.0.0.1"):
544
+ raise ValueError(f"cannot listen on {redirect_url.hostname}")
389
545
  feedback = []
390
- logger.info(f'Opening {self._authorization_url} in a browser')
546
+ logger.info(f"Opening {self._authorization_url} in a browser")
391
547
  webbrowser.open_new(self._authorization_url)
392
548
  port = redirect_url.port
393
549
  handler_factory = functools.partial(_OAuthCallback, feedback)
394
550
  with HTTPServer(("localhost", port), handler_factory) as httpd:
395
- logger.info(f'Waiting for redirect to http://localhost:{port}')
551
+ logger.info(f"Waiting for redirect to http://localhost:{port}")
396
552
  httpd.handle_request()
397
553
  if not feedback:
398
- raise ValueError('No data received in callback')
554
+ raise ValueError("No data received in callback")
399
555
  query = feedback.pop()
400
556
  return self.exchange_callback_parameters(query)
401
557
 
402
558
  def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
403
- if 'error' in query:
404
- raise ValueError('{error}: {error_description}'.format(**query))
405
- if 'code' not in query or 'state' not in query:
406
- raise ValueError('No code returned in callback')
407
- return self.exchange(query['code'], query['state'])
559
+ if "error" in query:
560
+ raise ValueError("{error}: {error_description}".format(**query))
561
+ if "code" not in query or "state" not in query:
562
+ raise ValueError("No code returned in callback")
563
+ return self.exchange(query["code"], query["state"])
408
564
 
409
565
  def exchange(self, code: str, state: str) -> SessionCredentials:
410
566
  if self._state != state:
411
- raise ValueError('state mismatch')
567
+ raise ValueError("state mismatch")
412
568
  params = {
413
- 'redirect_uri': self._redirect_url,
414
- 'grant_type': 'authorization_code',
415
- 'code_verifier': self._verifier,
416
- 'code': code
569
+ "redirect_uri": self._redirect_url,
570
+ "grant_type": "authorization_code",
571
+ "code_verifier": self._verifier,
572
+ "code": code,
417
573
  }
418
574
  headers = {}
419
575
  while True:
420
576
  try:
421
- token = retrieve_token(client_id=self._client_id,
422
- client_secret=self._client_secret,
423
- token_url=self._token_endpoint,
424
- params=params,
425
- headers=headers,
426
- use_params=True)
427
- return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret,
428
- self._redirect_url)
577
+ token = retrieve_token(
578
+ client_id=self._client_id,
579
+ client_secret=self._client_secret,
580
+ token_url=self._token_endpoint,
581
+ params=params,
582
+ headers=headers,
583
+ use_params=True,
584
+ )
585
+ return SessionCredentials(
586
+ token,
587
+ self._token_endpoint,
588
+ self._client_id,
589
+ self._client_secret,
590
+ self._redirect_url,
591
+ )
429
592
  except ValueError as e:
430
593
  if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
431
594
  # Retry in cases of 'Single-Page Application' client-type with
432
595
  # 'Origin' header equal to client's redirect URL.
433
- headers['Origin'] = self._redirect_url
434
- msg = f'Retrying OAuth token exchange with {self._redirect_url} origin'
596
+ headers["Origin"] = self._redirect_url
597
+ msg = f"Retrying OAuth token exchange with {self._redirect_url} origin"
435
598
  logger.debug(msg)
436
599
  continue
437
600
  raise e
@@ -456,15 +619,17 @@ class OAuthClient:
456
619
  exchange it for a token without possessing the Code Verifier.
457
620
  """
458
621
 
459
- def __init__(self,
460
- oidc_endpoints: OidcEndpoints,
461
- redirect_url: str,
462
- client_id: str,
463
- scopes: List[str] = None,
464
- client_secret: str = None):
622
+ def __init__(
623
+ self,
624
+ oidc_endpoints: OidcEndpoints,
625
+ redirect_url: str,
626
+ client_id: str,
627
+ scopes: List[str] = None,
628
+ client_secret: str = None,
629
+ ):
465
630
 
466
631
  if not scopes:
467
- scopes = ['all-apis']
632
+ scopes = ["all-apis"]
468
633
 
469
634
  self.redirect_url = redirect_url
470
635
  self._client_id = client_id
@@ -473,25 +638,27 @@ class OAuthClient:
473
638
  self._scopes = scopes
474
639
 
475
640
  @staticmethod
476
- def from_host(host: str,
477
- client_id: str,
478
- redirect_url: str,
479
- *,
480
- scopes: List[str] = None,
481
- client_secret: str = None) -> 'OAuthClient':
641
+ def from_host(
642
+ host: str,
643
+ client_id: str,
644
+ redirect_url: str,
645
+ *,
646
+ scopes: List[str] = None,
647
+ client_secret: str = None,
648
+ ) -> "OAuthClient":
482
649
  from .core import Config
483
650
  from .credentials_provider import credentials_strategy
484
651
 
485
- @credentials_strategy('noop', [])
652
+ @credentials_strategy("noop", [])
486
653
  def noop_credentials(_: any):
487
654
  return lambda: {}
488
655
 
489
656
  config = Config(host=host, credentials_strategy=noop_credentials)
490
657
  if not scopes:
491
- scopes = ['all-apis']
658
+ scopes = ["all-apis"]
492
659
  oidc = config.oidc_endpoints
493
660
  if not oidc:
494
- raise ValueError(f'{host} does not support OAuth')
661
+ raise ValueError(f"{host} does not support OAuth")
495
662
  return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)
496
663
 
497
664
  def initiate_consent(self) -> Consent:
@@ -500,28 +667,30 @@ class OAuthClient:
500
667
  # token_urlsafe() already returns base64-encoded string
501
668
  verifier = secrets.token_urlsafe(32)
502
669
  digest = hashlib.sha256(verifier.encode("UTF-8")).digest()
503
- challenge = (base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", ""))
670
+ challenge = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
504
671
 
505
672
  params = {
506
- 'response_type': 'code',
507
- 'client_id': self._client_id,
508
- 'redirect_uri': self.redirect_url,
509
- 'scope': ' '.join(self._scopes),
510
- 'state': state,
511
- 'code_challenge': challenge,
512
- 'code_challenge_method': 'S256'
673
+ "response_type": "code",
674
+ "client_id": self._client_id,
675
+ "redirect_uri": self.redirect_url,
676
+ "scope": " ".join(self._scopes),
677
+ "state": state,
678
+ "code_challenge": challenge,
679
+ "code_challenge_method": "S256",
513
680
  }
514
- auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}'
515
- return Consent(state,
516
- verifier,
517
- authorization_url=auth_url,
518
- redirect_url=self.redirect_url,
519
- token_endpoint=self._oidc_endpoints.token_endpoint,
520
- client_id=self._client_id,
521
- client_secret=self._client_secret)
681
+ auth_url = f"{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}"
682
+ return Consent(
683
+ state,
684
+ verifier,
685
+ authorization_url=auth_url,
686
+ redirect_url=self.redirect_url,
687
+ token_endpoint=self._oidc_endpoints.token_endpoint,
688
+ client_id=self._client_id,
689
+ client_secret=self._client_secret,
690
+ )
522
691
 
523
692
  def __repr__(self) -> str:
524
- return f'<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>'
693
+ return f"<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>"
525
694
 
526
695
 
527
696
  @dataclass
@@ -535,6 +704,7 @@ class ClientCredentials(Refreshable):
535
704
  the background job uses the Client ID and Client Secret to obtain
536
705
  an Access Token from the Authorization Server.
537
706
  """
707
+
538
708
  client_id: str
539
709
  client_secret: str
540
710
  token_url: str
@@ -542,9 +712,10 @@ class ClientCredentials(Refreshable):
542
712
  scopes: List[str] = None
543
713
  use_params: bool = False
544
714
  use_header: bool = False
715
+ disable_async: bool = True
545
716
 
546
717
  def __post_init__(self):
547
- super().__init__()
718
+ super().__init__(disable_async=self.disable_async)
548
719
 
549
720
  def refresh(self) -> Token:
550
721
  params = {"grant_type": "client_credentials"}
@@ -553,24 +724,28 @@ class ClientCredentials(Refreshable):
553
724
  if self.endpoint_params:
554
725
  for k, v in self.endpoint_params.items():
555
726
  params[k] = v
556
- return retrieve_token(self.client_id,
557
- self.client_secret,
558
- self.token_url,
559
- params,
560
- use_params=self.use_params,
561
- use_header=self.use_header)
727
+ return retrieve_token(
728
+ self.client_id,
729
+ self.client_secret,
730
+ self.token_url,
731
+ params,
732
+ use_params=self.use_params,
733
+ use_header=self.use_header,
734
+ )
562
735
 
563
736
 
564
737
  class TokenCache:
565
738
  BASE_PATH = "~/.config/databricks-sdk-py/oauth"
566
739
 
567
- def __init__(self,
568
- host: str,
569
- oidc_endpoints: OidcEndpoints,
570
- client_id: str,
571
- redirect_url: str = None,
572
- client_secret: str = None,
573
- scopes: List[str] = None) -> None:
740
+ def __init__(
741
+ self,
742
+ host: str,
743
+ oidc_endpoints: OidcEndpoints,
744
+ client_id: str,
745
+ redirect_url: Optional[str] = None,
746
+ client_secret: Optional[str] = None,
747
+ scopes: Optional[List[str]] = None,
748
+ ) -> None:
574
749
  self._host = host
575
750
  self._client_id = client_id
576
751
  self._oidc_endpoints = oidc_endpoints
@@ -582,8 +757,12 @@ class TokenCache:
582
757
  def filename(self) -> str:
583
758
  # Include host, client_id, and scopes in the cache filename to make it unique.
584
759
  hash = hashlib.sha256()
585
- for chunk in [self._host, self._client_id, ",".join(self._scopes), ]:
586
- hash.update(chunk.encode('utf-8'))
760
+ for chunk in [
761
+ self._host,
762
+ self._client_id,
763
+ ",".join(self._scopes),
764
+ ]:
765
+ hash.update(chunk.encode("utf-8"))
587
766
  return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))
588
767
 
589
768
  def load(self) -> Optional[SessionCredentials]:
@@ -594,13 +773,15 @@ class TokenCache:
594
773
  return None
595
774
 
596
775
  try:
597
- with open(self.filename, 'r') as f:
776
+ with open(self.filename, "r") as f:
598
777
  raw = json.load(f)
599
- return SessionCredentials.from_dict(raw,
600
- token_endpoint=self._oidc_endpoints.token_endpoint,
601
- client_id=self._client_id,
602
- client_secret=self._client_secret,
603
- redirect_url=self._redirect_url)
778
+ return SessionCredentials.from_dict(
779
+ raw,
780
+ token_endpoint=self._oidc_endpoints.token_endpoint,
781
+ client_id=self._client_id,
782
+ client_secret=self._client_secret,
783
+ redirect_url=self._redirect_url,
784
+ )
604
785
  except Exception:
605
786
  return None
606
787
 
@@ -609,6 +790,6 @@ class TokenCache:
609
790
  Save credentials to cache file.
610
791
  """
611
792
  os.makedirs(os.path.dirname(self.filename), exist_ok=True)
612
- with open(self.filename, 'w') as f:
793
+ with open(self.filename, "w") as f:
613
794
  json.dump(credentials.as_dict(), f)
614
795
  os.chmod(self.filename, 0o600)