databricks-sdk 0.44.1__py3-none-any.whl → 0.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. See the registry's advisory for this release for more details.

Files changed (63):
  1. databricks/sdk/__init__.py +123 -115
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +152 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +300 -205
  12. databricks/sdk/data_plane.py +86 -3
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +191 -16
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +372 -196
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5241 -3531
  40. databricks/sdk/service/dashboards.py +1313 -923
  41. databricks/sdk/service/files.py +442 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3364 -2255
  46. databricks/sdk/service/oauth2.py +922 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2040 -1278
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.45.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.1.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
databricks/sdk/oauth.py CHANGED
@@ -9,8 +9,10 @@ import threading
9
9
  import urllib.parse
10
10
  import webbrowser
11
11
  from abc import abstractmethod
12
+ from concurrent.futures import ThreadPoolExecutor
12
13
  from dataclasses import dataclass
13
14
  from datetime import datetime, timedelta
15
+ from enum import Enum
14
16
  from http.server import BaseHTTPRequestHandler, HTTPServer
15
17
  from typing import Any, Dict, List, Optional
16
18
 
@@ -21,7 +23,7 @@ from ._base_client import _BaseClient, _fix_host_if_needed
21
23
 
22
24
  # Error code for PKCE flow in Azure Active Directory, that gets additional retry.
23
25
  # See https://stackoverflow.com/a/75466778/277035 for more info
24
- NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327'
26
+ NO_ORIGIN_FOR_SPA_CLIENT_ERROR = "AADSTS9002327"
25
27
 
26
28
  URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
27
29
  JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
@@ -52,28 +54,33 @@ class OidcEndpoints:
52
54
  The endpoints used for OAuth-based authentication in Databricks.
53
55
  """
54
56
 
55
- authorization_endpoint: str # ../v1/authorize
57
+ authorization_endpoint: str # ../v1/authorize
56
58
  """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for
57
59
  the user to login and authorize the client for user-to-machine (U2M) flows."""
58
60
 
59
- token_endpoint: str # ../v1/token
61
+ token_endpoint: str # ../v1/token
60
62
  """The token endpoint for the OAuth flow."""
61
63
 
62
64
  @staticmethod
63
- def from_dict(d: dict) -> 'OidcEndpoints':
64
- return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'),
65
- token_endpoint=d.get('token_endpoint'))
65
+ def from_dict(d: dict) -> "OidcEndpoints":
66
+ return OidcEndpoints(
67
+ authorization_endpoint=d.get("authorization_endpoint"),
68
+ token_endpoint=d.get("token_endpoint"),
69
+ )
66
70
 
67
71
  def as_dict(self) -> dict:
68
- return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint}
72
+ return {
73
+ "authorization_endpoint": self.authorization_endpoint,
74
+ "token_endpoint": self.token_endpoint,
75
+ }
69
76
 
70
77
 
71
78
  @dataclass
72
79
  class Token:
73
80
  access_token: str
74
- token_type: str = None
75
- refresh_token: str = None
76
- expiry: datetime = None
81
+ token_type: Optional[str] = None
82
+ refresh_token: Optional[str] = None
83
+ expiry: Optional[datetime] = None
77
84
 
78
85
  @property
79
86
  def expired(self):
@@ -91,19 +98,24 @@ class Token:
91
98
  return self.access_token and not self.expired
92
99
 
93
100
  def as_dict(self) -> dict:
94
- raw = {'access_token': self.access_token, 'token_type': self.token_type}
101
+ raw = {
102
+ "access_token": self.access_token,
103
+ "token_type": self.token_type,
104
+ }
95
105
  if self.expiry:
96
- raw['expiry'] = self.expiry.isoformat()
106
+ raw["expiry"] = self.expiry.isoformat()
97
107
  if self.refresh_token:
98
- raw['refresh_token'] = self.refresh_token
108
+ raw["refresh_token"] = self.refresh_token
99
109
  return raw
100
110
 
101
111
  @staticmethod
102
- def from_dict(raw: dict) -> 'Token':
103
- return Token(access_token=raw['access_token'],
104
- token_type=raw['token_type'],
105
- expiry=datetime.fromisoformat(raw['expiry']),
106
- refresh_token=raw.get('refresh_token'))
112
+ def from_dict(raw: dict) -> "Token":
113
+ return Token(
114
+ access_token=raw["access_token"],
115
+ token_type=raw["token_type"],
116
+ expiry=datetime.fromisoformat(raw["expiry"]),
117
+ refresh_token=raw.get("refresh_token"),
118
+ )
107
119
 
108
120
  def jwt_claims(self) -> Dict[str, str]:
109
121
  """Get claims from the access token or return an empty dictionary if it is not a JWT token.
@@ -131,7 +143,7 @@ class Token:
131
143
  try:
132
144
  jwt_split = self.access_token.split(".")
133
145
  if len(jwt_split) != 3:
134
- logger.debug(f'Tried to decode access token as JWT, but failed: {len(jwt_split)} components')
146
+ logger.debug(f"Tried to decode access token as JWT, but failed: {len(jwt_split)} components")
135
147
  return {}
136
148
  payload_with_padding = jwt_split[1] + "=="
137
149
  payload_bytes = base64.standard_b64decode(payload_with_padding)
@@ -139,7 +151,7 @@ class Token:
139
151
  claims = json.loads(payload_json)
140
152
  return claims
141
153
  except ValueError as err:
142
- logger.debug(f'Tried to decode access token as JWT, but failed: {err}')
154
+ logger.debug(f"Tried to decode access token as JWT, but failed: {err}")
143
155
  return {}
144
156
 
145
157
 
@@ -150,17 +162,21 @@ class TokenSource:
150
162
  pass
151
163
 
152
164
 
153
- def retrieve_token(client_id,
154
- client_secret,
155
- token_url,
156
- params,
157
- use_params=False,
158
- use_header=False,
159
- headers=None) -> Token:
160
- logger.debug(f'Retrieving token for {client_id}')
165
+ def retrieve_token(
166
+ client_id,
167
+ client_secret,
168
+ token_url,
169
+ params,
170
+ use_params=False,
171
+ use_header=False,
172
+ headers=None,
173
+ ) -> Token:
174
+ logger.debug(f"Retrieving token for {client_id}")
161
175
  if use_params:
162
- if client_id: params["client_id"] = client_id
163
- if client_secret: params["client_secret"] = client_secret
176
+ if client_id:
177
+ params["client_id"] = client_id
178
+ if client_secret:
179
+ params["client_secret"] = client_secret
164
180
  auth = None
165
181
  if use_header:
166
182
  auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
@@ -168,40 +184,156 @@ def retrieve_token(client_id,
168
184
  auth = IgnoreNetrcAuth()
169
185
  resp = requests.post(token_url, params, auth=auth, headers=headers)
170
186
  if not resp.ok:
171
- if resp.headers['Content-Type'].startswith('application/json'):
187
+ if resp.headers["Content-Type"].startswith("application/json"):
172
188
  err = resp.json()
173
- code = err.get('errorCode', err.get('error', 'unknown'))
174
- summary = err.get('errorSummary', err.get('error_description', 'unknown'))
175
- summary = summary.replace("\r\n", ' ')
176
- raise ValueError(f'{code}: {summary}')
189
+ code = err.get("errorCode", err.get("error", "unknown"))
190
+ summary = err.get("errorSummary", err.get("error_description", "unknown"))
191
+ summary = summary.replace("\r\n", " ")
192
+ raise ValueError(f"{code}: {summary}")
177
193
  raise ValueError(resp.content)
178
194
  try:
179
195
  j = resp.json()
180
196
  expires_in = int(j["expires_in"])
181
197
  expiry = datetime.now() + timedelta(seconds=expires_in)
182
- return Token(access_token=j["access_token"],
183
- refresh_token=j.get('refresh_token'),
184
- token_type=j["token_type"],
185
- expiry=expiry)
198
+ return Token(
199
+ access_token=j["access_token"],
200
+ refresh_token=j.get("refresh_token"),
201
+ token_type=j["token_type"],
202
+ expiry=expiry,
203
+ )
186
204
  except Exception as e:
187
205
  raise NotImplementedError(f"Not supported yet: {e}")
188
206
 
189
207
 
190
- class Refreshable(TokenSource):
208
+ class _TokenState(Enum):
209
+ """
210
+ Represents the state of a token. Each token can be in one of
211
+ the following three states:
212
+ - FRESH: The token is valid.
213
+ - STALE: The token is valid but will expire soon.
214
+ - EXPIRED: The token has expired and cannot be used.
215
+ """
191
216
 
192
- def __init__(self, token=None):
193
- self._lock = threading.Lock() # to guard _token
194
- self._token = token
217
+ FRESH = 1 # The token is valid.
218
+ STALE = 2 # The token is valid but will expire soon.
219
+ EXPIRED = 3 # The token has expired and cannot be used.
195
220
 
221
+
222
+ class Refreshable(TokenSource):
223
+ """A token source that supports refreshing expired tokens."""
224
+
225
+ _EXECUTOR = None
226
+ _EXECUTOR_LOCK = threading.Lock()
227
+ _DEFAULT_STALE_DURATION = timedelta(minutes=3)
228
+
229
+ @classmethod
230
+ def _get_executor(cls):
231
+ """Lazy initialization of the ThreadPoolExecutor."""
232
+ if cls._EXECUTOR is None:
233
+ with cls._EXECUTOR_LOCK:
234
+ if cls._EXECUTOR is None:
235
+ # This thread pool has multiple workers because it is shared by all instances of Refreshable.
236
+ cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
237
+ return cls._EXECUTOR
238
+
239
+ def __init__(
240
+ self,
241
+ token: Optional[Token] = None,
242
+ disable_async: bool = True,
243
+ stale_duration: timedelta = _DEFAULT_STALE_DURATION,
244
+ ):
245
+ # Config properties
246
+ self._stale_duration = stale_duration
247
+ self._disable_async = disable_async
248
+ # Lock
249
+ self._lock = threading.Lock()
250
+ # Non Thread safe properties. They should be accessed only when protected by the lock above.
251
+ self._token = token or Token("")
252
+ self._is_refreshing = False
253
+ self._refresh_err = False
254
+
255
+ # This is the main entry point for the Token. Do not access the token
256
+ # using any of the internal functions.
196
257
  def token(self) -> Token:
197
- self._lock.acquire()
198
- try:
199
- if self._token and self._token.valid:
200
- return self._token
201
- self._token = self.refresh()
258
+ """Returns a valid token, blocking if async refresh is disabled."""
259
+ with self._lock:
260
+ if self._disable_async:
261
+ return self._blocking_token()
262
+ return self._async_token()
263
+
264
+ def _async_token(self) -> Token:
265
+ """
266
+ Returns a token.
267
+ If the token is stale, triggers an asynchronous refresh.
268
+ If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
269
+ """
270
+ state = self._token_state()
271
+ token = self._token
272
+
273
+ if state == _TokenState.FRESH:
274
+ return token
275
+ if state == _TokenState.STALE:
276
+ self._trigger_async_refresh()
277
+ return token
278
+ return self._blocking_token()
279
+
280
+ def _token_state(self) -> _TokenState:
281
+ """Returns the current state of the token."""
282
+ if not self._token or not self._token.valid:
283
+ return _TokenState.EXPIRED
284
+ if not self._token.expiry:
285
+ return _TokenState.FRESH
286
+
287
+ lifespan = self._token.expiry - datetime.now()
288
+ if lifespan < timedelta(seconds=0):
289
+ return _TokenState.EXPIRED
290
+ if lifespan < self._stale_duration:
291
+ return _TokenState.STALE
292
+ return _TokenState.FRESH
293
+
294
+ def _blocking_token(self) -> Token:
295
+ """Returns a token, blocking if necessary to refresh it."""
296
+ state = self._token_state()
297
+ # This is important to recover from potential previous failed attempts
298
+ # to refresh the token asynchronously.
299
+ self._refresh_err = False
300
+ self._is_refreshing = False
301
+
302
+ # It's possible that the token got refreshed (either by a _blocking_refresh or
303
+ # an _async_refresh call) while this particular call was waiting to acquire
304
+ # the lock. This check avoids refreshing the token again in such cases.
305
+ if state != _TokenState.EXPIRED:
202
306
  return self._token
203
- finally:
204
- self._lock.release()
307
+
308
+ self._token = self.refresh()
309
+ return self._token
310
+
311
+ def _trigger_async_refresh(self):
312
+ """Starts an asynchronous refresh if none is in progress."""
313
+
314
+ def _refresh_internal():
315
+ new_token = None
316
+ try:
317
+ new_token = self.refresh()
318
+ except Exception as e:
319
+ # This happens on a thread, so we don't want to propagate the error.
320
+ # Instead, if there is no new_token for any reason, we will disable async refresh below
321
+ # But we will do it inside the lock.
322
+ logger.warning(f"Tried to refresh token asynchronously, but failed: {e}")
323
+
324
+ with self._lock:
325
+ if new_token is not None:
326
+ self._token = new_token
327
+ else:
328
+ self._refresh_err = True
329
+ self._is_refreshing = False
330
+
331
+ # The token may have been refreshed by another thread.
332
+ if self._token_state() == _TokenState.FRESH:
333
+ return
334
+ if not self._is_refreshing and not self._refresh_err:
335
+ self._is_refreshing = True
336
+ Refreshable._get_executor().submit(_refresh_internal)
205
337
 
206
338
  @abstractmethod
207
339
  def refresh(self) -> Token:
@@ -219,23 +351,24 @@ class _OAuthCallback(BaseHTTPRequestHandler):
219
351
 
220
352
  def do_GET(self):
221
353
  from urllib.parse import parse_qsl
222
- parts = self.path.split('?')
354
+
355
+ parts = self.path.split("?")
223
356
  if len(parts) != 2:
224
- self.send_error(400, 'Missing Query')
357
+ self.send_error(400, "Missing Query")
225
358
  return
226
359
 
227
360
  query = dict(parse_qsl(parts[1]))
228
361
  self._feedback.append(query)
229
362
 
230
- if 'error' in query:
231
- self.send_error(400, query['error'], query.get('error_description'))
363
+ if "error" in query:
364
+ self.send_error(400, query["error"], query.get("error_description"))
232
365
  return
233
366
 
234
367
  self.send_response(200)
235
- self.send_header('Content-type', 'text/html')
368
+ self.send_header("Content-type", "text/html")
236
369
  self.end_headers()
237
370
  # TODO: show better message
238
- self.wfile.write(b'You can close this tab.')
371
+ self.wfile.write(b"You can close this tab.")
239
372
 
240
373
 
241
374
  def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
@@ -246,8 +379,8 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas
246
379
  :return: The account's OIDC endpoints.
247
380
  """
248
381
  host = _fix_host_if_needed(host)
249
- oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server'
250
- resp = client.do('GET', oidc)
382
+ oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
383
+ resp = client.do("GET", oidc)
251
384
  return OidcEndpoints.from_dict(resp)
252
385
 
253
386
 
@@ -258,12 +391,14 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
258
391
  :return: The workspace's OIDC endpoints.
259
392
  """
260
393
  host = _fix_host_if_needed(host)
261
- oidc = f'{host}/oidc/.well-known/oauth-authorization-server'
262
- resp = client.do('GET', oidc)
394
+ oidc = f"{host}/oidc/.well-known/oauth-authorization-server"
395
+ resp = client.do("GET", oidc)
263
396
  return OidcEndpoints.from_dict(resp)
264
397
 
265
398
 
266
- def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]:
399
+ def get_azure_entra_id_workspace_endpoints(
400
+ host: str,
401
+ ) -> Optional[OidcEndpoints]:
267
402
  """
268
403
  Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
269
404
  using an application registered in Azure Entra ID.
@@ -272,22 +407,26 @@ def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]
272
407
  """
273
408
  # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
274
409
  host = _fix_host_if_needed(host)
275
- res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
276
- real_auth_url = res.headers.get('location')
410
+ res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
411
+ real_auth_url = res.headers.get("location")
277
412
  if not real_auth_url:
278
413
  return None
279
- return OidcEndpoints(authorization_endpoint=real_auth_url,
280
- token_endpoint=real_auth_url.replace('/authorize', '/token'))
414
+ return OidcEndpoints(
415
+ authorization_endpoint=real_auth_url,
416
+ token_endpoint=real_auth_url.replace("/authorize", "/token"),
417
+ )
281
418
 
282
419
 
283
420
  class SessionCredentials(Refreshable):
284
421
 
285
- def __init__(self,
286
- token: Token,
287
- token_endpoint: str,
288
- client_id: str,
289
- client_secret: str = None,
290
- redirect_url: str = None):
422
+ def __init__(
423
+ self,
424
+ token: Token,
425
+ token_endpoint: str,
426
+ client_id: str,
427
+ client_secret: str = None,
428
+ redirect_url: str = None,
429
+ ):
291
430
  self._token_endpoint = token_endpoint
292
431
  self._client_id = client_id
293
432
  self._client_secret = client_secret
@@ -295,61 +434,72 @@ class SessionCredentials(Refreshable):
295
434
  super().__init__(token)
296
435
 
297
436
  def as_dict(self) -> dict:
298
- return {'token': self._token.as_dict()}
437
+ return {"token": self.token().as_dict()}
299
438
 
300
439
  @staticmethod
301
- def from_dict(raw: dict,
302
- token_endpoint: str,
303
- client_id: str,
304
- client_secret: str = None,
305
- redirect_url: str = None) -> 'SessionCredentials':
306
- return SessionCredentials(token=Token.from_dict(raw['token']),
307
- token_endpoint=token_endpoint,
308
- client_id=client_id,
309
- client_secret=client_secret,
310
- redirect_url=redirect_url)
440
+ def from_dict(
441
+ raw: dict,
442
+ token_endpoint: str,
443
+ client_id: str,
444
+ client_secret: str = None,
445
+ redirect_url: str = None,
446
+ ) -> "SessionCredentials":
447
+ return SessionCredentials(
448
+ token=Token.from_dict(raw["token"]),
449
+ token_endpoint=token_endpoint,
450
+ client_id=client_id,
451
+ client_secret=client_secret,
452
+ redirect_url=redirect_url,
453
+ )
311
454
 
312
455
  def auth_type(self):
313
456
  """Implementing CredentialsProvider protocol"""
314
457
  # TODO: distinguish between Databricks IDP and Azure AD
315
- return 'oauth'
458
+ return "oauth"
316
459
 
317
460
  def __call__(self, *args, **kwargs):
318
461
  """Implementing CredentialsProvider protocol"""
319
462
 
320
463
  def inner() -> Dict[str, str]:
321
- return {'Authorization': f"Bearer {self.token().access_token}"}
464
+ return {"Authorization": f"Bearer {self.token().access_token}"}
322
465
 
323
466
  return inner
324
467
 
325
468
  def refresh(self) -> Token:
326
469
  refresh_token = self._token.refresh_token
327
470
  if not refresh_token:
328
- raise ValueError('oauth2: token expired and refresh token is not set')
329
- params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token}
471
+ raise ValueError("oauth2: token expired and refresh token is not set")
472
+ params = {
473
+ "grant_type": "refresh_token",
474
+ "refresh_token": refresh_token,
475
+ }
330
476
  headers = {}
331
- if 'microsoft' in self._token_endpoint:
477
+ if "microsoft" in self._token_endpoint:
332
478
  # Tokens issued for the 'Single-Page Application' client-type may
333
479
  # only be redeemed via cross-origin requests
334
- headers = {'Origin': self._redirect_url}
335
- return retrieve_token(client_id=self._client_id,
336
- client_secret=self._client_secret,
337
- token_url=self._token_endpoint,
338
- params=params,
339
- use_params=True,
340
- headers=headers)
480
+ headers = {"Origin": self._redirect_url}
481
+ return retrieve_token(
482
+ client_id=self._client_id,
483
+ client_secret=self._client_secret,
484
+ token_url=self._token_endpoint,
485
+ params=params,
486
+ use_params=True,
487
+ headers=headers,
488
+ )
341
489
 
342
490
 
343
491
  class Consent:
344
492
 
345
- def __init__(self,
346
- state: str,
347
- verifier: str,
348
- authorization_url: str,
349
- redirect_url: str,
350
- token_endpoint: str,
351
- client_id: str,
352
- client_secret: str = None) -> None:
493
+ def __init__(
494
+ self,
495
+ state: str,
496
+ verifier: str,
497
+ authorization_url: str,
498
+ redirect_url: str,
499
+ token_endpoint: str,
500
+ client_id: str,
501
+ client_secret: str = None,
502
+ ) -> None:
353
503
  self._verifier = verifier
354
504
  self._state = state
355
505
  self._authorization_url = authorization_url
@@ -360,12 +510,12 @@ class Consent:
360
510
 
361
511
  def as_dict(self) -> dict:
362
512
  return {
363
- 'state': self._state,
364
- 'verifier': self._verifier,
365
- 'authorization_url': self._authorization_url,
366
- 'redirect_url': self._redirect_url,
367
- 'token_endpoint': self._token_endpoint,
368
- 'client_id': self._client_id,
513
+ "state": self._state,
514
+ "verifier": self._verifier,
515
+ "authorization_url": self._authorization_url,
516
+ "redirect_url": self._redirect_url,
517
+ "token_endpoint": self._token_endpoint,
518
+ "client_id": self._client_id,
369
519
  }
370
520
 
371
521
  @property
@@ -373,65 +523,74 @@ class Consent:
373
523
  return self._authorization_url
374
524
 
375
525
  @staticmethod
376
- def from_dict(raw: dict, client_secret: str = None) -> 'Consent':
377
- return Consent(raw['state'],
378
- raw['verifier'],
379
- authorization_url=raw['authorization_url'],
380
- redirect_url=raw['redirect_url'],
381
- token_endpoint=raw['token_endpoint'],
382
- client_id=raw['client_id'],
383
- client_secret=client_secret)
526
+ def from_dict(raw: dict, client_secret: str = None) -> "Consent":
527
+ return Consent(
528
+ raw["state"],
529
+ raw["verifier"],
530
+ authorization_url=raw["authorization_url"],
531
+ redirect_url=raw["redirect_url"],
532
+ token_endpoint=raw["token_endpoint"],
533
+ client_id=raw["client_id"],
534
+ client_secret=client_secret,
535
+ )
384
536
 
385
537
  def launch_external_browser(self) -> SessionCredentials:
386
538
  redirect_url = urllib.parse.urlparse(self._redirect_url)
387
- if redirect_url.hostname not in ('localhost', '127.0.0.1'):
388
- raise ValueError(f'cannot listen on {redirect_url.hostname}')
539
+ if redirect_url.hostname not in ("localhost", "127.0.0.1"):
540
+ raise ValueError(f"cannot listen on {redirect_url.hostname}")
389
541
  feedback = []
390
- logger.info(f'Opening {self._authorization_url} in a browser')
542
+ logger.info(f"Opening {self._authorization_url} in a browser")
391
543
  webbrowser.open_new(self._authorization_url)
392
544
  port = redirect_url.port
393
545
  handler_factory = functools.partial(_OAuthCallback, feedback)
394
546
  with HTTPServer(("localhost", port), handler_factory) as httpd:
395
- logger.info(f'Waiting for redirect to http://localhost:{port}')
547
+ logger.info(f"Waiting for redirect to http://localhost:{port}")
396
548
  httpd.handle_request()
397
549
  if not feedback:
398
- raise ValueError('No data received in callback')
550
+ raise ValueError("No data received in callback")
399
551
  query = feedback.pop()
400
552
  return self.exchange_callback_parameters(query)
401
553
 
402
554
  def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
403
- if 'error' in query:
404
- raise ValueError('{error}: {error_description}'.format(**query))
405
- if 'code' not in query or 'state' not in query:
406
- raise ValueError('No code returned in callback')
407
- return self.exchange(query['code'], query['state'])
555
+ if "error" in query:
556
+ raise ValueError("{error}: {error_description}".format(**query))
557
+ if "code" not in query or "state" not in query:
558
+ raise ValueError("No code returned in callback")
559
+ return self.exchange(query["code"], query["state"])
408
560
 
409
561
  def exchange(self, code: str, state: str) -> SessionCredentials:
410
562
  if self._state != state:
411
- raise ValueError('state mismatch')
563
+ raise ValueError("state mismatch")
412
564
  params = {
413
- 'redirect_uri': self._redirect_url,
414
- 'grant_type': 'authorization_code',
415
- 'code_verifier': self._verifier,
416
- 'code': code
565
+ "redirect_uri": self._redirect_url,
566
+ "grant_type": "authorization_code",
567
+ "code_verifier": self._verifier,
568
+ "code": code,
417
569
  }
418
570
  headers = {}
419
571
  while True:
420
572
  try:
421
- token = retrieve_token(client_id=self._client_id,
422
- client_secret=self._client_secret,
423
- token_url=self._token_endpoint,
424
- params=params,
425
- headers=headers,
426
- use_params=True)
427
- return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret,
428
- self._redirect_url)
573
+ token = retrieve_token(
574
+ client_id=self._client_id,
575
+ client_secret=self._client_secret,
576
+ token_url=self._token_endpoint,
577
+ params=params,
578
+ headers=headers,
579
+ use_params=True,
580
+ )
581
+ return SessionCredentials(
582
+ token,
583
+ self._token_endpoint,
584
+ self._client_id,
585
+ self._client_secret,
586
+ self._redirect_url,
587
+ )
429
588
  except ValueError as e:
430
589
  if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
431
590
  # Retry in cases of 'Single-Page Application' client-type with
432
591
  # 'Origin' header equal to client's redirect URL.
433
- headers['Origin'] = self._redirect_url
434
- msg = f'Retrying OAuth token exchange with {self._redirect_url} origin'
592
+ headers["Origin"] = self._redirect_url
593
+ msg = f"Retrying OAuth token exchange with {self._redirect_url} origin"
435
594
  logger.debug(msg)
436
595
  continue
437
596
  raise e
@@ -456,15 +615,17 @@ class OAuthClient:
456
615
  exchange it for a token without possessing the Code Verifier.
457
616
  """
458
617
 
459
- def __init__(self,
460
- oidc_endpoints: OidcEndpoints,
461
- redirect_url: str,
462
- client_id: str,
463
- scopes: List[str] = None,
464
- client_secret: str = None):
618
+ def __init__(
619
+ self,
620
+ oidc_endpoints: OidcEndpoints,
621
+ redirect_url: str,
622
+ client_id: str,
623
+ scopes: List[str] = None,
624
+ client_secret: str = None,
625
+ ):
465
626
 
466
627
  if not scopes:
467
- scopes = ['all-apis']
628
+ scopes = ["all-apis"]
468
629
 
469
630
  self.redirect_url = redirect_url
470
631
  self._client_id = client_id
@@ -473,25 +634,27 @@ class OAuthClient:
473
634
  self._scopes = scopes
474
635
 
475
636
  @staticmethod
476
- def from_host(host: str,
477
- client_id: str,
478
- redirect_url: str,
479
- *,
480
- scopes: List[str] = None,
481
- client_secret: str = None) -> 'OAuthClient':
637
+ def from_host(
638
+ host: str,
639
+ client_id: str,
640
+ redirect_url: str,
641
+ *,
642
+ scopes: List[str] = None,
643
+ client_secret: str = None,
644
+ ) -> "OAuthClient":
482
645
  from .core import Config
483
646
  from .credentials_provider import credentials_strategy
484
647
 
485
- @credentials_strategy('noop', [])
648
+ @credentials_strategy("noop", [])
486
649
  def noop_credentials(_: any):
487
650
  return lambda: {}
488
651
 
489
652
  config = Config(host=host, credentials_strategy=noop_credentials)
490
653
  if not scopes:
491
- scopes = ['all-apis']
654
+ scopes = ["all-apis"]
492
655
  oidc = config.oidc_endpoints
493
656
  if not oidc:
494
- raise ValueError(f'{host} does not support OAuth')
657
+ raise ValueError(f"{host} does not support OAuth")
495
658
  return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)
496
659
 
497
660
  def initiate_consent(self) -> Consent:
@@ -500,28 +663,30 @@ class OAuthClient:
500
663
  # token_urlsafe() already returns base64-encoded string
501
664
  verifier = secrets.token_urlsafe(32)
502
665
  digest = hashlib.sha256(verifier.encode("UTF-8")).digest()
503
- challenge = (base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", ""))
666
+ challenge = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
504
667
 
505
668
  params = {
506
- 'response_type': 'code',
507
- 'client_id': self._client_id,
508
- 'redirect_uri': self.redirect_url,
509
- 'scope': ' '.join(self._scopes),
510
- 'state': state,
511
- 'code_challenge': challenge,
512
- 'code_challenge_method': 'S256'
669
+ "response_type": "code",
670
+ "client_id": self._client_id,
671
+ "redirect_uri": self.redirect_url,
672
+ "scope": " ".join(self._scopes),
673
+ "state": state,
674
+ "code_challenge": challenge,
675
+ "code_challenge_method": "S256",
513
676
  }
514
- auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}'
515
- return Consent(state,
516
- verifier,
517
- authorization_url=auth_url,
518
- redirect_url=self.redirect_url,
519
- token_endpoint=self._oidc_endpoints.token_endpoint,
520
- client_id=self._client_id,
521
- client_secret=self._client_secret)
677
+ auth_url = f"{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}"
678
+ return Consent(
679
+ state,
680
+ verifier,
681
+ authorization_url=auth_url,
682
+ redirect_url=self.redirect_url,
683
+ token_endpoint=self._oidc_endpoints.token_endpoint,
684
+ client_id=self._client_id,
685
+ client_secret=self._client_secret,
686
+ )
522
687
 
523
688
  def __repr__(self) -> str:
524
- return f'<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>'
689
+ return f"<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>"
525
690
 
526
691
 
527
692
  @dataclass
@@ -535,6 +700,7 @@ class ClientCredentials(Refreshable):
535
700
  the background job uses the Client ID and Client Secret to obtain
536
701
  an Access Token from the Authorization Server.
537
702
  """
703
+
538
704
  client_id: str
539
705
  client_secret: str
540
706
  token_url: str
@@ -553,24 +719,28 @@ class ClientCredentials(Refreshable):
553
719
  if self.endpoint_params:
554
720
  for k, v in self.endpoint_params.items():
555
721
  params[k] = v
556
- return retrieve_token(self.client_id,
557
- self.client_secret,
558
- self.token_url,
559
- params,
560
- use_params=self.use_params,
561
- use_header=self.use_header)
722
+ return retrieve_token(
723
+ self.client_id,
724
+ self.client_secret,
725
+ self.token_url,
726
+ params,
727
+ use_params=self.use_params,
728
+ use_header=self.use_header,
729
+ )
562
730
 
563
731
 
564
732
  class TokenCache:
565
733
  BASE_PATH = "~/.config/databricks-sdk-py/oauth"
566
734
 
567
- def __init__(self,
568
- host: str,
569
- oidc_endpoints: OidcEndpoints,
570
- client_id: str,
571
- redirect_url: str = None,
572
- client_secret: str = None,
573
- scopes: List[str] = None) -> None:
735
+ def __init__(
736
+ self,
737
+ host: str,
738
+ oidc_endpoints: OidcEndpoints,
739
+ client_id: str,
740
+ redirect_url: Optional[str] = None,
741
+ client_secret: Optional[str] = None,
742
+ scopes: Optional[List[str]] = None,
743
+ ) -> None:
574
744
  self._host = host
575
745
  self._client_id = client_id
576
746
  self._oidc_endpoints = oidc_endpoints
@@ -582,8 +752,12 @@ class TokenCache:
582
752
  def filename(self) -> str:
583
753
  # Include host, client_id, and scopes in the cache filename to make it unique.
584
754
  hash = hashlib.sha256()
585
- for chunk in [self._host, self._client_id, ",".join(self._scopes), ]:
586
- hash.update(chunk.encode('utf-8'))
755
+ for chunk in [
756
+ self._host,
757
+ self._client_id,
758
+ ",".join(self._scopes),
759
+ ]:
760
+ hash.update(chunk.encode("utf-8"))
587
761
  return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))
588
762
 
589
763
  def load(self) -> Optional[SessionCredentials]:
@@ -594,13 +768,15 @@ class TokenCache:
594
768
  return None
595
769
 
596
770
  try:
597
- with open(self.filename, 'r') as f:
771
+ with open(self.filename, "r") as f:
598
772
  raw = json.load(f)
599
- return SessionCredentials.from_dict(raw,
600
- token_endpoint=self._oidc_endpoints.token_endpoint,
601
- client_id=self._client_id,
602
- client_secret=self._client_secret,
603
- redirect_url=self._redirect_url)
773
+ return SessionCredentials.from_dict(
774
+ raw,
775
+ token_endpoint=self._oidc_endpoints.token_endpoint,
776
+ client_id=self._client_id,
777
+ client_secret=self._client_secret,
778
+ redirect_url=self._redirect_url,
779
+ )
604
780
  except Exception:
605
781
  return None
606
782
 
@@ -609,6 +785,6 @@ class TokenCache:
609
785
  Save credentials to cache file.
610
786
  """
611
787
  os.makedirs(os.path.dirname(self.filename), exist_ok=True)
612
- with open(self.filename, 'w') as f:
788
+ with open(self.filename, "w") as f:
613
789
  json.dump(credentials.as_dict(), f)
614
790
  os.chmod(self.filename, 0o600)