databricks-sdk 0.44.1__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +123 -115
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +152 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +300 -205
- databricks/sdk/data_plane.py +86 -3
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +191 -16
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +372 -196
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5241 -3531
- databricks/sdk/service/dashboards.py +1313 -923
- databricks/sdk/service/files.py +442 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3364 -2255
- databricks/sdk/service/oauth2.py +922 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2040 -1278
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.45.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.1.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
|
@@ -26,13 +26,17 @@ from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
|
|
|
26
26
|
|
|
27
27
|
CredentialsProvider = Callable[[], Dict[str, str]]
|
|
28
28
|
|
|
29
|
-
logger = logging.getLogger(
|
|
29
|
+
logger = logging.getLogger("databricks.sdk")
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class OAuthCredentialsProvider:
|
|
33
|
-
"""
|
|
33
|
+
"""OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens."""
|
|
34
34
|
|
|
35
|
-
def __init__(
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
credentials_provider: CredentialsProvider,
|
|
38
|
+
token_provider: Callable[[], Token],
|
|
39
|
+
):
|
|
36
40
|
self._credentials_provider = credentials_provider
|
|
37
41
|
self._token_provider = token_provider
|
|
38
42
|
|
|
@@ -44,45 +48,49 @@ class OAuthCredentialsProvider:
|
|
|
44
48
|
|
|
45
49
|
|
|
46
50
|
class CredentialsStrategy(abc.ABC):
|
|
47
|
-
"""
|
|
48
|
-
|
|
51
|
+
"""CredentialsProvider is the protocol (call-side interface)
|
|
52
|
+
for authenticating requests to Databricks REST APIs"""
|
|
49
53
|
|
|
50
54
|
@abc.abstractmethod
|
|
51
|
-
def auth_type(self) -> str:
|
|
52
|
-
...
|
|
55
|
+
def auth_type(self) -> str: ...
|
|
53
56
|
|
|
54
57
|
@abc.abstractmethod
|
|
55
|
-
def __call__(self, cfg:
|
|
56
|
-
...
|
|
58
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider: ...
|
|
57
59
|
|
|
58
60
|
|
|
59
61
|
class OauthCredentialsStrategy(CredentialsStrategy):
|
|
60
|
-
"""
|
|
62
|
+
"""OauthCredentialsProvider is a CredentialsProvider which
|
|
61
63
|
supports Oauth tokens"""
|
|
62
64
|
|
|
63
|
-
def __init__(
|
|
65
|
+
def __init__(
|
|
66
|
+
self,
|
|
67
|
+
auth_type: str,
|
|
68
|
+
headers_provider: Callable[["Config"], OAuthCredentialsProvider],
|
|
69
|
+
):
|
|
64
70
|
self._headers_provider = headers_provider
|
|
65
71
|
self._auth_type = auth_type
|
|
66
72
|
|
|
67
73
|
def auth_type(self) -> str:
|
|
68
74
|
return self._auth_type
|
|
69
75
|
|
|
70
|
-
def __call__(self, cfg:
|
|
76
|
+
def __call__(self, cfg: "Config") -> OAuthCredentialsProvider:
|
|
71
77
|
return self._headers_provider(cfg)
|
|
72
78
|
|
|
73
|
-
def oauth_token(self, cfg:
|
|
79
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
74
80
|
return self._headers_provider(cfg).oauth_token()
|
|
75
81
|
|
|
76
82
|
|
|
77
83
|
def credentials_strategy(name: str, require: List[str]):
|
|
78
|
-
"""
|
|
84
|
+
"""Given the function that receives a Config and returns RequestVisitor,
|
|
79
85
|
create CredentialsProvider with a given name and required configuration
|
|
80
|
-
attribute names to be present for this function to be called.
|
|
86
|
+
attribute names to be present for this function to be called."""
|
|
81
87
|
|
|
82
|
-
def inner(
|
|
88
|
+
def inner(
|
|
89
|
+
func: Callable[["Config"], CredentialsProvider],
|
|
90
|
+
) -> CredentialsStrategy:
|
|
83
91
|
|
|
84
92
|
@functools.wraps(func)
|
|
85
|
-
def wrapper(cfg:
|
|
93
|
+
def wrapper(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
86
94
|
for attr in require:
|
|
87
95
|
getattr(cfg, attr)
|
|
88
96
|
if not getattr(cfg, attr):
|
|
@@ -96,14 +104,16 @@ def credentials_strategy(name: str, require: List[str]):
|
|
|
96
104
|
|
|
97
105
|
|
|
98
106
|
def oauth_credentials_strategy(name: str, require: List[str]):
|
|
99
|
-
"""
|
|
107
|
+
"""Given the function that receives a Config and returns an OauthHeaderFactory,
|
|
100
108
|
create an OauthCredentialsProvider with a given name and required configuration
|
|
101
|
-
attribute names to be present for this function to be called.
|
|
109
|
+
attribute names to be present for this function to be called."""
|
|
102
110
|
|
|
103
|
-
def inner(
|
|
111
|
+
def inner(
|
|
112
|
+
func: Callable[["Config"], OAuthCredentialsProvider],
|
|
113
|
+
) -> OauthCredentialsStrategy:
|
|
104
114
|
|
|
105
115
|
@functools.wraps(func)
|
|
106
|
-
def wrapper(cfg:
|
|
116
|
+
def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]:
|
|
107
117
|
for attr in require:
|
|
108
118
|
if not getattr(cfg, attr):
|
|
109
119
|
return None
|
|
@@ -114,11 +124,11 @@ def oauth_credentials_strategy(name: str, require: List[str]):
|
|
|
114
124
|
return inner
|
|
115
125
|
|
|
116
126
|
|
|
117
|
-
@credentials_strategy(
|
|
118
|
-
def basic_auth(cfg:
|
|
119
|
-
"""
|
|
120
|
-
encoded = base64.b64encode(f
|
|
121
|
-
static_credentials = {
|
|
127
|
+
@credentials_strategy("basic", ["host", "username", "password"])
|
|
128
|
+
def basic_auth(cfg: "Config") -> CredentialsProvider:
|
|
129
|
+
"""Given username and password, add base64-encoded Basic credentials"""
|
|
130
|
+
encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode()
|
|
131
|
+
static_credentials = {"Authorization": f"Basic {encoded}"}
|
|
122
132
|
|
|
123
133
|
def inner() -> Dict[str, str]:
|
|
124
134
|
return static_credentials
|
|
@@ -126,10 +136,10 @@ def basic_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
126
136
|
return inner
|
|
127
137
|
|
|
128
138
|
|
|
129
|
-
@credentials_strategy(
|
|
130
|
-
def pat_auth(cfg:
|
|
131
|
-
"""
|
|
132
|
-
static_credentials = {
|
|
139
|
+
@credentials_strategy("pat", ["host", "token"])
|
|
140
|
+
def pat_auth(cfg: "Config") -> CredentialsProvider:
|
|
141
|
+
"""Adds Databricks Personal Access Token to every request"""
|
|
142
|
+
static_credentials = {"Authorization": f"Bearer {cfg.token}"}
|
|
133
143
|
|
|
134
144
|
def inner() -> Dict[str, str]:
|
|
135
145
|
return static_credentials
|
|
@@ -137,9 +147,9 @@ def pat_auth(cfg: 'Config') -> CredentialsProvider:
|
|
|
137
147
|
return inner
|
|
138
148
|
|
|
139
149
|
|
|
140
|
-
@credentials_strategy(
|
|
141
|
-
def runtime_native_auth(cfg:
|
|
142
|
-
if
|
|
150
|
+
@credentials_strategy("runtime", [])
|
|
151
|
+
def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
152
|
+
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
|
|
143
153
|
return None
|
|
144
154
|
|
|
145
155
|
# This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
|
|
@@ -148,36 +158,44 @@ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
148
158
|
from databricks.sdk.runtime import (init_runtime_legacy_auth,
|
|
149
159
|
init_runtime_native_auth,
|
|
150
160
|
init_runtime_repl_auth)
|
|
151
|
-
|
|
161
|
+
|
|
162
|
+
for init in [
|
|
163
|
+
init_runtime_native_auth,
|
|
164
|
+
init_runtime_repl_auth,
|
|
165
|
+
init_runtime_legacy_auth,
|
|
166
|
+
]:
|
|
152
167
|
if init is None:
|
|
153
168
|
continue
|
|
154
169
|
host, inner = init()
|
|
155
170
|
if host is None:
|
|
156
|
-
logger.debug(f
|
|
171
|
+
logger.debug(f"[{init.__name__}] no host detected")
|
|
157
172
|
continue
|
|
158
173
|
cfg.host = host
|
|
159
|
-
logger.debug(f
|
|
174
|
+
logger.debug(f"[{init.__name__}] runtime native auth configured")
|
|
160
175
|
return inner
|
|
161
176
|
return None
|
|
162
177
|
|
|
163
178
|
|
|
164
|
-
@oauth_credentials_strategy(
|
|
165
|
-
def oauth_service_principal(cfg:
|
|
166
|
-
"""
|
|
167
|
-
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
179
|
+
@oauth_credentials_strategy("oauth-m2m", ["host", "client_id", "client_secret"])
|
|
180
|
+
def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
181
|
+
"""Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
|
|
182
|
+
if /oidc/.well-known/oauth-authorization-server is available on the given host.
|
|
183
|
+
"""
|
|
168
184
|
oidc = cfg.oidc_endpoints
|
|
169
185
|
if oidc is None:
|
|
170
186
|
return None
|
|
171
187
|
|
|
172
|
-
token_source = ClientCredentials(
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
188
|
+
token_source = ClientCredentials(
|
|
189
|
+
client_id=cfg.client_id,
|
|
190
|
+
client_secret=cfg.client_secret,
|
|
191
|
+
token_url=oidc.token_endpoint,
|
|
192
|
+
scopes=["all-apis"],
|
|
193
|
+
use_header=True,
|
|
194
|
+
)
|
|
177
195
|
|
|
178
196
|
def inner() -> Dict[str, str]:
|
|
179
197
|
token = token_source.token()
|
|
180
|
-
return {
|
|
198
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
181
199
|
|
|
182
200
|
def token() -> Token:
|
|
183
201
|
return token_source.token()
|
|
@@ -185,9 +203,9 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
185
203
|
return OAuthCredentialsProvider(inner, token)
|
|
186
204
|
|
|
187
205
|
|
|
188
|
-
@credentials_strategy(
|
|
189
|
-
def external_browser(cfg:
|
|
190
|
-
if cfg.auth_type !=
|
|
206
|
+
@credentials_strategy("external-browser", ["host", "auth_type"])
|
|
207
|
+
def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
208
|
+
if cfg.auth_type != "external-browser":
|
|
191
209
|
return None
|
|
192
210
|
|
|
193
211
|
client_id, client_secret = None, None
|
|
@@ -198,17 +216,19 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
198
216
|
client_id = cfg.azure_client
|
|
199
217
|
client_secret = cfg.azure_client_secret
|
|
200
218
|
if not client_id:
|
|
201
|
-
client_id =
|
|
219
|
+
client_id = "databricks-cli"
|
|
202
220
|
|
|
203
221
|
# Load cached credentials from disk if they exist. Note that these are
|
|
204
222
|
# local to the Python SDK and not reused by other SDKs.
|
|
205
223
|
oidc_endpoints = cfg.oidc_endpoints
|
|
206
|
-
redirect_url =
|
|
207
|
-
token_cache = TokenCache(
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
224
|
+
redirect_url = "http://localhost:8020"
|
|
225
|
+
token_cache = TokenCache(
|
|
226
|
+
host=cfg.host,
|
|
227
|
+
oidc_endpoints=oidc_endpoints,
|
|
228
|
+
client_id=client_id,
|
|
229
|
+
client_secret=client_secret,
|
|
230
|
+
redirect_url=redirect_url,
|
|
231
|
+
)
|
|
212
232
|
credentials = token_cache.load()
|
|
213
233
|
if credentials:
|
|
214
234
|
try:
|
|
@@ -219,12 +239,14 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
219
239
|
return credentials(cfg)
|
|
220
240
|
# TODO: We should ideally use more specific exceptions.
|
|
221
241
|
except Exception as e:
|
|
222
|
-
logger.warning(f
|
|
223
|
-
|
|
224
|
-
oauth_client = OAuthClient(
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
242
|
+
logger.warning(f"Failed to refresh cached token: {e}. Initiating new OAuth login flow")
|
|
243
|
+
|
|
244
|
+
oauth_client = OAuthClient(
|
|
245
|
+
oidc_endpoints=oidc_endpoints,
|
|
246
|
+
client_id=client_id,
|
|
247
|
+
redirect_url=redirect_url,
|
|
248
|
+
client_secret=client_secret,
|
|
249
|
+
)
|
|
228
250
|
consent = oauth_client.initiate_consent()
|
|
229
251
|
if not consent:
|
|
230
252
|
return None
|
|
@@ -234,33 +256,41 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
234
256
|
return credentials(cfg)
|
|
235
257
|
|
|
236
258
|
|
|
237
|
-
def _ensure_host_present(cfg:
|
|
238
|
-
"""
|
|
259
|
+
def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], TokenSource]):
|
|
260
|
+
"""Resolves Azure Databricks workspace URL from ARM Resource ID"""
|
|
239
261
|
if cfg.host:
|
|
240
262
|
return
|
|
241
263
|
if not cfg.azure_workspace_resource_id:
|
|
242
264
|
return
|
|
243
265
|
arm = cfg.arm_environment.resource_manager_endpoint
|
|
244
266
|
token = token_source_for(arm).token()
|
|
245
|
-
resp = requests.get(
|
|
246
|
-
|
|
267
|
+
resp = requests.get(
|
|
268
|
+
f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
|
|
269
|
+
headers={"Authorization": f"Bearer {token.access_token}"},
|
|
270
|
+
)
|
|
247
271
|
if not resp.ok:
|
|
248
272
|
raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
|
|
249
273
|
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
|
|
250
274
|
|
|
251
275
|
|
|
252
|
-
@oauth_credentials_strategy(
|
|
253
|
-
|
|
254
|
-
"""
|
|
255
|
-
|
|
276
|
+
@oauth_credentials_strategy(
|
|
277
|
+
"azure-client-secret",
|
|
278
|
+
["is_azure", "azure_client_id", "azure_client_secret"],
|
|
279
|
+
)
|
|
280
|
+
def azure_service_principal(cfg: "Config") -> CredentialsProvider:
|
|
281
|
+
"""Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
|
|
282
|
+
to every request, while automatically resolving different Azure environment endpoints.
|
|
283
|
+
"""
|
|
256
284
|
|
|
257
285
|
def token_source_for(resource: str) -> TokenSource:
|
|
258
286
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
259
|
-
return ClientCredentials(
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
287
|
+
return ClientCredentials(
|
|
288
|
+
client_id=cfg.azure_client_id,
|
|
289
|
+
client_secret=cfg.azure_client_secret,
|
|
290
|
+
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
291
|
+
endpoint_params={"resource": resource},
|
|
292
|
+
use_params=True,
|
|
293
|
+
)
|
|
264
294
|
|
|
265
295
|
_ensure_host_present(cfg, token_source_for)
|
|
266
296
|
cfg.load_azure_tenant_id()
|
|
@@ -269,7 +299,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
269
299
|
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
|
|
270
300
|
|
|
271
301
|
def refreshed_headers() -> Dict[str, str]:
|
|
272
|
-
headers = {
|
|
302
|
+
headers = {
|
|
303
|
+
"Authorization": f"Bearer {inner.token().access_token}",
|
|
304
|
+
}
|
|
273
305
|
add_workspace_id_header(cfg, headers)
|
|
274
306
|
add_sp_management_token(cloud, headers)
|
|
275
307
|
return headers
|
|
@@ -280,9 +312,9 @@ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
|
280
312
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
281
313
|
|
|
282
314
|
|
|
283
|
-
@oauth_credentials_strategy(
|
|
284
|
-
def github_oidc_azure(cfg:
|
|
285
|
-
if
|
|
315
|
+
@oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"])
|
|
316
|
+
def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
317
|
+
if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ:
|
|
286
318
|
# not in GitHub actions
|
|
287
319
|
return None
|
|
288
320
|
|
|
@@ -292,7 +324,7 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
292
324
|
return None
|
|
293
325
|
|
|
294
326
|
# See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
|
|
295
|
-
headers = {
|
|
327
|
+
headers = {"Authorization": f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
|
|
296
328
|
endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
|
|
297
329
|
response = requests.get(endpoint, headers=headers)
|
|
298
330
|
if not response.ok:
|
|
@@ -300,30 +332,34 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
300
332
|
|
|
301
333
|
# get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
|
|
302
334
|
response_json = response.json()
|
|
303
|
-
if
|
|
335
|
+
if "value" not in response_json:
|
|
304
336
|
return None
|
|
305
337
|
|
|
306
|
-
logger.info(
|
|
338
|
+
logger.info(
|
|
339
|
+
"Configured AAD token for GitHub Actions OIDC (%s)",
|
|
340
|
+
cfg.azure_client_id,
|
|
341
|
+
)
|
|
307
342
|
params = {
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
343
|
+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
|
344
|
+
"resource": cfg.effective_azure_login_app_id,
|
|
345
|
+
"client_assertion": response_json["value"],
|
|
311
346
|
}
|
|
312
347
|
aad_endpoint = cfg.arm_environment.active_directory_endpoint
|
|
313
348
|
if not cfg.azure_tenant_id:
|
|
314
349
|
# detect Azure AD Tenant ID if it's not specified directly
|
|
315
350
|
token_endpoint = cfg.oidc_endpoints.token_endpoint
|
|
316
|
-
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint,
|
|
351
|
+
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]
|
|
317
352
|
inner = ClientCredentials(
|
|
318
353
|
client_id=cfg.azure_client_id,
|
|
319
|
-
client_secret="",
|
|
354
|
+
client_secret="", # we have no (rotatable) secrets in OIDC flow
|
|
320
355
|
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
|
|
321
356
|
endpoint_params=params,
|
|
322
|
-
use_params=True
|
|
357
|
+
use_params=True,
|
|
358
|
+
)
|
|
323
359
|
|
|
324
360
|
def refreshed_headers() -> Dict[str, str]:
|
|
325
361
|
token = inner.token()
|
|
326
|
-
return {
|
|
362
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
327
363
|
|
|
328
364
|
def token() -> Token:
|
|
329
365
|
return inner.token()
|
|
@@ -331,29 +367,32 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
331
367
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
332
368
|
|
|
333
369
|
|
|
334
|
-
GcpScopes = [
|
|
370
|
+
GcpScopes = [
|
|
371
|
+
"https://www.googleapis.com/auth/cloud-platform",
|
|
372
|
+
"https://www.googleapis.com/auth/compute",
|
|
373
|
+
]
|
|
335
374
|
|
|
336
375
|
|
|
337
|
-
@oauth_credentials_strategy(
|
|
338
|
-
def google_credentials(cfg:
|
|
376
|
+
@oauth_credentials_strategy("google-credentials", ["host", "google_credentials"])
|
|
377
|
+
def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
339
378
|
if not cfg.is_gcp:
|
|
340
379
|
return None
|
|
341
380
|
# Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
|
|
342
381
|
# Obtain the id token by providing the json file path and target audience.
|
|
343
|
-
if
|
|
382
|
+
if os.path.isfile(cfg.google_credentials):
|
|
344
383
|
with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
|
|
345
384
|
account_info = json.load(json_file)
|
|
346
385
|
else:
|
|
347
386
|
# If the file doesn't exist, assume that the config is the actual JSON content.
|
|
348
387
|
account_info = json.loads(cfg.google_credentials)
|
|
349
388
|
|
|
350
|
-
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
351
|
-
|
|
389
|
+
credentials = service_account.IDTokenCredentials.from_service_account_info(
|
|
390
|
+
info=account_info, target_audience=cfg.host
|
|
391
|
+
)
|
|
352
392
|
|
|
353
393
|
request = Request()
|
|
354
394
|
|
|
355
|
-
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
|
|
356
|
-
scopes=GcpScopes)
|
|
395
|
+
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, scopes=GcpScopes)
|
|
357
396
|
|
|
358
397
|
def token() -> Token:
|
|
359
398
|
credentials.refresh(request)
|
|
@@ -361,7 +400,7 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
361
400
|
|
|
362
401
|
def refreshed_headers() -> Dict[str, str]:
|
|
363
402
|
credentials.refresh(request)
|
|
364
|
-
headers = {
|
|
403
|
+
headers = {"Authorization": f"Bearer {credentials.token}"}
|
|
365
404
|
if cfg.is_account_client:
|
|
366
405
|
gcp_credentials.refresh(request)
|
|
367
406
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
|
|
@@ -370,24 +409,29 @@ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
370
409
|
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
371
410
|
|
|
372
411
|
|
|
373
|
-
@oauth_credentials_strategy(
|
|
374
|
-
def google_id(cfg:
|
|
412
|
+
@oauth_credentials_strategy("google-id", ["host", "google_service_account"])
|
|
413
|
+
def google_id(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
375
414
|
if not cfg.is_gcp:
|
|
376
415
|
return None
|
|
377
416
|
credentials, _project_id = google.auth.default()
|
|
378
417
|
|
|
379
418
|
# Create the impersonated credential.
|
|
380
|
-
target_credentials = impersonated_credentials.Credentials(
|
|
381
|
-
|
|
382
|
-
|
|
419
|
+
target_credentials = impersonated_credentials.Credentials(
|
|
420
|
+
source_credentials=credentials,
|
|
421
|
+
target_principal=cfg.google_service_account,
|
|
422
|
+
target_scopes=[],
|
|
423
|
+
)
|
|
383
424
|
|
|
384
425
|
# Set the impersonated credential, target audience and token options.
|
|
385
|
-
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
386
|
-
|
|
387
|
-
|
|
426
|
+
id_creds = impersonated_credentials.IDTokenCredentials(
|
|
427
|
+
target_credentials, target_audience=cfg.host, include_email=True
|
|
428
|
+
)
|
|
388
429
|
|
|
389
430
|
gcp_impersonated_credentials = impersonated_credentials.Credentials(
|
|
390
|
-
source_credentials=credentials,
|
|
431
|
+
source_credentials=credentials,
|
|
432
|
+
target_principal=cfg.google_service_account,
|
|
433
|
+
target_scopes=GcpScopes,
|
|
434
|
+
)
|
|
391
435
|
|
|
392
436
|
request = Request()
|
|
393
437
|
|
|
@@ -397,7 +441,7 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
397
441
|
|
|
398
442
|
def refreshed_headers() -> Dict[str, str]:
|
|
399
443
|
id_creds.refresh(request)
|
|
400
|
-
headers = {
|
|
444
|
+
headers = {"Authorization": f"Bearer {id_creds.token}"}
|
|
401
445
|
if cfg.is_account_client:
|
|
402
446
|
gcp_impersonated_credentials.refresh(request)
|
|
403
447
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
|
|
@@ -408,7 +452,13 @@ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
408
452
|
|
|
409
453
|
class CliTokenSource(Refreshable):
|
|
410
454
|
|
|
411
|
-
def __init__(
|
|
455
|
+
def __init__(
|
|
456
|
+
self,
|
|
457
|
+
cmd: List[str],
|
|
458
|
+
token_type_field: str,
|
|
459
|
+
access_token_field: str,
|
|
460
|
+
expiry_field: str,
|
|
461
|
+
):
|
|
412
462
|
super().__init__()
|
|
413
463
|
self._cmd = cmd
|
|
414
464
|
self._token_type_field = token_type_field
|
|
@@ -431,52 +481,74 @@ class CliTokenSource(Refreshable):
|
|
|
431
481
|
out = _run_subprocess(self._cmd, capture_output=True, check=True)
|
|
432
482
|
it = json.loads(out.stdout.decode())
|
|
433
483
|
expires_on = self._parse_expiry(it[self._expiry_field])
|
|
434
|
-
return Token(
|
|
435
|
-
|
|
436
|
-
|
|
484
|
+
return Token(
|
|
485
|
+
access_token=it[self._access_token_field],
|
|
486
|
+
token_type=it[self._token_type_field],
|
|
487
|
+
expiry=expires_on,
|
|
488
|
+
)
|
|
437
489
|
except ValueError as e:
|
|
438
490
|
raise ValueError(f"cannot unmarshal CLI result: {e}")
|
|
439
491
|
except subprocess.CalledProcessError as e:
|
|
440
492
|
stdout = e.stdout.decode().strip()
|
|
441
493
|
stderr = e.stderr.decode().strip()
|
|
442
494
|
message = stdout or stderr
|
|
443
|
-
raise IOError(f
|
|
495
|
+
raise IOError(f"cannot get access token: {message}") from e
|
|
444
496
|
|
|
445
497
|
|
|
446
|
-
def _run_subprocess(
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
498
|
+
def _run_subprocess(
|
|
499
|
+
popenargs,
|
|
500
|
+
input=None,
|
|
501
|
+
capture_output=True,
|
|
502
|
+
timeout=None,
|
|
503
|
+
check=False,
|
|
504
|
+
**kwargs,
|
|
505
|
+
) -> subprocess.CompletedProcess:
|
|
452
506
|
"""Runs subprocess with given arguments.
|
|
453
|
-
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
454
|
-
|
|
507
|
+
This handles OS-specific modifications that need to be made to the invocation of subprocess.run.
|
|
508
|
+
"""
|
|
509
|
+
kwargs["shell"] = sys.platform.startswith("win")
|
|
455
510
|
# windows requires shell=True to be able to execute 'az login' or other commands
|
|
456
511
|
# cannot use shell=True all the time, as it breaks macOS
|
|
457
512
|
logging.debug(f'Running command: {" ".join(popenargs)}')
|
|
458
|
-
return subprocess.run(
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
513
|
+
return subprocess.run(
|
|
514
|
+
popenargs,
|
|
515
|
+
input=input,
|
|
516
|
+
capture_output=capture_output,
|
|
517
|
+
timeout=timeout,
|
|
518
|
+
check=check,
|
|
519
|
+
**kwargs,
|
|
520
|
+
)
|
|
464
521
|
|
|
465
522
|
|
|
466
523
|
class AzureCliTokenSource(CliTokenSource):
|
|
467
|
-
"""
|
|
468
|
-
|
|
469
|
-
def __init__(
|
|
470
|
-
|
|
524
|
+
"""Obtain the token granted by `az login` CLI command"""
|
|
525
|
+
|
|
526
|
+
def __init__(
|
|
527
|
+
self,
|
|
528
|
+
resource: str,
|
|
529
|
+
subscription: Optional[str] = None,
|
|
530
|
+
tenant: Optional[str] = None,
|
|
531
|
+
):
|
|
532
|
+
cmd = [
|
|
533
|
+
"az",
|
|
534
|
+
"account",
|
|
535
|
+
"get-access-token",
|
|
536
|
+
"--resource",
|
|
537
|
+
resource,
|
|
538
|
+
"--output",
|
|
539
|
+
"json",
|
|
540
|
+
]
|
|
471
541
|
if subscription is not None:
|
|
472
542
|
cmd.append("--subscription")
|
|
473
543
|
cmd.append(subscription)
|
|
474
544
|
if tenant and not self.__is_cli_using_managed_identity():
|
|
475
545
|
cmd.extend(["--tenant", tenant])
|
|
476
|
-
super().__init__(
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
546
|
+
super().__init__(
|
|
547
|
+
cmd=cmd,
|
|
548
|
+
token_type_field="tokenType",
|
|
549
|
+
access_token_field="accessToken",
|
|
550
|
+
expiry_field="expiresOn",
|
|
551
|
+
)
|
|
480
552
|
|
|
481
553
|
@staticmethod
|
|
482
554
|
def __is_cli_using_managed_identity() -> bool:
|
|
@@ -489,7 +561,8 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
489
561
|
if user is None:
|
|
490
562
|
return False
|
|
491
563
|
return user.get("type") == "servicePrincipal" and user.get("name") in [
|
|
492
|
-
|
|
564
|
+
"systemAssignedIdentity",
|
|
565
|
+
"userAssignedIdentity",
|
|
493
566
|
]
|
|
494
567
|
except subprocess.CalledProcessError as e:
|
|
495
568
|
logger.debug("Failed to get account information from Azure CLI", exc_info=e)
|
|
@@ -512,15 +585,13 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
512
585
|
guaranteed to be unique within a tenant and should be used only for display purposes.
|
|
513
586
|
- 'upn' - The username of the user.
|
|
514
587
|
"""
|
|
515
|
-
return
|
|
588
|
+
return "upn" in self.token().jwt_claims()
|
|
516
589
|
|
|
517
590
|
@staticmethod
|
|
518
|
-
def for_resource(cfg:
|
|
591
|
+
def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource":
|
|
519
592
|
subscription = AzureCliTokenSource.get_subscription(cfg)
|
|
520
593
|
if subscription is not None:
|
|
521
|
-
token_source = AzureCliTokenSource(resource,
|
|
522
|
-
subscription=subscription,
|
|
523
|
-
tenant=cfg.azure_tenant_id)
|
|
594
|
+
token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
|
|
524
595
|
try:
|
|
525
596
|
# This will fail if the user has access to the workspace, but not to the subscription
|
|
526
597
|
# itself.
|
|
@@ -535,32 +606,32 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
535
606
|
return token_source
|
|
536
607
|
|
|
537
608
|
@staticmethod
|
|
538
|
-
def get_subscription(cfg:
|
|
609
|
+
def get_subscription(cfg: "Config") -> Optional[str]:
|
|
539
610
|
resource = cfg.azure_workspace_resource_id
|
|
540
611
|
if resource is None or resource == "":
|
|
541
612
|
return None
|
|
542
|
-
components = resource.split(
|
|
613
|
+
components = resource.split("/")
|
|
543
614
|
if len(components) < 3:
|
|
544
615
|
logger.warning("Invalid azure workspace resource ID")
|
|
545
616
|
return None
|
|
546
617
|
return components[2]
|
|
547
618
|
|
|
548
619
|
|
|
549
|
-
@credentials_strategy(
|
|
550
|
-
def azure_cli(cfg:
|
|
551
|
-
"""
|
|
620
|
+
@credentials_strategy("azure-cli", ["is_azure"])
|
|
621
|
+
def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
622
|
+
"""Adds refreshed OAuth token granted by `az login` command to every request."""
|
|
552
623
|
cfg.load_azure_tenant_id()
|
|
553
624
|
token_source = None
|
|
554
625
|
mgmt_token_source = None
|
|
555
626
|
try:
|
|
556
627
|
token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
|
|
557
628
|
except FileNotFoundError:
|
|
558
|
-
doc =
|
|
559
|
-
logger.debug(f
|
|
629
|
+
doc = "https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest"
|
|
630
|
+
logger.debug(f"Most likely Azure CLI is not installed. See {doc} for details")
|
|
560
631
|
return None
|
|
561
632
|
except OSError as e:
|
|
562
|
-
logger.debug(
|
|
563
|
-
logger.debug(
|
|
633
|
+
logger.debug("skipping Azure CLI auth", exc_info=e)
|
|
634
|
+
logger.debug("This may happen if you are attempting to login to a dev or staging workspace")
|
|
564
635
|
return None
|
|
565
636
|
|
|
566
637
|
if not token_source.is_human_user():
|
|
@@ -568,7 +639,10 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
568
639
|
management_endpoint = cfg.arm_environment.service_management_endpoint
|
|
569
640
|
mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
|
|
570
641
|
except Exception as e:
|
|
571
|
-
logger.debug(
|
|
642
|
+
logger.debug(
|
|
643
|
+
f"Not including service management token in headers",
|
|
644
|
+
exc_info=e,
|
|
645
|
+
)
|
|
572
646
|
mgmt_token_source = None
|
|
573
647
|
|
|
574
648
|
_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
|
|
@@ -576,7 +650,7 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
576
650
|
|
|
577
651
|
def inner() -> Dict[str, str]:
|
|
578
652
|
token = token_source.token()
|
|
579
|
-
headers = {
|
|
653
|
+
headers = {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
580
654
|
add_workspace_id_header(cfg, headers)
|
|
581
655
|
if mgmt_token_source:
|
|
582
656
|
add_sp_management_token(mgmt_token_source, headers)
|
|
@@ -586,12 +660,12 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
586
660
|
|
|
587
661
|
|
|
588
662
|
class DatabricksCliTokenSource(CliTokenSource):
|
|
589
|
-
"""
|
|
663
|
+
"""Obtain the token granted by `databricks auth login` CLI command"""
|
|
590
664
|
|
|
591
|
-
def __init__(self, cfg:
|
|
592
|
-
args = [
|
|
665
|
+
def __init__(self, cfg: "Config"):
|
|
666
|
+
args = ["auth", "token", "--host", cfg.host]
|
|
593
667
|
if cfg.is_account_client:
|
|
594
|
-
args += [
|
|
668
|
+
args += ["--account-id", cfg.account_id]
|
|
595
669
|
|
|
596
670
|
cli_path = cfg.databricks_cli_path
|
|
597
671
|
|
|
@@ -611,10 +685,12 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
611
685
|
elif cli_path.count("/") == 0:
|
|
612
686
|
cli_path = self.__class__._find_executable(cli_path)
|
|
613
687
|
|
|
614
|
-
super().__init__(
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
688
|
+
super().__init__(
|
|
689
|
+
cmd=[cli_path, *args],
|
|
690
|
+
token_type_field="token_type",
|
|
691
|
+
access_token_field="access_token",
|
|
692
|
+
expiry_field="expiry",
|
|
693
|
+
)
|
|
618
694
|
|
|
619
695
|
@staticmethod
|
|
620
696
|
def _find_executable(name) -> str:
|
|
@@ -636,8 +712,8 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
636
712
|
raise err
|
|
637
713
|
|
|
638
714
|
|
|
639
|
-
@oauth_credentials_strategy(
|
|
640
|
-
def databricks_cli(cfg:
|
|
715
|
+
@oauth_credentials_strategy("databricks-cli", ["host"])
|
|
716
|
+
def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
641
717
|
try:
|
|
642
718
|
token_source = DatabricksCliTokenSource(cfg)
|
|
643
719
|
except FileNotFoundError as e:
|
|
@@ -647,8 +723,8 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
647
723
|
try:
|
|
648
724
|
token_source.token()
|
|
649
725
|
except IOError as e:
|
|
650
|
-
if
|
|
651
|
-
logger.debug(f
|
|
726
|
+
if "databricks OAuth is not" in str(e):
|
|
727
|
+
logger.debug(f"OAuth not configured or not available: {e}")
|
|
652
728
|
return None
|
|
653
729
|
raise e
|
|
654
730
|
|
|
@@ -656,7 +732,7 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
656
732
|
|
|
657
733
|
def inner() -> Dict[str, str]:
|
|
658
734
|
token = token_source.token()
|
|
659
|
-
return {
|
|
735
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
660
736
|
|
|
661
737
|
def token() -> Token:
|
|
662
738
|
return token_source.token()
|
|
@@ -665,13 +741,14 @@ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
665
741
|
|
|
666
742
|
|
|
667
743
|
class MetadataServiceTokenSource(Refreshable):
|
|
668
|
-
"""
|
|
744
|
+
"""Obtain the token granted by Databricks Metadata Service"""
|
|
745
|
+
|
|
669
746
|
METADATA_SERVICE_VERSION = "1"
|
|
670
747
|
METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
|
|
671
748
|
METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
|
|
672
|
-
_metadata_service_timeout = 10
|
|
749
|
+
_metadata_service_timeout = 10 # seconds
|
|
673
750
|
|
|
674
|
-
def __init__(self, cfg:
|
|
751
|
+
def __init__(self, cfg: "Config"):
|
|
675
752
|
super().__init__()
|
|
676
753
|
self.url = cfg.metadata_service_url
|
|
677
754
|
self.host = cfg.host
|
|
@@ -682,13 +759,14 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
682
759
|
timeout=self._metadata_service_timeout,
|
|
683
760
|
headers={
|
|
684
761
|
self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
|
|
685
|
-
self.METADATA_SERVICE_HOST_HEADER: self.host
|
|
762
|
+
self.METADATA_SERVICE_HOST_HEADER: self.host,
|
|
686
763
|
},
|
|
687
764
|
proxies={
|
|
688
765
|
# Explicitly exclude localhost from being proxied. This is necessary
|
|
689
766
|
# for Metadata URLs which typically point to localhost.
|
|
690
767
|
"no_proxy": "localhost,127.0.0.1"
|
|
691
|
-
}
|
|
768
|
+
},
|
|
769
|
+
)
|
|
692
770
|
json_resp: dict[str, Union[str, float]] = resp.json()
|
|
693
771
|
access_token = json_resp.get("access_token", None)
|
|
694
772
|
if access_token is None:
|
|
@@ -706,9 +784,9 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
706
784
|
return Token(access_token=access_token, token_type=token_type, expiry=expiry)
|
|
707
785
|
|
|
708
786
|
|
|
709
|
-
@credentials_strategy(
|
|
710
|
-
def metadata_service(cfg:
|
|
711
|
-
"""
|
|
787
|
+
@credentials_strategy("metadata-service", ["host", "metadata_service_url"])
|
|
788
|
+
def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
789
|
+
"""Adds refreshed token granted by Databricks Metadata Service to every request."""
|
|
712
790
|
|
|
713
791
|
token_source = MetadataServiceTokenSource(cfg)
|
|
714
792
|
token_source.token()
|
|
@@ -716,14 +794,14 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
716
794
|
|
|
717
795
|
def inner() -> Dict[str, str]:
|
|
718
796
|
token = token_source.token()
|
|
719
|
-
return {
|
|
797
|
+
return {"Authorization": f"{token.token_type} {token.access_token}"}
|
|
720
798
|
|
|
721
799
|
return inner
|
|
722
800
|
|
|
723
801
|
|
|
724
802
|
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
|
|
725
803
|
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
|
|
726
|
-
class ModelServingAuthProvider
|
|
804
|
+
class ModelServingAuthProvider:
|
|
727
805
|
USER_CREDENTIALS = "user_credentials"
|
|
728
806
|
|
|
729
807
|
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
|
|
@@ -731,7 +809,7 @@ class ModelServingAuthProvider():
|
|
|
731
809
|
def __init__(self, credential_type: Optional[str]):
|
|
732
810
|
self.expiry_time = -1
|
|
733
811
|
self.current_token = None
|
|
734
|
-
self.refresh_duration = 300
|
|
812
|
+
self.refresh_duration = 300 # 300 Seconds
|
|
735
813
|
self.credential_type = credential_type
|
|
736
814
|
|
|
737
815
|
def should_fetch_model_serving_environment_oauth() -> bool:
|
|
@@ -740,10 +818,14 @@ class ModelServingAuthProvider():
|
|
|
740
818
|
Additionally check if the oauth token file path exists
|
|
741
819
|
"""
|
|
742
820
|
|
|
743
|
-
is_in_model_serving_env = (
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
821
|
+
is_in_model_serving_env = (
|
|
822
|
+
os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
|
|
823
|
+
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV")
|
|
824
|
+
or "false"
|
|
825
|
+
)
|
|
826
|
+
return is_in_model_serving_env == "true" and os.path.isfile(
|
|
827
|
+
ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH
|
|
828
|
+
)
|
|
747
829
|
|
|
748
830
|
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
|
|
749
831
|
# Use Cached value if it is valid
|
|
@@ -758,8 +840,10 @@ class ModelServingAuthProvider():
|
|
|
758
840
|
except Exception as e:
|
|
759
841
|
# sleep and retry in case of any race conditions with OAuth refreshing
|
|
760
842
|
if should_retry:
|
|
761
|
-
logger.warning(
|
|
762
|
-
|
|
843
|
+
logger.warning(
|
|
844
|
+
"Unable to read oauth token on first attmept in Model Serving Environment",
|
|
845
|
+
exc_info=e,
|
|
846
|
+
)
|
|
763
847
|
time.sleep(0.5)
|
|
764
848
|
return self._get_model_dependency_oauth_token(should_retry=False)
|
|
765
849
|
else:
|
|
@@ -769,8 +853,8 @@ class ModelServingAuthProvider():
|
|
|
769
853
|
return self.current_token
|
|
770
854
|
|
|
771
855
|
def _get_invokers_token(self):
|
|
772
|
-
|
|
773
|
-
thread_data =
|
|
856
|
+
main_thread = threading.main_thread()
|
|
857
|
+
thread_data = main_thread.__dict__
|
|
774
858
|
invokers_token = None
|
|
775
859
|
if "invokers_token" in thread_data:
|
|
776
860
|
invokers_token = thread_data["invokers_token"]
|
|
@@ -785,8 +869,7 @@ class ModelServingAuthProvider():
|
|
|
785
869
|
return None
|
|
786
870
|
|
|
787
871
|
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
|
|
788
|
-
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
|
|
789
|
-
"DB_MODEL_SERVING_HOST_URL")
|
|
872
|
+
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL")
|
|
790
873
|
|
|
791
874
|
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
|
|
792
875
|
return (host, self._get_invokers_token())
|
|
@@ -794,8 +877,7 @@ class ModelServingAuthProvider():
|
|
|
794
877
|
return (host, self._get_model_dependency_oauth_token())
|
|
795
878
|
|
|
796
879
|
|
|
797
|
-
def model_serving_auth_visitor(cfg:
|
|
798
|
-
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
|
|
880
|
+
def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
|
|
799
881
|
try:
|
|
800
882
|
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
|
|
801
883
|
host, token = model_serving_auth_provider.get_databricks_host_token()
|
|
@@ -806,7 +888,10 @@ def model_serving_auth_visitor(cfg: 'Config',
|
|
|
806
888
|
if cfg.host is None:
|
|
807
889
|
cfg.host = host
|
|
808
890
|
except Exception as e:
|
|
809
|
-
logger.warning(
|
|
891
|
+
logger.warning(
|
|
892
|
+
"Unable to get auth from Databricks Model Serving Environment",
|
|
893
|
+
exc_info=e,
|
|
894
|
+
)
|
|
810
895
|
return None
|
|
811
896
|
logger.info("Using Databricks Model Serving Authentication")
|
|
812
897
|
|
|
@@ -818,8 +903,8 @@ def model_serving_auth_visitor(cfg: 'Config',
|
|
|
818
903
|
return inner
|
|
819
904
|
|
|
820
905
|
|
|
821
|
-
@credentials_strategy(
|
|
822
|
-
def model_serving_auth(cfg:
|
|
906
|
+
@credentials_strategy("model-serving", [])
|
|
907
|
+
def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]:
|
|
823
908
|
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
824
909
|
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
|
|
825
910
|
return None
|
|
@@ -828,20 +913,30 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
|
828
913
|
|
|
829
914
|
|
|
830
915
|
class DefaultCredentials:
|
|
831
|
-
"""
|
|
916
|
+
"""Select the first applicable credential provider from the chain"""
|
|
832
917
|
|
|
833
918
|
def __init__(self) -> None:
|
|
834
|
-
self._auth_type =
|
|
919
|
+
self._auth_type = "default"
|
|
835
920
|
self._auth_providers = [
|
|
836
|
-
pat_auth,
|
|
837
|
-
|
|
838
|
-
|
|
921
|
+
pat_auth,
|
|
922
|
+
basic_auth,
|
|
923
|
+
metadata_service,
|
|
924
|
+
oauth_service_principal,
|
|
925
|
+
azure_service_principal,
|
|
926
|
+
github_oidc_azure,
|
|
927
|
+
azure_cli,
|
|
928
|
+
external_browser,
|
|
929
|
+
databricks_cli,
|
|
930
|
+
runtime_native_auth,
|
|
931
|
+
google_credentials,
|
|
932
|
+
google_id,
|
|
933
|
+
model_serving_auth,
|
|
839
934
|
]
|
|
840
935
|
|
|
841
936
|
def auth_type(self) -> str:
|
|
842
937
|
return self._auth_type
|
|
843
938
|
|
|
844
|
-
def oauth_token(self, cfg:
|
|
939
|
+
def oauth_token(self, cfg: "Config") -> Token:
|
|
845
940
|
for provider in self._auth_providers:
|
|
846
941
|
auth_type = provider.auth_type()
|
|
847
942
|
if auth_type != self._auth_type:
|
|
@@ -849,14 +944,14 @@ class DefaultCredentials:
|
|
|
849
944
|
continue
|
|
850
945
|
return provider.oauth_token(cfg)
|
|
851
946
|
|
|
852
|
-
def __call__(self, cfg:
|
|
947
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
853
948
|
for provider in self._auth_providers:
|
|
854
949
|
auth_type = provider.auth_type()
|
|
855
950
|
if cfg.auth_type and auth_type != cfg.auth_type:
|
|
856
951
|
# ignore other auth types if one is explicitly enforced
|
|
857
952
|
logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
|
|
858
953
|
continue
|
|
859
|
-
logger.debug(f
|
|
954
|
+
logger.debug(f"Attempting to configure auth: {auth_type}")
|
|
860
955
|
try:
|
|
861
956
|
header_factory = provider(cfg)
|
|
862
957
|
if not header_factory:
|
|
@@ -864,18 +959,18 @@ class DefaultCredentials:
|
|
|
864
959
|
self._auth_type = auth_type
|
|
865
960
|
return header_factory
|
|
866
961
|
except Exception as e:
|
|
867
|
-
raise ValueError(f
|
|
962
|
+
raise ValueError(f"{auth_type}: {e}") from e
|
|
868
963
|
auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
|
|
869
964
|
raise ValueError(
|
|
870
|
-
f
|
|
965
|
+
f"cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method."
|
|
871
966
|
)
|
|
872
967
|
|
|
873
968
|
|
|
874
969
|
class ModelServingUserCredentials(CredentialsStrategy):
|
|
875
970
|
"""
|
|
876
|
-
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
|
|
877
|
-
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
|
|
878
|
-
In any other environments, the class defaults to the DefaultCredentialStrategy.
|
|
971
|
+
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
|
|
972
|
+
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
|
|
973
|
+
In any other environments, the class defaults to the DefaultCredentialStrategy.
|
|
879
974
|
To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
|
|
880
975
|
|
|
881
976
|
invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
|
|
@@ -891,7 +986,7 @@ class ModelServingUserCredentials(CredentialsStrategy):
|
|
|
891
986
|
else:
|
|
892
987
|
return self.default_credentials.auth_type()
|
|
893
988
|
|
|
894
|
-
def __call__(self, cfg:
|
|
989
|
+
def __call__(self, cfg: "Config") -> CredentialsProvider:
|
|
895
990
|
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
|
|
896
991
|
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
|
|
897
992
|
if not header_factory:
|