databricks-sdk 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk has been flagged as potentially problematic; see the package registry's advisory page for more details.
- databricks/sdk/__init__.py +41 -5
- databricks/sdk/azure.py +17 -7
- databricks/sdk/clock.py +49 -0
- databricks/sdk/config.py +459 -0
- databricks/sdk/core.py +7 -1026
- databricks/sdk/credentials_provider.py +628 -0
- databricks/sdk/environments.py +72 -0
- databricks/sdk/errors/__init__.py +1 -1
- databricks/sdk/errors/mapper.py +5 -5
- databricks/sdk/mixins/workspace.py +3 -3
- databricks/sdk/oauth.py +2 -1
- databricks/sdk/retries.py +9 -5
- databricks/sdk/service/_internal.py +1 -1
- databricks/sdk/service/catalog.py +946 -82
- databricks/sdk/service/compute.py +106 -41
- databricks/sdk/service/files.py +145 -31
- databricks/sdk/service/iam.py +44 -40
- databricks/sdk/service/jobs.py +199 -20
- databricks/sdk/service/ml.py +33 -42
- databricks/sdk/service/oauth2.py +3 -4
- databricks/sdk/service/pipelines.py +51 -31
- databricks/sdk/service/serving.py +1 -2
- databricks/sdk/service/settings.py +377 -72
- databricks/sdk/service/sharing.py +3 -4
- databricks/sdk/service/sql.py +27 -19
- databricks/sdk/service/vectorsearch.py +13 -17
- databricks/sdk/service/workspace.py +20 -11
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.17.0.dist-info → databricks_sdk-0.19.0.dist-info}/METADATA +4 -4
- databricks_sdk-0.19.0.dist-info/RECORD +53 -0
- databricks_sdk-0.17.0.dist-info/RECORD +0 -49
- /databricks/sdk/errors/{mapping.py → platform.py} +0 -0
- {databricks_sdk-0.17.0.dist-info → databricks_sdk-0.19.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.17.0.dist-info → databricks_sdk-0.19.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.17.0.dist-info → databricks_sdk-0.19.0.dist-info}/WHEEL +0 -0
- {databricks_sdk-0.17.0.dist-info → databricks_sdk-0.19.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,628 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import base64
|
|
3
|
+
import functools
|
|
4
|
+
import io
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import pathlib
|
|
9
|
+
import platform
|
|
10
|
+
import subprocess
|
|
11
|
+
import sys
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from typing import Callable, Dict, List, Optional, Union
|
|
14
|
+
|
|
15
|
+
import google.auth
|
|
16
|
+
import requests
|
|
17
|
+
from google.auth import impersonated_credentials
|
|
18
|
+
from google.auth.transport.requests import Request
|
|
19
|
+
from google.oauth2 import service_account
|
|
20
|
+
|
|
21
|
+
from .azure import add_sp_management_token, add_workspace_id_header
|
|
22
|
+
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
|
|
23
|
+
TokenCache, TokenSource)
|
|
24
|
+
|
|
25
|
+
# A HeaderFactory produces the authentication headers for a single outgoing
# HTTP request; it is invoked per request so refreshed tokens are picked up.
HeaderFactory = Callable[[], Dict[str, str]]

# Shared SDK logger; all credential providers log through this name.
logger = logging.getLogger('databricks.sdk')
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CredentialsProvider(abc.ABC):
    """ CredentialsProvider is the protocol (call-side interface)
    for authenticating requests to Databricks REST APIs"""

    @abc.abstractmethod
    def auth_type(self) -> str:
        """Name of the authentication scheme implemented by this provider."""
        ...

    @abc.abstractmethod
    def __call__(self, cfg: 'Config') -> HeaderFactory:
        """Configure authentication from *cfg*; return a factory producing
        per-request auth headers, or raise if configuration is invalid."""
        ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def credentials_provider(name: str, require: List[str]):
    """Build a decorator that turns a ``Config -> HeaderFactory`` function into
    a named credentials provider.

    The produced provider only calls the wrapped function when every
    configuration attribute listed in ``require`` is truthy on the given
    Config; otherwise it returns None so the next provider in the default
    chain can be attempted.
    """

    def decorator(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider:

        @functools.wraps(func)
        def guarded(cfg: 'Config') -> Optional[HeaderFactory]:
            # Skip this provider entirely unless all prerequisites are present.
            if not all(getattr(cfg, attr) for attr in require):
                return None
            return func(cfg)

        # Expose the provider name the same way CredentialsProvider does.
        guarded.auth_type = lambda: name
        return guarded

    return decorator
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@credentials_provider('basic', ['host', 'username', 'password'])
def basic_auth(cfg: 'Config') -> HeaderFactory:
    """HTTP Basic authentication from username and password.

    The credential pair is base64-encoded once; the same static header is
    returned for every request.
    """
    raw = f'{cfg.username}:{cfg.password}'.encode()
    headers = {'Authorization': f'Basic {base64.b64encode(raw).decode()}'}

    def provide() -> Dict[str, str]:
        return headers

    return provide
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@credentials_provider('pat', ['host', 'token'])
def pat_auth(cfg: 'Config') -> HeaderFactory:
    """Authenticate with a Databricks Personal Access Token.

    The token is static for the process lifetime, so the header dict is built
    once and reused for every request.
    """
    headers = {'Authorization': f'Bearer {cfg.token}'}

    def provide() -> Dict[str, str]:
        return headers

    return provide
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@credentials_provider('runtime', [])
def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
    """Implicit authentication when executing inside a Databricks runtime.

    Tries the native, REPL, and legacy runtime auth initializers in order and
    uses the first one that reports a host. Returns None outside the runtime.
    """
    if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
        return None

    # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check
    # above, so that we are not throwing import errors when not in
    # runtime and no config variables are set.
    from databricks.sdk.runtime import (init_runtime_legacy_auth,
                                        init_runtime_native_auth,
                                        init_runtime_repl_auth)
    initializers = (init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth)
    for init in initializers:
        if init is None:
            continue
        host, header_factory = init()
        if host is None:
            logger.debug(f'[{init.__name__}] no host detected')
            continue
        cfg.host = host
        logger.debug(f'[{init.__name__}] runtime native auth configured')
        return header_factory
    return None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret'])
def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
    """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
    if /oidc/.well-known/oauth-authorization-server is available on the given host. """
    endpoints = cfg.oidc_endpoints
    if endpoints is None:
        # Host does not advertise OIDC endpoints; provider not applicable.
        return None
    source = ClientCredentials(client_id=cfg.client_id,
                               client_secret=cfg.client_secret,
                               token_url=endpoints.token_endpoint,
                               scopes=["all-apis"],
                               use_header=True)

    def refreshed() -> Dict[str, str]:
        token = source.token()
        return {'Authorization': f'{token.token_type} {token.access_token}'}

    return refreshed
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@credentials_provider('external-browser', ['host', 'auth_type'])
def external_browser(cfg: 'Config') -> Optional[HeaderFactory]:
    """Interactive OAuth login through the user's local web browser, caching
    the resulting credentials on disk for subsequent runs."""
    if cfg.auth_type != 'external-browser':
        return None

    def _pick_client_id() -> str:
        # An explicitly configured client ID always wins.
        if cfg.client_id:
            return cfg.client_id
        if cfg.is_aws:
            return 'databricks-cli'
        if cfg.is_azure:
            # Use Azure AD app for cases when Azure CLI is not available on the machine.
            # App has to be registered as Single-page multi-tenant to support PKCE
            # TODO: temporary app ID, change it later.
            return '6128a518-99a9-425b-8333-4cc94f04cacd'
        raise ValueError(f'local browser SSO is not supported')

    oauth_client = OAuthClient(host=cfg.host,
                               client_id=_pick_client_id(),
                               redirect_url='http://localhost:8020',
                               client_secret=cfg.client_secret)

    # Load cached credentials from disk if they exist.
    # Note that these are local to the Python SDK and not reused by other SDKs.
    token_cache = TokenCache(oauth_client)
    credentials = token_cache.load()
    if credentials:
        # Force a refresh in case the loaded credentials are expired.
        credentials.token()
    else:
        consent = oauth_client.initiate_consent()
        if not consent:
            return None
        credentials = consent.launch_external_browser()
        token_cache.save(credentials)
    return credentials(cfg)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]):
    """ Resolves Azure Databricks workspace URL from ARM Resource ID """
    # Nothing to do when the host is already known or no resource ID is given.
    if cfg.host or not cfg.azure_workspace_resource_id:
        return
    arm = cfg.arm_environment.resource_manager_endpoint
    token = token_source_for(arm).token()
    # Query ARM for the workspace resource to discover its workspaceUrl.
    resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01",
                        headers={"Authorization": f"Bearer {token.access_token}"})
    if not resp.ok:
        raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}")
    cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@credentials_provider('azure-client-secret',
                      ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
def azure_service_principal(cfg: 'Config') -> HeaderFactory:
    """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
    to every request, while automatically resolving different Azure environment endpoints. """

    def token_source_for(resource: str) -> TokenSource:
        aad = cfg.arm_environment.active_directory_endpoint
        return ClientCredentials(client_id=cfg.azure_client_id,
                                 client_secret=cfg.azure_client_secret,
                                 token_url=f"{aad}{cfg.azure_tenant_id}/oauth2/token",
                                 endpoint_params={"resource": resource},
                                 use_params=True)

    _ensure_host_present(cfg, token_source_for)
    logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
    # Token for the Databricks workspace itself plus a second one for the Azure
    # service-management plane (attached by add_sp_management_token).
    workspace_source = token_source_for(cfg.effective_azure_login_app_id)
    management_source = token_source_for(cfg.arm_environment.service_management_endpoint)

    def refreshed_headers() -> Dict[str, str]:
        headers = {'Authorization': f"Bearer {workspace_source.token().access_token}", }
        add_workspace_id_header(cfg, headers)
        add_sp_management_token(management_source, headers)
        return headers

    return refreshed_headers
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@credentials_provider('github-oidc-azure', ['host', 'azure_client_id'])
def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
    """Federated GitHub Actions OIDC authentication for Azure Databricks.

    Exchanges the GitHub Actions job's OIDC ID token for an AAD token via the
    client-assertion (JWT bearer) flow, so no client secret is required.
    Returns None when not running inside GitHub Actions or not targeting Azure.
    """
    if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
        # not in GitHub actions
        return None

    # Client ID is the minimal thing we need, as otherwise we get AADSTS700016: Application with
    # identifier 'https://token.actions.githubusercontent.com' was not found in the directory '...'.
    if not cfg.is_azure:
        return None

    # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers
    headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}
    endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange"
    response = requests.get(endpoint, headers=headers)
    if not response.ok:
        # Could not obtain an ID token from the runner; let other providers try.
        return None

    # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name
    response_json = response.json()
    if 'value' not in response_json:
        return None

    logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id)
    params = {
        'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
        'resource': cfg.effective_azure_login_app_id,
        'client_assertion': response_json['value'],
    }
    aad_endpoint = cfg.arm_environment.active_directory_endpoint
    if not cfg.azure_tenant_id:
        # detect Azure AD Tenant ID if it's not specified directly
        token_endpoint = cfg.oidc_endpoints.token_endpoint
        cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
    inner = ClientCredentials(client_id=cfg.azure_client_id,
                              client_secret="", # we have no (rotatable) secrets in OIDC flow
                              token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
                              endpoint_params=params,
                              use_params=True)

    def refreshed_headers() -> Dict[str, str]:
        token = inner.token()
        return {'Authorization': f'{token.token_type} {token.access_token}'}

    return refreshed_headers
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# OAuth scopes requested for the Google service-account access token that is
# attached as the X-Databricks-GCP-SA-Access-Token header on account-level calls.
GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@credentials_provider('google-credentials', ['host', 'google_credentials'])
def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
    """Authenticate with an explicitly supplied Google service account.

    ``cfg.google_credentials`` may be either a path to a service-account JSON
    key file or the JSON content itself.
    """
    if not cfg.is_gcp:
        return None

    def _load_account_info() -> dict:
        # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
        if os.path.isfile(cfg.google_credentials):
            with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file:
                return json.load(json_file)
        # If the file doesn't exist, assume that the config is the actual JSON content.
        return json.loads(cfg.google_credentials)

    account_info = _load_account_info()
    # ID token asserting identity towards the workspace (audience = host).
    id_token_creds = service_account.IDTokenCredentials.from_service_account_info(
        info=account_info, target_audience=cfg.host)
    # Access token for GCP APIs, attached only for account-level clients.
    sa_creds = service_account.Credentials.from_service_account_info(info=account_info,
                                                                     scopes=GcpScopes)
    request = Request()

    def refreshed_headers() -> Dict[str, str]:
        id_token_creds.refresh(request)
        headers = {'Authorization': f'Bearer {id_token_creds.token}'}
        if cfg.is_account_client:
            sa_creds.refresh(request)
            headers["X-Databricks-GCP-SA-Access-Token"] = sa_creds.token
        return headers

    return refreshed_headers
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@credentials_provider('google-id', ['host', 'google_service_account'])
def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
    """Authenticate by impersonating a Google service account using the
    Application Default Credentials of the current environment."""
    if not cfg.is_gcp:
        return None
    source_credentials, _project_id = google.auth.default()

    # Create the impersonated credential.
    target = impersonated_credentials.Credentials(source_credentials=source_credentials,
                                                  target_principal=cfg.google_service_account,
                                                  target_scopes=[])

    # Set the impersonated credential, target audience and token options.
    id_creds = impersonated_credentials.IDTokenCredentials(target,
                                                           target_audience=cfg.host,
                                                           include_email=True)

    # Separate impersonated access token for account-level GCP header.
    sa_access_creds = impersonated_credentials.Credentials(
        source_credentials=source_credentials,
        target_principal=cfg.google_service_account,
        target_scopes=GcpScopes)

    request = Request()

    def refreshed_headers() -> Dict[str, str]:
        id_creds.refresh(request)
        headers = {'Authorization': f'Bearer {id_creds.token}'}
        if cfg.is_account_client:
            sa_access_creds.refresh(request)
            headers["X-Databricks-GCP-SA-Access-Token"] = sa_access_creds.token
        return headers

    return refreshed_headers
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class CliTokenSource(Refreshable):
    """Token source that shells out to an external CLI command and parses the
    token fields from its JSON output."""

    def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
        super().__init__()
        self._cmd = cmd
        self._token_type_field = token_type_field
        self._access_token_field = access_token_field
        self._expiry_field = expiry_field

    @staticmethod
    def _parse_expiry(expiry: str) -> datetime:
        """Parse the CLI's expiry string, tolerating a trailing 'Z', fractional
        seconds, and both space- and 'T'-separated date/time forms."""
        expiry = expiry.rstrip("Z").split(".")[0]
        last_err = None
        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
            try:
                return datetime.strptime(expiry, fmt)
            except ValueError as exc:
                last_err = exc
        raise last_err

    def refresh(self) -> Token:
        try:
            # windows requires shell=True to be able to execute 'az login' or other commands
            # cannot use shell=True all the time, as it breaks macOS
            use_shell = sys.platform.startswith('win')
            completed = subprocess.run(self._cmd, capture_output=True, check=True, shell=use_shell)
            payload = json.loads(completed.stdout.decode())
            return Token(access_token=payload[self._access_token_field],
                         token_type=payload[self._token_type_field],
                         expiry=self._parse_expiry(payload[self._expiry_field]))
        except ValueError as e:
            raise ValueError(f"cannot unmarshal CLI result: {e}")
        except subprocess.CalledProcessError as e:
            # Prefer stdout; some CLIs print their error message there.
            message = e.stdout.decode().strip() or e.stderr.decode().strip()
            raise IOError(f'cannot get access token: {message}') from e
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class AzureCliTokenSource(CliTokenSource):
    """ Obtain the token granted by `az login` CLI command """

    def __init__(self, resource: str, subscription: str = ""):
        command = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
        if subscription != "":
            command += ["--subscription", subscription]
        super().__init__(cmd=command,
                         token_type_field='tokenType',
                         access_token_field='accessToken',
                         expiry_field='expiresOn')

    def is_human_user(self) -> bool:
        """Whether the Azure CLI is logged in as a human user (vs a service principal).

        Azure CLI can be authenticated both by humans (`az login`) and by service
        principals (`az login --service-principal --user $clientID --password
        $clientSecret --tenant $tenantID`). Human users receive additional claims
        in the AAD token ('amr', 'name', 'family_name', 'given_name', 'scp' with
        `user_impersonation`, 'unique_name'); the 'upn' claim in particular carries
        the human username and is absent for service principals, so its presence
        is used as the discriminator here.
        """
        return 'upn' in self.token().jwt_claims()

    @staticmethod
    def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
        subscription = AzureCliTokenSource.get_subscription(cfg)
        if subscription != "":
            scoped = AzureCliTokenSource(resource, subscription)
            try:
                # This will fail if the user has access to the workspace, but not to the subscription
                # itself.
                # In such case, we fall back to not using the subscription.
                scoped.token()
                return scoped
            except OSError:
                logger.warning("Failed to get token for subscription. Using resource only token.")

        unscoped = AzureCliTokenSource(resource)
        unscoped.token()
        return unscoped

    @staticmethod
    def get_subscription(cfg: 'Config') -> str:
        """Extract the subscription ID from the workspace ARM resource ID,
        or return '' when it is absent or malformed."""
        resource = cfg.azure_workspace_resource_id
        if not resource:
            return ""
        # /subscriptions/<id>/resourceGroups/... -> component index 2 is the id
        parts = resource.split('/')
        if len(parts) < 3:
            logger.warning("Invalid azure workspace resource ID")
            return ""
        return parts[2]
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
@credentials_provider('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
    """ Adds refreshed OAuth token granted by `az login` command to every request. """
    token_source = None
    mgmt_token_source = None
    try:
        token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id)
    except FileNotFoundError:
        doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
        logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
        return None
    except OSError as e:
        # `az account get-access-token` ran but failed; skip this provider.
        logger.debug('skipping Azure CLI auth', exc_info=e)
        logger.debug('This may happen if you are attempting to login to a dev or staging workspace')
        return None

    # Service principals additionally need a service-management-plane token;
    # failure to obtain one is tolerated (best-effort).
    if not token_source.is_human_user():
        try:
            management_endpoint = cfg.arm_environment.service_management_endpoint
            mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint)
        except Exception as e:
            logger.debug(f'Not including service management token in headers', exc_info=e)
            mgmt_token_source = None

    # Resolve cfg.host from the ARM resource ID when only the latter is set.
    _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
    logger.info("Using Azure CLI authentication with AAD tokens")
    if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
        logger.warning(
            "azure_workspace_resource_id field not provided. "
            "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
        )

    def inner() -> Dict[str, str]:
        token = token_source.token()
        headers = {'Authorization': f'{token.token_type} {token.access_token}'}
        add_workspace_id_header(cfg, headers)
        if mgmt_token_source:
            add_sp_management_token(mgmt_token_source, headers)
        return headers

    return inner
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class DatabricksCliTokenSource(CliTokenSource):
    """ Obtain the token granted by `databricks auth login` CLI command """

    def __init__(self, cfg: 'Config'):
        args = ['auth', 'token', '--host', cfg.host]
        if cfg.is_account_client:
            args += ['--account-id', cfg.account_id]

        cli_path = cfg.databricks_cli_path

        # If the path is not specified look for "databricks" / "databricks.exe" in PATH.
        if not cli_path:
            try:
                # Try to find "databricks" in PATH
                cli_path = self.__class__._find_executable("databricks")
            except FileNotFoundError as e:
                # If "databricks" is not found, try to find "databricks.exe" in PATH (Windows)
                if platform.system() == "Windows":
                    cli_path = self.__class__._find_executable("databricks.exe")
                else:
                    raise e

        # If the path is unqualified (a bare program name with no directory
        # component), look it up in PATH. Using os.path.basename instead of
        # counting '/' characters also handles Windows-style paths: previously
        # "bin\\databricks.exe" contained no '/' and was wrongly treated as an
        # unqualified name.
        elif os.path.basename(cli_path) == cli_path:
            cli_path = self.__class__._find_executable(cli_path)

        super().__init__(cmd=[cli_path, *args],
                         token_type_field='token_type',
                         access_token_field='access_token',
                         expiry_field='expiry')

    @staticmethod
    def _find_executable(name) -> str:
        """Search PATH for *name*, skipping legacy (<0.100.0) Databricks CLI binaries.

        Raises FileNotFoundError when no suitable executable is found.
        """
        err = FileNotFoundError("Most likely the Databricks CLI is not installed")
        for dir in os.getenv("PATH", default="").split(os.path.pathsep):
            path = pathlib.Path(dir).joinpath(name).resolve()
            if not path.is_file():
                continue

            # The new Databricks CLI is a single binary with size > 1MB.
            # We use the size as a signal to determine which Databricks CLI is installed.
            stat = path.stat()
            if stat.st_size < (1024 * 1024):
                err = FileNotFoundError("Databricks CLI version <0.100.0 detected")
                continue

            return str(path)

        raise err
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
@credentials_provider('databricks-cli', ['host'])
def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
    """Adds OAuth tokens obtained via the new Databricks CLI's `auth token`
    command; returns None when the CLI is missing or OAuth is not configured."""
    try:
        source = DatabricksCliTokenSource(cfg)
    except FileNotFoundError as e:
        logger.debug(e)
        return None

    try:
        # Probe once so a missing OAuth setup is detected up front.
        source.token()
    except IOError as e:
        if 'databricks OAuth is not' in str(e):
            logger.debug(f'OAuth not configured or not available: {e}')
            return None
        raise e

    logger.info("Using Databricks CLI authentication")

    def provide() -> Dict[str, str]:
        token = source.token()
        return {'Authorization': f'{token.token_type} {token.access_token}'}

    return provide
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
class MetadataServiceTokenSource(Refreshable):
    """ Obtain the token granted by Databricks Metadata Service """
    # Protocol version this client speaks; sent on every request.
    METADATA_SERVICE_VERSION = "1"
    METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
    METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
    _metadata_service_timeout = 10 # seconds

    def __init__(self, cfg: 'Config'):
        super().__init__()
        self.url = cfg.metadata_service_url
        self.host = cfg.host

    def refresh(self) -> Token:
        """Fetch a fresh token from the local metadata service.

        Raises ValueError when the response is missing the access token, the
        token type, or a usable 'expires_on' timestamp.
        """
        resp = requests.get(self.url,
                            timeout=self._metadata_service_timeout,
                            headers={
                                self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
                                self.METADATA_SERVICE_HOST_HEADER: self.host
                            })
        json_resp: Dict[str, Union[str, float]] = resp.json()
        access_token = json_resp.get("access_token", None)
        if access_token is None:
            raise ValueError("Metadata Service returned empty token")
        token_type = json_resp.get("token_type", None)
        if token_type is None:
            raise ValueError("Metadata Service returned empty token type")
        # Use .get() so a missing key reports the documented ValueError instead
        # of an undocumented KeyError.
        if json_resp.get("expires_on") in ["", None]:
            raise ValueError("Metadata Service returned invalid expiry")
        try:
            expiry = datetime.fromtimestamp(json_resp["expires_on"])
        # Previously a bare `except:` here also swallowed SystemExit and
        # KeyboardInterrupt; fromtimestamp only raises these for a bad value.
        except (TypeError, ValueError, OSError, OverflowError) as e:
            raise ValueError("Metadata Service returned invalid expiry") from e

        return Token(access_token=access_token, token_type=token_type, expiry=expiry)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
@credentials_provider('metadata-service', ['host', 'metadata_service_url'])
def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]:
    """ Adds refreshed token granted by Databricks Metadata Service to every request. """
    source = MetadataServiceTokenSource(cfg)
    # Eagerly fetch once so misconfiguration surfaces during auth selection.
    source.token()
    logger.info("Using Databricks Metadata Service authentication")

    def provide() -> Dict[str, str]:
        token = source.token()
        return {'Authorization': f'{token.token_type} {token.access_token}'}

    return provide
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
class DefaultCredentials:
    """ Select the first applicable credential provider from the chain """

    def __init__(self) -> None:
        self._auth_type = 'default'

    def auth_type(self) -> str:
        """Name of the provider that was actually selected ('default' until
        __call__ succeeds)."""
        return self._auth_type

    def __call__(self, cfg: 'Config') -> HeaderFactory:
        # Order matters: more explicit credentials are tried before implicit ones.
        chain = [
            pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
            github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
            google_credentials, google_id
        ]
        for provider in chain:
            auth_type = provider.auth_type()
            if cfg.auth_type and auth_type != cfg.auth_type:
                # ignore other auth types if one is explicitly enforced
                logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred")
                continue
            logger.debug(f'Attempting to configure auth: {auth_type}')
            try:
                header_factory = provider(cfg)
            except Exception as e:
                raise ValueError(f'{auth_type}: {e}') from e
            if not header_factory:
                # Provider not applicable in this environment; try the next one.
                continue
            self._auth_type = auth_type
            return header_factory
        auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication"
        raise ValueError(
            f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
        )
|