databricks-sdk 0.27.1__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic; see the registry's advisory page for more details.

Files changed (32):
  1. databricks/sdk/__init__.py +16 -12
  2. databricks/sdk/azure.py +0 -27
  3. databricks/sdk/config.py +71 -19
  4. databricks/sdk/core.py +27 -0
  5. databricks/sdk/credentials_provider.py +121 -44
  6. databricks/sdk/dbutils.py +81 -3
  7. databricks/sdk/environments.py +34 -1
  8. databricks/sdk/errors/__init__.py +1 -0
  9. databricks/sdk/errors/mapper.py +4 -0
  10. databricks/sdk/errors/private_link.py +60 -0
  11. databricks/sdk/oauth.py +8 -6
  12. databricks/sdk/service/catalog.py +774 -632
  13. databricks/sdk/service/compute.py +91 -116
  14. databricks/sdk/service/dashboards.py +707 -2
  15. databricks/sdk/service/jobs.py +126 -163
  16. databricks/sdk/service/marketplace.py +145 -31
  17. databricks/sdk/service/oauth2.py +22 -0
  18. databricks/sdk/service/pipelines.py +119 -4
  19. databricks/sdk/service/serving.py +217 -64
  20. databricks/sdk/service/settings.py +1 -0
  21. databricks/sdk/service/sharing.py +36 -2
  22. databricks/sdk/service/sql.py +103 -24
  23. databricks/sdk/service/vectorsearch.py +263 -1
  24. databricks/sdk/service/workspace.py +8 -4
  25. databricks/sdk/version.py +1 -1
  26. {databricks_sdk-0.27.1.dist-info → databricks_sdk-0.29.0.dist-info}/METADATA +2 -1
  27. databricks_sdk-0.29.0.dist-info/RECORD +57 -0
  28. databricks_sdk-0.27.1.dist-info/RECORD +0 -56
  29. {databricks_sdk-0.27.1.dist-info → databricks_sdk-0.29.0.dist-info}/LICENSE +0 -0
  30. {databricks_sdk-0.27.1.dist-info → databricks_sdk-0.29.0.dist-info}/NOTICE +0 -0
  31. {databricks_sdk-0.27.1.dist-info → databricks_sdk-0.29.0.dist-info}/WHEEL +0 -0
  32. {databricks_sdk-0.27.1.dist-info → databricks_sdk-0.29.0.dist-info}/top_level.txt +0 -0
@@ -22,12 +22,26 @@ from .azure import add_sp_management_token, add_workspace_id_header
22
22
  from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
23
23
  TokenCache, TokenSource)
24
24
 
25
- HeaderFactory = Callable[[], Dict[str, str]]
25
+ CredentialsProvider = Callable[[], Dict[str, str]]
26
26
 
27
27
  logger = logging.getLogger('databricks.sdk')
28
28
 
29
29
 
30
- class CredentialsProvider(abc.ABC):
30
+ class OAuthCredentialsProvider:
31
+ """ OAuthCredentialsProvider is a type of CredentialsProvider which exposes OAuth tokens. """
32
+
33
+ def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
34
+ self._credentials_provider = credentials_provider
35
+ self._token_provider = token_provider
36
+
37
+ def __call__(self) -> Dict[str, str]:
38
+ return self._credentials_provider()
39
+
40
+ def oauth_token(self) -> Token:
41
+ return self._token_provider()
42
+
43
+
44
+ class CredentialsStrategy(abc.ABC):
31
45
  """ CredentialsProvider is the protocol (call-side interface)
32
46
  for authenticating requests to Databricks REST APIs"""
33
47
 
@@ -36,20 +50,39 @@ class CredentialsProvider(abc.ABC):
36
50
  ...
37
51
 
38
52
  @abc.abstractmethod
39
- def __call__(self, cfg: 'Config') -> HeaderFactory:
53
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
40
54
  ...
41
55
 
42
56
 
43
- def credentials_provider(name: str, require: List[str]):
57
+ class OauthCredentialsStrategy(CredentialsStrategy):
58
+ """ OauthCredentialsProvider is a CredentialsProvider which
59
+ supports Oauth tokens"""
60
+
61
+ def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
62
+ self._headers_provider = headers_provider
63
+ self._auth_type = auth_type
64
+
65
+ def auth_type(self) -> str:
66
+ return self._auth_type
67
+
68
+ def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
69
+ return self._headers_provider(cfg)
70
+
71
+ def oauth_token(self, cfg: 'Config') -> Token:
72
+ return self._headers_provider(cfg).oauth_token()
73
+
74
+
75
+ def credentials_strategy(name: str, require: List[str]):
44
76
  """ Given the function that receives a Config and returns RequestVisitor,
45
77
  create CredentialsProvider with a given name and required configuration
46
78
  attribute names to be present for this function to be called. """
47
79
 
48
- def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider:
80
+ def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
49
81
 
50
82
  @functools.wraps(func)
51
- def wrapper(cfg: 'Config') -> Optional[HeaderFactory]:
83
+ def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
52
84
  for attr in require:
85
+ getattr(cfg, attr)
53
86
  if not getattr(cfg, attr):
54
87
  return None
55
88
  return func(cfg)
@@ -60,8 +93,27 @@ def credentials_provider(name: str, require: List[str]):
60
93
  return inner
61
94
 
62
95
 
63
- @credentials_provider('basic', ['host', 'username', 'password'])
64
- def basic_auth(cfg: 'Config') -> HeaderFactory:
96
+ def oauth_credentials_strategy(name: str, require: List[str]):
97
+ """ Given the function that receives a Config and returns an OauthHeaderFactory,
98
+ create an OauthCredentialsProvider with a given name and required configuration
99
+ attribute names to be present for this function to be called. """
100
+
101
+ def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:
102
+
103
+ @functools.wraps(func)
104
+ def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
105
+ for attr in require:
106
+ if not getattr(cfg, attr):
107
+ return None
108
+ return func(cfg)
109
+
110
+ return OauthCredentialsStrategy(name, wrapper)
111
+
112
+ return inner
113
+
114
+
115
+ @credentials_strategy('basic', ['host', 'username', 'password'])
116
+ def basic_auth(cfg: 'Config') -> CredentialsProvider:
65
117
  """ Given username and password, add base64-encoded Basic credentials """
66
118
  encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
67
119
  static_credentials = {'Authorization': f'Basic {encoded}'}
@@ -72,8 +124,8 @@ def basic_auth(cfg: 'Config') -> HeaderFactory:
72
124
  return inner
73
125
 
74
126
 
75
- @credentials_provider('pat', ['host', 'token'])
76
- def pat_auth(cfg: 'Config') -> HeaderFactory:
127
+ @credentials_strategy('pat', ['host', 'token'])
128
+ def pat_auth(cfg: 'Config') -> CredentialsProvider:
77
129
  """ Adds Databricks Personal Access Token to every request """
78
130
  static_credentials = {'Authorization': f'Bearer {cfg.token}'}
79
131
 
@@ -83,8 +135,8 @@ def pat_auth(cfg: 'Config') -> HeaderFactory:
83
135
  return inner
84
136
 
85
137
 
86
- @credentials_provider('runtime', [])
87
- def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
138
+ @credentials_strategy('runtime', [])
139
+ def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
88
140
  if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
89
141
  return None
90
142
 
@@ -107,8 +159,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
107
159
  return None
108
160
 
109
161
 
110
- @credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret'])
111
- def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
162
+ @oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
163
+ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
112
164
  """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
113
165
  if /oidc/.well-known/oauth-authorization-server is available on the given host. """
114
166
  oidc = cfg.oidc_endpoints
@@ -124,11 +176,14 @@ def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
124
176
  token = token_source.token()
125
177
  return {'Authorization': f'{token.token_type} {token.access_token}'}
126
178
 
127
- return inner
179
+ def token() -> Token:
180
+ return token_source.token()
181
+
182
+ return OAuthCredentialsProvider(inner, token)
128
183
 
129
184
 
130
- @credentials_provider('external-browser', ['host', 'auth_type'])
131
- def external_browser(cfg: 'Config') -> Optional[HeaderFactory]:
185
+ @credentials_strategy('external-browser', ['host', 'auth_type'])
186
+ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
132
187
  if cfg.auth_type != 'external-browser':
133
188
  return None
134
189
  if cfg.client_id:
@@ -178,9 +233,9 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
178
233
  cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
179
234
 
180
235
 
181
- @credentials_provider('azure-client-secret',
182
- ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
183
- def azure_service_principal(cfg: 'Config') -> HeaderFactory:
236
+ @oauth_credentials_strategy('azure-client-secret',
237
+ ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
238
+ def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
184
239
  """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
185
240
  to every request, while automatically resolving different Azure environment endpoints. """
186
241
 
@@ -203,11 +258,14 @@ def azure_service_principal(cfg: 'Config') -> HeaderFactory:
203
258
  add_sp_management_token(cloud, headers)
204
259
  return headers
205
260
 
206
- return refreshed_headers
261
+ def token() -> Token:
262
+ return inner.token()
263
+
264
+ return OAuthCredentialsProvider(refreshed_headers, token)
207
265
 
208
266
 
209
- @credentials_provider('github-oidc-azure', ['host', 'azure_client_id'])
210
- def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
267
+ @oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
268
+ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
211
269
  if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
212
270
  # not in GitHub actions
213
271
  return None
@@ -250,14 +308,17 @@ def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
250
308
  token = inner.token()
251
309
  return {'Authorization': f'{token.token_type} {token.access_token}'}
252
310
 
253
- return refreshed_headers
311
+ def token() -> Token:
312
+ return inner.token()
313
+
314
+ return OAuthCredentialsProvider(refreshed_headers, token)
254
315
 
255
316
 
256
317
  GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
257
318
 
258
319
 
259
- @credentials_provider('google-credentials', ['host', 'google_credentials'])
260
- def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
320
+ @oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
321
+ def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
261
322
  if not cfg.is_gcp:
262
323
  return None
263
324
  # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
@@ -277,6 +338,10 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
277
338
  gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
278
339
  scopes=GcpScopes)
279
340
 
341
+ def token() -> Token:
342
+ credentials.refresh(request)
343
+ return credentials.token
344
+
280
345
  def refreshed_headers() -> Dict[str, str]:
281
346
  credentials.refresh(request)
282
347
  headers = {'Authorization': f'Bearer {credentials.token}'}
@@ -285,11 +350,11 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
285
350
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
286
351
  return headers
287
352
 
288
- return refreshed_headers
353
+ return OAuthCredentialsProvider(refreshed_headers, token)
289
354
 
290
355
 
291
- @credentials_provider('google-id', ['host', 'google_service_account'])
292
- def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
356
+ @oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
357
+ def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
293
358
  if not cfg.is_gcp:
294
359
  return None
295
360
  credentials, _project_id = google.auth.default()
@@ -309,6 +374,10 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
309
374
 
310
375
  request = Request()
311
376
 
377
+ def token() -> Token:
378
+ id_creds.refresh(request)
379
+ return id_creds.token
380
+
312
381
  def refreshed_headers() -> Dict[str, str]:
313
382
  id_creds.refresh(request)
314
383
  headers = {'Authorization': f'Bearer {id_creds.token}'}
@@ -317,7 +386,7 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
317
386
  headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
318
387
  return headers
319
388
 
320
- return refreshed_headers
389
+ return OAuthCredentialsProvider(refreshed_headers, token)
321
390
 
322
391
 
323
392
  class CliTokenSource(Refreshable):
@@ -422,8 +491,8 @@ class AzureCliTokenSource(CliTokenSource):
422
491
  return components[2]
423
492
 
424
493
 
425
- @credentials_provider('azure-cli', ['is_azure'])
426
- def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
494
+ @credentials_strategy('azure-cli', ['is_azure'])
495
+ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
427
496
  """ Adds refreshed OAuth token granted by `az login` command to every request. """
428
497
  token_source = None
429
498
  mgmt_token_source = None
@@ -516,8 +585,8 @@ class DatabricksCliTokenSource(CliTokenSource):
516
585
  raise err
517
586
 
518
587
 
519
- @credentials_provider('databricks-cli', ['host'])
520
- def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
588
+ @oauth_credentials_strategy('databricks-cli', ['host'])
589
+ def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
521
590
  try:
522
591
  token_source = DatabricksCliTokenSource(cfg)
523
592
  except FileNotFoundError as e:
@@ -538,7 +607,7 @@ def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
538
607
  token = token_source.token()
539
608
  return {'Authorization': f'{token.token_type} {token.access_token}'}
540
609
 
541
- return inner
610
+ return OAuthCredentialsProvider(inner, token_source.token)
542
611
 
543
612
 
544
613
  class MetadataServiceTokenSource(Refreshable):
@@ -577,8 +646,8 @@ class MetadataServiceTokenSource(Refreshable):
577
646
  return Token(access_token=access_token, token_type=token_type, expiry=expiry)
578
647
 
579
648
 
580
- @credentials_provider('metadata-service', ['host', 'metadata_service_url'])
581
- def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]:
649
+ @credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
650
+ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
582
651
  """ Adds refreshed token granted by Databricks Metadata Service to every request. """
583
652
 
584
653
  token_source = MetadataServiceTokenSource(cfg)
@@ -597,17 +666,25 @@ class DefaultCredentials:
597
666
 
598
667
  def __init__(self) -> None:
599
668
  self._auth_type = 'default'
600
-
601
- def auth_type(self) -> str:
602
- return self._auth_type
603
-
604
- def __call__(self, cfg: 'Config') -> HeaderFactory:
605
- auth_providers = [
669
+ self._auth_providers = [
606
670
  pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
607
671
  github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
608
672
  google_credentials, google_id
609
673
  ]
610
- for provider in auth_providers:
674
+
675
+ def auth_type(self) -> str:
676
+ return self._auth_type
677
+
678
+ def oauth_token(self, cfg: 'Config') -> Token:
679
+ for provider in self._auth_providers:
680
+ auth_type = provider.auth_type()
681
+ if auth_type != self._auth_type:
682
+ # ignore other auth types if they don't match the selected one
683
+ continue
684
+ return provider.oauth_token(cfg)
685
+
686
+ def __call__(self, cfg: 'Config') -> CredentialsProvider:
687
+ for provider in self._auth_providers:
611
688
  auth_type = provider.auth_type()
612
689
  if cfg.auth_type and auth_type != cfg.auth_type:
613
690
  # ignore other auth types if one is explicitly enforced
databricks/sdk/dbutils.py CHANGED
@@ -1,10 +1,11 @@
1
1
  import base64
2
2
  import json
3
3
  import logging
4
- import os.path
4
+ import os
5
5
  import threading
6
6
  from collections import namedtuple
7
- from typing import Callable, Dict, List
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional
8
9
 
9
10
  from .core import ApiClient, Config, DatabricksError
10
11
  from .mixins import compute as compute_ext
@@ -240,6 +241,76 @@ class RemoteDbUtils:
240
241
  name=util)
241
242
 
242
243
 
244
+ @dataclass
245
+ class OverrideResult:
246
+ result: Any
247
+
248
+
249
+ def get_local_notebook_path():
250
+ value = os.getenv("DATABRICKS_SOURCE_FILE")
251
+ if value is None:
252
+ raise ValueError(
253
+ "Getting the current notebook path is only supported when running a notebook using the `Databricks Connect: Run as File` or `Databricks Connect: Debug as File` commands in the Databricks extension for VS Code. To bypass this error, set environment variable `DATABRICKS_SOURCE_FILE` to the desired notebook path."
254
+ )
255
+
256
+ return value
257
+
258
+
259
+ class _OverrideProxyUtil:
260
+
261
+ @classmethod
262
+ def new(cls, path: str):
263
+ if len(cls.__get_matching_overrides(path)) > 0:
264
+ return _OverrideProxyUtil(path)
265
+ return None
266
+
267
+ def __init__(self, name: str):
268
+ self._name = name
269
+
270
+ # These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
271
+ # are sent to remote either. This could lead to unintentional breakage.
272
+ # Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
273
+ # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
274
+ # are being proxied to remote dbutils currently.
275
+ proxy_override_paths = {
276
+ 'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
277
+ get_local_notebook_path,
278
+ }
279
+
280
+ @classmethod
281
+ def __get_matching_overrides(cls, path: str):
282
+ return [x for x in cls.proxy_override_paths.keys() if x.startswith(path)]
283
+
284
+ def __run_override(self, path: str) -> Optional[OverrideResult]:
285
+ overrides = self.__get_matching_overrides(path)
286
+ if len(overrides) == 1 and overrides[0] == path:
287
+ return OverrideResult(self.proxy_override_paths[overrides[0]]())
288
+
289
+ if len(overrides) > 0:
290
+ return OverrideResult(_OverrideProxyUtil(name=path))
291
+
292
+ return None
293
+
294
+ def __call__(self, *args, **kwds) -> Any:
295
+ if len(args) != 0 or len(kwds) != 0:
296
+ raise TypeError(
297
+ f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")
298
+
299
+ callable_path = f"{self._name}()"
300
+ result = self.__run_override(callable_path)
301
+ if result:
302
+ return result.result
303
+
304
+ raise TypeError(f"{self._name} is not callable")
305
+
306
+ def __getattr__(self, method: str) -> Any:
307
+ result = self.__run_override(f"{self._name}.{method}")
308
+ if result:
309
+ return result.result
310
+
311
+ raise AttributeError(f"module {self._name} has no attribute {method}")
312
+
313
+
243
314
  class _ProxyUtil:
244
315
  """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""
245
316
 
@@ -250,7 +321,14 @@ class _ProxyUtil:
250
321
  self._context_factory = context_factory
251
322
  self._name = name
252
323
 
253
- def __getattr__(self, method: str) -> '_ProxyCall':
324
+ def __call__(self):
325
+ raise NotImplementedError(f"dbutils.{self._name} is not callable")
326
+
327
+ def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
328
+ override = _OverrideProxyUtil.new(f"{self._name}.{method}")
329
+ if override:
330
+ return override
331
+
254
332
  return _ProxyCall(command_execution=self._commands,
255
333
  cluster_id=self._cluster_id,
256
334
  context_factory=self._context_factory,
@@ -2,7 +2,31 @@ from dataclasses import dataclass
2
2
  from enum import Enum
3
3
  from typing import Optional
4
4
 
5
- from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment
5
+
6
+ @dataclass
7
+ class AzureEnvironment:
8
+ name: str
9
+ service_management_endpoint: str
10
+ resource_manager_endpoint: str
11
+ active_directory_endpoint: str
12
+
13
+
14
+ ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
15
+
16
+ ENVIRONMENTS = dict(
17
+ PUBLIC=AzureEnvironment(name="PUBLIC",
18
+ service_management_endpoint="https://management.core.windows.net/",
19
+ resource_manager_endpoint="https://management.azure.com/",
20
+ active_directory_endpoint="https://login.microsoftonline.com/"),
21
+ USGOVERNMENT=AzureEnvironment(name="USGOVERNMENT",
22
+ service_management_endpoint="https://management.core.usgovcloudapi.net/",
23
+ resource_manager_endpoint="https://management.usgovcloudapi.net/",
24
+ active_directory_endpoint="https://login.microsoftonline.us/"),
25
+ CHINA=AzureEnvironment(name="CHINA",
26
+ service_management_endpoint="https://management.core.chinacloudapi.cn/",
27
+ resource_manager_endpoint="https://management.chinacloudapi.cn/",
28
+ active_directory_endpoint="https://login.chinacloudapi.cn/"),
29
+ )
6
30
 
7
31
 
8
32
  class Cloud(Enum):
@@ -70,3 +94,12 @@ ALL_ENVS = [
70
94
  DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
71
95
  DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com")
72
96
  ]
97
+
98
+
99
+ def get_environment_for_hostname(hostname: str) -> DatabricksEnvironment:
100
+ if not hostname:
101
+ return DEFAULT_ENVIRONMENT
102
+ for env in ALL_ENVS:
103
+ if hostname.endswith(env.dns_zone):
104
+ return env
105
+ return DEFAULT_ENVIRONMENT
@@ -1,4 +1,5 @@
1
1
  from .base import DatabricksError, ErrorDetail
2
2
  from .mapper import error_mapper
3
3
  from .platform import *
4
+ from .private_link import PrivateLinkValidationError
4
5
  from .sdk import *
@@ -4,6 +4,8 @@ from databricks.sdk.errors import platform
4
4
  from databricks.sdk.errors.base import DatabricksError
5
5
 
6
6
  from .overrides import _ALL_OVERRIDES
7
+ from .private_link import (_get_private_link_validation_error,
8
+ _is_private_link_redirect)
7
9
 
8
10
 
9
11
  def error_mapper(response: requests.Response, raw: dict) -> DatabricksError:
@@ -21,6 +23,8 @@ def error_mapper(response: requests.Response, raw: dict) -> DatabricksError:
21
23
  # where there's a default exception class per HTTP status code, and we do
22
24
  # rely on Databricks platform exception mapper to do the right thing.
23
25
  return platform.STATUS_CODE_MAPPING[status_code](**raw)
26
+ if _is_private_link_redirect(response):
27
+ return _get_private_link_validation_error(response.url)
24
28
 
25
29
  # backwards-compatible error creation for cases like using older versions of
26
30
  # the SDK on way never releases of the platform.
@@ -0,0 +1,60 @@
1
+ from dataclasses import dataclass
2
+ from urllib import parse
3
+
4
+ import requests
5
+
6
+ from ..environments import Cloud, get_environment_for_hostname
7
+ from .platform import PermissionDenied
8
+
9
+
10
+ @dataclass
11
+ class _PrivateLinkInfo:
12
+ serviceName: str
13
+ endpointName: str
14
+ referencePage: str
15
+
16
+ def error_message(self):
17
+ return (
18
+ f'The requested workspace has {self.serviceName} enabled and is not accessible from the current network. '
19
+ f'Ensure that {self.serviceName} is properly configured and that your device has access to the '
20
+ f'{self.endpointName}. For more information, see {self.referencePage}.')
21
+
22
+
23
+ _private_link_info_map = {
24
+ Cloud.AWS:
25
+ _PrivateLinkInfo(serviceName='AWS PrivateLink',
26
+ endpointName='AWS VPC endpoint',
27
+ referencePage='https://docs.databricks.com/en/security/network/classic/privatelink.html',
28
+ ),
29
+ Cloud.AZURE:
30
+ _PrivateLinkInfo(
31
+ serviceName='Azure Private Link',
32
+ endpointName='Azure Private Link endpoint',
33
+ referencePage='https://learn.microsoft.com/en-us/azure/databricks/security/network/classic/private-link-standard#authentication-troubleshooting',
34
+ ),
35
+ Cloud.GCP:
36
+ _PrivateLinkInfo(
37
+ serviceName='Private Service Connect',
38
+ endpointName='GCP VPC endpoint',
39
+ referencePage='https://docs.gcp.databricks.com/en/security/network/classic/private-service-connect.html',
40
+ )
41
+ }
42
+
43
+
44
+ class PrivateLinkValidationError(PermissionDenied):
45
+ """Raised when a user tries to access a Private Link-enabled workspace, but the user's network does not have access
46
+ to the workspace."""
47
+
48
+
49
+ def _is_private_link_redirect(resp: requests.Response) -> bool:
50
+ parsed = parse.urlparse(resp.url)
51
+ return parsed.path == '/login.html' and 'error=private-link-validation-error' in parsed.query
52
+
53
+
54
+ def _get_private_link_validation_error(url: str) -> _PrivateLinkInfo:
55
+ parsed = parse.urlparse(url)
56
+ env = get_environment_for_hostname(parsed.hostname)
57
+ return PrivateLinkValidationError(message=_private_link_info_map[env.cloud].error_message(),
58
+ error_code='PRIVATE_LINK_VALIDATION_ERROR',
59
+ status_code=403,
60
+ )
databricks/sdk/oauth.py CHANGED
@@ -21,6 +21,10 @@ import requests.auth
21
21
  # See https://stackoverflow.com/a/75466778/277035 for more info
22
22
  NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327'
23
23
 
24
+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
25
+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
26
+ OIDC_TOKEN_PATH = "/oidc/v1/token"
27
+
24
28
  logger = logging.getLogger(__name__)
25
29
 
26
30
 
@@ -358,18 +362,15 @@ class OAuthClient:
358
362
  client_secret: str = None):
359
363
  # TODO: is it a circular dependency?..
360
364
  from .core import Config
361
- from .credentials_provider import credentials_provider
365
+ from .credentials_provider import credentials_strategy
362
366
 
363
- @credentials_provider('noop', [])
367
+ @credentials_strategy('noop', [])
364
368
  def noop_credentials(_: any):
365
369
  return lambda: {}
366
370
 
367
- config = Config(host=host, credentials_provider=noop_credentials)
371
+ config = Config(host=host, credentials_strategy=noop_credentials)
368
372
  if not scopes:
369
373
  scopes = ['all-apis']
370
- if config.is_azure:
371
- # Azure AD only supports full access to Azure Databricks.
372
- scopes = [f'{config.effective_azure_login_app_id}/user_impersonation', 'offline_access']
373
374
  oidc = config.oidc_endpoints
374
375
  if not oidc:
375
376
  raise ValueError(f'{host} does not support OAuth')
@@ -381,6 +382,7 @@ class OAuthClient:
381
382
  self.token_url = oidc.token_endpoint
382
383
  self.is_aws = config.is_aws
383
384
  self.is_azure = config.is_azure
385
+ self.is_gcp = config.is_gcp
384
386
 
385
387
  self._auth_url = oidc.authorization_endpoint
386
388
  self._scopes = scopes