databricks-sdk 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

databricks/sdk/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import databricks.sdk.core as client
2
2
  import databricks.sdk.dbutils as dbutils
3
3
  from databricks.sdk import azure
4
- from databricks.sdk.credentials_provider import CredentialsProvider
4
+ from databricks.sdk.credentials_provider import CredentialsStrategy
5
5
  from databricks.sdk.mixins.compute import ClustersExt
6
6
  from databricks.sdk.mixins.files import DbfsExt
7
7
  from databricks.sdk.mixins.workspace import WorkspaceExt
@@ -131,7 +131,8 @@ class WorkspaceClient:
131
131
  debug_headers: bool = None,
132
132
  product="unknown",
133
133
  product_version="0.0.0",
134
- credentials_provider: CredentialsProvider = None,
134
+ credentials_strategy: CredentialsStrategy = None,
135
+ credentials_provider: CredentialsStrategy = None,
135
136
  config: client.Config = None):
136
137
  if not config:
137
138
  config = client.Config(host=host,
@@ -152,6 +153,7 @@ class WorkspaceClient:
152
153
  cluster_id=cluster_id,
153
154
  google_credentials=google_credentials,
154
155
  google_service_account=google_service_account,
156
+ credentials_strategy=credentials_strategy,
155
157
  credentials_provider=credentials_provider,
156
158
  debug_truncate_bytes=debug_truncate_bytes,
157
159
  debug_headers=debug_headers,
@@ -700,7 +702,8 @@ class AccountClient:
700
702
  debug_headers: bool = None,
701
703
  product="unknown",
702
704
  product_version="0.0.0",
703
- credentials_provider: CredentialsProvider = None,
705
+ credentials_strategy: CredentialsStrategy = None,
706
+ credentials_provider: CredentialsStrategy = None,
704
707
  config: client.Config = None):
705
708
  if not config:
706
709
  config = client.Config(host=host,
@@ -721,6 +724,7 @@ class AccountClient:
721
724
  cluster_id=cluster_id,
722
725
  google_credentials=google_credentials,
723
726
  google_service_account=google_service_account,
727
+ credentials_strategy=credentials_strategy,
724
728
  credentials_provider=credentials_provider,
725
729
  debug_truncate_bytes=debug_truncate_bytes,
726
730
  debug_headers=debug_headers,
databricks/sdk/config.py CHANGED
@@ -6,15 +6,15 @@ import pathlib
6
6
  import platform
7
7
  import sys
8
8
  import urllib.parse
9
- from typing import Dict, Iterable, Optional
9
+ from typing import Dict, Iterable, List, Optional, Tuple
10
10
 
11
11
  import requests
12
12
 
13
13
  from .clock import Clock, RealClock
14
- from .credentials_provider import CredentialsProvider, DefaultCredentials
14
+ from .credentials_provider import CredentialsStrategy, DefaultCredentials
15
15
  from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
16
16
  DatabricksEnvironment, get_environment_for_hostname)
17
- from .oauth import OidcEndpoints
17
+ from .oauth import OidcEndpoints, Token
18
18
  from .version import __version__
19
19
 
20
20
  logger = logging.getLogger('databricks.sdk')
@@ -44,6 +44,32 @@ class ConfigAttribute:
44
44
  return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
45
45
 
46
46
 
47
+ _DEFAULT_PRODUCT_NAME = 'unknown'
48
+ _DEFAULT_PRODUCT_VERSION = '0.0.0'
49
+ _STATIC_USER_AGENT: Tuple[str, str, List[str]] = (_DEFAULT_PRODUCT_NAME, _DEFAULT_PRODUCT_VERSION, [])
50
+
51
+
52
+ def with_product(product: str, product_version: str):
53
+ """[INTERNAL API] Change the product name and version used in the User-Agent header."""
54
+ global _STATIC_USER_AGENT
55
+ prev_product, prev_version, prev_other_info = _STATIC_USER_AGENT
56
+ logger.debug(f'Changing product from {prev_product}/{prev_version} to {product}/{product_version}')
57
+ _STATIC_USER_AGENT = product, product_version, prev_other_info
58
+
59
+
60
+ def with_user_agent_extra(key: str, value: str):
61
+ """[INTERNAL API] Add extra metadata to the User-Agent header when developing a library."""
62
+ global _STATIC_USER_AGENT
63
+ product_name, product_version, other_info = _STATIC_USER_AGENT
64
+ for item in other_info:
65
+ if item.startswith(f"{key}/"):
66
+ # ensure that we don't have duplicates
67
+ other_info.remove(item)
68
+ break
69
+ other_info.append(f"{key}/{value}")
70
+ _STATIC_USER_AGENT = product_name, product_version, other_info
71
+
72
+
47
73
  class Config:
48
74
  host: str = ConfigAttribute(env='DATABRICKS_HOST')
49
75
  account_id: str = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID')
@@ -66,6 +92,7 @@ class Config:
66
92
  auth_type: str = ConfigAttribute(env='DATABRICKS_AUTH_TYPE')
67
93
  cluster_id: str = ConfigAttribute(env='DATABRICKS_CLUSTER_ID')
68
94
  warehouse_id: str = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID')
95
+ serverless_compute_id: str = ConfigAttribute(env='DATABRICKS_SERVERLESS_COMPUTE_ID')
69
96
  skip_verify: bool = ConfigAttribute()
70
97
  http_timeout_seconds: float = ConfigAttribute()
71
98
  debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES')
@@ -81,15 +108,34 @@ class Config:
81
108
 
82
109
  def __init__(self,
83
110
  *,
84
- credentials_provider: CredentialsProvider = None,
85
- product="unknown",
86
- product_version="0.0.0",
111
+ # Deprecated. Use credentials_strategy instead.
112
+ credentials_provider: CredentialsStrategy = None,
113
+ credentials_strategy: CredentialsStrategy = None,
114
+ product=_DEFAULT_PRODUCT_NAME,
115
+ product_version=_DEFAULT_PRODUCT_VERSION,
87
116
  clock: Clock = None,
88
117
  **kwargs):
89
118
  self._header_factory = None
90
119
  self._inner = {}
120
+ # as in SDK for Go, pull information from global static user agent context,
121
+ # so that we can track additional metadata for mid-stream libraries, as well
122
+ # as for cases, when the downstream product is used as a library and is not
123
+ # configured with a proper product name and version.
124
+ static_product, static_version, _ = _STATIC_USER_AGENT
125
+ if product == _DEFAULT_PRODUCT_NAME:
126
+ product = static_product
127
+ if product_version == _DEFAULT_PRODUCT_VERSION:
128
+ product_version = static_version
91
129
  self._user_agent_other_info = []
92
- self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials()
130
+ if credentials_strategy and credentials_provider:
131
+ raise ValueError(
132
+ "When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
133
+ if credentials_provider:
134
+ logger.warning(
135
+ "parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.")
136
+ self._credentials_strategy = next(
137
+ s for s in [credentials_strategy, credentials_provider,
138
+ DefaultCredentials()] if s is not None)
93
139
  if 'databricks_environment' in kwargs:
94
140
  self.databricks_environment = kwargs['databricks_environment']
95
141
  del kwargs['databricks_environment']
@@ -107,6 +153,9 @@ class Config:
107
153
  message = self.wrap_debug_info(str(e))
108
154
  raise ValueError(message) from e
109
155
 
156
+ def oauth_token(self) -> Token:
157
+ return self._credentials_strategy.oauth_token(self)
158
+
110
159
  def wrap_debug_info(self, message: str) -> str:
111
160
  debug_string = self.debug_string()
112
161
  if debug_string:
@@ -220,6 +269,12 @@ class Config:
220
269
  ]
221
270
  if len(self._user_agent_other_info) > 0:
222
271
  ua.append(' '.join(self._user_agent_other_info))
272
+ # as in SDK for Go, pull information from global static user agent context,
273
+ # so that we can track additional metadata for mid-stream libraries. this value
274
+ # is shared across all instances of Config objects intentionally.
275
+ _, _, static_info = _STATIC_USER_AGENT
276
+ if len(static_info) > 0:
277
+ ua.append(' '.join(static_info))
223
278
  if len(self._upstream_user_agent) > 0:
224
279
  ua.append(self._upstream_user_agent)
225
280
  if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
@@ -436,12 +491,12 @@ class Config:
436
491
 
437
492
  def init_auth(self):
438
493
  try:
439
- self._header_factory = self._credentials_provider(self)
440
- self.auth_type = self._credentials_provider.auth_type()
494
+ self._header_factory = self._credentials_strategy(self)
495
+ self.auth_type = self._credentials_strategy.auth_type()
441
496
  if not self._header_factory:
442
497
  raise ValueError('not configured')
443
498
  except ValueError as e:
444
- raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e
499
+ raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e
445
500
 
446
501
  def __repr__(self):
447
502
  return f'<{self.debug_string()}>'
databricks/sdk/core.py CHANGED
@@ -4,6 +4,7 @@ from datetime import timedelta
4
4
  from json import JSONDecodeError
5
5
  from types import TracebackType
6
6
  from typing import Any, BinaryIO, Iterator, Type
7
+ from urllib.parse import urlencode
7
8
 
8
9
  from requests.adapters import HTTPAdapter
9
10
 
@@ -13,12 +14,17 @@ from .config import *
13
14
  from .credentials_provider import *
14
15
  from .errors import DatabricksError, error_mapper
15
16
  from .errors.private_link import _is_private_link_redirect
17
+ from .oauth import retrieve_token
16
18
  from .retries import retried
17
19
 
18
20
  __all__ = ['Config', 'DatabricksError']
19
21
 
20
22
  logger = logging.getLogger('databricks.sdk')
21
23
 
24
+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
25
+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
26
+ OIDC_TOKEN_PATH = "/oidc/v1/token"
27
+
22
28
 
23
29
  class ApiClient:
24
30
  _cfg: Config
@@ -109,6 +115,22 @@ class ApiClient:
109
115
  flattened = dict(flatten_dict(with_fixed_bools))
110
116
  return flattened
111
117
 
118
+ def get_oauth_token(self, auth_details: str) -> Token:
119
+ if not self._cfg.auth_type:
120
+ self._cfg.authenticate()
121
+ original_token = self._cfg.oauth_token()
122
+ headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
123
+ params = urlencode({
124
+ "grant_type": JWT_BEARER_GRANT_TYPE,
125
+ "authorization_details": auth_details,
126
+ "assertion": original_token.access_token
127
+ })
128
+ return retrieve_token(client_id=self._cfg.client_id,
129
+ client_secret=self._cfg.client_secret,
130
+ token_url=self._cfg.host + OIDC_TOKEN_PATH,
131
+ params=params,
132
+ headers=headers)
133
+
112
134
  def do(self,
113
135
  method: str,
114
136
  path: str,
databricks/sdk/credentials_provider.py CHANGED
@@ -22,12 +22,26 @@ from .azure import add_sp_management_token, add_workspace_id_header
22
22
  from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
23
23
  TokenCache, TokenSource)
24
24
 
25
- HeaderFactory = Callable[[], Dict[str, str]]
25
+ CredentialsProvider = Callable[[], Dict[str, str]]
26
26
 
27
27
  logger = logging.getLogger('databricks.sdk')
28
28
 
29
29
 
30
- class CredentialsProvider(abc.ABC):
30
+ class OAuthCredentialsProvider:
31
+ """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """
32
+
33
+ def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
34
+ self._credentials_provider = credentials_provider
35
+ self._token_provider = token_provider
36
+
37
+ def __call__(self) -> Dict[str, str]:
38
+ return self._credentials_provider()
39
+
40
+ def oauth_token(self) -> Token:
41
+ return self._token_provider()
42
+
43
+
44
+ class CredentialsStrategy(abc.ABC):
31
45
  """ CredentialsProvider is the protocol (call-side interface)
32
46
  for authenticating requests to Databricks REST APIs"""
33
47
 
@@ -36,20 +50,39 @@ class CredentialsProvider(abc.ABC):
36
50
  ...
37
51
 
38
52
  @abc.abstractmethod
39
- def __call__(self, cfg: 'Config') -> HeaderFactory:
53
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
40
54
  ...
41
55
 
42
56
 
43
- def credentials_provider(name: str, require: List[str]):
57
+ class OauthCredentialsStrategy(CredentialsStrategy):
58
+ """ OauthCredentialsProvider is a CredentialsProvider which
59
+ supports Oauth tokens"""
60
+
61
+ def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
62
+ self._headers_provider = headers_provider
63
+ self._auth_type = auth_type
64
+
65
+ def auth_type(self) -> str:
66
+ return self._auth_type
67
+
68
+ def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
69
+ return self._headers_provider(cfg)
70
+
71
+ def oauth_token(self, cfg: 'Config') -> Token:
72
+ return self._headers_provider(cfg).oauth_token()
73
+
74
+
75
+ def credentials_strategy(name: str, require: List[str]):
44
76
  """ Given the function that receives a Config and returns RequestVisitor,
45
77
  create CredentialsProvider with a given name and required configuration
46
78
  attribute names to be present for this function to be called. """
47
79
 
48
- def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider:
80
+ def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
49
81
 
50
82
  @functools.wraps(func)
51
- def wrapper(cfg: 'Config') -> Optional[HeaderFactory]:
83
+ def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
52
84
  for attr in require:
85
+ getattr(cfg, attr)
53
86
  if not getattr(cfg, attr):
54
87
  return None
55
88
  return func(cfg)
@@ -60,8 +93,27 @@ def credentials_provider(name: str, require: List[str]):
60
93
  return inner
61
94
 
62
95
 
63
- @credentials_provider('basic', ['host', 'username', 'password'])
64
- def basic_auth(cfg: 'Config') -> HeaderFactory:
96
+ def oauth_credentials_strategy(name: str, require: List[str]):
97
+ """ Given the function that receives a Config and returns an OauthHeaderFactory,
98
+ create an OauthCredentialsProvider with a given name and required configuration
99
+ attribute names to be present for this function to be called. """
100
+
101
+ def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:
102
+
103
+ @functools.wraps(func)
104
+ def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
105
+ for attr in require:
106
+ if not getattr(cfg, attr):
107
+ return None
108
+ return func(cfg)
109
+
110
+ return OauthCredentialsStrategy(name, wrapper)
111
+
112
+ return inner
113
+
114
+
115
+ @credentials_strategy('basic', ['host', 'username', 'password'])
116
+ def basic_auth(cfg: 'Config') -> CredentialsProvider:
65
117
  """ Given username and password, add base64-encoded Basic credentials """
66
118
  encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
67
119
  static_credentials = {'Authorization': f'Basic {encoded}'}
@@ -72,8 +124,8 @@ def basic_auth(cfg: 'Config') -> HeaderFactory:
72
124
  return inner
73
125
 
74
126
 
75
- @credentials_provider('pat', ['host', 'token'])
76
- def pat_auth(cfg: 'Config') -> HeaderFactory:
127
+ @credentials_strategy('pat', ['host', 'token'])
128
+ def pat_auth(cfg: 'Config') -> CredentialsProvider:
77
129
  """ Adds Databricks Personal Access Token to every request """
78
130
  static_credentials = {'Authorization': f'Bearer {cfg.token}'}
79
131
 
@@ -83,8 +135,8 @@ def pat_auth(cfg: 'Config') -> HeaderFactory:
83
135
  return inner
84
136
 
85
137
 
86
- @credentials_provider('runtime', [])
87
- def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
138
+ @credentials_strategy('runtime', [])
139
+ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
88
140
  if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
89
141
  return None
90
142
 
@@ -107,8 +159,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
107
159
  return None
108
160
 
109
161
 
110
- @credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret'])
111
- def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
162
+ @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
163
+ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
112
164
  """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
113
165
  if /oidc/.well-known/oauth-authorization-server is available on the given host. """
114
166
  oidc = cfg.oidc_endpoints
@@ -124,11 +176,14 @@ def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
124
176
  token = token_source.token()
125
177
  return {'Authorization': f'{token.token_type} {token.access_token}'}
126
178
 
127
- return inner
179
+ def token() -> Token:
180
+ return token_source.token()
181
+
182
+ return OAuthCredentialsProvider(inner, token)
128
183
 
129
184
 
130
- @credentials_provider('external-browser', ['host', 'auth_type'])
131
- def external_browser(cfg: 'Config') -> Optional[HeaderFactory]:
185
+ @credentials_strategy('external-browser', ['host', 'auth_type'])
186
+ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
132
187
  if cfg.auth_type != 'external-browser':
133
188
  return None
134
189
  if cfg.client_id:
@@ -178,9 +233,9 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
178
233
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
179
234
 
180
235
 
181
- @credentials_provider('azure-client-secret',
182
- ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
183
- def azure_service_principal(cfg: 'Config') -> HeaderFactory:
236
+ @oauth_credentials_strategy('azure-client-secret',
237
+ ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
238
+ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
184
239
  """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
185
240
  to every request, while automatically resolving different Azure environment endpoints. """
186
241
 
@@ -203,11 +258,14 @@ def azure_service_principal(cfg: 'Config') -> HeaderFactory:
203
258
  add_sp_management_token(cloud, headers)
204
259
  return headers
205
260
 
206
- return refreshed_headers
261
+ def token() -> Token:
262
+ return inner.token()
263
+
264
+ return OAuthCredentialsProvider(refreshed_headers, token)
207
265
 
208
266
 
209
- @credentials_provider('github-oidc-azure', ['host', 'azure_client_id'])
210
- def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
267
+ @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
268
+ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
211
269
  if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
212
270
  # not in GitHub actions
213
271
  return None
@@ -250,14 +308,17 @@ def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
250
308
  token = inner.token()
251
309
  return {'Authorization': f'{token.token_type} {token.access_token}'}
252
310
 
253
- return refreshed_headers
311
+ def token() -> Token:
312
+ return inner.token()
313
+
314
+ return OAuthCredentialsProvider(refreshed_headers, token)
254
315
 
255
316
 
256
317
  GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
257
318
 
258
319
 
259
- @credentials_provider('google-credentials', ['host', 'google_credentials'])
260
- def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
320
+ @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
321
+ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
261
322
  if not cfg.is_gcp:
262
323
  return None
263
324
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
@@ -277,6 +338,10 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
277
338
  gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
278
339
  scopes=GcpScopes)
279
340
 
341
+ def token() -> Token:
342
+ credentials.refresh(request)
343
+ return credentials.token
344
+
280
345
  def refreshed_headers() -> Dict[str, str]:
281
346
  credentials.refresh(request)
282
347
  headers = {'Authorization': f'Bearer {credentials.token}'}
@@ -285,11 +350,11 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
285
350
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
286
351
  return headers
287
352
 
288
- return refreshed_headers
353
+ return OAuthCredentialsProvider(refreshed_headers, token)
289
354
 
290
355
 
291
- @credentials_provider('google-id', ['host', 'google_service_account'])
292
- def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
356
+ @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
357
+ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
293
358
  if not cfg.is_gcp:
294
359
  return None
295
360
  credentials, _project_id = google.auth.default()
@@ -309,6 +374,10 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
309
374
 
310
375
  request = Request()
311
376
 
377
+ def token() -> Token:
378
+ id_creds.refresh(request)
379
+ return id_creds.token
380
+
312
381
  def refreshed_headers() -> Dict[str, str]:
313
382
  id_creds.refresh(request)
314
383
  headers = {'Authorization': f'Bearer {id_creds.token}'}
@@ -317,7 +386,7 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
317
386
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
318
387
  return headers
319
388
 
320
- return refreshed_headers
389
+ return OAuthCredentialsProvider(refreshed_headers, token)
321
390
 
322
391
 
323
392
  class CliTokenSource(Refreshable):
@@ -422,8 +491,8 @@ class AzureCliTokenSource(CliTokenSource):
422
491
  return components[2]
423
492
 
424
493
 
425
- @credentials_provider('azure-cli', ['is_azure'])
426
- def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
494
+ @credentials_strategy('azure-cli', ['is_azure'])
495
+ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
427
496
  """ Adds refreshed OAuth token granted by `az login` command to every request. """
428
497
  token_source = None
429
498
  mgmt_token_source = None
@@ -516,8 +585,8 @@ class DatabricksCliTokenSource(CliTokenSource):
516
585
  raise err
517
586
 
518
587
 
519
- @credentials_provider('databricks-cli', ['host'])
520
- def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
588
+ @oauth_credentials_strategy('databricks-cli', ['host'])
589
+ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
521
590
  try:
522
591
  token_source = DatabricksCliTokenSource(cfg)
523
592
  except FileNotFoundError as e:
@@ -538,7 +607,7 @@ def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
538
607
  token = token_source.token()
539
608
  return {'Authorization': f'{token.token_type} {token.access_token}'}
540
609
 
541
- return inner
610
+ return OAuthCredentialsProvider(inner, token_source.token)
542
611
 
543
612
 
544
613
  class MetadataServiceTokenSource(Refreshable):
@@ -577,8 +646,8 @@ class MetadataServiceTokenSource(Refreshable):
577
646
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
578
647
 
579
648
 
580
- @credentials_provider('metadata-service', ['host', 'metadata_service_url'])
581
- def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]:
649
+ @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
650
+ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
582
651
  """ Adds refreshed token granted by Databricks Metadata Service to every request. """
583
652
 
584
653
  token_source = MetadataServiceTokenSource(cfg)
@@ -597,17 +666,25 @@ class DefaultCredentials:
597
666
 
598
667
  def __init__(self) -> None:
599
668
  self._auth_type = 'default'
600
-
601
- def auth_type(self) -> str:
602
- return self._auth_type
603
-
604
- def __call__(self, cfg: 'Config') -> HeaderFactory:
605
- auth_providers = [
669
+ self._auth_providers = [
606
670
  pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
607
671
  github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
608
672
  google_credentials, google_id
609
673
  ]
610
- for provider in auth_providers:
674
+
675
+ def auth_type(self) -> str:
676
+ return self._auth_type
677
+
678
+ def oauth_token(self, cfg: 'Config') -> Token:
679
+ for provider in self._auth_providers:
680
+ auth_type = provider.auth_type()
681
+ if auth_type != self._auth_type:
682
+ # ignore other auth types if they don't match the selected one
683
+ continue
684
+ return provider.oauth_token(cfg)
685
+
686
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
687
+ for provider in self._auth_providers:
611
688
  auth_type = provider.auth_type()
612
689
  if cfg.auth_type and auth_type != cfg.auth_type:
613
690
  # ignore other auth types if one is explicitly enforced
databricks/sdk/dbutils.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import base64
2
2
  import json
3
3
  import logging
4
- import os.path
4
+ import os
5
5
  import threading
6
6
  from collections import namedtuple
7
- from typing import Callable, Dict, List
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional
8
9
 
9
10
  from .core import ApiClient, Config, DatabricksError
10
11
  from .mixins import compute as compute_ext
@@ -240,6 +241,76 @@ class RemoteDbUtils:
240
241
  name=util)
241
242
 
242
243
 
244
+ @dataclass
245
+ class OverrideResult:
246
+ result: Any
247
+
248
+
249
+ def get_local_notebook_path():
250
+ value = os.getenv("DATABRICKS_SOURCE_FILE")
251
+ if value is None:
252
+ raise ValueError(
253
+ "Getting the current notebook path is only supported when running a notebook using the `Databricks Connect: Run as File` or `Databricks Connect: Debug as File` commands in the Databricks extension for VS Code. To bypass this error, set environment variable `DATABRICKS_SOURCE_FILE` to the desired notebook path."
254
+ )
255
+
256
+ return value
257
+
258
+
259
+ class _OverrideProxyUtil:
260
+
261
+ @classmethod
262
+ def new(cls, path: str):
263
+ if len(cls.__get_matching_overrides(path)) > 0:
264
+ return _OverrideProxyUtil(path)
265
+ return None
266
+
267
+ def __init__(self, name: str):
268
+ self._name = name
269
+
270
+ # These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
271
+ # are sent to remote either. This could lead to unintentional breakage.
272
+ # Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
273
+ # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
274
+ # are being proxied to remote dbutils currently.
275
+ proxy_override_paths = {
276
+ 'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
277
+ get_local_notebook_path,
278
+ }
279
+
280
+ @classmethod
281
+ def __get_matching_overrides(cls, path: str):
282
+ return [x for x in cls.proxy_override_paths.keys() if x.startswith(path)]
283
+
284
+ def __run_override(self, path: str) -> Optional[OverrideResult]:
285
+ overrides = self.__get_matching_overrides(path)
286
+ if len(overrides) == 1 and overrides[0] == path:
287
+ return OverrideResult(self.proxy_override_paths[overrides[0]]())
288
+
289
+ if len(overrides) > 0:
290
+ return OverrideResult(_OverrideProxyUtil(name=path))
291
+
292
+ return None
293
+
294
+ def __call__(self, *args, **kwds) -> Any:
295
+ if len(args) != 0 or len(kwds) != 0:
296
+ raise TypeError(
297
+ f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")
298
+
299
+ callable_path = f"{self._name}()"
300
+ result = self.__run_override(callable_path)
301
+ if result:
302
+ return result.result
303
+
304
+ raise TypeError(f"{self._name} is not callable")
305
+
306
+ def __getattr__(self, method: str) -> Any:
307
+ result = self.__run_override(f"{self._name}.{method}")
308
+ if result:
309
+ return result.result
310
+
311
+ raise AttributeError(f"module {self._name} has no attribute {method}")
312
+
313
+
243
314
  class _ProxyUtil:
244
315
  """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""
245
316
 
@@ -250,7 +321,14 @@ class _ProxyUtil:
250
321
  self._context_factory = context_factory
251
322
  self._name = name
252
323
 
253
- def __getattr__(self, method: str) -> '_ProxyCall':
324
+ def __call__(self):
325
+ raise NotImplementedError(f"dbutils.{self._name} is not callable")
326
+
327
+ def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
328
+ override = _OverrideProxyUtil.new(f"{self._name}.{method}")
329
+ if override:
330
+ return override
331
+
254
332
  return _ProxyCall(command_execution=self._commands,
255
333
  cluster_id=self._cluster_id,
256
334
  context_factory=self._context_factory,