databricks-sdk 0.28.0__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

Files changed (31) hide show
  1. databricks/sdk/__init__.py +74 -22
  2. databricks/sdk/config.py +89 -48
  3. databricks/sdk/core.py +38 -9
  4. databricks/sdk/credentials_provider.py +134 -57
  5. databricks/sdk/data_plane.py +65 -0
  6. databricks/sdk/dbutils.py +81 -3
  7. databricks/sdk/mixins/files.py +12 -4
  8. databricks/sdk/oauth.py +8 -6
  9. databricks/sdk/service/apps.py +977 -0
  10. databricks/sdk/service/billing.py +602 -218
  11. databricks/sdk/service/catalog.py +263 -62
  12. databricks/sdk/service/compute.py +515 -94
  13. databricks/sdk/service/dashboards.py +1310 -2
  14. databricks/sdk/service/iam.py +99 -88
  15. databricks/sdk/service/jobs.py +159 -166
  16. databricks/sdk/service/marketplace.py +74 -58
  17. databricks/sdk/service/oauth2.py +149 -70
  18. databricks/sdk/service/pipelines.py +73 -53
  19. databricks/sdk/service/serving.py +332 -694
  20. databricks/sdk/service/settings.py +424 -4
  21. databricks/sdk/service/sharing.py +235 -26
  22. databricks/sdk/service/sql.py +2484 -553
  23. databricks/sdk/service/vectorsearch.py +75 -0
  24. databricks/sdk/useragent.py +144 -0
  25. databricks/sdk/version.py +1 -1
  26. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/METADATA +37 -16
  27. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/RECORD +31 -28
  28. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/WHEEL +1 -1
  29. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/LICENSE +0 -0
  30. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/NOTICE +0 -0
  31. {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.30.0.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,26 @@ from .azure import add_sp_management_token, add_workspace_id_header
22
22
  from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
23
23
  TokenCache, TokenSource)
24
24
 
25
- HeaderFactory = Callable[[], Dict[str, str]]
25
+ CredentialsProvider = Callable[[], Dict[str, str]]
26
26
 
27
27
  logger = logging.getLogger('databricks.sdk')
28
28
 
29
29
 
30
- class CredentialsProvider(abc.ABC):
30
class OAuthCredentialsProvider:
    """ Callable credentials provider that can also return the underlying OAuth token.

    Calling the instance yields the request headers produced by the wrapped
    credentials provider; `oauth_token()` yields whatever the wrapped token
    provider returns. Both callables are invoked on every call, nothing is
    cached at this level. """

    def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
        # Header factory: zero-argument callable returning the headers dict.
        self._credentials_provider = credentials_provider
        # Token factory: zero-argument callable returning the token.
        self._token_provider = token_provider

    def __call__(self) -> Dict[str, str]:
        """ Return the headers produced by the wrapped credentials provider. """
        headers = self._credentials_provider()
        return headers

    def oauth_token(self) -> Token:
        """ Return the token produced by the wrapped token provider. """
        current_token = self._token_provider()
        return current_token
42
+
43
+
44
+ class CredentialsStrategy(abc.ABC):
31
45
  """ CredentialsProvider is the protocol (call-side interface)
32
46
  for authenticating requests to Databricks REST APIs"""
33
47
 
@@ -36,20 +50,39 @@ class CredentialsProvider(abc.ABC):
36
50
  ...
37
51
 
38
52
  @abc.abstractmethod
39
- def __call__(self, cfg: 'Config') -> HeaderFactory:
53
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
40
54
  ...
41
55
 
42
56
 
43
- def credentials_provider(name: str, require: List[str]):
57
class OauthCredentialsStrategy(CredentialsStrategy):
    """ OauthCredentialsStrategy is a CredentialsStrategy which
    supports OAuth tokens: besides building the per-request credentials
    provider for a Config, it can return the OAuth token of that provider.

    The wrapped `headers_provider` is invoked on every `__call__` and on every
    `oauth_token` call; this class does not cache the provider it builds. """

    def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
        """
        :param auth_type: name reported by `auth_type()`.
        :param headers_provider: function that builds an OAuthCredentialsProvider for a Config.
        """
        self._headers_provider = headers_provider
        self._auth_type = auth_type

    def auth_type(self) -> str:
        """ Return the name this strategy was created with. """
        return self._auth_type

    def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
        """ Build the credentials provider for `cfg` by delegating to `headers_provider`. """
        return self._headers_provider(cfg)

    def oauth_token(self, cfg: 'Config') -> Token:
        """ Build the provider for `cfg` and return its OAuth token.

        NOTE(review): if `headers_provider` returns None for this Config (the
        wrappers built by `oauth_credentials_strategy` do so when a required
        attribute is unset), this raises AttributeError - confirm callers only
        use it after the strategy has been selected. """
        return self._headers_provider(cfg).oauth_token()
73
+
74
+
75
+ def credentials_strategy(name: str, require: List[str]):
44
76
  """ Given the function that receives a Config and returns RequestVisitor,
45
77
  create CredentialsProvider with a given name and required configuration
46
78
  attribute names to be present for this function to be called. """
47
79
 
48
- def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider:
80
+ def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
49
81
 
50
82
  @functools.wraps(func)
51
- def wrapper(cfg: 'Config') -> Optional[HeaderFactory]:
83
+ def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
52
84
  for attr in require:
85
+ getattr(cfg, attr)
53
86
  if not getattr(cfg, attr):
54
87
  return None
55
88
  return func(cfg)
@@ -60,8 +93,27 @@ def credentials_provider(name: str, require: List[str]):
60
93
  return inner
61
94
 
62
95
 
63
- @credentials_provider('basic', ['host', 'username', 'password'])
64
- def basic_auth(cfg: 'Config') -> HeaderFactory:
96
def oauth_credentials_strategy(name: str, require: List[str]):
    """ Given the function that receives a Config and returns an OAuthCredentialsProvider,
    create an OauthCredentialsStrategy with a given name and required configuration
    attribute names to be present for this function to be called.

    :param name: auth type name reported by the resulting strategy.
    :param require: Config attribute names that must all be truthy; otherwise
        the decorated function is not called and the strategy yields None.
    :return: decorator turning the function into an OauthCredentialsStrategy.
    """

    def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:

        @functools.wraps(func)
        def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
            # Attributes are checked in order and checking stops at the first
            # falsy one, exactly like an explicit loop with an early return.
            if not all(getattr(cfg, attr) for attr in require):
                return None
            return func(cfg)

        return OauthCredentialsStrategy(name, wrapper)

    return inner
113
+
114
+
115
+ @credentials_strategy('basic', ['host', 'username', 'password'])
116
+ def basic_auth(cfg: 'Config') -> CredentialsProvider:
65
117
  """ Given username and password, add base64-encoded Basic credentials """
66
118
  encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
67
119
  static_credentials = {'Authorization': f'Basic {encoded}'}
@@ -72,8 +124,8 @@ def basic_auth(cfg: 'Config') -> HeaderFactory:
72
124
  return inner
73
125
 
74
126
 
75
- @credentials_provider('pat', ['host', 'token'])
76
- def pat_auth(cfg: 'Config') -> HeaderFactory:
127
+ @credentials_strategy('pat', ['host', 'token'])
128
+ def pat_auth(cfg: 'Config') -> CredentialsProvider:
77
129
  """ Adds Databricks Personal Access Token to every request """
78
130
  static_credentials = {'Authorization': f'Bearer {cfg.token}'}
79
131
 
@@ -83,8 +135,8 @@ def pat_auth(cfg: 'Config') -> HeaderFactory:
83
135
  return inner
84
136
 
85
137
 
86
- @credentials_provider('runtime', [])
87
- def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
138
+ @credentials_strategy('runtime', [])
139
+ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
88
140
  if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
89
141
  return None
90
142
 
@@ -107,8 +159,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
107
159
  return None
108
160
 
109
161
 
110
- @credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret'])
111
- def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
162
+ @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
163
+ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
112
164
  """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
113
165
  if /oidc/.well-known/oauth-authorization-server is available on the given host. """
114
166
  oidc = cfg.oidc_endpoints
@@ -124,11 +176,14 @@ def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
124
176
  token = token_source.token()
125
177
  return {'Authorization': f'{token.token_type} {token.access_token}'}
126
178
 
127
- return inner
179
+ def token() -> Token:
180
+ return token_source.token()
181
+
182
+ return OAuthCredentialsProvider(inner, token)
128
183
 
129
184
 
130
- @credentials_provider('external-browser', ['host', 'auth_type'])
131
- def external_browser(cfg: 'Config') -> Optional[HeaderFactory]:
185
+ @credentials_strategy('external-browser', ['host', 'auth_type'])
186
+ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
132
187
  if cfg.auth_type != 'external-browser':
133
188
  return None
134
189
  if cfg.client_id:
@@ -178,9 +233,8 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
178
233
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
179
234
 
180
235
 
181
- @credentials_provider('azure-client-secret',
182
- ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
183
- def azure_service_principal(cfg: 'Config') -> HeaderFactory:
236
+ @oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
237
+ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
184
238
  """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
185
239
  to every request, while automatically resolving different Azure environment endpoints. """
186
240
 
@@ -193,6 +247,7 @@ def azure_service_principal(cfg: 'Config') -> HeaderFactory:
193
247
  use_params=True)
194
248
 
195
249
  _ensure_host_present(cfg, token_source_for)
250
+ cfg.load_azure_tenant_id()
196
251
  logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
197
252
  inner = token_source_for(cfg.effective_azure_login_app_id)
198
253
  cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
@@ -203,11 +258,14 @@ def azure_service_principal(cfg: 'Config') -> HeaderFactory:
203
258
  add_sp_management_token(cloud, headers)
204
259
  return headers
205
260
 
206
- return refreshed_headers
261
+ def token() -> Token:
262
+ return inner.token()
207
263
 
264
+ return OAuthCredentialsProvider(refreshed_headers, token)
208
265
 
209
- @credentials_provider('github-oidc-azure', ['host', 'azure_client_id'])
210
- def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
266
+
267
+ @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
268
+ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
211
269
  if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
212
270
  # not in GitHub actions
213
271
  return None
@@ -250,14 +308,17 @@ def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
250
308
  token = inner.token()
251
309
  return {'Authorization': f'{token.token_type} {token.access_token}'}
252
310
 
253
- return refreshed_headers
311
+ def token() -> Token:
312
+ return inner.token()
313
+
314
+ return OAuthCredentialsProvider(refreshed_headers, token)
254
315
 
255
316
 
256
317
  GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
257
318
 
258
319
 
259
- @credentials_provider('google-credentials', ['host', 'google_credentials'])
260
- def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
320
+ @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
321
+ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
261
322
  if not cfg.is_gcp:
262
323
  return None
263
324
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
@@ -277,6 +338,10 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
277
338
  gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
278
339
  scopes=GcpScopes)
279
340
 
341
+ def token() -> Token:
342
+ credentials.refresh(request)
343
+ return credentials.token
344
+
280
345
  def refreshed_headers() -> Dict[str, str]:
281
346
  credentials.refresh(request)
282
347
  headers = {'Authorization': f'Bearer {credentials.token}'}
@@ -285,11 +350,11 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
285
350
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
286
351
  return headers
287
352
 
288
- return refreshed_headers
353
+ return OAuthCredentialsProvider(refreshed_headers, token)
289
354
 
290
355
 
291
- @credentials_provider('google-id', ['host', 'google_service_account'])
292
- def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
356
+ @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
357
+ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
293
358
  if not cfg.is_gcp:
294
359
  return None
295
360
  credentials, _project_id = google.auth.default()
@@ -309,6 +374,10 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
309
374
 
310
375
  request = Request()
311
376
 
377
+ def token() -> Token:
378
+ id_creds.refresh(request)
379
+ return id_creds.token
380
+
312
381
  def refreshed_headers() -> Dict[str, str]:
313
382
  id_creds.refresh(request)
314
383
  headers = {'Authorization': f'Bearer {id_creds.token}'}
@@ -317,7 +386,7 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
317
386
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
318
387
  return headers
319
388
 
320
- return refreshed_headers
389
+ return OAuthCredentialsProvider(refreshed_headers, token)
321
390
 
322
391
 
323
392
  class CliTokenSource(Refreshable):
@@ -363,11 +432,13 @@ class CliTokenSource(Refreshable):
363
432
  class AzureCliTokenSource(CliTokenSource):
364
433
  """ Obtain the token granted by `az login` CLI command """
365
434
 
366
- def __init__(self, resource: str, subscription: str = ""):
435
+ def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
367
436
  cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
368
- if subscription != "":
437
+ if subscription is not None:
369
438
  cmd.append("--subscription")
370
439
  cmd.append(subscription)
440
+ if tenant:
441
+ cmd.extend(["--tenant", tenant])
371
442
  super().__init__(cmd=cmd,
372
443
  token_type_field='tokenType',
373
444
  access_token_field='accessToken',
@@ -395,8 +466,10 @@ class AzureCliTokenSource(CliTokenSource):
395
466
  @staticmethod
396
467
  def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
397
468
  subscription = AzureCliTokenSource.get_subscription(cfg)
398
- if subscription != "":
399
- token_source = AzureCliTokenSource(resource, subscription)
469
+ if subscription is not None:
470
+ token_source = AzureCliTokenSource(resource,
471
+ subscription=subscription,
472
+ tenant=cfg.azure_tenant_id)
400
473
  try:
401
474
  # This will fail if the user has access to the workspace, but not to the subscription
402
475
  # itself.
@@ -406,25 +479,26 @@ class AzureCliTokenSource(CliTokenSource):
406
479
  except OSError:
407
480
  logger.warning("Failed to get token for subscription. Using resource only token.")
408
481
 
409
- token_source = AzureCliTokenSource(resource)
482
+ token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id)
410
483
  token_source.token()
411
484
  return token_source
412
485
 
413
486
  @staticmethod
414
- def get_subscription(cfg: 'Config') -> str:
487
+ def get_subscription(cfg: 'Config') -> Optional[str]:
415
488
  resource = cfg.azure_workspace_resource_id
416
489
  if resource is None or resource == "":
417
- return ""
490
+ return None
418
491
  components = resource.split('/')
419
492
  if len(components) < 3:
420
493
  logger.warning("Invalid azure workspace resource ID")
421
- return ""
494
+ return None
422
495
  return components[2]
423
496
 
424
497
 
425
- @credentials_provider('azure-cli', ['is_azure'])
426
- def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
498
+ @credentials_strategy('azure-cli', ['is_azure'])
499
+ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
427
500
  """ Adds refreshed OAuth token granted by `az login` command to every request. """
501
+ cfg.load_azure_tenant_id()
428
502
  token_source = None
429
503
  mgmt_token_source = None
430
504
  try:
@@ -448,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
448
522
 
449
523
  _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
450
524
  logger.info("Using Azure CLI authentication with AAD tokens")
451
- if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
452
- logger.warning(
453
- "azure_workspace_resource_id field not provided. "
454
- "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
455
- )
456
525
 
457
526
  def inner() -> Dict[str, str]:
458
527
  token = token_source.token()
@@ -516,8 +585,8 @@ class DatabricksCliTokenSource(CliTokenSource):
516
585
  raise err
517
586
 
518
587
 
519
- @credentials_provider('databricks-cli', ['host'])
520
- def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
588
+ @oauth_credentials_strategy('databricks-cli', ['host'])
589
+ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
521
590
  try:
522
591
  token_source = DatabricksCliTokenSource(cfg)
523
592
  except FileNotFoundError as e:
@@ -538,7 +607,7 @@ def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
538
607
  token = token_source.token()
539
608
  return {'Authorization': f'{token.token_type} {token.access_token}'}
540
609
 
541
- return inner
610
+ return OAuthCredentialsProvider(inner, token_source.token)
542
611
 
543
612
 
544
613
  class MetadataServiceTokenSource(Refreshable):
@@ -577,8 +646,8 @@ class MetadataServiceTokenSource(Refreshable):
577
646
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
578
647
 
579
648
 
580
- @credentials_provider('metadata-service', ['host', 'metadata_service_url'])
581
- def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]:
649
+ @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
650
+ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
582
651
  """ Adds refreshed token granted by Databricks Metadata Service to every request. """
583
652
 
584
653
  token_source = MetadataServiceTokenSource(cfg)
@@ -597,17 +666,25 @@ class DefaultCredentials:
597
666
 
598
667
  def __init__(self) -> None:
599
668
  self._auth_type = 'default'
600
-
601
- def auth_type(self) -> str:
602
- return self._auth_type
603
-
604
- def __call__(self, cfg: 'Config') -> HeaderFactory:
605
- auth_providers = [
669
+ self._auth_providers = [
606
670
  pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
607
671
  github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
608
672
  google_credentials, google_id
609
673
  ]
610
- for provider in auth_providers:
674
+
675
+ def auth_type(self) -> str:
676
+ return self._auth_type
677
+
678
+ def oauth_token(self, cfg: 'Config') -> Token:
679
+ for provider in self._auth_providers:
680
+ auth_type = provider.auth_type()
681
+ if auth_type != self._auth_type:
682
+ # ignore other auth types if they don't match the selected one
683
+ continue
684
+ return provider.oauth_token(cfg)
685
+
686
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
687
+ for provider in self._auth_providers:
611
688
  auth_type = provider.auth_type()
612
689
  if cfg.auth_type and auth_type != cfg.auth_type:
613
690
  # ignore other auth types if one is explicitly enforced
@@ -0,0 +1,65 @@
1
+ import threading
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List
4
+
5
+ from databricks.sdk.oauth import Token
6
+ from databricks.sdk.service.oauth2 import DataPlaneInfo
7
+
8
+
9
@dataclass
class DataPlaneDetails:
    """
    Contains details required to query a DataPlane endpoint.
    """
    endpoint_url: str
    """URL used to query the endpoint through the DataPlane."""
    token: Token
    """Token to query the DataPlane endpoint."""


class DataPlaneService:
    """Helper class to fetch and manage DataPlane details."""

    def __init__(self):
        # Both caches are keyed by "<method>/<param>/<param>/...".
        self._data_plane_info = {}
        self._tokens = {}
        # Guards the check-and-fill sequences on both caches.
        self._lock = threading.Lock()

    def get_data_plane_details(self, method: str, params: List[str], info_getter: Callable[[], DataPlaneInfo],
                               refresh: Callable[[str], Token]):
        """Get and cache information required to query a Data Plane endpoint using the provided methods.

        Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid.
        If not, it uses the provided functions to fetch the details.

        :param method: method name. Used to construct a unique key for the cache.
        :param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache.
        :param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache.
        :param refresh: function to refresh the token. It will only be called if the token is missing or expired.
        :return: DataPlaneDetails with the endpoint URL from the cached info and a token that was valid when checked.
        """
        # Build the key without copying/mutating the caller's sequence.
        map_key = "/".join([method, *params])

        info = self._data_plane_info.get(map_key)
        if not info:
            # Re-check under the lock so info_getter runs once per key even
            # when several threads miss the cache at the same time. The
            # context manager releases the lock if info_getter raises.
            with self._lock:
                info = self._data_plane_info.get(map_key)
                if not info:
                    info = info_getter()
                    self._data_plane_info[map_key] = info

        token = self._tokens.get(map_key)
        if not token or not token.valid:
            # Same check/lock/re-check shape for the token, so only one
            # thread refreshes a missing or expired token.
            with self._lock:
                token = self._tokens.get(map_key)
                if not token or not token.valid:
                    token = refresh(info.authorization_details)
                    self._tokens[map_key] = token

        return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token)
databricks/sdk/dbutils.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import base64
2
2
  import json
3
3
  import logging
4
- import os.path
4
+ import os
5
5
  import threading
6
6
  from collections import namedtuple
7
- from typing import Callable, Dict, List
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional
8
9
 
9
10
  from .core import ApiClient, Config, DatabricksError
10
11
  from .mixins import compute as compute_ext
@@ -240,6 +241,76 @@ class RemoteDbUtils:
240
241
  name=util)
241
242
 
242
243
 
244
+ @dataclass
245
+ class OverrideResult:
246
+ result: Any
247
+
248
+
249
+ def get_local_notebook_path():
250
+ value = os.getenv("DATABRICKS_SOURCE_FILE")
251
+ if value is None:
252
+ raise ValueError(
253
+ "Getting the current notebook path is only supported when running a notebook using the `Databricks Connect: Run as File` or `Databricks Connect: Debug as File` commands in the Databricks extension for VS Code. To bypass this error, set environment variable `DATABRICKS_SOURCE_FILE` to the desired notebook path."
254
+ )
255
+
256
+ return value
257
+
258
+
259
+ class _OverrideProxyUtil:
260
+
261
+ @classmethod
262
+ def new(cls, path: str):
263
+ if len(cls.__get_matching_overrides(path)) > 0:
264
+ return _OverrideProxyUtil(path)
265
+ return None
266
+
267
+ def __init__(self, name: str):
268
+ self._name = name
269
+
270
+ # These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
271
+ # are sent to remote either. This could lead to unintentional breakage.
272
+ # Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
273
+ # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
274
+ # are being proxied to remote dbutils currently.
275
+ proxy_override_paths = {
276
+ 'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
277
+ get_local_notebook_path,
278
+ }
279
+
280
+ @classmethod
281
+ def __get_matching_overrides(cls, path: str):
282
+ return [x for x in cls.proxy_override_paths.keys() if x.startswith(path)]
283
+
284
+ def __run_override(self, path: str) -> Optional[OverrideResult]:
285
+ overrides = self.__get_matching_overrides(path)
286
+ if len(overrides) == 1 and overrides[0] == path:
287
+ return OverrideResult(self.proxy_override_paths[overrides[0]]())
288
+
289
+ if len(overrides) > 0:
290
+ return OverrideResult(_OverrideProxyUtil(name=path))
291
+
292
+ return None
293
+
294
+ def __call__(self, *args, **kwds) -> Any:
295
+ if len(args) != 0 or len(kwds) != 0:
296
+ raise TypeError(
297
+ f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")
298
+
299
+ callable_path = f"{self._name}()"
300
+ result = self.__run_override(callable_path)
301
+ if result:
302
+ return result.result
303
+
304
+ raise TypeError(f"{self._name} is not callable")
305
+
306
+ def __getattr__(self, method: str) -> Any:
307
+ result = self.__run_override(f"{self._name}.{method}")
308
+ if result:
309
+ return result.result
310
+
311
+ raise AttributeError(f"module {self._name} has no attribute {method}")
312
+
313
+
243
314
  class _ProxyUtil:
244
315
  """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""
245
316
 
@@ -250,7 +321,14 @@ class _ProxyUtil:
250
321
  self._context_factory = context_factory
251
322
  self._name = name
252
323
 
253
- def __getattr__(self, method: str) -> '_ProxyCall':
324
+ def __call__(self):
325
+ raise NotImplementedError(f"dbutils.{self._name} is not callable")
326
+
327
+ def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
328
+ override = _OverrideProxyUtil.new(f"{self._name}.{method}")
329
+ if override:
330
+ return override
331
+
254
332
  return _ProxyCall(command_execution=self._commands,
255
333
  cluster_id=self._cluster_id,
256
334
  context_factory=self._context_factory,
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import base64
4
4
  import os
5
5
  import pathlib
6
+ import platform
6
7
  import shutil
7
8
  import sys
8
9
  from abc import ABC, abstractmethod
@@ -266,8 +267,9 @@ class _VolumesIO(BinaryIO):
266
267
 
267
268
  class _Path(ABC):
268
269
 
269
- def __init__(self, path: str):
270
- self._path = pathlib.Path(str(path).replace('dbfs:', '').replace('file:', ''))
270
+ @abstractmethod
271
+ def __init__(self):
272
+ ...
271
273
 
272
274
  @property
273
275
  def is_local(self) -> bool:
@@ -327,6 +329,12 @@ class _Path(ABC):
327
329
 
328
330
  class _LocalPath(_Path):
329
331
 
332
+ def __init__(self, path: str):
333
+ if platform.system() == "Windows":
334
+ self._path = pathlib.Path(str(path).replace('file:///', '').replace('file:', ''))
335
+ else:
336
+ self._path = pathlib.Path(str(path).replace('file:', ''))
337
+
330
338
  def _is_local(self) -> bool:
331
339
  return True
332
340
 
@@ -393,7 +401,7 @@ class _LocalPath(_Path):
393
401
  class _VolumesPath(_Path):
394
402
 
395
403
  def __init__(self, api: files.FilesAPI, src: Union[str, pathlib.Path]):
396
- super().__init__(src)
404
+ self._path = pathlib.PurePosixPath(str(src).replace('dbfs:', '').replace('file:', ''))
397
405
  self._api = api
398
406
 
399
407
  def _is_local(self) -> bool:
@@ -462,7 +470,7 @@ class _VolumesPath(_Path):
462
470
  class _DbfsPath(_Path):
463
471
 
464
472
  def __init__(self, api: files.DbfsAPI, src: str):
465
- super().__init__(src)
473
+ self._path = pathlib.PurePosixPath(str(src).replace('dbfs:', '').replace('file:', ''))
466
474
  self._api = api
467
475
 
468
476
  def _is_local(self) -> bool:
databricks/sdk/oauth.py CHANGED
@@ -21,6 +21,10 @@ import requests.auth
21
21
  # See https://stackoverflow.com/a/75466778/277035 for more info
22
22
  NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327'
23
23
 
24
+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
25
+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
26
+ OIDC_TOKEN_PATH = "/oidc/v1/token"
27
+
24
28
  logger = logging.getLogger(__name__)
25
29
 
26
30
 
@@ -358,18 +362,15 @@ class OAuthClient:
358
362
  client_secret: str = None):
359
363
  # TODO: is it a circular dependency?..
360
364
  from .core import Config
361
- from .credentials_provider import credentials_provider
365
+ from .credentials_provider import credentials_strategy
362
366
 
363
- @credentials_provider('noop', [])
367
+ @credentials_strategy('noop', [])
364
368
  def noop_credentials(_: any):
365
369
  return lambda: {}
366
370
 
367
- config = Config(host=host, credentials_provider=noop_credentials)
371
+ config = Config(host=host, credentials_strategy=noop_credentials)
368
372
  if not scopes:
369
373
  scopes = ['all-apis']
370
- if config.is_azure:
371
- # Azure AD only supports full access to Azure Databricks.
372
- scopes = [f'{config.effective_azure_login_app_id}/user_impersonation', 'offline_access']
373
374
  oidc = config.oidc_endpoints
374
375
  if not oidc:
375
376
  raise ValueError(f'{host} does not support OAuth')
@@ -381,6 +382,7 @@ class OAuthClient:
381
382
  self.token_url = oidc.token_endpoint
382
383
  self.is_aws = config.is_aws
383
384
  self.is_azure = config.is_azure
385
+ self.is_gcp = config.is_gcp
384
386
 
385
387
  self._auth_url = oidc.authorization_endpoint
386
388
  self._scopes = scopes