databricks-sdk 0.44.0__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +123 -115
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +152 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +360 -210
- databricks/sdk/data_plane.py +86 -3
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +201 -20
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +372 -196
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5241 -3531
- databricks/sdk/service/dashboards.py +1313 -923
- databricks/sdk/service/files.py +442 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3364 -2255
- databricks/sdk/service/oauth2.py +922 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2040 -1278
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.45.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.0.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
|
@@ -9,6 +9,7 @@ import pathlib
|
|
|
9
9
|
import platform
|
|
10
10
|
import subprocess
|
|
11
11
|
import sys
|
|
12
|
+
import threading
|
|
12
13
|
import time
|
|
13
14
|
from datetime import datetime
|
|
14
15
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
@@ -25,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
|
|
|
25
26
|
|
|
26
27
|
CredentialsProvider = Callable[[], Dict[str, str]]
|
|
27
28
|
|
|
28
|
-
logger = logging.getLogger(
|
|
29
|
+
logger = logging.getLogger("databricks.sdk")
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
class OAuthCredentialsProvider:
|
|
32
|
-
"""
|
|
33
|
+
"""OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
|
|
33
34
|
|
|
34
|
-
def __init__(
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
credentials_provider: CredentialsProvider,
|
|
38
|
+
token_provider: Callable[[], Token],
|
|
39
|
+
):
|
|
35
40
|
self._credentials_provider = credentials_provider
|
|
36
41
|
self._token_provider = token_provider
|
|
37
42
|
|
|
@@ -43,45 +48,49 @@ class OAuthCredentialsProvider:
|
|
|
43
48
|
|
|
44
49
|
|
|
45
50
|
class CredentialsStrategy(abc.ABC):
|
|
46
|
-
"""
|
|
47
|
-
|
|
51
|
+
"""CredentialsProvider is the protocol (call-side interface)
|
|
52
|
+
for authenticating requests to Databricks REST APIs"""
|
|
48
53
|
|
|
49
54
|
@abc.abstractmethod
|
|
50
|
-
def auth_type(self) -> str:
|
|
51
|
-
...
|
|
55
|
+
def auth_type(self) -> str: ...
|
|
52
56
|
|
|
53
57
|
@abc.abstractmethod
|
|
54
|
-
def __call__(self, cfg:
|
|
55
|
-
...
|
|
58
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider: ...
|
|
56
59
|
|
|
57
60
|
|
|
58
61
|
class OauthCredentialsStrategy(CredentialsStrategy):
|
|
59
|
-
"""
|
|
62
|
+
"""OauthCredentialsProvider is a CredentialsProvider which
|
|
60
63
|
supports Oauth tokens"""
|
|
61
64
|
|
|
62
|
-
def __init__(
|
|
65
|
+
def __init__(
|
|
66
|
+
self,
|
|
67
|
+
auth_type: str,
|
|
68
|
+
headers_provider: Callable[["Config"], OAuthCredentialsProvider],
|
|
69
|
+
):
|
|
63
70
|
self._headers_provider = headers_provider
|
|
64
71
|
self._auth_type = auth_type
|
|
65
72
|
|
|
66
73
|
def auth_type(self) -> str:
|
|
67
74
|
return self._auth_type
|
|
68
75
|
|
|
69
|
-
def __call__(self, cfg:
|
|
76
|
+
def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
|
|
70
77
|
return self._headers_provider(cfg)
|
|
71
78
|
|
|
72
|
-
def oauth_token(self, cfg:
|
|
79
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
73
80
|
return self._headers_provider(cfg).oauth_token()
|
|
74
81
|
|
|
75
82
|
|
|
76
83
|
def credentials_strategy(name: str, require: List[str]):
|
|
77
|
-
"""
|
|
84
|
+
"""Given the function that receives a Config and returns RequestVisitor,
|
|
78
85
|
create CredentialsProvider with a given name and required configuration
|
|
79
|
-
attribute names to be present for this function to be called.
|
|
86
|
+
attribute names to be present for this function to be called."""
|
|
80
87
|
|
|
81
|
-
def inner(
|
|
88
|
+
def inner(
|
|
89
|
+
func: Callable[["Config"], CredentialsProvider],
|
|
90
|
+
) -> CredentialsStrategy:
|
|
82
91
|
|
|
83
92
|
@functools.wraps(func)
|
|
84
|
-
def wrapper(cfg:
|
|
93
|
+
def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
85
94
|
for attr in require:
|
|
86
95
|
getattr(cfg, attr)
|
|
87
96
|
if not getattr(cfg, attr):
|
|
@@ -95,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
|
|
|
95
104
|
|
|
96
105
|
|
|
97
106
|
def oauth_credentials_strategy(name: str, require: List[str]):
|
|
98
|
-
"""
|
|
107
|
+
"""Given the function that receives a Config and returns an OauthHeaderFactory,
|
|
99
108
|
create an OauthCredentialsProvider with a given name and required configuration
|
|
100
|
-
attribute names to be present for this function to be called.
|
|
109
|
+
attribute names to be present for this function to be called."""
|
|
101
110
|
|
|
102
|
-
def inner(
|
|
111
|
+
def inner(
|
|
112
|
+
func: Callable[["Config"], OAuthCredentialsProvider],
|
|
113
|
+
) -> OauthCredentialsStrategy:
|
|
103
114
|
|
|
104
115
|
@functools.wraps(func)
|
|
105
|
-
def wrapper(cfg:
|
|
116
|
+
def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
|
|
106
117
|
for attr in require:
|
|
107
118
|
if not getattr(cfg, attr):
|
|
108
119
|
return None
|
|
@@ -113,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
|
|
|
113
124
|
return inner
|
|
114
125
|
|
|
115
126
|
|
|
116
|
-
@credentials_strategy(
|
|
117
|
-
def basic_auth(cfg:
|
|
118
|
-
"""
|
|
119
|
-
encoded = base64.b64encode(f
|
|
120
|
-
static_credentials = {
|
|
127
|
+
@credentials_strategy("basic", ["host", "username", "password"])
|
|
128
|
+
def basic_auth(cfg: "Config") -> CredentialsProvider:
|
|
129
|
+
"""Given username and password, add base64-encoded Basic credentials"""
|
|
130
|
+
encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
|
|
131
|
+
static_credentials = {"Authorization": f"Basic {encoded}"}
|
|
121
132
|
|
|
122
133
|
def inner() -> Dict[str, str]:
|
|
123
134
|
return static_credentials
|
|
@@ -125,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
125
136
|
return inner
|
|
126
137
|
|
|
127
138
|
|
|
128
|
-
@credentials_strategy(
|
|
129
|
-
def pat_auth(cfg:
|
|
130
|
-
"""
|
|
131
|
-
static_credentials = {
|
|
139
|
+
@credentials_strategy("pat", ["host", "token"])
|
|
140
|
+
def pat_auth(cfg: "Config") -> CredentialsProvider:
|
|
141
|
+
"""Adds Databricks Personal Access Token to every request"""
|
|
142
|
+
static_credentials = {"Authorization": f"Bearer {cfg.token}"}
|
|
132
143
|
|
|
133
144
|
def inner() -> Dict[str, str]:
|
|
134
145
|
return static_credentials
|
|
@@ -136,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
136
147
|
return inner
|
|
137
148
|
|
|
138
149
|
|
|
139
|
-
@credentials_strategy(
|
|
140
|
-
def runtime_native_auth(cfg:
|
|
141
|
-
if
|
|
150
|
+
@credentials_strategy("runtime", [])
|
|
151
|
+
def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
152
|
+
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
|
|
142
153
|
return None
|
|
143
154
|
|
|
144
155
|
# This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
|
|
@@ -147,36 +158,44 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
147
158
|
from databricks.sdk.runtime import (init_runtime_legacy_auth,
|
|
148
159
|
init_runtime_native_auth,
|
|
149
160
|
init_runtime_repl_auth)
|
|
150
|
-
|
|
161
|
+
|
|
162
|
+
for init in [
|
|
163
|
+
init_runtime_native_auth,
|
|
164
|
+
init_runtime_repl_auth,
|
|
165
|
+
init_runtime_legacy_auth,
|
|
166
|
+
]:
|
|
151
167
|
if init is None:
|
|
152
168
|
continue
|
|
153
169
|
host, inner = init()
|
|
154
170
|
if host is None:
|
|
155
|
-
logger.debug(f
|
|
171
|
+
logger.debug(f"[{init.__name__}] no host detected")
|
|
156
172
|
continue
|
|
157
173
|
cfg.host = host
|
|
158
|
-
logger.debug(f
|
|
174
|
+
logger.debug(f"[{init.__name__}] runtime native auth configured")
|
|
159
175
|
return inner
|
|
160
176
|
return None
|
|
161
177
|
|
|
162
178
|
|
|
163
|
-
@oauth_credentials_strategy(
|
|
164
|
-
def oauth_service_principal(cfg:
|
|
165
|
-
"""
|
|
166
|
-
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
179
|
+
@oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
|
|
180
|
+
def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
181
|
+
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
|
|
182
|
+
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
183
|
+
"""
|
|
167
184
|
oidc = cfg.oidc_endpoints
|
|
168
185
|
if oidc is None:
|
|
169
186
|
return None
|
|
170
187
|
|
|
171
|
-
token_source = ClientCredentials(
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
188
|
+
token_source = ClientCredentials(
|
|
189
|
+
client_id=cfg.client_id,
|
|
190
|
+
client_secret=cfg.client_secret,
|
|
191
|
+
token_url=oidc.token_endpoint,
|
|
192
|
+
scopes=["all-apis"],
|
|
193
|
+
use_header=True,
|
|
194
|
+
)
|
|
176
195
|
|
|
177
196
|
def inner() -> Dict[str, str]:
|
|
178
197
|
token = token_source.token()
|
|
179
|
-
return {
|
|
198
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
180
199
|
|
|
181
200
|
def token() -> Token:
|
|
182
201
|
return token_source.token()
|
|
@@ -184,9 +203,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
184
203
|
return OAuthCredentialsProvider(inner, token)
|
|
185
204
|
|
|
186
205
|
|
|
187
|
-
@credentials_strategy(
|
|
188
|
-
def external_browser(cfg:
|
|
189
|
-
if cfg.auth_type !=
|
|
206
|
+
@credentials_strategy("external-browser", ["host", "auth_type"])
|
|
207
|
+
def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
208
|
+
if cfg.auth_type != "external-browser":
|
|
190
209
|
return None
|
|
191
210
|
|
|
192
211
|
client_id, client_secret = None, None
|
|
@@ -197,17 +216,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
197
216
|
client_id = cfg.azure_client
|
|
198
217
|
client_secret = cfg.azure_client_secret
|
|
199
218
|
if not client_id:
|
|
200
|
-
client_id =
|
|
219
|
+
client_id = "databricks-cli"
|
|
201
220
|
|
|
202
221
|
# Load cached credentials from disk if they exist. Note that these are
|
|
203
222
|
# local to the Python SDK and not reused by other SDKs.
|
|
204
223
|
oidc_endpoints = cfg.oidc_endpoints
|
|
205
|
-
redirect_url =
|
|
206
|
-
token_cache = TokenCache(
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
224
|
+
redirect_url = "http://localhost:8020"
|
|
225
|
+
token_cache = TokenCache(
|
|
226
|
+
host=cfg.host,
|
|
227
|
+
oidc_endpoints=oidc_endpoints,
|
|
228
|
+
client_id=client_id,
|
|
229
|
+
client_secret=client_secret,
|
|
230
|
+
redirect_url=redirect_url,
|
|
231
|
+
)
|
|
211
232
|
credentials = token_cache.load()
|
|
212
233
|
if credentials:
|
|
213
234
|
try:
|
|
@@ -218,12 +239,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
218
239
|
return credentials(cfg)
|
|
219
240
|
# TODO: We should ideally use more specific exceptions.
|
|
220
241
|
except Exception as e:
|
|
221
|
-
logger.warning(f
|
|
222
|
-
|
|
223
|
-
oauth_client = OAuthClient(
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
242
|
+
logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
|
|
243
|
+
|
|
244
|
+
oauth_client = OAuthClient(
|
|
245
|
+
oidc_endpoints=oidc_endpoints,
|
|
246
|
+
client_id=client_id,
|
|
247
|
+
redirect_url=redirect_url,
|
|
248
|
+
client_secret=client_secret,
|
|
249
|
+
)
|
|
227
250
|
consent = oauth_client.initiate_consent()
|
|
228
251
|
if not consent:
|
|
229
252
|
return None
|
|
@@ -233,33 +256,41 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
233
256
|
return credentials(cfg)
|
|
234
257
|
|
|
235
258
|
|
|
236
|
-
def _ensure_host_present(cfg:
|
|
237
|
-
"""
|
|
259
|
+
def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
|
|
260
|
+
"""Resolves Azure Databricks workspace URL from ARM Resource ID"""
|
|
238
261
|
if cfg.host:
|
|
239
262
|
return
|
|
240
263
|
if not cfg.azure_workspace_resource_id:
|
|
241
264
|
return
|
|
242
265
|
arm = cfg.arm_environment.resource_manager_endpoint
|
|
243
266
|
token = token_source_for(arm).token()
|
|
244
|
-
resp = requests.get(
|
|
245
|
-
|
|
267
|
+
resp = requests.get(
|
|
268
|
+
f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
|
|
269
|
+
headers={"Authorization": f"Bearer {token.access_token}"},
|
|
270
|
+
)
|
|
246
271
|
if not resp.ok:
|
|
247
272
|
raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
|
|
248
273
|
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
|
|
249
274
|
|
|
250
275
|
|
|
251
|
-
@oauth_credentials_strategy(
|
|
252
|
-
|
|
253
|
-
"""
|
|
254
|
-
|
|
276
|
+
@oauth_credentials_strategy(
|
|
277
|
+
"azure-client-secret",
|
|
278
|
+
["is_azure", "azure_client_id", "azure_client_secret"],
|
|
279
|
+
)
|
|
280
|
+
def azure_service_principal(cfg: "Config") -> CredentialsProvider:
|
|
281
|
+
"""Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
|
|
282
|
+
to every request, while automatically resolving different Azure environment endpoints.
|
|
283
|
+
"""
|
|
255
284
|
|
|
256
285
|
def token_source_for(resource: str) -> TokenSource:
|
|
257
286
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
258
|
-
return ClientCredentials(
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
287
|
+
return ClientCredentials(
|
|
288
|
+
client_id=cfg.azure_client_id,
|
|
289
|
+
client_secret=cfg.azure_client_secret,
|
|
290
|
+
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
291
|
+
endpoint_params={"resource": resource},
|
|
292
|
+
use_params=True,
|
|
293
|
+
)
|
|
263
294
|
|
|
264
295
|
_ensure_host_present(cfg, token_source_for)
|
|
265
296
|
cfg.load_azure_tenant_id()
|
|
@@ -268,7 +299,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
268
299
|
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
|
|
269
300
|
|
|
270
301
|
def refreshed_headers() -> Dict[str, str]:
|
|
271
|
-
headers = {
|
|
302
|
+
headers = {
|
|
303
|
+
"Authorization": f"Bearer {inner.token().access_token}",
|
|
304
|
+
}
|
|
272
305
|
add_workspace_id_header(cfg, headers)
|
|
273
306
|
add_sp_management_token(cloud, headers)
|
|
274
307
|
return headers
|
|
@@ -279,9 +312,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
279
312
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
280
313
|
|
|
281
314
|
|
|
282
|
-
@oauth_credentials_strategy(
|
|
283
|
-
def github_oidc_azure(cfg:
|
|
284
|
-
if
|
|
315
|
+
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
|
|
316
|
+
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
317
|
+
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
|
|
285
318
|
# not in GitHub actions
|
|
286
319
|
return None
|
|
287
320
|
|
|
@@ -291,7 +324,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
291
324
|
return None
|
|
292
325
|
|
|
293
326
|
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
|
|
294
|
-
headers = {
|
|
327
|
+
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
|
|
295
328
|
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
|
|
296
329
|
response = requests.get(endpoint, headers=headers)
|
|
297
330
|
if not response.ok:
|
|
@@ -299,30 +332,34 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
299
332
|
|
|
300
333
|
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
|
|
301
334
|
response_json = response.json()
|
|
302
|
-
if
|
|
335
|
+
if "value" not in response_json:
|
|
303
336
|
return None
|
|
304
337
|
|
|
305
|
-
logger.info(
|
|
338
|
+
logger.info(
|
|
339
|
+
"Configured AAD token for GitHub Actions OIDC (%s)",
|
|
340
|
+
cfg.azure_client_id,
|
|
341
|
+
)
|
|
306
342
|
params = {
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
343
|
+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
|
344
|
+
"resource": cfg.effective_azure_login_app_id,
|
|
345
|
+
"client_assertion": response_json["value"],
|
|
310
346
|
}
|
|
311
347
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
312
348
|
if not cfg.azure_tenant_id:
|
|
313
349
|
# detect Azure AD Tenant ID if it's not specified directly
|
|
314
350
|
token_endpoint = cfg.oidc_endpoints.token_endpoint
|
|
315
|
-
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint,
|
|
351
|
+
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
|
|
316
352
|
inner = ClientCredentials(
|
|
317
353
|
client_id=cfg.azure_client_id,
|
|
318
|
-
client_secret="",
|
|
354
|
+
client_secret="", # we have no (rotatable) secrets in OIDC flow
|
|
319
355
|
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
320
356
|
endpoint_params=params,
|
|
321
|
-
use_params=True
|
|
357
|
+
use_params=True,
|
|
358
|
+
)
|
|
322
359
|
|
|
323
360
|
def refreshed_headers() -> Dict[str, str]:
|
|
324
361
|
token = inner.token()
|
|
325
|
-
return {
|
|
362
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
326
363
|
|
|
327
364
|
def token() -> Token:
|
|
328
365
|
return inner.token()
|
|
@@ -330,29 +367,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
330
367
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
331
368
|
|
|
332
369
|
|
|
333
|
-
GcpScopes = [
|
|
370
|
+
GcpScopes = [
|
|
371
|
+
"https://www.googleapis.com/auth/cloud-platform",
|
|
372
|
+
"https://www.googleapis.com/auth/compute",
|
|
373
|
+
]
|
|
334
374
|
|
|
335
375
|
|
|
336
|
-
@oauth_credentials_strategy(
|
|
337
|
-
def google_credentials(cfg:
|
|
376
|
+
@oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
|
|
377
|
+
def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
338
378
|
if not cfg.is_gcp:
|
|
339
379
|
return None
|
|
340
380
|
# Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
|
|
341
381
|
# Obtain the id token by providing the json file path and target audience.
|
|
342
|
-
if
|
|
382
|
+
if os.path.isfile(cfg.google_credentials):
|
|
343
383
|
with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
|
|
344
384
|
account_info = json.load(json_file)
|
|
345
385
|
else:
|
|
346
386
|
# If the file doesn't exist, assume that the config is the actual JSON content.
|
|
347
387
|
account_info = json.loads(cfg.google_credentials)
|
|
348
388
|
|
|
349
|
-
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
350
|
-
|
|
389
|
+
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
390
|
+
info=account_info, target_audience=cfg.host
|
|
391
|
+
)
|
|
351
392
|
|
|
352
393
|
request = Request()
|
|
353
394
|
|
|
354
|
-
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
|
|
355
|
-
scopes=GcpScopes)
|
|
395
|
+
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
|
|
356
396
|
|
|
357
397
|
def token() -> Token:
|
|
358
398
|
credentials.refresh(request)
|
|
@@ -360,7 +400,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
360
400
|
|
|
361
401
|
def refreshed_headers() -> Dict[str, str]:
|
|
362
402
|
credentials.refresh(request)
|
|
363
|
-
headers = {
|
|
403
|
+
headers = {"Authorization": f"Bearer {credentials.token}"}
|
|
364
404
|
if cfg.is_account_client:
|
|
365
405
|
gcp_credentials.refresh(request)
|
|
366
406
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
|
|
@@ -369,24 +409,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
369
409
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
370
410
|
|
|
371
411
|
|
|
372
|
-
@oauth_credentials_strategy(
|
|
373
|
-
def google_id(cfg:
|
|
412
|
+
@oauth_credentials_strategy("google-id", ["host", "google_service_account"])
|
|
413
|
+
def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
374
414
|
if not cfg.is_gcp:
|
|
375
415
|
return None
|
|
376
416
|
credentials, _project_id = google.auth.default()
|
|
377
417
|
|
|
378
418
|
# Create the impersonated credential.
|
|
379
|
-
target_credentials = impersonated_credentials.Credentials(
|
|
380
|
-
|
|
381
|
-
|
|
419
|
+
target_credentials = impersonated_credentials.Credentials(
|
|
420
|
+
source_credentials=credentials,
|
|
421
|
+
target_principal=cfg.google_service_account,
|
|
422
|
+
target_scopes=[],
|
|
423
|
+
)
|
|
382
424
|
|
|
383
425
|
# Set the impersonated credential, target audience and token options.
|
|
384
|
-
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
385
|
-
|
|
386
|
-
|
|
426
|
+
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
427
|
+
target_credentials, target_audience=cfg.host, include_email=True
|
|
428
|
+
)
|
|
387
429
|
|
|
388
430
|
gcp_impersonated_credentials = impersonated_credentials.Credentials(
|
|
389
|
-
source_credentials=credentials,
|
|
431
|
+
source_credentials=credentials,
|
|
432
|
+
target_principal=cfg.google_service_account,
|
|
433
|
+
target_scopes=GcpScopes,
|
|
434
|
+
)
|
|
390
435
|
|
|
391
436
|
request = Request()
|
|
392
437
|
|
|
@@ -396,7 +441,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
396
441
|
|
|
397
442
|
def refreshed_headers() -> Dict[str, str]:
|
|
398
443
|
id_creds.refresh(request)
|
|
399
|
-
headers = {
|
|
444
|
+
headers = {"Authorization": f"Bearer {id_creds.token}"}
|
|
400
445
|
if cfg.is_account_client:
|
|
401
446
|
gcp_impersonated_credentials.refresh(request)
|
|
402
447
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
|
|
@@ -407,7 +452,13 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
407
452
|
|
|
408
453
|
class CliTokenSource(Refreshable):
|
|
409
454
|
|
|
410
|
-
def __init__(
|
|
455
|
+
def __init__(
|
|
456
|
+
self,
|
|
457
|
+
cmd: List[str],
|
|
458
|
+
token_type_field: str,
|
|
459
|
+
access_token_field: str,
|
|
460
|
+
expiry_field: str,
|
|
461
|
+
):
|
|
411
462
|
super().__init__()
|
|
412
463
|
self._cmd = cmd
|
|
413
464
|
self._token_type_field = token_type_field
|
|
@@ -430,52 +481,74 @@ class CliTokenSource(Refreshable):
|
|
|
430
481
|
out = _run_subprocess(self._cmd, capture_output=True, check=True)
|
|
431
482
|
it = json.loads(out.stdout.decode())
|
|
432
483
|
expires_on = self._parse_expiry(it[self._expiry_field])
|
|
433
|
-
return Token(
|
|
434
|
-
|
|
435
|
-
|
|
484
|
+
return Token(
|
|
485
|
+
access_token=it[self._access_token_field],
|
|
486
|
+
token_type=it[self._token_type_field],
|
|
487
|
+
expiry=expires_on,
|
|
488
|
+
)
|
|
436
489
|
except ValueError as e:
|
|
437
490
|
raise ValueError(f"cannot unmarshal CLI result: {e}")
|
|
438
491
|
except subprocess.CalledProcessError as e:
|
|
439
492
|
stdout = e.stdout.decode().strip()
|
|
440
493
|
stderr = e.stderr.decode().strip()
|
|
441
494
|
message = stdout or stderr
|
|
442
|
-
raise IOError(f
|
|
495
|
+
raise IOError(f"cannot get access token: {message}") from e
|
|
443
496
|
|
|
444
497
|
|
|
445
|
-
def _run_subprocess(
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
498
|
+
def _run_subprocess(
|
|
499
|
+
popenargs,
|
|
500
|
+
input=None,
|
|
501
|
+
capture_output=True,
|
|
502
|
+
timeout=None,
|
|
503
|
+
check=False,
|
|
504
|
+
**kwargs,
|
|
505
|
+
) -> subprocess.CompletedProcess:
|
|
451
506
|
"""Runs subprocess with given arguments.
|
|
452
|
-
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
453
|
-
|
|
507
|
+
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
508
|
+
"""
|
|
509
|
+
kwargs["shell"] = sys.platform.startswith("win")
|
|
454
510
|
# windows requires shell=True to be able to execute 'az login' or other commands
|
|
455
511
|
# cannot use shell=True all the time, as it breaks macOS
|
|
456
512
|
logging.debug(f'Running command: {" ".join(popenargs)}')
|
|
457
|
-
return subprocess.run(
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
513
|
+
return subprocess.run(
|
|
514
|
+
popenargs,
|
|
515
|
+
input=input,
|
|
516
|
+
capture_output=capture_output,
|
|
517
|
+
timeout=timeout,
|
|
518
|
+
check=check,
|
|
519
|
+
**kwargs,
|
|
520
|
+
)
|
|
463
521
|
|
|
464
522
|
|
|
465
523
|
class AzureCliTokenSource(CliTokenSource):
|
|
466
|
-
"""
|
|
467
|
-
|
|
468
|
-
def __init__(
|
|
469
|
-
|
|
524
|
+
"""Obtain the token granted by `az login` CLI command"""
|
|
525
|
+
|
|
526
|
+
def __init__(
|
|
527
|
+
self,
|
|
528
|
+
resource: str,
|
|
529
|
+
subscription: Optional[str] = None,
|
|
530
|
+
tenant: Optional[str] = None,
|
|
531
|
+
):
|
|
532
|
+
cmd = [
|
|
533
|
+
"az",
|
|
534
|
+
"account",
|
|
535
|
+
"get-access-token",
|
|
536
|
+
"--resource",
|
|
537
|
+
resource,
|
|
538
|
+
"--output",
|
|
539
|
+
"json",
|
|
540
|
+
]
|
|
470
541
|
if subscription is not None:
|
|
471
542
|
cmd.append("--subscription")
|
|
472
543
|
cmd.append(subscription)
|
|
473
544
|
if tenant and not self.__is_cli_using_managed_identity():
|
|
474
545
|
cmd.extend(["--tenant", tenant])
|
|
475
|
-
super().__init__(
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
546
|
+
super().__init__(
|
|
547
|
+
cmd=cmd,
|
|
548
|
+
token_type_field="tokenType",
|
|
549
|
+
access_token_field="accessToken",
|
|
550
|
+
expiry_field="expiresOn",
|
|
551
|
+
)
|
|
479
552
|
|
|
480
553
|
@staticmethod
|
|
481
554
|
def __is_cli_using_managed_identity() -> bool:
|
|
@@ -488,7 +561,8 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
488
561
|
if user is None:
|
|
489
562
|
return False
|
|
490
563
|
return user.get("type") == "servicePrincipal" and user.get("name") in [
|
|
491
|
-
|
|
564
|
+
"systemAssignedIdentity",
|
|
565
|
+
"userAssignedIdentity",
|
|
492
566
|
]
|
|
493
567
|
except subprocess.CalledProcessError as e:
|
|
494
568
|
logger.debug("Failed to get account information from Azure CLI", exc_info=e)
|
|
@@ -511,15 +585,13 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
511
585
|
guaranteed to be unique within a tenant and should be used only for display purposes.
|
|
512
586
|
- 'upn' - The username of the user.
|
|
513
587
|
"""
|
|
514
|
-
return
|
|
588
|
+
return "upn" in self.token().jwt_claims()
|
|
515
589
|
|
|
516
590
|
@staticmethod
|
|
517
|
-
def for_resource(cfg:
|
|
591
|
+
def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
|
|
518
592
|
subscription = AzureCliTokenSource.get_subscription(cfg)
|
|
519
593
|
if subscription is not None:
|
|
520
|
-
token_source = AzureCliTokenSource(resource,
|
|
521
|
-
subscription=subscription,
|
|
522
|
-
tenant=cfg.azure_tenant_id)
|
|
594
|
+
token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
|
|
523
595
|
try:
|
|
524
596
|
# This will fail if the user has access to the workspace, but not to the subscription
|
|
525
597
|
# itself.
|
|
@@ -534,32 +606,32 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
534
606
|
return token_source
|
|
535
607
|
|
|
536
608
|
@staticmethod
|
|
537
|
-
def get_subscription(cfg:
|
|
609
|
+
def get_subscription(cfg: "Config") -> Optional[str]:
|
|
538
610
|
resource = cfg.azure_workspace_resource_id
|
|
539
611
|
if resource is None or resource == "":
|
|
540
612
|
return None
|
|
541
|
-
components = resource.split(
|
|
613
|
+
components = resource.split("/")
|
|
542
614
|
if len(components) < 3:
|
|
543
615
|
logger.warning("Invalid azure workspace resource ID")
|
|
544
616
|
return None
|
|
545
617
|
return components[2]
|
|
546
618
|
|
|
547
619
|
|
|
548
|
-
@credentials_strategy(
|
|
549
|
-
def azure_cli(cfg:
|
|
550
|
-
"""
|
|
620
|
+
@credentials_strategy("azure-cli", ["is_azure"])
|
|
621
|
+
def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
622
|
+
"""Adds refreshed OAuth token granted by `az login` command to every request."""
|
|
551
623
|
cfg.load_azure_tenant_id()
|
|
552
624
|
token_source = None
|
|
553
625
|
mgmt_token_source = None
|
|
554
626
|
try:
|
|
555
627
|
token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
|
|
556
628
|
except FileNotFoundError:
|
|
557
|
-
doc =
|
|
558
|
-
logger.debug(f
|
|
629
|
+
doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
|
|
630
|
+
logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
|
|
559
631
|
return None
|
|
560
632
|
except OSError as e:
|
|
561
|
-
logger.debug(
|
|
562
|
-
logger.debug(
|
|
633
|
+
logger.debug("skipping Azure CLI auth", exc_info=e)
|
|
634
|
+
logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
|
|
563
635
|
return None
|
|
564
636
|
|
|
565
637
|
if not token_source.is_human_user():
|
|
@@ -567,7 +639,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
567
639
|
management_endpoint = cfg.arm_environment.service_management_endpoint
|
|
568
640
|
mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
|
|
569
641
|
except Exception as e:
|
|
570
|
-
logger.debug(
|
|
642
|
+
logger.debug(
|
|
643
|
+
f"Not including service management token in headers",
|
|
644
|
+
exc_info=e,
|
|
645
|
+
)
|
|
571
646
|
mgmt_token_source = None
|
|
572
647
|
|
|
573
648
|
_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
|
|
@@ -575,7 +650,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
575
650
|
|
|
576
651
|
def inner() -> Dict[str, str]:
|
|
577
652
|
token = token_source.token()
|
|
578
|
-
headers = {
|
|
653
|
+
headers = {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
579
654
|
add_workspace_id_header(cfg, headers)
|
|
580
655
|
if mgmt_token_source:
|
|
581
656
|
add_sp_management_token(mgmt_token_source, headers)
|
|
@@ -585,12 +660,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
585
660
|
|
|
586
661
|
|
|
587
662
|
class DatabricksCliTokenSource(CliTokenSource):
|
|
588
|
-
"""
|
|
663
|
+
"""Obtain the token granted by `databricks auth login` CLI command"""
|
|
589
664
|
|
|
590
|
-
def __init__(self, cfg:
|
|
591
|
-
args = [
|
|
665
|
+
def __init__(self, cfg: "Config"):
|
|
666
|
+
args = ["auth", "token", "--host", cfg.host]
|
|
592
667
|
if cfg.is_account_client:
|
|
593
|
-
args += [
|
|
668
|
+
args += ["--account-id", cfg.account_id]
|
|
594
669
|
|
|
595
670
|
cli_path = cfg.databricks_cli_path
|
|
596
671
|
|
|
@@ -610,10 +685,12 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
610
685
|
elif cli_path.count("/") == 0:
|
|
611
686
|
cli_path = self.__class__._find_executable(cli_path)
|
|
612
687
|
|
|
613
|
-
super().__init__(
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
688
|
+
super().__init__(
|
|
689
|
+
cmd=[cli_path, *args],
|
|
690
|
+
token_type_field="token_type",
|
|
691
|
+
access_token_field="access_token",
|
|
692
|
+
expiry_field="expiry",
|
|
693
|
+
)
|
|
617
694
|
|
|
618
695
|
@staticmethod
|
|
619
696
|
def _find_executable(name) -> str:
|
|
@@ -635,8 +712,8 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
635
712
|
raise err
|
|
636
713
|
|
|
637
714
|
|
|
638
|
-
@oauth_credentials_strategy(
|
|
639
|
-
def databricks_cli(cfg:
|
|
715
|
+
@oauth_credentials_strategy("databricks-cli", ["host"])
|
|
716
|
+
def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
640
717
|
try:
|
|
641
718
|
token_source = DatabricksCliTokenSource(cfg)
|
|
642
719
|
except FileNotFoundError as e:
|
|
@@ -646,8 +723,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
646
723
|
try:
|
|
647
724
|
token_source.token()
|
|
648
725
|
except IOError as e:
|
|
649
|
-
if
|
|
650
|
-
logger.debug(f
|
|
726
|
+
if "databricks OAuth is not" in str(e):
|
|
727
|
+
logger.debug(f"OAuth not configured or not available: {e}")
|
|
651
728
|
return None
|
|
652
729
|
raise e
|
|
653
730
|
|
|
@@ -655,7 +732,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
655
732
|
|
|
656
733
|
def inner() -> Dict[str, str]:
|
|
657
734
|
token = token_source.token()
|
|
658
|
-
return {
|
|
735
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
659
736
|
|
|
660
737
|
def token() -> Token:
|
|
661
738
|
return token_source.token()
|
|
@@ -664,13 +741,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
664
741
|
|
|
665
742
|
|
|
666
743
|
class MetadataServiceTokenSource(Refreshable):
|
|
667
|
-
"""
|
|
744
|
+
"""Obtain the token granted by Databricks Metadata Service"""
|
|
745
|
+
|
|
668
746
|
METADATA_SERVICE_VERSION = "1"
|
|
669
747
|
METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
|
|
670
748
|
METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
|
|
671
|
-
_metadata_service_timeout = 10
|
|
749
|
+
_metadata_service_timeout = 10 # seconds
|
|
672
750
|
|
|
673
|
-
def __init__(self, cfg:
|
|
751
|
+
def __init__(self, cfg: "Config"):
|
|
674
752
|
super().__init__()
|
|
675
753
|
self.url = cfg.metadata_service_url
|
|
676
754
|
self.host = cfg.host
|
|
@@ -681,13 +759,14 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
681
759
|
timeout=self._metadata_service_timeout,
|
|
682
760
|
headers={
|
|
683
761
|
self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
|
|
684
|
-
self.METADATA_SERVICE_HOST_HEADER: self.host
|
|
762
|
+
self.METADATA_SERVICE_HOST_HEADER: self.host,
|
|
685
763
|
},
|
|
686
764
|
proxies={
|
|
687
765
|
# Explicitly exclude localhost from being proxied. This is necessary
|
|
688
766
|
# for Metadata URLs which typically point to localhost.
|
|
689
767
|
"no_proxy": "localhost,127.0.0.1"
|
|
690
|
-
}
|
|
768
|
+
},
|
|
769
|
+
)
|
|
691
770
|
json_resp: dict[str, Union[str, float]] = resp.json()
|
|
692
771
|
access_token = json_resp.get("access_token", None)
|
|
693
772
|
if access_token is None:
|
|
@@ -705,9 +784,9 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
705
784
|
return Token(access_token=access_token, token_type=token_type, expiry=expiry)
|
|
706
785
|
|
|
707
786
|
|
|
708
|
-
@credentials_strategy(
|
|
709
|
-
def metadata_service(cfg:
|
|
710
|
-
"""
|
|
787
|
+
@credentials_strategy("metadata-service", ["host", "metadata_service_url"])
|
|
788
|
+
def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
789
|
+
"""Adds refreshed token granted by Databricks Metadata Service to every request."""
|
|
711
790
|
|
|
712
791
|
token_source = MetadataServiceTokenSource(cfg)
|
|
713
792
|
token_source.token()
|
|
@@ -715,74 +794,92 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
715
794
|
|
|
716
795
|
def inner() -> Dict[str, str]:
|
|
717
796
|
token = token_source.token()
|
|
718
|
-
return {
|
|
797
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
719
798
|
|
|
720
799
|
return inner
|
|
721
800
|
|
|
722
801
|
|
|
723
802
|
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
|
|
724
803
|
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
|
|
725
|
-
class ModelServingAuthProvider
|
|
804
|
+
class ModelServingAuthProvider:
|
|
805
|
+
USER_CREDENTIALS = "user_credentials"
|
|
806
|
+
|
|
726
807
|
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
|
|
727
808
|
|
|
728
|
-
def __init__(self):
|
|
809
|
+
def __init__(self, credential_type: Optional[str]):
|
|
729
810
|
self.expiry_time = -1
|
|
730
811
|
self.current_token = None
|
|
731
|
-
self.refresh_duration = 300
|
|
812
|
+
self.refresh_duration = 300 # 300 Seconds
|
|
813
|
+
self.credential_type = credential_type
|
|
732
814
|
|
|
733
|
-
def should_fetch_model_serving_environment_oauth(
|
|
815
|
+
def should_fetch_model_serving_environment_oauth() -> bool:
|
|
734
816
|
"""
|
|
735
817
|
Check whether this is the model serving environment
|
|
736
818
|
Additionally check if the oauth token file path exists
|
|
737
819
|
"""
|
|
738
820
|
|
|
739
|
-
is_in_model_serving_env = (
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
821
|
+
is_in_model_serving_env = (
|
|
822
|
+
os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
|
|
823
|
+
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
|
|
824
|
+
or "false"
|
|
825
|
+
)
|
|
826
|
+
return is_in_model_serving_env == "true" and os.path.isfile(
|
|
827
|
+
ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
|
|
828
|
+
)
|
|
743
829
|
|
|
744
|
-
def
|
|
830
|
+
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
|
|
745
831
|
# Use Cached value if it is valid
|
|
746
832
|
if self.current_token is not None and self.expiry_time > time.time():
|
|
747
833
|
return self.current_token
|
|
748
834
|
|
|
749
835
|
try:
|
|
750
|
-
with open(
|
|
836
|
+
with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
|
|
751
837
|
oauth_dict = json.load(f)
|
|
752
838
|
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
|
|
753
839
|
self.expiry_time = time.time() + self.refresh_duration
|
|
754
840
|
except Exception as e:
|
|
755
841
|
# sleep and retry in case of any race conditions with OAuth refreshing
|
|
756
842
|
if should_retry:
|
|
757
|
-
logger.warning(
|
|
758
|
-
|
|
843
|
+
logger.warning(
|
|
844
|
+
"Unable to read oauth token on first attmept in Model Serving Environment",
|
|
845
|
+
exc_info=e,
|
|
846
|
+
)
|
|
759
847
|
time.sleep(0.5)
|
|
760
|
-
return self.
|
|
848
|
+
return self._get_model_dependency_oauth_token(should_retry=False)
|
|
761
849
|
else:
|
|
762
850
|
raise RuntimeError(
|
|
763
851
|
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
|
|
764
852
|
) from e
|
|
765
853
|
return self.current_token
|
|
766
854
|
|
|
855
|
+
def _get_invokers_token(self):
|
|
856
|
+
main_thread = threading.main_thread()
|
|
857
|
+
thread_data = main_thread.__dict__
|
|
858
|
+
invokers_token = None
|
|
859
|
+
if "invokers_token" in thread_data:
|
|
860
|
+
invokers_token = thread_data["invokers_token"]
|
|
861
|
+
|
|
862
|
+
if invokers_token is None:
|
|
863
|
+
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
|
|
864
|
+
|
|
865
|
+
return invokers_token
|
|
866
|
+
|
|
767
867
|
def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
|
|
768
|
-
if not
|
|
868
|
+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
769
869
|
return None
|
|
770
870
|
|
|
771
871
|
# read the host from DATABRICKS_MODEL_SERVING_HOST_URL if set, otherwise fall back to DB_MODEL_SERVING_HOST_URL
|
|
772
|
-
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
|
|
773
|
-
"DB_MODEL_SERVING_HOST_URL")
|
|
774
|
-
token = self.get_model_dependency_oauth_token()
|
|
872
|
+
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")
|
|
775
873
|
|
|
776
|
-
|
|
874
|
+
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
|
|
875
|
+
return (host, self._get_invokers_token())
|
|
876
|
+
else:
|
|
877
|
+
return (host, self._get_model_dependency_oauth_token())
|
|
777
878
|
|
|
778
879
|
|
|
779
|
-
|
|
780
|
-
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
880
|
+
def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
|
|
781
881
|
try:
|
|
782
|
-
model_serving_auth_provider = ModelServingAuthProvider()
|
|
783
|
-
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
|
|
784
|
-
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
|
|
785
|
-
return None
|
|
882
|
+
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
|
|
786
883
|
host, token = model_serving_auth_provider.get_databricks_host_token()
|
|
787
884
|
if token is None:
|
|
788
885
|
raise ValueError(
|
|
@@ -791,9 +888,11 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
791
888
|
if cfg.host is None:
|
|
792
889
|
cfg.host = host
|
|
793
890
|
except Exception as e:
|
|
794
|
-
logger.warning(
|
|
891
|
+
logger.warning(
|
|
892
|
+
"Unable to get auth from Databricks Model Serving Environment",
|
|
893
|
+
exc_info=e,
|
|
894
|
+
)
|
|
795
895
|
return None
|
|
796
|
-
|
|
797
896
|
logger.info("Using Databricks Model Serving Authentication")
|
|
798
897
|
|
|
799
898
|
def inner() -> Dict[str, str]:
|
|
@@ -804,21 +903,40 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
804
903
|
return inner
|
|
805
904
|
|
|
806
905
|
|
|
906
|
+
@credentials_strategy("model-serving", [])
|
|
907
|
+
def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
908
|
+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
909
|
+
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
|
|
910
|
+
return None
|
|
911
|
+
|
|
912
|
+
return model_serving_auth_visitor(cfg)
|
|
913
|
+
|
|
914
|
+
|
|
807
915
|
class DefaultCredentials:
|
|
808
|
-
"""
|
|
916
|
+
"""Select the first applicable credential provider from the chain"""
|
|
809
917
|
|
|
810
918
|
def __init__(self) -> None:
|
|
811
|
-
self._auth_type =
|
|
919
|
+
self._auth_type = "default"
|
|
812
920
|
self._auth_providers = [
|
|
813
|
-
pat_auth,
|
|
814
|
-
|
|
815
|
-
|
|
921
|
+
pat_auth,
|
|
922
|
+
basic_auth,
|
|
923
|
+
metadata_service,
|
|
924
|
+
oauth_service_principal,
|
|
925
|
+
azure_service_principal,
|
|
926
|
+
github_oidc_azure,
|
|
927
|
+
azure_cli,
|
|
928
|
+
external_browser,
|
|
929
|
+
databricks_cli,
|
|
930
|
+
runtime_native_auth,
|
|
931
|
+
google_credentials,
|
|
932
|
+
google_id,
|
|
933
|
+
model_serving_auth,
|
|
816
934
|
]
|
|
817
935
|
|
|
818
936
|
def auth_type(self) -> str:
|
|
819
937
|
return self._auth_type
|
|
820
938
|
|
|
821
|
-
def oauth_token(self, cfg:
|
|
939
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
822
940
|
for provider in self._auth_providers:
|
|
823
941
|
auth_type = provider.auth_type()
|
|
824
942
|
if auth_type != self._auth_type:
|
|
@@ -826,14 +944,14 @@ class DefaultCredentials:
|
|
|
826
944
|
continue
|
|
827
945
|
return provider.oauth_token(cfg)
|
|
828
946
|
|
|
829
|
-
def __call__(self, cfg:
|
|
947
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
830
948
|
for provider in self._auth_providers:
|
|
831
949
|
auth_type = provider.auth_type()
|
|
832
950
|
if cfg.auth_type and auth_type != cfg.auth_type:
|
|
833
951
|
# ignore other auth types if one is explicitly enforced
|
|
834
952
|
logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
|
|
835
953
|
continue
|
|
836
|
-
logger.debug(f
|
|
954
|
+
logger.debug(f"Attempting to configure auth: {auth_type}")
|
|
837
955
|
try:
|
|
838
956
|
header_factory = provider(cfg)
|
|
839
957
|
if not header_factory:
|
|
@@ -841,8 +959,40 @@ class DefaultCredentials:
|
|
|
841
959
|
self._auth_type = auth_type
|
|
842
960
|
return header_factory
|
|
843
961
|
except Exception as e:
|
|
844
|
-
raise ValueError(f
|
|
962
|
+
raise ValueError(f"{auth_type}: {e}") from e
|
|
845
963
|
auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
|
|
846
964
|
raise ValueError(
|
|
847
|
-
f
|
|
965
|
+
f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
|
|
848
966
|
)
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
class ModelServingUserCredentials(CredentialsStrategy):
|
|
970
|
+
"""
|
|
971
|
+
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
|
|
972
|
+
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
|
|
973
|
+
In any other environment, the class falls back to DefaultCredentials.
|
|
974
|
+
To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
|
|
975
|
+
|
|
976
|
+
invokers_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials())
|
|
977
|
+
"""
|
|
978
|
+
|
|
979
|
+
def __init__(self):
|
|
980
|
+
self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
|
|
981
|
+
self.default_credentials = DefaultCredentials()
|
|
982
|
+
|
|
983
|
+
def auth_type(self):
|
|
984
|
+
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
985
|
+
return "model_serving_" + self.credential_type
|
|
986
|
+
else:
|
|
987
|
+
return self.default_credentials.auth_type()
|
|
988
|
+
|
|
989
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
990
|
+
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
991
|
+
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
|
|
992
|
+
if not header_factory:
|
|
993
|
+
raise ValueError(
|
|
994
|
+
f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
|
|
995
|
+
)
|
|
996
|
+
return header_factory
|
|
997
|
+
else:
|
|
998
|
+
return self.default_credentials(cfg)
|