databricks-sdk 0.28.0__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic. Click here for more details.
- databricks/sdk/__init__.py +7 -3
- databricks/sdk/config.py +65 -10
- databricks/sdk/core.py +22 -0
- databricks/sdk/credentials_provider.py +121 -44
- databricks/sdk/dbutils.py +81 -3
- databricks/sdk/oauth.py +8 -6
- databricks/sdk/service/catalog.py +132 -28
- databricks/sdk/service/compute.py +21 -13
- databricks/sdk/service/dashboards.py +707 -2
- databricks/sdk/service/jobs.py +126 -152
- databricks/sdk/service/marketplace.py +136 -0
- databricks/sdk/service/oauth2.py +22 -0
- databricks/sdk/service/pipelines.py +1 -1
- databricks/sdk/service/serving.py +140 -55
- databricks/sdk/service/settings.py +1 -0
- databricks/sdk/service/sharing.py +0 -1
- databricks/sdk/service/sql.py +103 -23
- databricks/sdk/service/vectorsearch.py +75 -0
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/METADATA +2 -1
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/RECORD +25 -25
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/WHEEL +0 -0
- {databricks_sdk-0.28.0.dist-info → databricks_sdk-0.29.0.dist-info}/top_level.txt +0 -0
databricks/sdk/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import databricks.sdk.core as client
|
|
2
2
|
import databricks.sdk.dbutils as dbutils
|
|
3
3
|
from databricks.sdk import azure
|
|
4
|
-
from databricks.sdk.credentials_provider import
|
|
4
|
+
from databricks.sdk.credentials_provider import CredentialsStrategy
|
|
5
5
|
from databricks.sdk.mixins.compute import ClustersExt
|
|
6
6
|
from databricks.sdk.mixins.files import DbfsExt
|
|
7
7
|
from databricks.sdk.mixins.workspace import WorkspaceExt
|
|
@@ -131,7 +131,8 @@ class WorkspaceClient:
|
|
|
131
131
|
debug_headers: bool = None,
|
|
132
132
|
product="unknown",
|
|
133
133
|
product_version="0.0.0",
|
|
134
|
-
|
|
134
|
+
credentials_strategy: CredentialsStrategy = None,
|
|
135
|
+
credentials_provider: CredentialsStrategy = None,
|
|
135
136
|
config: client.Config = None):
|
|
136
137
|
if not config:
|
|
137
138
|
config = client.Config(host=host,
|
|
@@ -152,6 +153,7 @@ class WorkspaceClient:
|
|
|
152
153
|
cluster_id=cluster_id,
|
|
153
154
|
google_credentials=google_credentials,
|
|
154
155
|
google_service_account=google_service_account,
|
|
156
|
+
credentials_strategy=credentials_strategy,
|
|
155
157
|
credentials_provider=credentials_provider,
|
|
156
158
|
debug_truncate_bytes=debug_truncate_bytes,
|
|
157
159
|
debug_headers=debug_headers,
|
|
@@ -700,7 +702,8 @@ class AccountClient:
|
|
|
700
702
|
debug_headers: bool = None,
|
|
701
703
|
product="unknown",
|
|
702
704
|
product_version="0.0.0",
|
|
703
|
-
|
|
705
|
+
credentials_strategy: CredentialsStrategy = None,
|
|
706
|
+
credentials_provider: CredentialsStrategy = None,
|
|
704
707
|
config: client.Config = None):
|
|
705
708
|
if not config:
|
|
706
709
|
config = client.Config(host=host,
|
|
@@ -721,6 +724,7 @@ class AccountClient:
|
|
|
721
724
|
cluster_id=cluster_id,
|
|
722
725
|
google_credentials=google_credentials,
|
|
723
726
|
google_service_account=google_service_account,
|
|
727
|
+
credentials_strategy=credentials_strategy,
|
|
724
728
|
credentials_provider=credentials_provider,
|
|
725
729
|
debug_truncate_bytes=debug_truncate_bytes,
|
|
726
730
|
debug_headers=debug_headers,
|
databricks/sdk/config.py
CHANGED
|
@@ -6,15 +6,15 @@ import pathlib
|
|
|
6
6
|
import platform
|
|
7
7
|
import sys
|
|
8
8
|
import urllib.parse
|
|
9
|
-
from typing import Dict, Iterable, Optional
|
|
9
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
10
10
|
|
|
11
11
|
import requests
|
|
12
12
|
|
|
13
13
|
from .clock import Clock, RealClock
|
|
14
|
-
from .credentials_provider import
|
|
14
|
+
from .credentials_provider import CredentialsStrategy, DefaultCredentials
|
|
15
15
|
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
|
|
16
16
|
DatabricksEnvironment, get_environment_for_hostname)
|
|
17
|
-
from .oauth import OidcEndpoints
|
|
17
|
+
from .oauth import OidcEndpoints, Token
|
|
18
18
|
from .version import __version__
|
|
19
19
|
|
|
20
20
|
logger = logging.getLogger('databricks.sdk')
|
|
@@ -44,6 +44,32 @@ class ConfigAttribute:
|
|
|
44
44
|
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
|
|
45
45
|
|
|
46
46
|
|
|
47
|
+
_DEFAULT_PRODUCT_NAME = 'unknown'
|
|
48
|
+
_DEFAULT_PRODUCT_VERSION = '0.0.0'
|
|
49
|
+
_STATIC_USER_AGENT: Tuple[str, str, List[str]] = (_DEFAULT_PRODUCT_NAME, _DEFAULT_PRODUCT_VERSION, [])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def with_product(product: str, product_version: str):
    """[INTERNAL API] Change the product name and version used in the User-Agent header.

    Rebinds the module-level ``_STATIC_USER_AGENT`` tuple with the new product
    name and version. The list of extra entries already stored in the tuple is
    carried over unchanged.

    :param product: product name to store.
    :param product_version: product version to store.
    """
    global _STATIC_USER_AGENT
    prev_product, prev_version, prev_other_info = _STATIC_USER_AGENT
    # Lazy %-style arguments: the message is only rendered if DEBUG logging is enabled.
    logger.debug('Changing product from %s/%s to %s/%s', prev_product, prev_version, product,
                 product_version)
    _STATIC_USER_AGENT = product, product_version, prev_other_info
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def with_user_agent_extra(key: str, value: str):
    """[INTERNAL API] Add extra metadata to the User-Agent header when developing a library.

    Stores a ``key/value`` entry in the extra-info list of the module-level
    ``_STATIC_USER_AGENT`` tuple. Any entry previously stored for the same key
    is dropped, and the new entry is appended at the end.

    :param key: name of the metadata entry.
    :param value: value of the metadata entry.
    """
    global _STATIC_USER_AGENT
    product_name, product_version, other_info = _STATIC_USER_AGENT
    prefix = f"{key}/"
    # Build a fresh list instead of removing from the shared list while iterating it:
    # this drops every stale entry for the key (not only the first one) and leaves
    # the previously published list object untouched.
    updated_info = [item for item in other_info if not item.startswith(prefix)]
    updated_info.append(f"{key}/{value}")
    _STATIC_USER_AGENT = product_name, product_version, updated_info
|
|
71
|
+
|
|
72
|
+
|
|
47
73
|
class Config:
|
|
48
74
|
host: str = ConfigAttribute(env='DATABRICKS_HOST')
|
|
49
75
|
account_id: str = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID')
|
|
@@ -66,6 +92,7 @@ class Config:
|
|
|
66
92
|
auth_type: str = ConfigAttribute(env='DATABRICKS_AUTH_TYPE')
|
|
67
93
|
cluster_id: str = ConfigAttribute(env='DATABRICKS_CLUSTER_ID')
|
|
68
94
|
warehouse_id: str = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID')
|
|
95
|
+
serverless_compute_id: str = ConfigAttribute(env='DATABRICKS_SERVERLESS_COMPUTE_ID')
|
|
69
96
|
skip_verify: bool = ConfigAttribute()
|
|
70
97
|
http_timeout_seconds: float = ConfigAttribute()
|
|
71
98
|
debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES')
|
|
@@ -81,15 +108,34 @@ class Config:
|
|
|
81
108
|
|
|
82
109
|
def __init__(self,
|
|
83
110
|
*,
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
111
|
+
# Deprecated. Use credentials_strategy instead.
|
|
112
|
+
credentials_provider: CredentialsStrategy = None,
|
|
113
|
+
credentials_strategy: CredentialsStrategy = None,
|
|
114
|
+
product=_DEFAULT_PRODUCT_NAME,
|
|
115
|
+
product_version=_DEFAULT_PRODUCT_VERSION,
|
|
87
116
|
clock: Clock = None,
|
|
88
117
|
**kwargs):
|
|
89
118
|
self._header_factory = None
|
|
90
119
|
self._inner = {}
|
|
120
|
+
# as in SDK for Go, pull information from global static user agent context,
|
|
121
|
+
# so that we can track additional metadata for mid-stream libraries, as well
|
|
122
|
+
# as for cases when the downstream product is used as a library and is not
|
|
123
|
+
# configured with a proper product name and version.
|
|
124
|
+
static_product, static_version, _ = _STATIC_USER_AGENT
|
|
125
|
+
if product == _DEFAULT_PRODUCT_NAME:
|
|
126
|
+
product = static_product
|
|
127
|
+
if product_version == _DEFAULT_PRODUCT_VERSION:
|
|
128
|
+
product_version = static_version
|
|
91
129
|
self._user_agent_other_info = []
|
|
92
|
-
|
|
130
|
+
if credentials_strategy and credentials_provider:
|
|
131
|
+
raise ValueError(
|
|
132
|
+
"When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
|
|
133
|
+
if credentials_provider:
|
|
134
|
+
logger.warning(
|
|
135
|
+
"parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.")
|
|
136
|
+
self._credentials_strategy = next(
|
|
137
|
+
s for s in [credentials_strategy, credentials_provider,
|
|
138
|
+
DefaultCredentials()] if s is not None)
|
|
93
139
|
if 'databricks_environment' in kwargs:
|
|
94
140
|
self.databricks_environment = kwargs['databricks_environment']
|
|
95
141
|
del kwargs['databricks_environment']
|
|
@@ -107,6 +153,9 @@ class Config:
|
|
|
107
153
|
message = self.wrap_debug_info(str(e))
|
|
108
154
|
raise ValueError(message) from e
|
|
109
155
|
|
|
156
|
+
def oauth_token(self) -> Token:
|
|
157
|
+
return self._credentials_strategy.oauth_token(self)
|
|
158
|
+
|
|
110
159
|
def wrap_debug_info(self, message: str) -> str:
|
|
111
160
|
debug_string = self.debug_string()
|
|
112
161
|
if debug_string:
|
|
@@ -220,6 +269,12 @@ class Config:
|
|
|
220
269
|
]
|
|
221
270
|
if len(self._user_agent_other_info) > 0:
|
|
222
271
|
ua.append(' '.join(self._user_agent_other_info))
|
|
272
|
+
# as in SDK for Go, pull information from global static user agent context,
|
|
273
|
+
# so that we can track additional metadata for mid-stream libraries. this value
|
|
274
|
+
# is shared across all instances of Config objects intentionally.
|
|
275
|
+
_, _, static_info = _STATIC_USER_AGENT
|
|
276
|
+
if len(static_info) > 0:
|
|
277
|
+
ua.append(' '.join(static_info))
|
|
223
278
|
if len(self._upstream_user_agent) > 0:
|
|
224
279
|
ua.append(self._upstream_user_agent)
|
|
225
280
|
if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
|
|
@@ -436,12 +491,12 @@ class Config:
|
|
|
436
491
|
|
|
437
492
|
def init_auth(self):
|
|
438
493
|
try:
|
|
439
|
-
self._header_factory = self.
|
|
440
|
-
self.auth_type = self.
|
|
494
|
+
self._header_factory = self._credentials_strategy(self)
|
|
495
|
+
self.auth_type = self._credentials_strategy.auth_type()
|
|
441
496
|
if not self._header_factory:
|
|
442
497
|
raise ValueError('not configured')
|
|
443
498
|
except ValueError as e:
|
|
444
|
-
raise ValueError(f'{self.
|
|
499
|
+
raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e
|
|
445
500
|
|
|
446
501
|
def __repr__(self):
|
|
447
502
|
return f'<{self.debug_string()}>'
|
databricks/sdk/core.py
CHANGED
|
@@ -4,6 +4,7 @@ from datetime import timedelta
|
|
|
4
4
|
from json import JSONDecodeError
|
|
5
5
|
from types import TracebackType
|
|
6
6
|
from typing import Any, BinaryIO, Iterator, Type
|
|
7
|
+
from urllib.parse import urlencode
|
|
7
8
|
|
|
8
9
|
from requests.adapters import HTTPAdapter
|
|
9
10
|
|
|
@@ -13,12 +14,17 @@ from .config import *
|
|
|
13
14
|
from .credentials_provider import *
|
|
14
15
|
from .errors import DatabricksError, error_mapper
|
|
15
16
|
from .errors.private_link import _is_private_link_redirect
|
|
17
|
+
from .oauth import retrieve_token
|
|
16
18
|
from .retries import retried
|
|
17
19
|
|
|
18
20
|
__all__ = ['Config', 'DatabricksError']
|
|
19
21
|
|
|
20
22
|
logger = logging.getLogger('databricks.sdk')
|
|
21
23
|
|
|
24
|
+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
|
|
25
|
+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
|
26
|
+
OIDC_TOKEN_PATH = "/oidc/v1/token"
|
|
27
|
+
|
|
22
28
|
|
|
23
29
|
class ApiClient:
|
|
24
30
|
_cfg: Config
|
|
@@ -109,6 +115,22 @@ class ApiClient:
|
|
|
109
115
|
flattened = dict(flatten_dict(with_fixed_bools))
|
|
110
116
|
return flattened
|
|
111
117
|
|
|
118
|
+
def get_oauth_token(self, auth_details: str) -> Token:
    """Exchange the configured OAuth token for a new one via the JWT-bearer grant.

    Authenticates the config first if no auth type has been resolved yet, then
    posts the current access token as the ``assertion`` to the host's
    ``/oidc/v1/token`` endpoint.

    :param auth_details: value sent as the ``authorization_details`` form field.
    :returns: the token returned by ``retrieve_token``.
    """
    cfg = self._cfg
    if not cfg.auth_type:
        cfg.authenticate()
    source_token = cfg.oauth_token()
    form_fields = {
        "grant_type": JWT_BEARER_GRANT_TYPE,
        "authorization_details": auth_details,
        "assertion": source_token.access_token,
    }
    return retrieve_token(client_id=cfg.client_id,
                          client_secret=cfg.client_secret,
                          token_url=cfg.host + OIDC_TOKEN_PATH,
                          params=urlencode(form_fields),
                          headers={"Content-Type": URL_ENCODED_CONTENT_TYPE})
|
|
133
|
+
|
|
112
134
|
def do(self,
|
|
113
135
|
method: str,
|
|
114
136
|
path: str,
|
|
@@ -22,12 +22,26 @@ from .azure import add_sp_management_token, add_workspace_id_header
|
|
|
22
22
|
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
|
|
23
23
|
TokenCache, TokenSource)
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
CredentialsProvider = Callable[[], Dict[str, str]]
|
|
26
26
|
|
|
27
27
|
logger = logging.getLogger('databricks.sdk')
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
class
|
|
30
|
+
class OAuthCredentialsProvider:
    """ Credentials provider that can also hand out the OAuth token behind the headers.

    Calling the instance returns whatever the wrapped credentials provider returns
    (the request headers); `oauth_token()` returns whatever the wrapped token
    provider returns. """

    def __init__(self, credentials_provider: CredentialsProvider, token_provider: Callable[[], Token]):
        # Both callables are stored as-is and invoked on every request for headers/token.
        self._token_provider = token_provider
        self._credentials_provider = credentials_provider

    def __call__(self) -> Dict[str, str]:
        headers = self._credentials_provider()
        return headers

    def oauth_token(self) -> Token:
        token = self._token_provider()
        return token
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CredentialsStrategy(abc.ABC):
|
|
31
45
|
""" CredentialsProvider is the protocol (call-side interface)
|
|
32
46
|
for authenticating requests to Databricks REST APIs"""
|
|
33
47
|
|
|
@@ -36,20 +50,39 @@ class CredentialsProvider(abc.ABC):
|
|
|
36
50
|
...
|
|
37
51
|
|
|
38
52
|
@abc.abstractmethod
|
|
39
|
-
def __call__(self, cfg: 'Config') ->
|
|
53
|
+
def __call__(self, cfg: 'Config') -> CredentialsProvider:
|
|
40
54
|
...
|
|
41
55
|
|
|
42
56
|
|
|
43
|
-
|
|
57
|
+
class OauthCredentialsStrategy(CredentialsStrategy):
    """ OauthCredentialsStrategy is a CredentialsStrategy which
    supports Oauth tokens.

    It wraps a function that, given a Config, returns an OAuthCredentialsProvider,
    i.e. an object that yields both request headers and the underlying token. """

    def __init__(self, auth_type: str, headers_provider: Callable[['Config'], OAuthCredentialsProvider]):
        """
        :param auth_type: name reported by `auth_type()`.
        :param headers_provider: function called with the Config to build the provider.
        """
        self._headers_provider = headers_provider
        self._auth_type = auth_type

    def auth_type(self) -> str:
        """ Name of the authentication type this strategy was created with. """
        return self._auth_type

    def __call__(self, cfg: 'Config') -> OAuthCredentialsProvider:
        """ Build the credentials provider for the given Config. """
        return self._headers_provider(cfg)

    def oauth_token(self, cfg: 'Config') -> Token:
        """ Return the OAuth token of the provider built for the given Config.

        :raises ValueError: if the wrapped function yields no provider for this Config.
        """
        provider = self._headers_provider(cfg)
        if provider is None:
            # The wrapped function may return None (e.g. required attributes missing);
            # fail with an explicit error instead of an AttributeError on NoneType.
            raise ValueError(f'{self._auth_type} auth: not configured')
        return provider.oauth_token()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def credentials_strategy(name: str, require: List[str]):
|
|
44
76
|
""" Given the function that receives a Config and returns RequestVisitor,
|
|
45
77
|
create CredentialsProvider with a given name and required configuration
|
|
46
78
|
attribute names to be present for this function to be called. """
|
|
47
79
|
|
|
48
|
-
def inner(func: Callable[['Config'],
|
|
80
|
+
def inner(func: Callable[['Config'], CredentialsProvider]) -> CredentialsStrategy:
|
|
49
81
|
|
|
50
82
|
@functools.wraps(func)
|
|
51
|
-
def wrapper(cfg: 'Config') -> Optional[
|
|
83
|
+
def wrapper(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
52
84
|
for attr in require:
|
|
85
|
+
getattr(cfg, attr)
|
|
53
86
|
if not getattr(cfg, attr):
|
|
54
87
|
return None
|
|
55
88
|
return func(cfg)
|
|
@@ -60,8 +93,27 @@ def credentials_provider(name: str, require: List[str]):
|
|
|
60
93
|
return inner
|
|
61
94
|
|
|
62
95
|
|
|
63
|
-
|
|
64
|
-
|
|
96
|
+
def oauth_credentials_strategy(name: str, require: List[str]):
    """ Decorator factory: wrap a function that receives a Config and returns an
    OAuthCredentialsProvider into an OauthCredentialsStrategy with the given name.
    The wrapped function is only called when every configuration attribute listed
    in `require` has a truthy value; otherwise the strategy yields None. """

    def inner(func: Callable[['Config'], OAuthCredentialsProvider]) -> OauthCredentialsStrategy:

        @functools.wraps(func)
        def wrapper(cfg: 'Config') -> Optional[OAuthCredentialsProvider]:
            # Stops at the first required attribute that is unset.
            if any(not getattr(cfg, attr) for attr in require):
                return None
            return func(cfg)

        strategy = OauthCredentialsStrategy(name, wrapper)
        return strategy

    return inner
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@credentials_strategy('basic', ['host', 'username', 'password'])
|
|
116
|
+
def basic_auth(cfg: 'Config') -> CredentialsProvider:
|
|
65
117
|
""" Given username and password, add base64-encoded Basic credentials """
|
|
66
118
|
encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode()
|
|
67
119
|
static_credentials = {'Authorization': f'Basic {encoded}'}
|
|
@@ -72,8 +124,8 @@ def basic_auth(cfg: 'Config') -> HeaderFactory:
|
|
|
72
124
|
return inner
|
|
73
125
|
|
|
74
126
|
|
|
75
|
-
@
|
|
76
|
-
def pat_auth(cfg: 'Config') ->
|
|
127
|
+
@credentials_strategy('pat', ['host', 'token'])
|
|
128
|
+
def pat_auth(cfg: 'Config') -> CredentialsProvider:
|
|
77
129
|
""" Adds Databricks Personal Access Token to every request """
|
|
78
130
|
static_credentials = {'Authorization': f'Bearer {cfg.token}'}
|
|
79
131
|
|
|
@@ -83,8 +135,8 @@ def pat_auth(cfg: 'Config') -> HeaderFactory:
|
|
|
83
135
|
return inner
|
|
84
136
|
|
|
85
137
|
|
|
86
|
-
@
|
|
87
|
-
def runtime_native_auth(cfg: 'Config') -> Optional[
|
|
138
|
+
@credentials_strategy('runtime', [])
|
|
139
|
+
def runtime_native_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
88
140
|
if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
|
|
89
141
|
return None
|
|
90
142
|
|
|
@@ -107,8 +159,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
107
159
|
return None
|
|
108
160
|
|
|
109
161
|
|
|
110
|
-
@
|
|
111
|
-
def oauth_service_principal(cfg: 'Config') -> Optional[
|
|
162
|
+
@oauth_credentials_strategy('oauth-m2m', ['host', 'client_id', 'client_secret'])
|
|
163
|
+
def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
112
164
|
""" Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request,
|
|
113
165
|
if /oidc/.well-known/oauth-authorization-server is available on the given host. """
|
|
114
166
|
oidc = cfg.oidc_endpoints
|
|
@@ -124,11 +176,14 @@ def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
124
176
|
token = token_source.token()
|
|
125
177
|
return {'Authorization': f'{token.token_type} {token.access_token}'}
|
|
126
178
|
|
|
127
|
-
|
|
179
|
+
def token() -> Token:
|
|
180
|
+
return token_source.token()
|
|
181
|
+
|
|
182
|
+
return OAuthCredentialsProvider(inner, token)
|
|
128
183
|
|
|
129
184
|
|
|
130
|
-
@
|
|
131
|
-
def external_browser(cfg: 'Config') -> Optional[
|
|
185
|
+
@credentials_strategy('external-browser', ['host', 'auth_type'])
|
|
186
|
+
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
132
187
|
if cfg.auth_type != 'external-browser':
|
|
133
188
|
return None
|
|
134
189
|
if cfg.client_id:
|
|
@@ -178,9 +233,9 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
|
|
|
178
233
|
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"
|
|
179
234
|
|
|
180
235
|
|
|
181
|
-
@
|
|
182
|
-
|
|
183
|
-
def azure_service_principal(cfg: 'Config') ->
|
|
236
|
+
@oauth_credentials_strategy('azure-client-secret',
|
|
237
|
+
['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
|
|
238
|
+
def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
|
|
184
239
|
""" Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
|
|
185
240
|
to every request, while automatically resolving different Azure environment endpoints. """
|
|
186
241
|
|
|
@@ -203,11 +258,14 @@ def azure_service_principal(cfg: 'Config') -> HeaderFactory:
|
|
|
203
258
|
add_sp_management_token(cloud, headers)
|
|
204
259
|
return headers
|
|
205
260
|
|
|
206
|
-
|
|
261
|
+
def token() -> Token:
|
|
262
|
+
return inner.token()
|
|
263
|
+
|
|
264
|
+
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
207
265
|
|
|
208
266
|
|
|
209
|
-
@
|
|
210
|
-
def github_oidc_azure(cfg: 'Config') -> Optional[
|
|
267
|
+
@oauth_credentials_strategy('github-oidc-azure', ['host', 'azure_client_id'])
|
|
268
|
+
def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
211
269
|
if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ:
|
|
212
270
|
# not in GitHub actions
|
|
213
271
|
return None
|
|
@@ -250,14 +308,17 @@ def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
250
308
|
token = inner.token()
|
|
251
309
|
return {'Authorization': f'{token.token_type} {token.access_token}'}
|
|
252
310
|
|
|
253
|
-
|
|
311
|
+
def token() -> Token:
|
|
312
|
+
return inner.token()
|
|
313
|
+
|
|
314
|
+
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
254
315
|
|
|
255
316
|
|
|
256
317
|
GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"]
|
|
257
318
|
|
|
258
319
|
|
|
259
|
-
@
|
|
260
|
-
def google_credentials(cfg: 'Config') -> Optional[
|
|
320
|
+
@oauth_credentials_strategy('google-credentials', ['host', 'google_credentials'])
|
|
321
|
+
def google_credentials(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
261
322
|
if not cfg.is_gcp:
|
|
262
323
|
return None
|
|
263
324
|
# Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string.
|
|
@@ -277,6 +338,10 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
277
338
|
gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info,
|
|
278
339
|
scopes=GcpScopes)
|
|
279
340
|
|
|
341
|
+
def token() -> Token:
|
|
342
|
+
credentials.refresh(request)
|
|
343
|
+
return credentials.token
|
|
344
|
+
|
|
280
345
|
def refreshed_headers() -> Dict[str, str]:
|
|
281
346
|
credentials.refresh(request)
|
|
282
347
|
headers = {'Authorization': f'Bearer {credentials.token}'}
|
|
@@ -285,11 +350,11 @@ def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
285
350
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
|
|
286
351
|
return headers
|
|
287
352
|
|
|
288
|
-
return refreshed_headers
|
|
353
|
+
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
289
354
|
|
|
290
355
|
|
|
291
|
-
@
|
|
292
|
-
def google_id(cfg: 'Config') -> Optional[
|
|
356
|
+
@oauth_credentials_strategy('google-id', ['host', 'google_service_account'])
|
|
357
|
+
def google_id(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
293
358
|
if not cfg.is_gcp:
|
|
294
359
|
return None
|
|
295
360
|
credentials, _project_id = google.auth.default()
|
|
@@ -309,6 +374,10 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
309
374
|
|
|
310
375
|
request = Request()
|
|
311
376
|
|
|
377
|
+
def token() -> Token:
|
|
378
|
+
id_creds.refresh(request)
|
|
379
|
+
return id_creds.token
|
|
380
|
+
|
|
312
381
|
def refreshed_headers() -> Dict[str, str]:
|
|
313
382
|
id_creds.refresh(request)
|
|
314
383
|
headers = {'Authorization': f'Bearer {id_creds.token}'}
|
|
@@ -317,7 +386,7 @@ def google_id(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
317
386
|
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
|
|
318
387
|
return headers
|
|
319
388
|
|
|
320
|
-
return refreshed_headers
|
|
389
|
+
return OAuthCredentialsProvider(refreshed_headers, token)
|
|
321
390
|
|
|
322
391
|
|
|
323
392
|
class CliTokenSource(Refreshable):
|
|
@@ -422,8 +491,8 @@ class AzureCliTokenSource(CliTokenSource):
|
|
|
422
491
|
return components[2]
|
|
423
492
|
|
|
424
493
|
|
|
425
|
-
@
|
|
426
|
-
def azure_cli(cfg: 'Config') -> Optional[
|
|
494
|
+
@credentials_strategy('azure-cli', ['is_azure'])
|
|
495
|
+
def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
427
496
|
""" Adds refreshed OAuth token granted by `az login` command to every request. """
|
|
428
497
|
token_source = None
|
|
429
498
|
mgmt_token_source = None
|
|
@@ -516,8 +585,8 @@ class DatabricksCliTokenSource(CliTokenSource):
|
|
|
516
585
|
raise err
|
|
517
586
|
|
|
518
587
|
|
|
519
|
-
@
|
|
520
|
-
def databricks_cli(cfg: 'Config') -> Optional[
|
|
588
|
+
@oauth_credentials_strategy('databricks-cli', ['host'])
|
|
589
|
+
def databricks_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
521
590
|
try:
|
|
522
591
|
token_source = DatabricksCliTokenSource(cfg)
|
|
523
592
|
except FileNotFoundError as e:
|
|
@@ -538,7 +607,7 @@ def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
|
|
|
538
607
|
token = token_source.token()
|
|
539
608
|
return {'Authorization': f'{token.token_type} {token.access_token}'}
|
|
540
609
|
|
|
541
|
-
return inner
|
|
610
|
+
return OAuthCredentialsProvider(inner, token_source.token)
|
|
542
611
|
|
|
543
612
|
|
|
544
613
|
class MetadataServiceTokenSource(Refreshable):
|
|
@@ -577,8 +646,8 @@ class MetadataServiceTokenSource(Refreshable):
|
|
|
577
646
|
return Token(access_token=access_token, token_type=token_type, expiry=expiry)
|
|
578
647
|
|
|
579
648
|
|
|
580
|
-
@
|
|
581
|
-
def metadata_service(cfg: 'Config') -> Optional[
|
|
649
|
+
@credentials_strategy('metadata-service', ['host', 'metadata_service_url'])
|
|
650
|
+
def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
|
|
582
651
|
""" Adds refreshed token granted by Databricks Metadata Service to every request. """
|
|
583
652
|
|
|
584
653
|
token_source = MetadataServiceTokenSource(cfg)
|
|
@@ -597,17 +666,25 @@ class DefaultCredentials:
|
|
|
597
666
|
|
|
598
667
|
def __init__(self) -> None:
|
|
599
668
|
self._auth_type = 'default'
|
|
600
|
-
|
|
601
|
-
def auth_type(self) -> str:
|
|
602
|
-
return self._auth_type
|
|
603
|
-
|
|
604
|
-
def __call__(self, cfg: 'Config') -> HeaderFactory:
|
|
605
|
-
auth_providers = [
|
|
669
|
+
self._auth_providers = [
|
|
606
670
|
pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
|
|
607
671
|
github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
|
|
608
672
|
google_credentials, google_id
|
|
609
673
|
]
|
|
610
|
-
|
|
674
|
+
|
|
675
|
+
def auth_type(self) -> str:
|
|
676
|
+
return self._auth_type
|
|
677
|
+
|
|
678
|
+
def oauth_token(self, cfg: 'Config') -> Token:
    """ Return the OAuth token from the provider whose auth type matches this object's
    auth type. Yields None when no provider in the list matches. """
    matching = next((candidate for candidate in self._auth_providers
                     if candidate.auth_type() == self._auth_type), None)
    if matching is None:
        return None
    return matching.oauth_token(cfg)
|
|
685
|
+
|
|
686
|
+
def __call__(self, cfg: 'Config') -> CredentialsProvider:
|
|
687
|
+
for provider in self._auth_providers:
|
|
611
688
|
auth_type = provider.auth_type()
|
|
612
689
|
if cfg.auth_type and auth_type != cfg.auth_type:
|
|
613
690
|
# ignore other auth types if one is explicitly enforced
|
databricks/sdk/dbutils.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
|
-
import os
|
|
4
|
+
import os
|
|
5
5
|
import threading
|
|
6
6
|
from collections import namedtuple
|
|
7
|
-
from
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
8
9
|
|
|
9
10
|
from .core import ApiClient, Config, DatabricksError
|
|
10
11
|
from .mixins import compute as compute_ext
|
|
@@ -240,6 +241,76 @@ class RemoteDbUtils:
|
|
|
240
241
|
name=util)
|
|
241
242
|
|
|
242
243
|
|
|
244
|
+
@dataclass
class OverrideResult:
    """Container for the value produced when a dbutils call is answered by an override."""

    # The wrapped value; may be of any type.
    result: Any
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def get_local_notebook_path():
    """Return the notebook path taken from the `DATABRICKS_SOURCE_FILE` environment variable.

    :raises ValueError: if the environment variable is not set.
    """
    notebook_path = os.environ.get("DATABRICKS_SOURCE_FILE")
    if notebook_path is not None:
        return notebook_path

    raise ValueError(
        "Getting the current notebook path is only supported when running a notebook using the `Databricks Connect: Run as File` or `Databricks Connect: Debug as File` commands in the Databricks extension for VS Code. To bypass this error, set environment variable `DATABRICKS_SOURCE_FILE` to the desired notebook path."
    )
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class _OverrideProxyUtil:
    """Proxy for a dbutils attribute/call chain that is answered locally by an override.

    An instance represents a partial path such as ``notebook.entry_point``. Attribute
    access and zero-argument calls extend the path; once the path equals a key of
    ``proxy_override_paths``, the registered function is invoked and its result returned.
    """

    @classmethod
    def new(cls, path: str):
        """Return a proxy for `path` if it is a prefix of any override path, otherwise None."""
        if cls.__get_matching_overrides(path):
            return _OverrideProxyUtil(path)
        return None

    def __init__(self, name: str):
        self._name = name

    # These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
    # are sent to remote either. This could lead to unintentional breakage.
    # Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
    # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
    # are being proxied to remote dbutils currently.
    proxy_override_paths = {
        'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
        get_local_notebook_path,
    }

    @classmethod
    def __get_matching_overrides(cls, path: str):
        """List the override paths that start with `path`."""
        return [x for x in cls.proxy_override_paths if x.startswith(path)]

    def __run_override(self, path: str) -> Optional[OverrideResult]:
        """Resolve `path`: run the override on an exact match, extend the proxy on a prefix match.

        Returns None when `path` is not a prefix of any override path.
        """
        overrides = self.__get_matching_overrides(path)
        if len(overrides) == 1 and overrides[0] == path:
            return OverrideResult(self.proxy_override_paths[overrides[0]]())

        if overrides:
            return OverrideResult(_OverrideProxyUtil(name=path))

        return None

    def __call__(self, *args, **kwds) -> Any:
        """Extend the path with ``()``; overridden methods accept no arguments.

        :raises TypeError: if arguments are passed or the extended path matches no override.
        """
        if args or kwds:
            raise TypeError(
                f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")

        callable_path = f"{self._name}()"
        result = self.__run_override(callable_path)
        if result:
            return result.result

        raise TypeError(f"{self._name} is not callable")

    def __getattr__(self, method: str) -> Any:
        """Extend the path with ``.{method}``.

        :raises AttributeError: if the extended path matches no override.
        """
        # __getattr__ is also hit for lookups made on an instance whose `_name` is not set
        # yet (e.g. `__setstate__` probing by copy/pickle on a freshly created object).
        # Reading `self._name` there would re-enter __getattr__ until RecursionError,
        # so `_name` itself and dunder names are rejected up front.
        if method == '_name' or (method.startswith('__') and method.endswith('__')):
            raise AttributeError(f"{type(self).__name__} object has no attribute {method}")

        result = self.__run_override(f"{self._name}.{method}")
        if result:
            return result.result

        raise AttributeError(f"module {self._name} has no attribute {method}")
|
|
312
|
+
|
|
313
|
+
|
|
243
314
|
class _ProxyUtil:
|
|
244
315
|
"""Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""
|
|
245
316
|
|
|
@@ -250,7 +321,14 @@ class _ProxyUtil:
|
|
|
250
321
|
self._context_factory = context_factory
|
|
251
322
|
self._name = name
|
|
252
323
|
|
|
253
|
-
def
|
|
324
|
+
def __call__(self):
|
|
325
|
+
raise NotImplementedError(f"dbutils.{self._name} is not callable")
|
|
326
|
+
|
|
327
|
+
def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
|
|
328
|
+
override = _OverrideProxyUtil.new(f"{self._name}.{method}")
|
|
329
|
+
if override:
|
|
330
|
+
return override
|
|
331
|
+
|
|
254
332
|
return _ProxyCall(command_execution=self._commands,
|
|
255
333
|
cluster_id=self._cluster_id,
|
|
256
334
|
context_factory=self._context_factory,
|