databricks-sdk 0.44.1__py3-none-any.whl → 0.46.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of databricks-sdk might be problematic.
- databricks/sdk/__init__.py +135 -116
- databricks/sdk/_base_client.py +112 -88
- databricks/sdk/_property.py +12 -7
- databricks/sdk/_widgets/__init__.py +13 -2
- databricks/sdk/_widgets/default_widgets_utils.py +21 -15
- databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
- databricks/sdk/azure.py +8 -6
- databricks/sdk/casing.py +5 -5
- databricks/sdk/config.py +156 -99
- databricks/sdk/core.py +57 -47
- databricks/sdk/credentials_provider.py +306 -206
- databricks/sdk/data_plane.py +75 -50
- databricks/sdk/dbutils.py +123 -87
- databricks/sdk/environments.py +52 -35
- databricks/sdk/errors/base.py +61 -35
- databricks/sdk/errors/customizer.py +3 -3
- databricks/sdk/errors/deserializer.py +38 -25
- databricks/sdk/errors/details.py +417 -0
- databricks/sdk/errors/mapper.py +1 -1
- databricks/sdk/errors/overrides.py +27 -24
- databricks/sdk/errors/parser.py +26 -14
- databricks/sdk/errors/platform.py +10 -10
- databricks/sdk/errors/private_link.py +24 -24
- databricks/sdk/logger/round_trip_logger.py +28 -20
- databricks/sdk/mixins/compute.py +90 -60
- databricks/sdk/mixins/files.py +815 -145
- databricks/sdk/mixins/jobs.py +191 -16
- databricks/sdk/mixins/open_ai_client.py +26 -20
- databricks/sdk/mixins/workspace.py +45 -34
- databricks/sdk/oauth.py +379 -198
- databricks/sdk/retries.py +14 -12
- databricks/sdk/runtime/__init__.py +34 -17
- databricks/sdk/runtime/dbutils_stub.py +52 -39
- databricks/sdk/service/_internal.py +12 -7
- databricks/sdk/service/apps.py +618 -418
- databricks/sdk/service/billing.py +827 -604
- databricks/sdk/service/catalog.py +6552 -4474
- databricks/sdk/service/cleanrooms.py +550 -388
- databricks/sdk/service/compute.py +5263 -3536
- databricks/sdk/service/dashboards.py +1331 -924
- databricks/sdk/service/files.py +446 -309
- databricks/sdk/service/iam.py +2115 -1483
- databricks/sdk/service/jobs.py +4151 -2588
- databricks/sdk/service/marketplace.py +2210 -1517
- databricks/sdk/service/ml.py +3839 -2256
- databricks/sdk/service/oauth2.py +910 -584
- databricks/sdk/service/pipelines.py +1865 -1203
- databricks/sdk/service/provisioning.py +1435 -1029
- databricks/sdk/service/serving.py +2060 -1290
- databricks/sdk/service/settings.py +2846 -1929
- databricks/sdk/service/sharing.py +2201 -877
- databricks/sdk/service/sql.py +4650 -3103
- databricks/sdk/service/vectorsearch.py +816 -550
- databricks/sdk/service/workspace.py +1330 -906
- databricks/sdk/useragent.py +36 -22
- databricks/sdk/version.py +1 -1
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/METADATA +31 -31
- databricks_sdk-0.46.0.dist-info/RECORD +70 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/WHEEL +1 -1
- databricks_sdk-0.44.1.dist-info/RECORD +0 -69
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/LICENSE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/NOTICE +0 -0
- {databricks_sdk-0.44.1.dist-info → databricks_sdk-0.46.0.dist-info}/top_level.txt +0 -0
databricks/sdk/data_plane.py
CHANGED
@@ -1,65 +1,90 @@
+from __future__ import annotations
+
 import threading
 from dataclasses import dataclass
-from typing import Callable,
+from typing import Callable, Optional
+from urllib import parse
 
+from databricks.sdk import oauth
 from databricks.sdk.oauth import Token
 
+URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
+JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
+OIDC_TOKEN_PATH = "/oidc/v1/token"
+
+
+class DataPlaneTokenSource:
+    """
+    EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
+    """
+
+    # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
+    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
+        self._cpts = cpts
+        self._token_exchange_host = token_exchange_host
+        self._token_sources = {}
+        self._disable_async = disable_async
+        self._lock = threading.Lock()
+
+    def token(self, endpoint, auth_details):
+        key = f"{endpoint}:{auth_details}"
+
+        # First, try to read without acquiring the lock to avoid contention.
+        # Reads are atomic, so this is safe.
+        token_source = self._token_sources.get(key)
+        if token_source:
+            return token_source.token()
+
+        # If token_source is not found, acquire the lock and check again.
+        with self._lock:
+            # Another thread might have created it while we were waiting for the lock.
+            token_source = self._token_sources.get(key)
+            if not token_source:
+                token_source = DataPlaneEndpointTokenSource(
+                    self._token_exchange_host, self._cpts, auth_details, self._disable_async
+                )
+                self._token_sources[key] = token_source
+
+        return token_source.token()
+
+
+class DataPlaneEndpointTokenSource(oauth.Refreshable):
+    """
+    EXPERIMENTAL A token source for a specific DataPlane endpoint.
+    """
+
+    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
+        super().__init__(disable_async=disable_async)
+        self._auth_details = auth_details
+        self._cpts = cpts
+        self._token_exchange_host = token_exchange_host
+
+    def refresh(self) -> Token:
+        control_plane_token = self._cpts()
+        headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
+        params = parse.urlencode(
+            {
+                "grant_type": JWT_BEARER_GRANT_TYPE,
+                "authorization_details": self._auth_details,
+                "assertion": control_plane_token.access_token,
+            }
+        )
+        return oauth.retrieve_token(
+            client_id="",
+            client_secret="",
+            token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
+            params=params,
+            headers=headers,
+        )
+
 
 @dataclass
 class DataPlaneDetails:
     """
     Contains details required to query a DataPlane endpoint.
     """
+
     endpoint_url: str
     """URL used to query the endpoint through the DataPlane."""
     token: Token
     """Token to query the DataPlane endpoint."""
-
-
-class DataPlaneService:
-    """Helper class to fetch and manage DataPlane details."""
-    from .service.serving import DataPlaneInfo
-
-    def __init__(self):
-        self._data_plane_info = {}
-        self._tokens = {}
-        self._lock = threading.Lock()
-
-    def get_data_plane_details(self, method: str, params: List[str], info_getter: Callable[[], DataPlaneInfo],
-                               refresh: Callable[[str], Token]):
-        """Get and cache information required to query a Data Plane endpoint using the provided methods.
-
-        Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid.
-        If not, it uses the provided functions to fetch the details.
-
-        :param method: method name. Used to construct a unique key for the cache.
-        :param params: path params used in the "get" operation which uniquely determine the object. Used to construct a unique key for the cache.
-        :param info_getter: function which returns the DataPlaneInfo. It will only be called if the information is not already present in the cache.
-        :param refresh: function to refresh the token. It will only be called if the token is missing or expired.
-        """
-        all_elements = params.copy()
-        all_elements.insert(0, method)
-        map_key = "/".join(all_elements)
-        info = self._data_plane_info.get(map_key)
-        if not info:
-            self._lock.acquire()
-            try:
-                info = self._data_plane_info.get(map_key)
-                if not info:
-                    info = info_getter()
-                    self._data_plane_info[map_key] = info
-            finally:
-                self._lock.release()
-
-        token = self._tokens.get(map_key)
-        if not token or not token.valid:
-            self._lock.acquire()
-            token = self._tokens.get(map_key)
-            try:
-                if not token or not token.valid:
-                    token = refresh(info.authorization_details)
-                    self._tokens[map_key] = token
-            finally:
-                self._lock.release()
-
-        return DataPlaneDetails(endpoint_url=info.endpoint_url, token=token)
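The new DataPlaneTokenSource caches one DataPlaneEndpointTokenSource per (endpoint, authorization_details) pair and uses double-checked locking so the common cached path does not take the lock. A minimal usage sketch under stated assumptions: the control-plane token callable, the token exchange host and the endpoint/authorization_details values below are placeholders, not taken from the diff.

from databricks.sdk.data_plane import DataPlaneTokenSource
from databricks.sdk.oauth import Token

def control_plane_token() -> Token:
    # Placeholder: return a workspace-level OAuth Token obtained from your
    # existing credentials strategy; not part of the code shown above.
    ...

source = DataPlaneTokenSource(
    token_exchange_host="https://my-workspace.cloud.databricks.com",  # assumed host
    cpts=control_plane_token,
)

# Exchanges the control-plane token for an endpoint-scoped token via
# POST <token_exchange_host>/oidc/v1/token, then caches the token source.
dp_token = source.token(
    endpoint="https://serving.example.databricks.com",   # hypothetical endpoint
    auth_details="<authorization_details for the endpoint>",
)
print(dp_token.access_token)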
databricks/sdk/dbutils.py
CHANGED
@@ -12,123 +12,136 @@ from .mixins import compute as compute_ext
 from .mixins import files as dbfs_ext
 from .service import compute, workspace
 
-_LOG = logging.getLogger(
+_LOG = logging.getLogger("databricks.sdk")
 
 
-class FileInfo(namedtuple(
+class FileInfo(namedtuple("FileInfo", ["path", "name", "size", "modificationTime"])):
     pass
 
 
-class MountInfo(namedtuple(
+class MountInfo(namedtuple("MountInfo", ["mountPoint", "source", "encryptionType"])):
     pass
 
 
-class SecretScope(namedtuple(
+class SecretScope(namedtuple("SecretScope", ["name"])):
 
     def getName(self):
         return self.name
 
 
-class SecretMetadata(namedtuple(
+class SecretMetadata(namedtuple("SecretMetadata", ["key"])):
     pass
 
 
 class _FsUtil:
-    """
+    """Manipulates the Databricks filesystem (DBFS)"""
 
-    def __init__(
+    def __init__(
+        self,
+        dbfs_ext: dbfs_ext.DbfsExt,
+        proxy_factory: Callable[[str], "_ProxyUtil"],
+    ):
         self._dbfs = dbfs_ext
         self._proxy_factory = proxy_factory
 
     def cp(self, from_: str, to: str, recurse: bool = False) -> bool:
-        """Copies a file or directory, possibly across FileSystems
+        """Copies a file or directory, possibly across FileSystems"""
         self._dbfs.copy(from_, to, recursive=recurse)
         return True
 
     def head(self, file: str, maxBytes: int = 65536) -> str:
-        """Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8
+        """Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8"""
         with self._dbfs.download(file) as f:
-            return f.read(maxBytes).decode(
+            return f.read(maxBytes).decode("utf8")
 
     def ls(self, dir: str) -> List[FileInfo]:
-        """Lists the contents of a directory
+        """Lists the contents of a directory"""
         return [
-            FileInfo(
+            FileInfo(
+                f.path,
+                os.path.basename(f.path),
+                f.file_size,
+                f.modification_time,
+            )
            for f in self._dbfs.list(dir)
         ]
 
     def mkdirs(self, dir: str) -> bool:
-        """Creates the given directory if it does not exist, also creating any necessary parent directories
+        """Creates the given directory if it does not exist, also creating any necessary parent directories"""
         self._dbfs.mkdirs(dir)
         return True
 
     def mv(self, from_: str, to: str, recurse: bool = False) -> bool:
-        """Moves a file or directory, possibly across FileSystems
+        """Moves a file or directory, possibly across FileSystems"""
         self._dbfs.move_(from_, to, recursive=recurse, overwrite=True)
         return True
 
     def put(self, file: str, contents: str, overwrite: bool = False) -> bool:
-        """Writes the given String out to a file, encoded in UTF-8
+        """Writes the given String out to a file, encoded in UTF-8"""
         with self._dbfs.open(file, write=True, overwrite=overwrite) as f:
-            f.write(contents.encode(
+            f.write(contents.encode("utf8"))
         return True
 
     def rm(self, dir: str, recurse: bool = False) -> bool:
-        """Removes a file or directory
+        """Removes a file or directory"""
         self._dbfs.delete(dir, recursive=recurse)
         return True
 
-    def mount(
+    def mount(
+        self,
+        source: str,
+        mount_point: str,
+        encryption_type: str = None,
+        owner: str = None,
+        extra_configs: Dict[str, str] = None,
+    ) -> bool:
         """Mounts the given source directory into DBFS at the given mount point"""
-        fs = self._proxy_factory(
+        fs = self._proxy_factory("fs")
         kwargs = {}
         if encryption_type:
-            kwargs[
+            kwargs["encryption_type"] = encryption_type
         if owner:
-            kwargs[
+            kwargs["owner"] = owner
         if extra_configs:
-            kwargs[
+            kwargs["extra_configs"] = extra_configs
         return fs.mount(source, mount_point, **kwargs)
 
     def unmount(self, mount_point: str) -> bool:
         """Deletes a DBFS mount point"""
-        fs = self._proxy_factory(
+        fs = self._proxy_factory("fs")
         return fs.unmount(mount_point)
 
-    def updateMount(
+    def updateMount(
+        self,
+        source: str,
+        mount_point: str,
+        encryption_type: str = None,
+        owner: str = None,
+        extra_configs: Dict[str, str] = None,
+    ) -> bool:
+        """Similar to mount(), but updates an existing mount point (if present) instead of creating a new one"""
+        fs = self._proxy_factory("fs")
         kwargs = {}
         if encryption_type:
-            kwargs[
+            kwargs["encryption_type"] = encryption_type
         if owner:
-            kwargs[
+            kwargs["owner"] = owner
         if extra_configs:
-            kwargs[
+            kwargs["extra_configs"] = extra_configs
         return fs.updateMount(source, mount_point, **kwargs)
 
     def mounts(self) -> List[MountInfo]:
-        """
+        """Displays information about what is mounted within DBFS"""
         result = []
-        fs = self._proxy_factory(
+        fs = self._proxy_factory("fs")
         for info in fs.mounts():
             result.append(MountInfo(info[0], info[1], info[2]))
         return result
 
     def refreshMounts(self) -> bool:
-        """
-        ensuring they receive the most recent information
-        fs = self._proxy_factory(
+        """Forces all machines in this cluster to refresh their mount cache,
+        ensuring they receive the most recent information"""
+        fs = self._proxy_factory("fs")
         return fs.refreshMounts()
 
 
@@ -136,13 +149,13 @@ class _SecretsUtil:
     """Remote equivalent of secrets util"""
 
     def __init__(self, secrets_api: workspace.SecretsAPI):
-        self._api = secrets_api
+        self._api = secrets_api  # nolint
 
     def getBytes(self, scope: str, key: str) -> bytes:
         """Gets the bytes representation of a secret value for the specified scope and key."""
-        query = {
-        raw = self._api._api.do(
-        return base64.b64decode(raw[
+        query = {"scope": scope, "key": key}
+        raw = self._api._api.do("GET", "/api/2.0/secrets/get", query=query)
+        return base64.b64decode(raw["value"])
 
     def get(self, scope: str, key: str) -> str:
         """Gets the string representation of a secret value for the specified secrets scope and key."""
@@ -169,13 +182,19 @@ class _JobsUtil:
 class _TaskValuesUtil:
     """Remote equivalent of task values util"""
 
-    def get(
+    def get(
+        self,
+        taskKey: str,
+        key: str,
+        default: any = None,
+        debugValue: any = None,
+    ) -> None:
         """
         Returns `debugValue` if present, throws an error otherwise as this implementation is always run outside of a job run
         """
         if debugValue is None:
             raise TypeError(
+                "Must pass debugValue when calling get outside of a job context. debugValue cannot be None."
             )
         return debugValue
 
@@ -190,7 +209,7 @@ class _JobsUtil:
 
 class RemoteDbUtils:
 
-    def __init__(self, config:
+    def __init__(self, config: "Config" = None):
         self._config = Config() if not config else config
         self._client = ApiClient(self._config)
         self._clusters = compute_ext.ClustersExt(self._client)
@@ -211,6 +230,7 @@ class RemoteDbUtils:
     def widgets(self):
         if self._widgets is None:
             from ._widgets import widget_impl
+
             self._widgets = widget_impl()
 
         return self._widgets
@@ -219,7 +239,7 @@ class RemoteDbUtils:
     def _cluster_id(self) -> str:
         cluster_id = self._config.cluster_id
         if not cluster_id:
-            message =
+            message = "cluster_id is required in the configuration"
            raise ValueError(self._config.wrap_debug_info(message))
         return cluster_id
 
@@ -230,15 +250,16 @@ class RemoteDbUtils:
         if self._ctx:
             return self._ctx
         self._clusters.ensure_cluster_is_running(self._cluster_id)
-        self._ctx = self._commands.create(cluster_id=self._cluster_id,
-                                          language=compute.Language.PYTHON).result()
+        self._ctx = self._commands.create(cluster_id=self._cluster_id, language=compute.Language.PYTHON).result()
         return self._ctx
 
-    def __getattr__(self, util) ->
-        return _ProxyUtil(
+    def __getattr__(self, util) -> "_ProxyUtil":
+        return _ProxyUtil(
+            command_execution=self._commands,
+            context_factory=self._running_command_context,
+            cluster_id=self._cluster_id,
+            name=util,
+        )
 
 
 @dataclass
@@ -273,8 +294,7 @@ class _OverrideProxyUtil:
     # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
     # are being proxied to remote dbutils currently.
     proxy_override_paths = {
-        get_local_notebook_path,
+        "notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()": get_local_notebook_path,
     }
 
     @classmethod
@@ -294,7 +314,8 @@ class _OverrideProxyUtil:
     def __call__(self, *args, **kwds) -> Any:
         if len(args) != 0 or len(kwds) != 0:
             raise TypeError(
-                f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()"
+                f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()"
+            )
 
         callable_path = f"{self._name}()"
         result = self.__run_override(callable_path)
@@ -314,8 +335,14 @@ class _OverrideProxyUtil:
 class _ProxyUtil:
     """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""
 
-    def __init__(
+    def __init__(
+        self,
+        *,
+        command_execution: compute.CommandExecutionAPI,
+        context_factory: Callable[[], compute.ContextStatusResponse],
+        cluster_id: str,
+        name: str,
+    ):
         self._commands = command_execution
         self._cluster_id = cluster_id
         self._context_factory = context_factory
@@ -324,16 +351,18 @@ class _ProxyUtil:
     def __call__(self):
         raise NotImplementedError(f"dbutils.{self._name} is not callable")
 
-    def __getattr__(self, method: str) ->
+    def __getattr__(self, method: str) -> "_ProxyCall | _ProxyUtil | _OverrideProxyUtil":
         override = _OverrideProxyUtil.new(f"{self._name}.{method}")
         if override:
             return override
 
-        return _ProxyCall(
+        return _ProxyCall(
+            command_execution=self._commands,
+            cluster_id=self._cluster_id,
+            context_factory=self._context_factory,
+            util=self._name,
+            method=method,
+        )
 
 
 import html
@@ -342,29 +371,34 @@ import re
 
 class _ProxyCall:
 
-    def __init__(
+    def __init__(
+        self,
+        *,
+        command_execution: compute.CommandExecutionAPI,
+        context_factory: Callable[[], compute.ContextStatusResponse],
+        cluster_id: str,
+        util: str,
+        method: str,
+    ):
         self._commands = command_execution
         self._cluster_id = cluster_id
         self._context_factory = context_factory
         self._util = util
         self._method = method
 
-    _out_re = re.compile(r
-    _tag_re = re.compile(r
-    _exception_re = re.compile(r
-    _execution_error_re = re.compile(
-    _ascii_escape_re = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')
+    _out_re = re.compile(r"Out\[[\d\s]+]:\s")
+    _tag_re = re.compile(r"<[^>]*>")
+    _exception_re = re.compile(r".*Exception:\s+(.*)")
+    _execution_error_re = re.compile(r"ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n")
+    _error_message_re = re.compile(r"ErrorMessage=(.+)\n")
+    _ascii_escape_re = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
 
     def _is_failed(self, results: compute.Results) -> bool:
         return results.result_type == compute.ResultType.ERROR
 
     def _text(self, results: compute.Results) -> str:
         if results.result_type != compute.ResultType.TEXT:
-            return
+            return ""
         return self._out_re.sub("", str(results.data))
 
     def _raise_if_failed(self, results: compute.Results):
@@ -399,17 +433,19 @@ class _ProxyCall:
 
     def __call__(self, *args, **kwargs):
         raw = json.dumps((args, kwargs))
-        code = f
+        code = f"""
 import json
 (args, kwargs) = json.loads('{raw}')
 result = dbutils.{self._util}.{self._method}(*args, **kwargs)
 dbutils.notebook.exit(json.dumps(result))
+"""
         ctx = self._context_factory()
-        result = self._commands.execute(
+        result = self._commands.execute(
+            cluster_id=self._cluster_id,
+            language=compute.Language.PYTHON,
+            context_id=ctx.id,
+            command=code,
+        ).result()
         if result.status == compute.CommandStatus.FINISHED:
             self._raise_if_failed(result.results)
             raw = result.results.data
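Functionally, most of the dbutils changes above are a formatting pass (double-quoted strings, wrapped signatures) plus small additions such as the ErrorMessage regex and the updateMount docstring. For orientation, a brief sketch of how the remote dbutils shown here is typically driven, assuming a workspace configured through the usual environment variables and a cluster_id; the paths, scope and key below are placeholders, not from the diff.

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()  # resolves host/token/cluster_id from the environment

# _FsUtil methods backed by the DBFS extension shown above:
w.dbutils.fs.put("/tmp/sdk-example.txt", "hello", overwrite=True)  # UTF-8 write
print(w.dbutils.fs.head("/tmp/sdk-example.txt"))                   # first bytes as str
for f in w.dbutils.fs.ls("/tmp"):                                  # List[FileInfo]
    print(f.path, f.size, f.modificationTime)

# _SecretsUtil.get calls /api/2.0/secrets/get and base64-decodes the value:
value = w.dbutils.secrets.get(scope="my-scope", key="my-key")      # hypothetical scope/key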
databricks/sdk/environments.py
CHANGED
@@ -14,18 +14,24 @@ class AzureEnvironment:
 ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
 
 ENVIRONMENTS = dict(
-    PUBLIC=AzureEnvironment(
+    PUBLIC=AzureEnvironment(
+        name="PUBLIC",
+        service_management_endpoint="https://management.core.windows.net/",
+        resource_manager_endpoint="https://management.azure.com/",
+        active_directory_endpoint="https://login.microsoftonline.com/",
+    ),
+    USGOVERNMENT=AzureEnvironment(
+        name="USGOVERNMENT",
+        service_management_endpoint="https://management.core.usgovcloudapi.net/",
+        resource_manager_endpoint="https://management.usgovcloudapi.net/",
+        active_directory_endpoint="https://login.microsoftonline.us/",
+    ),
+    CHINA=AzureEnvironment(
+        name="CHINA",
+        service_management_endpoint="https://management.core.chinacloudapi.cn/",
+        resource_manager_endpoint="https://management.chinacloudapi.cn/",
+        active_directory_endpoint="https://login.chinacloudapi.cn/",
+    ),
 )
 
 
@@ -69,34 +75,45 @@ DEFAULT_ENVIRONMENT = DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com")
 ALL_ENVS = [
     DatabricksEnvironment(Cloud.AWS, ".dev.databricks.com"),
     DatabricksEnvironment(Cloud.AWS, ".staging.cloud.databricks.com"),
-    DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"),
+    DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"),
+    DEFAULT_ENVIRONMENT,
+    DatabricksEnvironment(
+        Cloud.AZURE,
+        ".dev.azuredatabricks.net",
+        azure_application_id="62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
+        azure_environment=ENVIRONMENTS["PUBLIC"],
+    ),
+    DatabricksEnvironment(
+        Cloud.AZURE,
+        ".staging.azuredatabricks.net",
+        azure_application_id="4a67d088-db5c-48f1-9ff2-0aace800ae68",
+        azure_environment=ENVIRONMENTS["PUBLIC"],
+    ),
+    DatabricksEnvironment(
+        Cloud.AZURE,
+        ".azuredatabricks.net",
+        azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+        azure_environment=ENVIRONMENTS["PUBLIC"],
+    ),
+    DatabricksEnvironment(
+        Cloud.AZURE,
+        ".databricks.azure.us",
+        azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+        azure_environment=ENVIRONMENTS["USGOVERNMENT"],
+    ),
+    DatabricksEnvironment(
+        Cloud.AZURE,
+        ".databricks.azure.cn",
+        azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+        azure_environment=ENVIRONMENTS["CHINA"],
+    ),
     DatabricksEnvironment(Cloud.GCP, ".dev.gcp.databricks.com"),
     DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
-    DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com")
+    DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com"),
 ]
 
 
-def get_environment_for_hostname(hostname: str) -> DatabricksEnvironment:
+def get_environment_for_hostname(hostname: Optional[str]) -> DatabricksEnvironment:
     if not hostname:
         return DEFAULT_ENVIRONMENT
     for env in ALL_ENVS:
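For reference, a small sketch of how the reformatted environment table is consumed through get_environment_for_hostname. The hostname is a made-up example, and the attribute names cloud and dns_zone are assumptions about the DatabricksEnvironment dataclass, since the excerpt only shows positional constructor calls.

from databricks.sdk.environments import Cloud, get_environment_for_hostname

env = get_environment_for_hostname("adb-1234567890123456.7.azuredatabricks.net")
print(env.cloud)     # expected Cloud.AZURE for *.azuredatabricks.net hosts (assumed field name)
print(env.dns_zone)  # expected ".azuredatabricks.net" (assumed field name)

# Passing None, now allowed by the Optional[str] signature above, falls back to
# DEFAULT_ENVIRONMENT, i.e. AWS ".cloud.databricks.com".
default_env = get_environment_for_hostname(None)
print(default_env.cloud)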