databricks-sdk 0.44.0__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +123 -115
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +152 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +360 -210
- databricks/sdk/data_plane.py +86 -3
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +201 -20
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +372 -196
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5241 -3531
- databricks/sdk/service/dashboards.py +1313 -923
- databricks/sdk/service/files.py +442 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3364 -2255
- databricks/sdk/service/oauth2.py +922 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2040 -1278
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.45.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.0.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
databricks/sdk/oauth.py
CHANGED
|
@@ -9,8 +9,10 @@ import threading
|
|
|
9
9
|
import urllib.parse
|
|
10
10
|
import webbrowser
|
|
11
11
|
from abc import abstractmethod
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
12
13
|
from dataclasses import dataclass
|
|
13
14
|
from datetime import datetime, timedelta
|
|
15
|
+
from enum import Enum
|
|
14
16
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
15
17
|
from typing import Any, Dict, List, Optional
|
|
16
18
|
|
|
@@ -21,7 +23,7 @@ from ._base_client import _BaseClient, _fix_host_if_needed
|
|
|
21
23
|
|
|
22
24
|
# Error code for PKCE flow in Azure Active Directory, that gets additional retry.
|
|
23
25
|
# See https://stackoverflow.com/a/75466778/277035 for more info
|
|
24
|
-
NO_ORIGIN_FOR_SPA_CLIENT_ERROR =
|
|
26
|
+
NO_ORIGIN_FOR_SPA_CLIENT_ERROR = "AADSTS9002327"
|
|
25
27
|
|
|
26
28
|
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
|
|
27
29
|
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
|
@@ -52,28 +54,33 @@ class OidcEndpoints:
|
|
|
52
54
|
The endpoints used for OAuth-based authentication in Databricks.
|
|
53
55
|
"""
|
|
54
56
|
|
|
55
|
-
authorization_endpoint: str
|
|
57
|
+
authorization_endpoint: str # ../v1/authorize
|
|
56
58
|
"""The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for
|
|
57
59
|
the user to login and authorize the client for user-to-machine (U2M) flows."""
|
|
58
60
|
|
|
59
|
-
token_endpoint: str
|
|
61
|
+
token_endpoint: str # ../v1/token
|
|
60
62
|
"""The token endpoint for the OAuth flow."""
|
|
61
63
|
|
|
62
64
|
@staticmethod
|
|
63
|
-
def from_dict(d: dict) ->
|
|
64
|
-
return OidcEndpoints(
|
|
65
|
-
|
|
65
|
+
def from_dict(d: dict) -> "OidcEndpoints":
|
|
66
|
+
return OidcEndpoints(
|
|
67
|
+
authorization_endpoint=d.get("authorization_endpoint"),
|
|
68
|
+
token_endpoint=d.get("token_endpoint"),
|
|
69
|
+
)
|
|
66
70
|
|
|
67
71
|
def as_dict(self) -> dict:
|
|
68
|
-
return {
|
|
72
|
+
return {
|
|
73
|
+
"authorization_endpoint": self.authorization_endpoint,
|
|
74
|
+
"token_endpoint": self.token_endpoint,
|
|
75
|
+
}
|
|
69
76
|
|
|
70
77
|
|
|
71
78
|
@dataclass
|
|
72
79
|
class Token:
|
|
73
80
|
access_token: str
|
|
74
|
-
token_type: str = None
|
|
75
|
-
refresh_token: str = None
|
|
76
|
-
expiry: datetime = None
|
|
81
|
+
token_type: Optional[str] = None
|
|
82
|
+
refresh_token: Optional[str] = None
|
|
83
|
+
expiry: Optional[datetime] = None
|
|
77
84
|
|
|
78
85
|
@property
|
|
79
86
|
def expired(self):
|
|
@@ -91,19 +98,24 @@ class Token:
|
|
|
91
98
|
return self.access_token and not self.expired
|
|
92
99
|
|
|
93
100
|
def as_dict(self) -> dict:
|
|
94
|
-
raw = {
|
|
101
|
+
raw = {
|
|
102
|
+
"access_token": self.access_token,
|
|
103
|
+
"token_type": self.token_type,
|
|
104
|
+
}
|
|
95
105
|
if self.expiry:
|
|
96
|
-
raw[
|
|
106
|
+
raw["expiry"] = self.expiry.isoformat()
|
|
97
107
|
if self.refresh_token:
|
|
98
|
-
raw[
|
|
108
|
+
raw["refresh_token"] = self.refresh_token
|
|
99
109
|
return raw
|
|
100
110
|
|
|
101
111
|
@staticmethod
|
|
102
|
-
def from_dict(raw: dict) ->
|
|
103
|
-
return Token(
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
112
|
+
def from_dict(raw: dict) -> "Token":
|
|
113
|
+
return Token(
|
|
114
|
+
access_token=raw["access_token"],
|
|
115
|
+
token_type=raw["token_type"],
|
|
116
|
+
expiry=datetime.fromisoformat(raw["expiry"]),
|
|
117
|
+
refresh_token=raw.get("refresh_token"),
|
|
118
|
+
)
|
|
107
119
|
|
|
108
120
|
def jwt_claims(self) -> Dict[str, str]:
|
|
109
121
|
"""Get claims from the access token or return an empty dictionary if it is not a JWT token.
|
|
@@ -131,7 +143,7 @@ class Token:
|
|
|
131
143
|
try:
|
|
132
144
|
jwt_split = self.access_token.split(".")
|
|
133
145
|
if len(jwt_split) != 3:
|
|
134
|
-
logger.debug(f
|
|
146
|
+
logger.debug(f"Tried to decode access token as JWT, but failed: {len(jwt_split)} components")
|
|
135
147
|
return {}
|
|
136
148
|
payload_with_padding = jwt_split[1] + "=="
|
|
137
149
|
payload_bytes = base64.standard_b64decode(payload_with_padding)
|
|
@@ -139,7 +151,7 @@ class Token:
|
|
|
139
151
|
claims = json.loads(payload_json)
|
|
140
152
|
return claims
|
|
141
153
|
except ValueError as err:
|
|
142
|
-
logger.debug(f
|
|
154
|
+
logger.debug(f"Tried to decode access token as JWT, but failed: {err}")
|
|
143
155
|
return {}
|
|
144
156
|
|
|
145
157
|
|
|
@@ -150,17 +162,21 @@ class TokenSource:
|
|
|
150
162
|
pass
|
|
151
163
|
|
|
152
164
|
|
|
153
|
-
def retrieve_token(
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
165
|
+
def retrieve_token(
|
|
166
|
+
client_id,
|
|
167
|
+
client_secret,
|
|
168
|
+
token_url,
|
|
169
|
+
params,
|
|
170
|
+
use_params=False,
|
|
171
|
+
use_header=False,
|
|
172
|
+
headers=None,
|
|
173
|
+
) -> Token:
|
|
174
|
+
logger.debug(f"Retrieving token for {client_id}")
|
|
161
175
|
if use_params:
|
|
162
|
-
if client_id:
|
|
163
|
-
|
|
176
|
+
if client_id:
|
|
177
|
+
params["client_id"] = client_id
|
|
178
|
+
if client_secret:
|
|
179
|
+
params["client_secret"] = client_secret
|
|
164
180
|
auth = None
|
|
165
181
|
if use_header:
|
|
166
182
|
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
|
|
@@ -168,40 +184,156 @@ def retrieve_token(client_id,
|
|
|
168
184
|
auth = IgnoreNetrcAuth()
|
|
169
185
|
resp = requests.post(token_url, params, auth=auth, headers=headers)
|
|
170
186
|
if not resp.ok:
|
|
171
|
-
if resp.headers[
|
|
187
|
+
if resp.headers["Content-Type"].startswith("application/json"):
|
|
172
188
|
err = resp.json()
|
|
173
|
-
code = err.get(
|
|
174
|
-
summary = err.get(
|
|
175
|
-
summary = summary.replace("\r\n",
|
|
176
|
-
raise ValueError(f
|
|
189
|
+
code = err.get("errorCode", err.get("error", "unknown"))
|
|
190
|
+
summary = err.get("errorSummary", err.get("error_description", "unknown"))
|
|
191
|
+
summary = summary.replace("\r\n", " ")
|
|
192
|
+
raise ValueError(f"{code}: {summary}")
|
|
177
193
|
raise ValueError(resp.content)
|
|
178
194
|
try:
|
|
179
195
|
j = resp.json()
|
|
180
196
|
expires_in = int(j["expires_in"])
|
|
181
197
|
expiry = datetime.now() + timedelta(seconds=expires_in)
|
|
182
|
-
return Token(
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
198
|
+
return Token(
|
|
199
|
+
access_token=j["access_token"],
|
|
200
|
+
refresh_token=j.get("refresh_token"),
|
|
201
|
+
token_type=j["token_type"],
|
|
202
|
+
expiry=expiry,
|
|
203
|
+
)
|
|
186
204
|
except Exception as e:
|
|
187
205
|
raise NotImplementedError(f"Not supported yet: {e}")
|
|
188
206
|
|
|
189
207
|
|
|
190
|
-
class
|
|
208
|
+
class _TokenState(Enum):
|
|
209
|
+
"""
|
|
210
|
+
Represents the state of a token. Each token can be in one of
|
|
211
|
+
the following three states:
|
|
212
|
+
- FRESH: The token is valid.
|
|
213
|
+
- STALE: The token is valid but will expire soon.
|
|
214
|
+
- EXPIRED: The token has expired and cannot be used.
|
|
215
|
+
"""
|
|
191
216
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
217
|
+
FRESH = 1 # The token is valid.
|
|
218
|
+
STALE = 2 # The token is valid but will expire soon.
|
|
219
|
+
EXPIRED = 3 # The token has expired and cannot be used.
|
|
195
220
|
|
|
221
|
+
|
|
222
|
+
class Refreshable(TokenSource):
|
|
223
|
+
"""A token source that supports refreshing expired tokens."""
|
|
224
|
+
|
|
225
|
+
_EXECUTOR = None
|
|
226
|
+
_EXECUTOR_LOCK = threading.Lock()
|
|
227
|
+
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
|
|
228
|
+
|
|
229
|
+
@classmethod
|
|
230
|
+
def _get_executor(cls):
|
|
231
|
+
"""Lazy initialization of the ThreadPoolExecutor."""
|
|
232
|
+
if cls._EXECUTOR is None:
|
|
233
|
+
with cls._EXECUTOR_LOCK:
|
|
234
|
+
if cls._EXECUTOR is None:
|
|
235
|
+
# This thread pool has multiple workers because it is shared by all instances of Refreshable.
|
|
236
|
+
cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
|
|
237
|
+
return cls._EXECUTOR
|
|
238
|
+
|
|
239
|
+
def __init__(
|
|
240
|
+
self,
|
|
241
|
+
token: Optional[Token] = None,
|
|
242
|
+
disable_async: bool = True,
|
|
243
|
+
stale_duration: timedelta = _DEFAULT_STALE_DURATION,
|
|
244
|
+
):
|
|
245
|
+
# Config properties
|
|
246
|
+
self._stale_duration = stale_duration
|
|
247
|
+
self._disable_async = disable_async
|
|
248
|
+
# Lock
|
|
249
|
+
self._lock = threading.Lock()
|
|
250
|
+
# Non Thread safe properties. They should be accessed only when protected by the lock above.
|
|
251
|
+
self._token = token or Token("")
|
|
252
|
+
self._is_refreshing = False
|
|
253
|
+
self._refresh_err = False
|
|
254
|
+
|
|
255
|
+
# This is the main entry point for the Token. Do not access the token
|
|
256
|
+
# using any of the internal functions.
|
|
196
257
|
def token(self) -> Token:
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
if self.
|
|
200
|
-
return self.
|
|
201
|
-
|
|
258
|
+
"""Returns a valid token, blocking if async refresh is disabled."""
|
|
259
|
+
with self._lock:
|
|
260
|
+
if self._disable_async:
|
|
261
|
+
return self._blocking_token()
|
|
262
|
+
return self._async_token()
|
|
263
|
+
|
|
264
|
+
def _async_token(self) -> Token:
|
|
265
|
+
"""
|
|
266
|
+
Returns a token.
|
|
267
|
+
If the token is stale, triggers an asynchronous refresh.
|
|
268
|
+
If the token is expired, refreshes it synchronously, blocking until the refresh is complete.
|
|
269
|
+
"""
|
|
270
|
+
state = self._token_state()
|
|
271
|
+
token = self._token
|
|
272
|
+
|
|
273
|
+
if state == _TokenState.FRESH:
|
|
274
|
+
return token
|
|
275
|
+
if state == _TokenState.STALE:
|
|
276
|
+
self._trigger_async_refresh()
|
|
277
|
+
return token
|
|
278
|
+
return self._blocking_token()
|
|
279
|
+
|
|
280
|
+
def _token_state(self) -> _TokenState:
|
|
281
|
+
"""Returns the current state of the token."""
|
|
282
|
+
if not self._token or not self._token.valid:
|
|
283
|
+
return _TokenState.EXPIRED
|
|
284
|
+
if not self._token.expiry:
|
|
285
|
+
return _TokenState.FRESH
|
|
286
|
+
|
|
287
|
+
lifespan = self._token.expiry - datetime.now()
|
|
288
|
+
if lifespan < timedelta(seconds=0):
|
|
289
|
+
return _TokenState.EXPIRED
|
|
290
|
+
if lifespan < self._stale_duration:
|
|
291
|
+
return _TokenState.STALE
|
|
292
|
+
return _TokenState.FRESH
|
|
293
|
+
|
|
294
|
+
def _blocking_token(self) -> Token:
|
|
295
|
+
"""Returns a token, blocking if necessary to refresh it."""
|
|
296
|
+
state = self._token_state()
|
|
297
|
+
# This is important to recover from potential previous failed attempts
|
|
298
|
+
# to refresh the token asynchronously.
|
|
299
|
+
self._refresh_err = False
|
|
300
|
+
self._is_refreshing = False
|
|
301
|
+
|
|
302
|
+
# It's possible that the token got refreshed (either by a _blocking_refresh or
|
|
303
|
+
# an _async_refresh call) while this particular call was waiting to acquire
|
|
304
|
+
# the lock. This check avoids refreshing the token again in such cases.
|
|
305
|
+
if state != _TokenState.EXPIRED:
|
|
202
306
|
return self._token
|
|
203
|
-
|
|
204
|
-
|
|
307
|
+
|
|
308
|
+
self._token = self.refresh()
|
|
309
|
+
return self._token
|
|
310
|
+
|
|
311
|
+
def _trigger_async_refresh(self):
|
|
312
|
+
"""Starts an asynchronous refresh if none is in progress."""
|
|
313
|
+
|
|
314
|
+
def _refresh_internal():
|
|
315
|
+
new_token = None
|
|
316
|
+
try:
|
|
317
|
+
new_token = self.refresh()
|
|
318
|
+
except Exception as e:
|
|
319
|
+
# This happens on a thread, so we don't want to propagate the error.
|
|
320
|
+
# Instead, if there is no new_token for any reason, we will disable async refresh below
|
|
321
|
+
# But we will do it inside the lock.
|
|
322
|
+
logger.warning(f"Tried to refresh token asynchronously, but failed: {e}")
|
|
323
|
+
|
|
324
|
+
with self._lock:
|
|
325
|
+
if new_token is not None:
|
|
326
|
+
self._token = new_token
|
|
327
|
+
else:
|
|
328
|
+
self._refresh_err = True
|
|
329
|
+
self._is_refreshing = False
|
|
330
|
+
|
|
331
|
+
# The token may have been refreshed by another thread.
|
|
332
|
+
if self._token_state() == _TokenState.FRESH:
|
|
333
|
+
return
|
|
334
|
+
if not self._is_refreshing and not self._refresh_err:
|
|
335
|
+
self._is_refreshing = True
|
|
336
|
+
Refreshable._get_executor().submit(_refresh_internal)
|
|
205
337
|
|
|
206
338
|
@abstractmethod
|
|
207
339
|
def refresh(self) -> Token:
|
|
@@ -219,23 +351,24 @@ class _OAuthCallback(BaseHTTPRequestHandler):
|
|
|
219
351
|
|
|
220
352
|
def do_GET(self):
|
|
221
353
|
from urllib.parse import parse_qsl
|
|
222
|
-
|
|
354
|
+
|
|
355
|
+
parts = self.path.split("?")
|
|
223
356
|
if len(parts) != 2:
|
|
224
|
-
self.send_error(400,
|
|
357
|
+
self.send_error(400, "Missing Query")
|
|
225
358
|
return
|
|
226
359
|
|
|
227
360
|
query = dict(parse_qsl(parts[1]))
|
|
228
361
|
self._feedback.append(query)
|
|
229
362
|
|
|
230
|
-
if
|
|
231
|
-
self.send_error(400, query[
|
|
363
|
+
if "error" in query:
|
|
364
|
+
self.send_error(400, query["error"], query.get("error_description"))
|
|
232
365
|
return
|
|
233
366
|
|
|
234
367
|
self.send_response(200)
|
|
235
|
-
self.send_header(
|
|
368
|
+
self.send_header("Content-type", "text/html")
|
|
236
369
|
self.end_headers()
|
|
237
370
|
# TODO: show better message
|
|
238
|
-
self.wfile.write(b
|
|
371
|
+
self.wfile.write(b"You can close this tab.")
|
|
239
372
|
|
|
240
373
|
|
|
241
374
|
def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
|
|
@@ -246,8 +379,8 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas
|
|
|
246
379
|
:return: The account's OIDC endpoints.
|
|
247
380
|
"""
|
|
248
381
|
host = _fix_host_if_needed(host)
|
|
249
|
-
oidc = f
|
|
250
|
-
resp = client.do(
|
|
382
|
+
oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
|
|
383
|
+
resp = client.do("GET", oidc)
|
|
251
384
|
return OidcEndpoints.from_dict(resp)
|
|
252
385
|
|
|
253
386
|
|
|
@@ -258,12 +391,14 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
|
|
|
258
391
|
:return: The workspace's OIDC endpoints.
|
|
259
392
|
"""
|
|
260
393
|
host = _fix_host_if_needed(host)
|
|
261
|
-
oidc = f
|
|
262
|
-
resp = client.do(
|
|
394
|
+
oidc = f"{host}/oidc/.well-known/oauth-authorization-server"
|
|
395
|
+
resp = client.do("GET", oidc)
|
|
263
396
|
return OidcEndpoints.from_dict(resp)
|
|
264
397
|
|
|
265
398
|
|
|
266
|
-
def get_azure_entra_id_workspace_endpoints(
|
|
399
|
+
def get_azure_entra_id_workspace_endpoints(
|
|
400
|
+
host: str,
|
|
401
|
+
) -> Optional[OidcEndpoints]:
|
|
267
402
|
"""
|
|
268
403
|
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
|
|
269
404
|
using an application registered in Azure Entra ID.
|
|
@@ -272,22 +407,26 @@ def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]
|
|
|
272
407
|
"""
|
|
273
408
|
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
|
|
274
409
|
host = _fix_host_if_needed(host)
|
|
275
|
-
res = requests.get(f
|
|
276
|
-
real_auth_url = res.headers.get(
|
|
410
|
+
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
|
|
411
|
+
real_auth_url = res.headers.get("location")
|
|
277
412
|
if not real_auth_url:
|
|
278
413
|
return None
|
|
279
|
-
return OidcEndpoints(
|
|
280
|
-
|
|
414
|
+
return OidcEndpoints(
|
|
415
|
+
authorization_endpoint=real_auth_url,
|
|
416
|
+
token_endpoint=real_auth_url.replace("/authorize", "/token"),
|
|
417
|
+
)
|
|
281
418
|
|
|
282
419
|
|
|
283
420
|
class SessionCredentials(Refreshable):
|
|
284
421
|
|
|
285
|
-
def __init__(
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
422
|
+
def __init__(
|
|
423
|
+
self,
|
|
424
|
+
token: Token,
|
|
425
|
+
token_endpoint: str,
|
|
426
|
+
client_id: str,
|
|
427
|
+
client_secret: str = None,
|
|
428
|
+
redirect_url: str = None,
|
|
429
|
+
):
|
|
291
430
|
self._token_endpoint = token_endpoint
|
|
292
431
|
self._client_id = client_id
|
|
293
432
|
self._client_secret = client_secret
|
|
@@ -295,61 +434,72 @@ class SessionCredentials(Refreshable):
|
|
|
295
434
|
super().__init__(token)
|
|
296
435
|
|
|
297
436
|
def as_dict(self) -> dict:
|
|
298
|
-
return {
|
|
437
|
+
return {"token": self.token().as_dict()}
|
|
299
438
|
|
|
300
439
|
@staticmethod
|
|
301
|
-
def from_dict(
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
440
|
+
def from_dict(
|
|
441
|
+
raw: dict,
|
|
442
|
+
token_endpoint: str,
|
|
443
|
+
client_id: str,
|
|
444
|
+
client_secret: str = None,
|
|
445
|
+
redirect_url: str = None,
|
|
446
|
+
) -> "SessionCredentials":
|
|
447
|
+
return SessionCredentials(
|
|
448
|
+
token=Token.from_dict(raw["token"]),
|
|
449
|
+
token_endpoint=token_endpoint,
|
|
450
|
+
client_id=client_id,
|
|
451
|
+
client_secret=client_secret,
|
|
452
|
+
redirect_url=redirect_url,
|
|
453
|
+
)
|
|
311
454
|
|
|
312
455
|
def auth_type(self):
|
|
313
456
|
"""Implementing CredentialsProvider protocol"""
|
|
314
457
|
# TODO: distinguish between Databricks IDP and Azure AD
|
|
315
|
-
return
|
|
458
|
+
return "oauth"
|
|
316
459
|
|
|
317
460
|
def __call__(self, *args, **kwargs):
|
|
318
461
|
"""Implementing CredentialsProvider protocol"""
|
|
319
462
|
|
|
320
463
|
def inner() -> Dict[str, str]:
|
|
321
|
-
return {
|
|
464
|
+
return {"Authorization": f"Bearer {self.token().access_token}"}
|
|
322
465
|
|
|
323
466
|
return inner
|
|
324
467
|
|
|
325
468
|
def refresh(self) -> Token:
|
|
326
469
|
refresh_token = self._token.refresh_token
|
|
327
470
|
if not refresh_token:
|
|
328
|
-
raise ValueError(
|
|
329
|
-
params = {
|
|
471
|
+
raise ValueError("oauth2: token expired and refresh token is not set")
|
|
472
|
+
params = {
|
|
473
|
+
"grant_type": "refresh_token",
|
|
474
|
+
"refresh_token": refresh_token,
|
|
475
|
+
}
|
|
330
476
|
headers = {}
|
|
331
|
-
if
|
|
477
|
+
if "microsoft" in self._token_endpoint:
|
|
332
478
|
# Tokens issued for the 'Single-Page Application' client-type may
|
|
333
479
|
# only be redeemed via cross-origin requests
|
|
334
|
-
headers = {
|
|
335
|
-
return retrieve_token(
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
480
|
+
headers = {"Origin": self._redirect_url}
|
|
481
|
+
return retrieve_token(
|
|
482
|
+
client_id=self._client_id,
|
|
483
|
+
client_secret=self._client_secret,
|
|
484
|
+
token_url=self._token_endpoint,
|
|
485
|
+
params=params,
|
|
486
|
+
use_params=True,
|
|
487
|
+
headers=headers,
|
|
488
|
+
)
|
|
341
489
|
|
|
342
490
|
|
|
343
491
|
class Consent:
|
|
344
492
|
|
|
345
|
-
def __init__(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
493
|
+
def __init__(
|
|
494
|
+
self,
|
|
495
|
+
state: str,
|
|
496
|
+
verifier: str,
|
|
497
|
+
authorization_url: str,
|
|
498
|
+
redirect_url: str,
|
|
499
|
+
token_endpoint: str,
|
|
500
|
+
client_id: str,
|
|
501
|
+
client_secret: str = None,
|
|
502
|
+
) -> None:
|
|
353
503
|
self._verifier = verifier
|
|
354
504
|
self._state = state
|
|
355
505
|
self._authorization_url = authorization_url
|
|
@@ -360,12 +510,12 @@ class Consent:
|
|
|
360
510
|
|
|
361
511
|
def as_dict(self) -> dict:
|
|
362
512
|
return {
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
513
|
+
"state": self._state,
|
|
514
|
+
"verifier": self._verifier,
|
|
515
|
+
"authorization_url": self._authorization_url,
|
|
516
|
+
"redirect_url": self._redirect_url,
|
|
517
|
+
"token_endpoint": self._token_endpoint,
|
|
518
|
+
"client_id": self._client_id,
|
|
369
519
|
}
|
|
370
520
|
|
|
371
521
|
@property
|
|
@@ -373,65 +523,74 @@ class Consent:
|
|
|
373
523
|
return self._authorization_url
|
|
374
524
|
|
|
375
525
|
@staticmethod
|
|
376
|
-
def from_dict(raw: dict, client_secret: str = None) ->
|
|
377
|
-
return Consent(
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
526
|
+
def from_dict(raw: dict, client_secret: str = None) -> "Consent":
|
|
527
|
+
return Consent(
|
|
528
|
+
raw["state"],
|
|
529
|
+
raw["verifier"],
|
|
530
|
+
authorization_url=raw["authorization_url"],
|
|
531
|
+
redirect_url=raw["redirect_url"],
|
|
532
|
+
token_endpoint=raw["token_endpoint"],
|
|
533
|
+
client_id=raw["client_id"],
|
|
534
|
+
client_secret=client_secret,
|
|
535
|
+
)
|
|
384
536
|
|
|
385
537
|
def launch_external_browser(self) -> SessionCredentials:
|
|
386
538
|
redirect_url = urllib.parse.urlparse(self._redirect_url)
|
|
387
|
-
if redirect_url.hostname not in (
|
|
388
|
-
raise ValueError(f
|
|
539
|
+
if redirect_url.hostname not in ("localhost", "127.0.0.1"):
|
|
540
|
+
raise ValueError(f"cannot listen on {redirect_url.hostname}")
|
|
389
541
|
feedback = []
|
|
390
|
-
logger.info(f
|
|
542
|
+
logger.info(f"Opening {self._authorization_url} in a browser")
|
|
391
543
|
webbrowser.open_new(self._authorization_url)
|
|
392
544
|
port = redirect_url.port
|
|
393
545
|
handler_factory = functools.partial(_OAuthCallback, feedback)
|
|
394
546
|
with HTTPServer(("localhost", port), handler_factory) as httpd:
|
|
395
|
-
logger.info(f
|
|
547
|
+
logger.info(f"Waiting for redirect to http://localhost:{port}")
|
|
396
548
|
httpd.handle_request()
|
|
397
549
|
if not feedback:
|
|
398
|
-
raise ValueError(
|
|
550
|
+
raise ValueError("No data received in callback")
|
|
399
551
|
query = feedback.pop()
|
|
400
552
|
return self.exchange_callback_parameters(query)
|
|
401
553
|
|
|
402
554
|
def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
|
|
403
|
-
if
|
|
404
|
-
raise ValueError(
|
|
405
|
-
if
|
|
406
|
-
raise ValueError(
|
|
407
|
-
return self.exchange(query[
|
|
555
|
+
if "error" in query:
|
|
556
|
+
raise ValueError("{error}: {error_description}".format(**query))
|
|
557
|
+
if "code" not in query or "state" not in query:
|
|
558
|
+
raise ValueError("No code returned in callback")
|
|
559
|
+
return self.exchange(query["code"], query["state"])
|
|
408
560
|
|
|
409
561
|
def exchange(self, code: str, state: str) -> SessionCredentials:
|
|
410
562
|
if self._state != state:
|
|
411
|
-
raise ValueError(
|
|
563
|
+
raise ValueError("state mismatch")
|
|
412
564
|
params = {
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
565
|
+
"redirect_uri": self._redirect_url,
|
|
566
|
+
"grant_type": "authorization_code",
|
|
567
|
+
"code_verifier": self._verifier,
|
|
568
|
+
"code": code,
|
|
417
569
|
}
|
|
418
570
|
headers = {}
|
|
419
571
|
while True:
|
|
420
572
|
try:
|
|
421
|
-
token = retrieve_token(
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
573
|
+
token = retrieve_token(
|
|
574
|
+
client_id=self._client_id,
|
|
575
|
+
client_secret=self._client_secret,
|
|
576
|
+
token_url=self._token_endpoint,
|
|
577
|
+
params=params,
|
|
578
|
+
headers=headers,
|
|
579
|
+
use_params=True,
|
|
580
|
+
)
|
|
581
|
+
return SessionCredentials(
|
|
582
|
+
token,
|
|
583
|
+
self._token_endpoint,
|
|
584
|
+
self._client_id,
|
|
585
|
+
self._client_secret,
|
|
586
|
+
self._redirect_url,
|
|
587
|
+
)
|
|
429
588
|
except ValueError as e:
|
|
430
589
|
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
|
|
431
590
|
# Retry in cases of 'Single-Page Application' client-type with
|
|
432
591
|
# 'Origin' header equal to client's redirect URL.
|
|
433
|
-
headers[
|
|
434
|
-
msg = f
|
|
592
|
+
headers["Origin"] = self._redirect_url
|
|
593
|
+
msg = f"Retrying OAuth token exchange with {self._redirect_url} origin"
|
|
435
594
|
logger.debug(msg)
|
|
436
595
|
continue
|
|
437
596
|
raise e
|
|
@@ -456,15 +615,17 @@ class OAuthClient:
|
|
|
456
615
|
exchange it for a token without possessing the Code Verifier.
|
|
457
616
|
"""
|
|
458
617
|
|
|
459
|
-
def __init__(
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
618
|
+
def __init__(
|
|
619
|
+
self,
|
|
620
|
+
oidc_endpoints: OidcEndpoints,
|
|
621
|
+
redirect_url: str,
|
|
622
|
+
client_id: str,
|
|
623
|
+
scopes: List[str] = None,
|
|
624
|
+
client_secret: str = None,
|
|
625
|
+
):
|
|
465
626
|
|
|
466
627
|
if not scopes:
|
|
467
|
-
scopes = [
|
|
628
|
+
scopes = ["all-apis"]
|
|
468
629
|
|
|
469
630
|
self.redirect_url = redirect_url
|
|
470
631
|
self._client_id = client_id
|
|
@@ -473,25 +634,27 @@ class OAuthClient:
|
|
|
473
634
|
self._scopes = scopes
|
|
474
635
|
|
|
475
636
|
@staticmethod
|
|
476
|
-
def from_host(
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
637
|
+
def from_host(
|
|
638
|
+
host: str,
|
|
639
|
+
client_id: str,
|
|
640
|
+
redirect_url: str,
|
|
641
|
+
*,
|
|
642
|
+
scopes: List[str] = None,
|
|
643
|
+
client_secret: str = None,
|
|
644
|
+
) -> "OAuthClient":
|
|
482
645
|
from .core import Config
|
|
483
646
|
from .credentials_provider import credentials_strategy
|
|
484
647
|
|
|
485
|
-
@credentials_strategy(
|
|
648
|
+
@credentials_strategy("noop", [])
|
|
486
649
|
def noop_credentials(_: any):
|
|
487
650
|
return lambda: {}
|
|
488
651
|
|
|
489
652
|
config = Config(host=host, credentials_strategy=noop_credentials)
|
|
490
653
|
if not scopes:
|
|
491
|
-
scopes = [
|
|
654
|
+
scopes = ["all-apis"]
|
|
492
655
|
oidc = config.oidc_endpoints
|
|
493
656
|
if not oidc:
|
|
494
|
-
raise ValueError(f
|
|
657
|
+
raise ValueError(f"{host} does not support OAuth")
|
|
495
658
|
return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)
|
|
496
659
|
|
|
497
660
|
def initiate_consent(self) -> Consent:
|
|
@@ -500,28 +663,30 @@ class OAuthClient:
|
|
|
500
663
|
# token_urlsafe() already returns base64-encoded string
|
|
501
664
|
verifier = secrets.token_urlsafe(32)
|
|
502
665
|
digest = hashlib.sha256(verifier.encode("UTF-8")).digest()
|
|
503
|
-
challenge =
|
|
666
|
+
challenge = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
|
|
504
667
|
|
|
505
668
|
params = {
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
669
|
+
"response_type": "code",
|
|
670
|
+
"client_id": self._client_id,
|
|
671
|
+
"redirect_uri": self.redirect_url,
|
|
672
|
+
"scope": " ".join(self._scopes),
|
|
673
|
+
"state": state,
|
|
674
|
+
"code_challenge": challenge,
|
|
675
|
+
"code_challenge_method": "S256",
|
|
513
676
|
}
|
|
514
|
-
auth_url = f
|
|
515
|
-
return Consent(
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
677
|
+
auth_url = f"{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}"
|
|
678
|
+
return Consent(
|
|
679
|
+
state,
|
|
680
|
+
verifier,
|
|
681
|
+
authorization_url=auth_url,
|
|
682
|
+
redirect_url=self.redirect_url,
|
|
683
|
+
token_endpoint=self._oidc_endpoints.token_endpoint,
|
|
684
|
+
client_id=self._client_id,
|
|
685
|
+
client_secret=self._client_secret,
|
|
686
|
+
)
|
|
522
687
|
|
|
523
688
|
def __repr__(self) -> str:
|
|
524
|
-
return f
|
|
689
|
+
return f"<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>"
|
|
525
690
|
|
|
526
691
|
|
|
527
692
|
@dataclass
|
|
@@ -535,6 +700,7 @@ class ClientCredentials(Refreshable):
|
|
|
535
700
|
the background job uses the Client ID and Client Secret to obtain
|
|
536
701
|
an Access Token from the Authorization Server.
|
|
537
702
|
"""
|
|
703
|
+
|
|
538
704
|
client_id: str
|
|
539
705
|
client_secret: str
|
|
540
706
|
token_url: str
|
|
@@ -553,24 +719,28 @@ class ClientCredentials(Refreshable):
|
|
|
553
719
|
if self.endpoint_params:
|
|
554
720
|
for k, v in self.endpoint_params.items():
|
|
555
721
|
params[k] = v
|
|
556
|
-
return retrieve_token(
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
722
|
+
return retrieve_token(
|
|
723
|
+
self.client_id,
|
|
724
|
+
self.client_secret,
|
|
725
|
+
self.token_url,
|
|
726
|
+
params,
|
|
727
|
+
use_params=self.use_params,
|
|
728
|
+
use_header=self.use_header,
|
|
729
|
+
)
|
|
562
730
|
|
|
563
731
|
|
|
564
732
|
class TokenCache:
|
|
565
733
|
BASE_PATH = "~/.config/databricks-sdk-py/oauth"
|
|
566
734
|
|
|
567
|
-
def __init__(
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
735
|
+
def __init__(
|
|
736
|
+
self,
|
|
737
|
+
host: str,
|
|
738
|
+
oidc_endpoints: OidcEndpoints,
|
|
739
|
+
client_id: str,
|
|
740
|
+
redirect_url: Optional[str] = None,
|
|
741
|
+
client_secret: Optional[str] = None,
|
|
742
|
+
scopes: Optional[List[str]] = None,
|
|
743
|
+
) -> None:
|
|
574
744
|
self._host = host
|
|
575
745
|
self._client_id = client_id
|
|
576
746
|
self._oidc_endpoints = oidc_endpoints
|
|
@@ -582,8 +752,12 @@ class TokenCache:
|
|
|
582
752
|
def filename(self) -> str:
|
|
583
753
|
# Include host, client_id, and scopes in the cache filename to make it unique.
|
|
584
754
|
hash = hashlib.sha256()
|
|
585
|
-
for chunk in [
|
|
586
|
-
|
|
755
|
+
for chunk in [
|
|
756
|
+
self._host,
|
|
757
|
+
self._client_id,
|
|
758
|
+
",".join(self._scopes),
|
|
759
|
+
]:
|
|
760
|
+
hash.update(chunk.encode("utf-8"))
|
|
587
761
|
return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))
|
|
588
762
|
|
|
589
763
|
def load(self) -> Optional[SessionCredentials]:
|
|
@@ -594,13 +768,15 @@ class TokenCache:
|
|
|
594
768
|
return None
|
|
595
769
|
|
|
596
770
|
try:
|
|
597
|
-
with open(self.filename,
|
|
771
|
+
with open(self.filename, "r") as f:
|
|
598
772
|
raw = json.load(f)
|
|
599
|
-
return SessionCredentials.from_dict(
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
773
|
+
return SessionCredentials.from_dict(
|
|
774
|
+
raw,
|
|
775
|
+
token_endpoint=self._oidc_endpoints.token_endpoint,
|
|
776
|
+
client_id=self._client_id,
|
|
777
|
+
client_secret=self._client_secret,
|
|
778
|
+
redirect_url=self._redirect_url,
|
|
779
|
+
)
|
|
604
780
|
except Exception:
|
|
605
781
|
return None
|
|
606
782
|
|
|
@@ -609,6 +785,6 @@ class TokenCache:
|
|
|
609
785
|
Save credentials to cache file.
|
|
610
786
|
"""
|
|
611
787
|
os.makedirs(os.path.dirname(self.filename), exist_ok=True)
|
|
612
|
-
with open(self.filename,
|
|
788
|
+
with open(self.filename, "w") as f:
|
|
613
789
|
json.dump(credentials.as_dict(), f)
|
|
614
790
|
os.chmod(self.filename, 0o600)
|