databricks-sdk 0.44.1__py3-none-any.whl → 0.46.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +135 -116
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +156 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +306 -206
- databricks/sdk/data_plane.py +75 -50
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +191 -16
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +379 -198
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5263 -3536
- databricks/sdk/service/dashboards.py +1331 -924
- databricks/sdk/service/files.py +446 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3839 -2256
- databricks/sdk/service/oauth2.py +910 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2060 -1290
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.46.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.1.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/top_level.txt +0 -0
databricks/sdk/oauth.py
CHANGED
|
@@ -9,8 +9,10 @@ import threading
|
|
|
9
9
|
import urllib.parse
|
|
10
10
|
import webbrowser
|
|
11
11
|
from abc import abstractmethod
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
12
13
|
from dataclasses import dataclass
|
|
13
14
|
from datetime import datetime, timedelta
|
|
15
|
+
from enum import Enum
|
|
14
16
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
15
17
|
from typing import Any, Dict, List, Optional
|
|
16
18
|
|
|
@@ -21,7 +23,7 @@ from ._base_client import _BaseClient, _fix_host_if_needed
|
|
|
21
23
|
|
|
22
24
|
# Error code for PKCE flow in Azure Active Directory, that gets additional retry.
|
|
23
25
|
# See https://stackoverflow.com/a/75466778/277035 for more info
|
|
24
|
-
NO_ORIGIN_FOR_SPA_CLIENT_ERROR =
|
|
26
|
+
NO_ORIGIN_FOR_SPA_CLIENT_ERROR = "AADSTS9002327"
|
|
25
27
|
|
|
26
28
|
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
|
|
27
29
|
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
|
@@ -52,28 +54,33 @@ class OidcEndpoints:
|
|
|
52
54
|
The endpoints used for OAuth-based authentication in Databricks.
|
|
53
55
|
"""
|
|
54
56
|
|
|
55
|
-
authorization_endpoint: str
|
|
57
|
+
authorization_endpoint: str # ../v1/authorize
|
|
56
58
|
"""The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for
|
|
57
59
|
the user to login and authorize the client for user-to-machine (U2M) flows."""
|
|
58
60
|
|
|
59
|
-
token_endpoint: str
|
|
61
|
+
token_endpoint: str # ../v1/token
|
|
60
62
|
"""The token endpoint for the OAuth flow."""
|
|
61
63
|
|
|
62
64
|
@staticmethod
def from_dict(d: dict) -> "OidcEndpoints":
    """Build an ``OidcEndpoints`` from a mapping.

    Both keys are read with ``dict.get``, so a missing key yields ``None``
    for the corresponding endpoint instead of raising.
    """
    auth_url = d.get("authorization_endpoint")
    token_url = d.get("token_endpoint")
    return OidcEndpoints(authorization_endpoint=auth_url, token_endpoint=token_url)
|
|
66
70
|
|
|
67
71
|
def as_dict(self) -> dict:
    """Return the two endpoint URLs as a plain dictionary."""
    serialized = {}
    serialized["authorization_endpoint"] = self.authorization_endpoint
    serialized["token_endpoint"] = self.token_endpoint
    return serialized
|
|
69
76
|
|
|
70
77
|
|
|
71
78
|
@dataclass
|
|
72
79
|
class Token:
|
|
73
80
|
access_token: str
|
|
74
|
-
token_type: str = None
|
|
75
|
-
refresh_token: str = None
|
|
76
|
-
expiry: datetime = None
|
|
81
|
+
token_type: Optional[str] = None
|
|
82
|
+
refresh_token: Optional[str] = None
|
|
83
|
+
expiry: Optional[datetime] = None
|
|
77
84
|
|
|
78
85
|
@property
|
|
79
86
|
def expired(self):
|
|
@@ -91,19 +98,24 @@ class Token:
|
|
|
91
98
|
return self.access_token and not self.expired
|
|
92
99
|
|
|
93
100
|
def as_dict(self) -> dict:
    """Serialize the token into a JSON-friendly dictionary.

    ``access_token`` and ``token_type`` are always present; ``expiry``
    (ISO-8601 text) and ``refresh_token`` are added only when they are set.
    """
    payload = dict(access_token=self.access_token, token_type=self.token_type)
    expiry = self.expiry
    if expiry:
        payload["expiry"] = expiry.isoformat()
    refresh = self.refresh_token
    if refresh:
        payload["refresh_token"] = refresh
    return payload
|
|
100
110
|
|
|
101
111
|
@staticmethod
def from_dict(raw: dict) -> "Token":
    """Rebuild a ``Token`` from the dictionary produced by ``as_dict``.

    ``as_dict`` leaves out ``expiry`` and ``refresh_token`` when they are
    unset, so both are treated as optional here; a token serialized without
    an expiry is restored with ``expiry=None`` instead of raising KeyError.

    :param raw: mapping with at least ``access_token`` and ``token_type``.
    :raises KeyError: if ``access_token`` or ``token_type`` is missing.
    :raises ValueError: if ``expiry`` is present but not valid ISO-8601 text.
    """
    expiry_text = raw.get("expiry")
    return Token(
        access_token=raw["access_token"],
        token_type=raw["token_type"],
        expiry=datetime.fromisoformat(expiry_text) if expiry_text else None,
        refresh_token=raw.get("refresh_token"),
    )
|
|
107
119
|
|
|
108
120
|
def jwt_claims(self) -> Dict[str, str]:
|
|
109
121
|
"""Get claims from the access token or return an empty dictionary if it is not a JWT token.
|
|
@@ -131,7 +143,7 @@ class Token:
|
|
|
131
143
|
try:
|
|
132
144
|
jwt_split = self.access_token.split(".")
|
|
133
145
|
if len(jwt_split) != 3:
|
|
134
|
-
logger.debug(f
|
|
146
|
+
logger.debug(f"Tried to decode access token as JWT, but failed: {len(jwt_split)} components")
|
|
135
147
|
return {}
|
|
136
148
|
payload_with_padding = jwt_split[1] + "=="
|
|
137
149
|
payload_bytes = base64.standard_b64decode(payload_with_padding)
|
|
@@ -139,7 +151,7 @@ class Token:
|
|
|
139
151
|
claims = json.loads(payload_json)
|
|
140
152
|
return claims
|
|
141
153
|
except ValueError as err:
|
|
142
|
-
logger.debug(f
|
|
154
|
+
logger.debug(f"Tried to decode access token as JWT, but failed: {err}")
|
|
143
155
|
return {}
|
|
144
156
|
|
|
145
157
|
|
|
@@ -150,17 +162,21 @@ class TokenSource:
|
|
|
150
162
|
pass
|
|
151
163
|
|
|
152
164
|
|
|
153
|
-
def retrieve_token(
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
165
|
+
def retrieve_token(
|
|
166
|
+
client_id,
|
|
167
|
+
client_secret,
|
|
168
|
+
token_url,
|
|
169
|
+
params,
|
|
170
|
+
use_params=False,
|
|
171
|
+
use_header=False,
|
|
172
|
+
headers=None,
|
|
173
|
+
) -> Token:
|
|
174
|
+
logger.debug(f"Retrieving token for {client_id}")
|
|
161
175
|
if use_params:
|
|
162
|
-
if client_id:
|
|
163
|
-
|
|
176
|
+
if client_id:
|
|
177
|
+
params["client_id"] = client_id
|
|
178
|
+
if client_secret:
|
|
179
|
+
params["client_secret"] = client_secret
|
|
164
180
|
auth = None
|
|
165
181
|
if use_header:
|
|
166
182
|
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
|
|
@@ -168,40 +184,156 @@ def retrieve_token(client_id,
|
|
|
168
184
|
auth = IgnoreNetrcAuth()
|
|
169
185
|
resp = requests.post(token_url, params, auth=auth, headers=headers)
|
|
170
186
|
if not resp.ok:
|
|
171
|
-
if resp.headers[
|
|
187
|
+
if resp.headers["Content-Type"].startswith("application/json"):
|
|
172
188
|
err = resp.json()
|
|
173
|
-
code = err.get(
|
|
174
|
-
summary = err.get(
|
|
175
|
-
summary = summary.replace("\r\n",
|
|
176
|
-
raise ValueError(f
|
|
189
|
+
code = err.get("errorCode", err.get("error", "unknown"))
|
|
190
|
+
summary = err.get("errorSummary", err.get("error_description", "unknown"))
|
|
191
|
+
summary = summary.replace("\r\n", " ")
|
|
192
|
+
raise ValueError(f"{code}: {summary}")
|
|
177
193
|
raise ValueError(resp.content)
|
|
178
194
|
try:
|
|
179
195
|
j = resp.json()
|
|
180
196
|
expires_in = int(j["expires_in"])
|
|
181
197
|
expiry = datetime.now() + timedelta(seconds=expires_in)
|
|
182
|
-
return Token(
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
198
|
+
return Token(
|
|
199
|
+
access_token=j["access_token"],
|
|
200
|
+
refresh_token=j.get("refresh_token"),
|
|
201
|
+
token_type=j["token_type"],
|
|
202
|
+
expiry=expiry,
|
|
203
|
+
)
|
|
186
204
|
except Exception as e:
|
|
187
205
|
raise NotImplementedError(f"Not supported yet: {e}")
|
|
188
206
|
|
|
189
207
|
|
|
190
|
-
class
|
|
208
|
+
class _TokenState(Enum):
    """Lifecycle stage of a token.

    A token is always in exactly one of these stages:
    - FRESH: valid and not close to expiring.
    - STALE: still valid, but due to expire soon.
    - EXPIRED: no longer usable.
    """

    FRESH = 1  # Valid.
    STALE = 2  # Valid, but expiring soon.
    EXPIRED = 3  # Expired; must not be used.
|
|
195
220
|
|
|
221
|
+
|
|
222
|
+
class Refreshable(TokenSource):
|
|
223
|
+
"""A token source that supports refreshing expired tokens."""
|
|
224
|
+
|
|
225
|
+
_EXECUTOR = None
|
|
226
|
+
_EXECUTOR_LOCK = threading.Lock()
|
|
227
|
+
_DEFAULT_STALE_DURATION = timedelta(minutes=3)
|
|
228
|
+
|
|
229
|
+
@classmethod
def _get_executor(cls):
    """Return the shared ThreadPoolExecutor, creating it on first use."""
    if cls._EXECUTOR is not None:
        return cls._EXECUTOR
    with cls._EXECUTOR_LOCK:
        # Re-check under the lock: another thread may have created the pool
        # between the unlocked check above and acquiring the lock.
        if cls._EXECUTOR is None:
            # Several workers, because one pool serves every Refreshable instance.
            cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
    return cls._EXECUTOR
|
|
238
|
+
|
|
239
|
+
def __init__(
    self,
    token: Optional[Token] = None,
    disable_async: bool = True,
    stale_duration: timedelta = _DEFAULT_STALE_DURATION,
):
    """Set up the token source.

    :param token: initial token; an empty ``Token("")`` is used when omitted.
    :param disable_async: when true, ``token()`` only refreshes synchronously.
    :param stale_duration: remaining lifetime below which a token counts as stale.
    """
    # Guards the mutable state below.
    self._lock = threading.Lock()
    # Mutable state: touch only while holding ``self._lock``.
    self._token = token if token else Token("")
    self._is_refreshing = False
    self._refresh_err = False
    # Configuration, fixed after construction.
    self._disable_async = disable_async
    self._stale_duration = stale_duration
|
|
254
|
+
|
|
255
|
+
# Public entry point for obtaining the token. Callers should go through this
# method rather than the internal helpers, which assume the lock is held.
def token(self) -> Token:
    """Return a valid token; blocks for the refresh when async refresh is disabled."""
    with self._lock:
        fetch = self._blocking_token if self._disable_async else self._async_token
        return fetch()
|
|
263
|
+
|
|
264
|
+
def _async_token(self) -> Token:
    """
    Return a token, refreshing in the background where possible.

    A fresh token is returned as-is. A stale token is returned immediately
    after a background refresh has been requested. Any other state falls back
    to a synchronous, blocking refresh.
    """
    current_state = self._token_state()
    snapshot = self._token

    if current_state not in (_TokenState.FRESH, _TokenState.STALE):
        return self._blocking_token()
    if current_state == _TokenState.STALE:
        self._trigger_async_refresh()
    return snapshot
|
|
279
|
+
|
|
280
|
+
def _token_state(self) -> _TokenState:
    """Classify the held token as FRESH, STALE or EXPIRED."""
    tok = self._token
    if not (tok and tok.valid):
        return _TokenState.EXPIRED
    if not tok.expiry:
        # No expiry recorded: nothing to measure staleness against.
        return _TokenState.FRESH

    remaining = tok.expiry - datetime.now()
    if remaining < timedelta(seconds=0):
        return _TokenState.EXPIRED
    if remaining < self._stale_duration:
        return _TokenState.STALE
    return _TokenState.FRESH
|
|
293
|
+
|
|
294
|
+
def _blocking_token(self) -> Token:
    """Return the token, refreshing it synchronously first if it has expired."""
    # The state is sampled before the flags are cleared, as in the original flow.
    needs_refresh = self._token_state() == _TokenState.EXPIRED

    # Clear leftovers from an earlier asynchronous attempt so that a failed
    # background refresh does not block future ones.
    self._is_refreshing = False
    self._refresh_err = False

    # Another caller (blocking or async) may have refreshed the token while
    # this call waited for the lock; in that case nothing more is needed.
    if needs_refresh:
        self._token = self.refresh()
    return self._token
|
|
310
|
+
|
|
311
|
+
def _trigger_async_refresh(self):
    """Starts an asynchronous refresh if none is in progress.

    Reached from ``token()`` via ``_async_token()``, i.e. while ``self._lock``
    is held by the calling thread. The submitted job runs on the shared
    executor and takes ``self._lock`` itself before publishing its result.
    """

    def _refresh_internal():
        # Runs on an executor worker; refresh() is called without holding the lock.
        new_token = None
        try:
            new_token = self.refresh()
        except Exception as e:
            # This happens on a thread, so we don't want to propagate the error.
            # Instead, if there is no new_token for any reason, we will disable async refresh below
            # But we will do it inside the lock.
            logger.warning(f"Tried to refresh token asynchronously, but failed: {e}")

        with self._lock:
            if new_token is not None:
                self._token = new_token
            else:
                # Blocks further async attempts until _blocking_token() clears the flag.
                self._refresh_err = True
            self._is_refreshing = False

    # The token may have been refreshed by another thread.
    if self._token_state() == _TokenState.FRESH:
        return
    # Submit at most one job at a time, and none after a failed attempt.
    if not self._is_refreshing and not self._refresh_err:
        self._is_refreshing = True
        Refreshable._get_executor().submit(_refresh_internal)
|
|
205
337
|
|
|
206
338
|
@abstractmethod
|
|
207
339
|
def refresh(self) -> Token:
|
|
@@ -219,23 +351,24 @@ class _OAuthCallback(BaseHTTPRequestHandler):
|
|
|
219
351
|
|
|
220
352
|
def do_GET(self):
    """Handle the OAuth redirect and record its query parameters.

    The parsed query is appended to ``self._feedback`` so the code waiting on
    the server can pick it up. Responds 400 when the request has no query
    string or carries an ``error`` parameter, otherwise 200 with a short
    confirmation page.
    """
    from urllib.parse import parse_qsl

    # Split on the first "?" only: a literal "?" is legal inside a query
    # string and must not cause the callback to be rejected.
    parts = self.path.split("?", 1)
    if len(parts) != 2:
        self.send_error(400, "Missing Query")
        return

    query = dict(parse_qsl(parts[1]))
    self._feedback.append(query)

    if "error" in query:
        self.send_error(400, query["error"], query.get("error_description"))
        return

    self.send_response(200)
    self.send_header("Content-type", "text/html")
    self.end_headers()
    # TODO: show better message
    self.wfile.write(b"You can close this tab.")
|
|
239
372
|
|
|
240
373
|
|
|
241
374
|
def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
|
|
@@ -246,8 +379,8 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas
|
|
|
246
379
|
:return: The account's OIDC endpoints.
|
|
247
380
|
"""
|
|
248
381
|
host = _fix_host_if_needed(host)
|
|
249
|
-
oidc = f
|
|
250
|
-
resp = client.do(
|
|
382
|
+
oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server"
|
|
383
|
+
resp = client.do("GET", oidc)
|
|
251
384
|
return OidcEndpoints.from_dict(resp)
|
|
252
385
|
|
|
253
386
|
|
|
@@ -258,12 +391,14 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
|
|
|
258
391
|
:return: The workspace's OIDC endpoints.
|
|
259
392
|
"""
|
|
260
393
|
host = _fix_host_if_needed(host)
|
|
261
|
-
oidc = f
|
|
262
|
-
resp = client.do(
|
|
394
|
+
oidc = f"{host}/oidc/.well-known/oauth-authorization-server"
|
|
395
|
+
resp = client.do("GET", oidc)
|
|
263
396
|
return OidcEndpoints.from_dict(resp)
|
|
264
397
|
|
|
265
398
|
|
|
266
|
-
def get_azure_entra_id_workspace_endpoints(
|
|
399
|
+
def get_azure_entra_id_workspace_endpoints(
|
|
400
|
+
host: str,
|
|
401
|
+
) -> Optional[OidcEndpoints]:
|
|
267
402
|
"""
|
|
268
403
|
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
|
|
269
404
|
using an application registered in Azure Entra ID.
|
|
@@ -272,84 +407,103 @@ def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]
|
|
|
272
407
|
"""
|
|
273
408
|
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
|
|
274
409
|
host = _fix_host_if_needed(host)
|
|
275
|
-
res = requests.get(f
|
|
276
|
-
real_auth_url = res.headers.get(
|
|
410
|
+
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
|
|
411
|
+
real_auth_url = res.headers.get("location")
|
|
277
412
|
if not real_auth_url:
|
|
278
413
|
return None
|
|
279
|
-
return OidcEndpoints(
|
|
280
|
-
|
|
414
|
+
return OidcEndpoints(
|
|
415
|
+
authorization_endpoint=real_auth_url,
|
|
416
|
+
token_endpoint=real_auth_url.replace("/authorize", "/token"),
|
|
417
|
+
)
|
|
281
418
|
|
|
282
419
|
|
|
283
420
|
class SessionCredentials(Refreshable):
    """Refreshable credentials backed by an OAuth token and its refresh token.

    Instances are callable and expose ``auth_type``, as the
    CredentialsProvider protocol expects; refreshing uses the
    ``refresh_token`` grant against the configured token endpoint.
    """

    def __init__(
        self,
        token: Token,
        token_endpoint: str,
        client_id: str,
        client_secret: str = None,
        redirect_url: str = None,
        disable_async: bool = True,
    ):
        super().__init__(token=token, disable_async=disable_async)
        self._token_endpoint = token_endpoint
        self._client_id = client_id
        self._client_secret = client_secret
        self._redirect_url = redirect_url

    def as_dict(self) -> dict:
        """Serialize the current (possibly refreshed) token under the ``token`` key."""
        token_payload = self.token().as_dict()
        return {"token": token_payload}

    @staticmethod
    def from_dict(
        raw: dict,
        token_endpoint: str,
        client_id: str,
        client_secret: str = None,
        redirect_url: str = None,
    ) -> "SessionCredentials":
        """Rebuild credentials from ``as_dict`` output plus the client settings."""
        restored = Token.from_dict(raw["token"])
        return SessionCredentials(
            token=restored,
            token_endpoint=token_endpoint,
            client_id=client_id,
            client_secret=client_secret,
            redirect_url=redirect_url,
        )

    def auth_type(self):
        """Implementing CredentialsProvider protocol"""
        # TODO: distinguish between Databricks IDP and Azure AD
        return "oauth"

    def __call__(self, *args, **kwargs):
        """Implementing CredentialsProvider protocol"""

        def inner() -> Dict[str, str]:
            access_token = self.token().access_token
            return {"Authorization": f"Bearer {access_token}"}

        return inner

    def refresh(self) -> Token:
        """Exchange the stored refresh token for a new token.

        :raises ValueError: if the current token carries no refresh token.
        """
        refresh_token = self._token.refresh_token
        if not refresh_token:
            raise ValueError("oauth2: token expired and refresh token is not set")
        if "microsoft" in self._token_endpoint:
            # Tokens issued for the 'Single-Page Application' client-type may
            # only be redeemed via cross-origin requests
            headers = {"Origin": self._redirect_url}
        else:
            headers = {}
        return retrieve_token(
            client_id=self._client_id,
            client_secret=self._client_secret,
            token_url=self._token_endpoint,
            params={"grant_type": "refresh_token", "refresh_token": refresh_token},
            use_params=True,
            headers=headers,
        )
|
|
341
493
|
|
|
342
494
|
|
|
343
495
|
class Consent:
|
|
344
496
|
|
|
345
|
-
def __init__(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
497
|
+
def __init__(
|
|
498
|
+
self,
|
|
499
|
+
state: str,
|
|
500
|
+
verifier: str,
|
|
501
|
+
authorization_url: str,
|
|
502
|
+
redirect_url: str,
|
|
503
|
+
token_endpoint: str,
|
|
504
|
+
client_id: str,
|
|
505
|
+
client_secret: str = None,
|
|
506
|
+
) -> None:
|
|
353
507
|
self._verifier = verifier
|
|
354
508
|
self._state = state
|
|
355
509
|
self._authorization_url = authorization_url
|
|
@@ -360,12 +514,12 @@ class Consent:
|
|
|
360
514
|
|
|
361
515
|
def as_dict(self) -> dict:
    """Serialize the consent state; the client secret is deliberately left out."""
    exported = (
        "state",
        "verifier",
        "authorization_url",
        "redirect_url",
        "token_endpoint",
        "client_id",
    )
    # Each exported key maps to the private attribute of the same name.
    return {key: getattr(self, f"_{key}") for key in exported}
|
|
370
524
|
|
|
371
525
|
@property
|
|
@@ -373,65 +527,74 @@ class Consent:
|
|
|
373
527
|
return self._authorization_url
|
|
374
528
|
|
|
375
529
|
@staticmethod
def from_dict(raw: dict, client_secret: str = None) -> "Consent":
    """Rebuild a ``Consent`` from ``as_dict`` output.

    The client secret is not part of the serialized form and has to be
    passed in separately.

    :raises KeyError: if any serialized field is missing from ``raw``.
    """
    keyword_fields = ("authorization_url", "redirect_url", "token_endpoint", "client_id")
    return Consent(
        raw["state"],
        raw["verifier"],
        **{name: raw[name] for name in keyword_fields},
        client_secret=client_secret,
    )
|
|
384
540
|
|
|
385
541
|
def launch_external_browser(self) -> SessionCredentials:
    """Run the interactive consent flow with a local browser.

    Binds a one-shot HTTP listener on the redirect URL's port, opens the
    authorization URL in the browser, waits for a single callback request
    and exchanges its parameters for credentials.

    :raises ValueError: if the redirect URL is not on localhost/127.0.0.1,
        has no explicit port, or no callback data was received.
    """
    redirect_url = urllib.parse.urlparse(self._redirect_url)
    if redirect_url.hostname not in ("localhost", "127.0.0.1"):
        raise ValueError(f"cannot listen on {redirect_url.hostname}")
    port = redirect_url.port
    if port is None:
        # Without a port the listener cannot be bound; fail with a clear
        # message before any browser window is opened.
        raise ValueError(f"redirect URL {self._redirect_url} must include an explicit port")
    feedback = []
    handler_factory = functools.partial(_OAuthCallback, feedback)
    # Bind the listener first so that a bind failure surfaces before the user
    # is sent to the browser, and so the redirect always has a server to hit.
    with HTTPServer(("localhost", port), handler_factory) as httpd:
        logger.info(f"Opening {self._authorization_url} in a browser")
        webbrowser.open_new(self._authorization_url)
        logger.info(f"Waiting for redirect to http://localhost:{port}")
        httpd.handle_request()
    if not feedback:
        raise ValueError("No data received in callback")
    query = feedback.pop()
    return self.exchange_callback_parameters(query)
|
|
401
557
|
|
|
402
558
|
def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
    """Turn the query parameters of the OAuth redirect into credentials.

    :param query: parsed query string of the callback request.
    :raises ValueError: if the callback reports an error, or lacks ``code``
        or ``state``.
    """
    if "error" in query:
        # error_description is optional in an OAuth error response; fall back
        # to "unknown" instead of failing with a KeyError while formatting.
        description = query.get("error_description", "unknown")
        raise ValueError(f"{query['error']}: {description}")
    if "code" not in query or "state" not in query:
        raise ValueError("No code returned in callback")
    return self.exchange(query["code"], query["state"])
|
|
408
564
|
|
|
409
565
|
def exchange(self, code: str, state: str) -> SessionCredentials:
    """Exchange an authorization code for session credentials.

    :param code: authorization code returned in the callback.
    :param state: state value returned in the callback; must equal the one
        this consent was created with.
    :raises ValueError: on state mismatch or when the token endpoint rejects
        the request.
    """
    if self._state != state:
        raise ValueError("state mismatch")
    params = {
        "redirect_uri": self._redirect_url,
        "grant_type": "authorization_code",
        "code_verifier": self._verifier,
        "code": code,
    }
    headers = {}
    while True:
        try:
            token = retrieve_token(
                client_id=self._client_id,
                client_secret=self._client_secret,
                token_url=self._token_endpoint,
                params=params,
                headers=headers,
                use_params=True,
            )
            return SessionCredentials(
                token,
                self._token_endpoint,
                self._client_id,
                self._client_secret,
                self._redirect_url,
            )
        except ValueError as e:
            # Retry exactly once with the Origin header; if the same error
            # comes back with the header already set, re-raise instead of
            # looping forever.
            if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e) and "Origin" not in headers:
                # Retry in cases of 'Single-Page Application' client-type with
                # 'Origin' header equal to client's redirect URL.
                headers["Origin"] = self._redirect_url
                msg = f"Retrying OAuth token exchange with {self._redirect_url} origin"
                logger.debug(msg)
                continue
            raise
|
|
@@ -456,15 +619,17 @@ class OAuthClient:
|
|
|
456
619
|
exchange it for a token without possessing the Code Verifier.
|
|
457
620
|
"""
|
|
458
621
|
|
|
459
|
-
def __init__(
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
622
|
+
def __init__(
|
|
623
|
+
self,
|
|
624
|
+
oidc_endpoints: OidcEndpoints,
|
|
625
|
+
redirect_url: str,
|
|
626
|
+
client_id: str,
|
|
627
|
+
scopes: List[str] = None,
|
|
628
|
+
client_secret: str = None,
|
|
629
|
+
):
|
|
465
630
|
|
|
466
631
|
if not scopes:
|
|
467
|
-
scopes = [
|
|
632
|
+
scopes = ["all-apis"]
|
|
468
633
|
|
|
469
634
|
self.redirect_url = redirect_url
|
|
470
635
|
self._client_id = client_id
|
|
@@ -473,25 +638,27 @@ class OAuthClient:
|
|
|
473
638
|
self._scopes = scopes
|
|
474
639
|
|
|
475
640
|
@staticmethod
def from_host(
    host: str,
    client_id: str,
    redirect_url: str,
    *,
    scopes: List[str] = None,
    client_secret: str = None,
) -> "OAuthClient":
    """Create an ``OAuthClient`` from the OIDC endpoints resolved for ``host``.

    A ``Config`` with a no-op credentials strategy is built solely to read
    its ``oidc_endpoints`` for the host.

    :param scopes: requested scopes; defaults to ``["all-apis"]`` when empty.
    :raises ValueError: if no OIDC endpoints are available for ``host``.
    """
    from .core import Config
    from .credentials_provider import credentials_strategy

    # ``Any`` (typing) rather than the builtin function ``any`` is the
    # correct annotation for the ignored argument.
    @credentials_strategy("noop", [])
    def noop_credentials(_: Any):
        return lambda: {}

    config = Config(host=host, credentials_strategy=noop_credentials)
    if not scopes:
        scopes = ["all-apis"]
    oidc = config.oidc_endpoints
    if not oidc:
        raise ValueError(f"{host} does not support OAuth")
    return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)
|
|
496
663
|
|
|
497
664
|
def initiate_consent(self) -> Consent:
|
|
@@ -500,28 +667,30 @@ class OAuthClient:
|
|
|
500
667
|
# token_urlsafe() already returns base64-encoded string
|
|
501
668
|
verifier = secrets.token_urlsafe(32)
|
|
502
669
|
digest = hashlib.sha256(verifier.encode("UTF-8")).digest()
|
|
503
|
-
challenge =
|
|
670
|
+
challenge = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
|
|
504
671
|
|
|
505
672
|
params = {
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
673
|
+
"response_type": "code",
|
|
674
|
+
"client_id": self._client_id,
|
|
675
|
+
"redirect_uri": self.redirect_url,
|
|
676
|
+
"scope": " ".join(self._scopes),
|
|
677
|
+
"state": state,
|
|
678
|
+
"code_challenge": challenge,
|
|
679
|
+
"code_challenge_method": "S256",
|
|
513
680
|
}
|
|
514
|
-
auth_url = f
|
|
515
|
-
return Consent(
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
681
|
+
auth_url = f"{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}"
|
|
682
|
+
return Consent(
|
|
683
|
+
state,
|
|
684
|
+
verifier,
|
|
685
|
+
authorization_url=auth_url,
|
|
686
|
+
redirect_url=self.redirect_url,
|
|
687
|
+
token_endpoint=self._oidc_endpoints.token_endpoint,
|
|
688
|
+
client_id=self._client_id,
|
|
689
|
+
client_secret=self._client_secret,
|
|
690
|
+
)
|
|
522
691
|
|
|
523
692
|
def __repr__(self) -> str:
    """Return a debug string with the client id and both endpoint URLs."""
    endpoints = self._oidc_endpoints
    token_url = endpoints.token_endpoint
    auth_url = endpoints.authorization_endpoint
    return f"<OAuthClient client_id={self._client_id} token_url={token_url} auth_url={auth_url}>"
|
|
525
694
|
|
|
526
695
|
|
|
527
696
|
@dataclass
|
|
@@ -535,6 +704,7 @@ class ClientCredentials(Refreshable):
|
|
|
535
704
|
the background job uses the Client ID and Client Secret to obtain
|
|
536
705
|
an Access Token from the Authorization Server.
|
|
537
706
|
"""
|
|
707
|
+
|
|
538
708
|
client_id: str
|
|
539
709
|
client_secret: str
|
|
540
710
|
token_url: str
|
|
@@ -542,9 +712,10 @@ class ClientCredentials(Refreshable):
|
|
|
542
712
|
scopes: List[str] = None
|
|
543
713
|
use_params: bool = False
|
|
544
714
|
use_header: bool = False
|
|
715
|
+
disable_async: bool = True
|
|
545
716
|
|
|
546
717
|
def __post_init__(self):
    # Forwards this instance's ``disable_async`` field to the base-class
    # initializer; no initial token is passed.
    super().__init__(disable_async=self.disable_async)
|
|
548
719
|
|
|
549
720
|
def refresh(self) -> Token:
|
|
550
721
|
params = {"grant_type": "client_credentials"}
|
|
@@ -553,24 +724,28 @@ class ClientCredentials(Refreshable):
|
|
|
553
724
|
if self.endpoint_params:
|
|
554
725
|
for k, v in self.endpoint_params.items():
|
|
555
726
|
params[k] = v
|
|
556
|
-
return retrieve_token(
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
727
|
+
return retrieve_token(
|
|
728
|
+
self.client_id,
|
|
729
|
+
self.client_secret,
|
|
730
|
+
self.token_url,
|
|
731
|
+
params,
|
|
732
|
+
use_params=self.use_params,
|
|
733
|
+
use_header=self.use_header,
|
|
734
|
+
)
|
|
562
735
|
|
|
563
736
|
|
|
564
737
|
class TokenCache:
|
|
565
738
|
BASE_PATH = "~/.config/databricks-sdk-py/oauth"
|
|
566
739
|
|
|
567
|
-
def __init__(
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
740
|
+
def __init__(
|
|
741
|
+
self,
|
|
742
|
+
host: str,
|
|
743
|
+
oidc_endpoints: OidcEndpoints,
|
|
744
|
+
client_id: str,
|
|
745
|
+
redirect_url: Optional[str] = None,
|
|
746
|
+
client_secret: Optional[str] = None,
|
|
747
|
+
scopes: Optional[List[str]] = None,
|
|
748
|
+
) -> None:
|
|
574
749
|
self._host = host
|
|
575
750
|
self._client_id = client_id
|
|
576
751
|
self._oidc_endpoints = oidc_endpoints
|
|
@@ -582,8 +757,12 @@ class TokenCache:
|
|
|
582
757
|
def filename(self) -> str:
    """Return the cache file path for this host / client id / scope set.

    The file name is the SHA-256 hex digest of the host, the client id and
    the comma-joined scopes, placed under the user-expanded ``BASE_PATH``
    with a ``.json`` suffix.
    """
    # Host, client_id and scopes all feed the digest so each combination
    # gets its own cache file.
    parts = (self._host, self._client_id, ",".join(self._scopes))
    material = b"".join(part.encode("utf-8") for part in parts)
    digest = hashlib.sha256(material).hexdigest()
    cache_path = os.path.join(self.__class__.BASE_PATH, digest + ".json")
    return os.path.expanduser(cache_path)
|
|
588
767
|
|
|
589
768
|
def load(self) -> Optional[SessionCredentials]:
|
|
@@ -594,13 +773,15 @@ class TokenCache:
|
|
|
594
773
|
return None
|
|
595
774
|
|
|
596
775
|
try:
|
|
597
|
-
with open(self.filename,
|
|
776
|
+
with open(self.filename, "r") as f:
|
|
598
777
|
raw = json.load(f)
|
|
599
|
-
return SessionCredentials.from_dict(
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
778
|
+
return SessionCredentials.from_dict(
|
|
779
|
+
raw,
|
|
780
|
+
token_endpoint=self._oidc_endpoints.token_endpoint,
|
|
781
|
+
client_id=self._client_id,
|
|
782
|
+
client_secret=self._client_secret,
|
|
783
|
+
redirect_url=self._redirect_url,
|
|
784
|
+
)
|
|
604
785
|
except Exception:
|
|
605
786
|
return None
|
|
606
787
|
|
|
@@ -609,6 +790,6 @@ class TokenCache:
|
|
|
609
790
|
Save credentials to cache file.
|
|
610
791
|
"""
|
|
611
792
|
os.makedirs(os.path.dirname(self.filename), exist_ok=True)
|
|
612
|
-
with open(self.filename,
|
|
793
|
+
with open(self.filename, "w") as f:
|
|
613
794
|
json.dump(credentials.as_dict(), f)
|
|
614
795
|
os.chmod(self.filename, 0o600)
|