databricks-sdk 0.44.0__py3-none-any.whl → 0.45.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to their public registry. It is provided for informational purposes only.

Note: this release of databricks-sdk has been flagged as potentially problematic.

Files changed (63)
  1. databricks/sdk/__init__.py +123 -115
  2. databricks/sdk/_base_client.py +112 -88
  3. databricks/sdk/_property.py +12 -7
  4. databricks/sdk/_widgets/__init__.py +13 -2
  5. databricks/sdk/_widgets/default_widgets_utils.py +21 -15
  6. databricks/sdk/_widgets/ipywidgets_utils.py +47 -24
  7. databricks/sdk/azure.py +8 -6
  8. databricks/sdk/casing.py +5 -5
  9. databricks/sdk/config.py +152 -99
  10. databricks/sdk/core.py +57 -47
  11. databricks/sdk/credentials_provider.py +360 -210
  12. databricks/sdk/data_plane.py +86 -3
  13. databricks/sdk/dbutils.py +123 -87
  14. databricks/sdk/environments.py +52 -35
  15. databricks/sdk/errors/base.py +61 -35
  16. databricks/sdk/errors/customizer.py +3 -3
  17. databricks/sdk/errors/deserializer.py +38 -25
  18. databricks/sdk/errors/details.py +417 -0
  19. databricks/sdk/errors/mapper.py +1 -1
  20. databricks/sdk/errors/overrides.py +27 -24
  21. databricks/sdk/errors/parser.py +26 -14
  22. databricks/sdk/errors/platform.py +10 -10
  23. databricks/sdk/errors/private_link.py +24 -24
  24. databricks/sdk/logger/round_trip_logger.py +28 -20
  25. databricks/sdk/mixins/compute.py +90 -60
  26. databricks/sdk/mixins/files.py +815 -145
  27. databricks/sdk/mixins/jobs.py +201 -20
  28. databricks/sdk/mixins/open_ai_client.py +26 -20
  29. databricks/sdk/mixins/workspace.py +45 -34
  30. databricks/sdk/oauth.py +372 -196
  31. databricks/sdk/retries.py +14 -12
  32. databricks/sdk/runtime/__init__.py +34 -17
  33. databricks/sdk/runtime/dbutils_stub.py +52 -39
  34. databricks/sdk/service/_internal.py +12 -7
  35. databricks/sdk/service/apps.py +618 -418
  36. databricks/sdk/service/billing.py +827 -604
  37. databricks/sdk/service/catalog.py +6552 -4474
  38. databricks/sdk/service/cleanrooms.py +550 -388
  39. databricks/sdk/service/compute.py +5241 -3531
  40. databricks/sdk/service/dashboards.py +1313 -923
  41. databricks/sdk/service/files.py +442 -309
  42. databricks/sdk/service/iam.py +2115 -1483
  43. databricks/sdk/service/jobs.py +4151 -2588
  44. databricks/sdk/service/marketplace.py +2210 -1517
  45. databricks/sdk/service/ml.py +3364 -2255
  46. databricks/sdk/service/oauth2.py +922 -584
  47. databricks/sdk/service/pipelines.py +1865 -1203
  48. databricks/sdk/service/provisioning.py +1435 -1029
  49. databricks/sdk/service/serving.py +2040 -1278
  50. databricks/sdk/service/settings.py +2846 -1929
  51. databricks/sdk/service/sharing.py +2201 -877
  52. databricks/sdk/service/sql.py +4650 -3103
  53. databricks/sdk/service/vectorsearch.py +816 -550
  54. databricks/sdk/service/workspace.py +1330 -906
  55. databricks/sdk/useragent.py +36 -22
  56. databricks/sdk/version.py +1 -1
  57. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/METADATA +31 -31
  58. databricks_sdk-0.45.0.dist-info/RECORD +70 -0
  59. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/WHEEL +1 -1
  60. databricks_sdk-0.44.0.dist-info/RECORD +0 -69
  61. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/LICENSE +0 -0
  62. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/NOTICE +0 -0
  63. {databricks_sdk-0.44.0.dist-info → databricks_sdk-0.45.0.dist-info}/top_level.txt +0 -0
databricks/sdk/data_plane.py CHANGED
@@ -1,23 +1,101 @@
+ from __future__ import annotations
+
  import threading
  from dataclasses import dataclass
- from typing import Callable, List
+ from typing import Callable, List, Optional
+ from urllib import parse

+ from databricks.sdk import oauth
  from databricks.sdk.oauth import Token

+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ OIDC_TOKEN_PATH = "/oidc/v1/token"
+
+
+ class DataPlaneTokenSource:
+     """
+     EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
+     """
+
+     # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
+     def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
+         self._cpts = cpts
+         self._token_exchange_host = token_exchange_host
+         self._token_sources = {}
+         self._disable_async = disable_async
+         self._lock = threading.Lock()
+
+     def token(self, endpoint, auth_details):
+         key = f"{endpoint}:{auth_details}"
+
+         # First, try to read without acquiring the lock to avoid contention.
+         # Reads are atomic, so this is safe.
+         token_source = self._token_sources.get(key)
+         if token_source:
+             return token_source.token()
+
+         # If token_source is not found, acquire the lock and check again.
+         with self._lock:
+             # Another thread might have created it while we were waiting for the lock.
+             token_source = self._token_sources.get(key)
+             if not token_source:
+                 token_source = DataPlaneEndpointTokenSource(
+                     self._token_exchange_host, self._cpts, auth_details, self._disable_async
+                 )
+                 self._token_sources[key] = token_source
+
+         return token_source.token()
+
+
+ class DataPlaneEndpointTokenSource(oauth.Refreshable):
+     """
+     EXPERIMENTAL A token source for a specific DataPlane endpoint.
+     """
+
+     def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
+         super().__init__(disable_async=disable_async)
+         self._auth_details = auth_details
+         self._cpts = cpts
+         self._token_exchange_host = token_exchange_host
+
+     def refresh(self) -> Token:
+         control_plane_token = self._cpts()
+         headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
+         params = parse.urlencode(
+             {
+                 "grant_type": JWT_BEARER_GRANT_TYPE,
+                 "authorization_details": self._auth_details,
+                 "assertion": control_plane_token.access_token,
+             }
+         )
+         return oauth.retrieve_token(
+             client_id="",
+             client_secret="",
+             token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
+             params=params,
+             headers=headers,
+         )
+

  @dataclass
  class DataPlaneDetails:
      """
      Contains details required to query a DataPlane endpoint.
      """
+
      endpoint_url: str
      """URL used to query the endpoint through the DataPlane."""
      token: Token
      """Token to query the DataPlane endpoint."""


+ ## Old implementation. #TODO: Remove after the new implementation is used
+
+
  class DataPlaneService:
      """Helper class to fetch and manage DataPlane details."""
+
      from .service.serving import DataPlaneInfo

      def __init__(self):
@@ -25,8 +103,13 @@ class DataPlaneService:
          self._tokens = {}
          self._lock = threading.Lock()

-     def get_data_plane_details(self, method: str, params: List[str], info_getter: Callable[[], DataPlaneInfo],
-                                refresh: Callable[[str], Token]):
+     def get_data_plane_details(
+         self,
+         method: str,
+         params: List[str],
+         info_getter: Callable[[], DataPlaneInfo],
+         refresh: Callable[[str], Token],
+     ):
          """Get and cache information required to query a Data Plane endpoint using the provided methods.

          Returns a cached DataPlaneDetails if the details have already been fetched previously and are still valid.
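
The double-checked locking in DataPlaneTokenSource.token() keeps the common path lock-free: a per-(endpoint, auth_details) token source is created once under the lock, then reused on every later call. A minimal usage sketch, assuming a callable that supplies a valid control-plane Token; the host, endpoint, and authorization details below are hypothetical:

    from databricks.sdk.data_plane import DataPlaneTokenSource
    from databricks.sdk.oauth import Token

    def control_plane_token() -> Token:
        # Assumption: in practice this token comes from the SDK's OAuth credentials machinery.
        return Token(access_token="<control-plane-access-token>", token_type="Bearer")

    source = DataPlaneTokenSource(
        token_exchange_host="https://my-workspace.cloud.databricks.com",  # hypothetical host
        cpts=control_plane_token,
    )
    # Exchanges the control-plane token via the JWT bearer grant against /oidc/v1/token;
    # repeated calls with the same endpoint and auth details reuse the cached token source.
    token = source.token("https://dataplane.example.net/serving", "<authorization-details>")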
databricks/sdk/dbutils.py CHANGED
@@ -12,123 +12,136 @@ from .mixins import compute as compute_ext
  from .mixins import files as dbfs_ext
  from .service import compute, workspace

- _LOG = logging.getLogger('databricks.sdk')
+ _LOG = logging.getLogger("databricks.sdk")


- class FileInfo(namedtuple('FileInfo', ['path', 'name', 'size', "modificationTime"])):
+ class FileInfo(namedtuple("FileInfo", ["path", "name", "size", "modificationTime"])):
      pass


- class MountInfo(namedtuple('MountInfo', ['mountPoint', 'source', 'encryptionType'])):
+ class MountInfo(namedtuple("MountInfo", ["mountPoint", "source", "encryptionType"])):
      pass


- class SecretScope(namedtuple('SecretScope', ['name'])):
+ class SecretScope(namedtuple("SecretScope", ["name"])):

      def getName(self):
          return self.name


- class SecretMetadata(namedtuple('SecretMetadata', ['key'])):
+ class SecretMetadata(namedtuple("SecretMetadata", ["key"])):
      pass


  class _FsUtil:
-     """ Manipulates the Databricks filesystem (DBFS) """
+     """Manipulates the Databricks filesystem (DBFS)"""

-     def __init__(self, dbfs_ext: dbfs_ext.DbfsExt, proxy_factory: Callable[[str], '_ProxyUtil']):
+     def __init__(
+         self,
+         dbfs_ext: dbfs_ext.DbfsExt,
+         proxy_factory: Callable[[str], "_ProxyUtil"],
+     ):
          self._dbfs = dbfs_ext
          self._proxy_factory = proxy_factory

      def cp(self, from_: str, to: str, recurse: bool = False) -> bool:
-         """Copies a file or directory, possibly across FileSystems """
+         """Copies a file or directory, possibly across FileSystems"""
          self._dbfs.copy(from_, to, recursive=recurse)
          return True

      def head(self, file: str, maxBytes: int = 65536) -> str:
-         """Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8 """
+         """Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8"""
          with self._dbfs.download(file) as f:
-             return f.read(maxBytes).decode('utf8')
+             return f.read(maxBytes).decode("utf8")

      def ls(self, dir: str) -> List[FileInfo]:
-         """Lists the contents of a directory """
+         """Lists the contents of a directory"""
          return [
-             FileInfo(f.path, os.path.basename(f.path), f.file_size, f.modification_time)
+             FileInfo(
+                 f.path,
+                 os.path.basename(f.path),
+                 f.file_size,
+                 f.modification_time,
+             )
              for f in self._dbfs.list(dir)
          ]

      def mkdirs(self, dir: str) -> bool:
-         """Creates the given directory if it does not exist, also creating any necessary parent directories """
+         """Creates the given directory if it does not exist, also creating any necessary parent directories"""
          self._dbfs.mkdirs(dir)
          return True

      def mv(self, from_: str, to: str, recurse: bool = False) -> bool:
-         """Moves a file or directory, possibly across FileSystems """
+         """Moves a file or directory, possibly across FileSystems"""
          self._dbfs.move_(from_, to, recursive=recurse, overwrite=True)
          return True

      def put(self, file: str, contents: str, overwrite: bool = False) -> bool:
-         """Writes the given String out to a file, encoded in UTF-8 """
+         """Writes the given String out to a file, encoded in UTF-8"""
          with self._dbfs.open(file, write=True, overwrite=overwrite) as f:
-             f.write(contents.encode('utf8'))
+             f.write(contents.encode("utf8"))
          return True

      def rm(self, dir: str, recurse: bool = False) -> bool:
-         """Removes a file or directory """
+         """Removes a file or directory"""
          self._dbfs.delete(dir, recursive=recurse)
          return True

-     def mount(self,
-               source: str,
-               mount_point: str,
-               encryption_type: str = None,
-               owner: str = None,
-               extra_configs: Dict[str, str] = None) -> bool:
+     def mount(
+         self,
+         source: str,
+         mount_point: str,
+         encryption_type: str = None,
+         owner: str = None,
+         extra_configs: Dict[str, str] = None,
+     ) -> bool:
          """Mounts the given source directory into DBFS at the given mount point"""
-         fs = self._proxy_factory('fs')
+         fs = self._proxy_factory("fs")
          kwargs = {}
          if encryption_type:
-             kwargs['encryption_type'] = encryption_type
+             kwargs["encryption_type"] = encryption_type
          if owner:
-             kwargs['owner'] = owner
+             kwargs["owner"] = owner
          if extra_configs:
-             kwargs['extra_configs'] = extra_configs
+             kwargs["extra_configs"] = extra_configs
          return fs.mount(source, mount_point, **kwargs)

      def unmount(self, mount_point: str) -> bool:
          """Deletes a DBFS mount point"""
-         fs = self._proxy_factory('fs')
+         fs = self._proxy_factory("fs")
          return fs.unmount(mount_point)

-     def updateMount(self,
-                     source: str,
-                     mount_point: str,
-                     encryption_type: str = None,
-                     owner: str = None,
-                     extra_configs: Dict[str, str] = None) -> bool:
-         """ Similar to mount(), but updates an existing mount point (if present) instead of creating a new one """
-         fs = self._proxy_factory('fs')
+     def updateMount(
+         self,
+         source: str,
+         mount_point: str,
+         encryption_type: str = None,
+         owner: str = None,
+         extra_configs: Dict[str, str] = None,
+     ) -> bool:
+         """Similar to mount(), but updates an existing mount point (if present) instead of creating a new one"""
+         fs = self._proxy_factory("fs")
          kwargs = {}
          if encryption_type:
-             kwargs['encryption_type'] = encryption_type
+             kwargs["encryption_type"] = encryption_type
          if owner:
-             kwargs['owner'] = owner
+             kwargs["owner"] = owner
          if extra_configs:
-             kwargs['extra_configs'] = extra_configs
+             kwargs["extra_configs"] = extra_configs
          return fs.updateMount(source, mount_point, **kwargs)

      def mounts(self) -> List[MountInfo]:
-         """ Displays information about what is mounted within DBFS """
+         """Displays information about what is mounted within DBFS"""
          result = []
-         fs = self._proxy_factory('fs')
+         fs = self._proxy_factory("fs")
          for info in fs.mounts():
              result.append(MountInfo(info[0], info[1], info[2]))
          return result

      def refreshMounts(self) -> bool:
-         """ Forces all machines in this cluster to refresh their mount cache,
-         ensuring they receive the most recent information """
-         fs = self._proxy_factory('fs')
+         """Forces all machines in this cluster to refresh their mount cache,
+         ensuring they receive the most recent information"""
+         fs = self._proxy_factory("fs")
          return fs.refreshMounts()


@@ -136,13 +149,13 @@ class _SecretsUtil:
      """Remote equivalent of secrets util"""

      def __init__(self, secrets_api: workspace.SecretsAPI):
-         self._api = secrets_api # nolint
+         self._api = secrets_api  # nolint

      def getBytes(self, scope: str, key: str) -> bytes:
          """Gets the bytes representation of a secret value for the specified scope and key."""
-         query = {'scope': scope, 'key': key}
-         raw = self._api._api.do('GET', '/api/2.0/secrets/get', query=query)
-         return base64.b64decode(raw['value'])
+         query = {"scope": scope, "key": key}
+         raw = self._api._api.do("GET", "/api/2.0/secrets/get", query=query)
+         return base64.b64decode(raw["value"])

      def get(self, scope: str, key: str) -> str:
          """Gets the string representation of a secret value for the specified secrets scope and key."""
@@ -169,13 +182,19 @@ class _JobsUtil:
  class _TaskValuesUtil:
      """Remote equivalent of task values util"""

-     def get(self, taskKey: str, key: str, default: any = None, debugValue: any = None) -> None:
+     def get(
+         self,
+         taskKey: str,
+         key: str,
+         default: any = None,
+         debugValue: any = None,
+     ) -> None:
          """
          Returns `debugValue` if present, throws an error otherwise as this implementation is always run outside of a job run
          """
          if debugValue is None:
              raise TypeError(
-                 'Must pass debugValue when calling get outside of a job context. debugValue cannot be None.'
+                 "Must pass debugValue when calling get outside of a job context. debugValue cannot be None."
              )
          return debugValue

@@ -190,7 +209,7 @@ class _JobsUtil:

  class RemoteDbUtils:

-     def __init__(self, config: 'Config' = None):
+     def __init__(self, config: "Config" = None):
          self._config = Config() if not config else config
          self._client = ApiClient(self._config)
          self._clusters = compute_ext.ClustersExt(self._client)
@@ -211,6 +230,7 @@ class RemoteDbUtils:
      def widgets(self):
          if self._widgets is None:
              from ._widgets import widget_impl
+
              self._widgets = widget_impl()

          return self._widgets
@@ -219,7 +239,7 @@ class RemoteDbUtils:
      def _cluster_id(self) -> str:
          cluster_id = self._config.cluster_id
          if not cluster_id:
-             message = 'cluster_id is required in the configuration'
+             message = "cluster_id is required in the configuration"
              raise ValueError(self._config.wrap_debug_info(message))
          return cluster_id

@@ -230,15 +250,16 @@ class RemoteDbUtils:
          if self._ctx:
              return self._ctx
          self._clusters.ensure_cluster_is_running(self._cluster_id)
-         self._ctx = self._commands.create(cluster_id=self._cluster_id,
-                                           language=compute.Language.PYTHON).result()
+         self._ctx = self._commands.create(cluster_id=self._cluster_id, language=compute.Language.PYTHON).result()
          return self._ctx

-     def __getattr__(self, util) -> '_ProxyUtil':
-         return _ProxyUtil(command_execution=self._commands,
-                           context_factory=self._running_command_context,
-                           cluster_id=self._cluster_id,
-                           name=util)
+     def __getattr__(self, util) -> "_ProxyUtil":
+         return _ProxyUtil(
+             command_execution=self._commands,
+             context_factory=self._running_command_context,
+             cluster_id=self._cluster_id,
+             name=util,
+         )


  @dataclass
@@ -273,8 +294,7 @@ class _OverrideProxyUtil:
      # This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
      # are being proxied to remote dbutils currently.
      proxy_override_paths = {
-         'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
-         get_local_notebook_path,
+         "notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()": get_local_notebook_path,
      }

      @classmethod
@@ -294,7 +314,8 @@ class _OverrideProxyUtil:
      def __call__(self, *args, **kwds) -> Any:
          if len(args) != 0 or len(kwds) != 0:
              raise TypeError(
-                 f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")
+                 f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()"
+             )

          callable_path = f"{self._name}()"
          result = self.__run_override(callable_path)
@@ -314,8 +335,14 @@ class _OverrideProxyUtil:
  class _ProxyUtil:
      """Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""

-     def __init__(self, *, command_execution: compute.CommandExecutionAPI,
-                  context_factory: Callable[[], compute.ContextStatusResponse], cluster_id: str, name: str):
+     def __init__(
+         self,
+         *,
+         command_execution: compute.CommandExecutionAPI,
+         context_factory: Callable[[], compute.ContextStatusResponse],
+         cluster_id: str,
+         name: str,
+     ):
          self._commands = command_execution
          self._cluster_id = cluster_id
          self._context_factory = context_factory
@@ -324,16 +351,18 @@ class _ProxyUtil:
      def __call__(self):
          raise NotImplementedError(f"dbutils.{self._name} is not callable")

-     def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
+     def __getattr__(self, method: str) -> "_ProxyCall | _ProxyUtil | _OverrideProxyUtil":
          override = _OverrideProxyUtil.new(f"{self._name}.{method}")
          if override:
              return override

-         return _ProxyCall(command_execution=self._commands,
-                           cluster_id=self._cluster_id,
-                           context_factory=self._context_factory,
-                           util=self._name,
-                           method=method)
+         return _ProxyCall(
+             command_execution=self._commands,
+             cluster_id=self._cluster_id,
+             context_factory=self._context_factory,
+             util=self._name,
+             method=method,
+         )


  import html
@@ -342,29 +371,34 @@ import re

  class _ProxyCall:

-     def __init__(self, *, command_execution: compute.CommandExecutionAPI,
-                  context_factory: Callable[[], compute.ContextStatusResponse], cluster_id: str, util: str,
-                  method: str):
+     def __init__(
+         self,
+         *,
+         command_execution: compute.CommandExecutionAPI,
+         context_factory: Callable[[], compute.ContextStatusResponse],
+         cluster_id: str,
+         util: str,
+         method: str,
+     ):
          self._commands = command_execution
          self._cluster_id = cluster_id
          self._context_factory = context_factory
          self._util = util
          self._method = method

-     _out_re = re.compile(r'Out\[[\d\s]+]:\s')
-     _tag_re = re.compile(r'<[^>]*>')
-     _exception_re = re.compile(r'.*Exception:\s+(.*)')
-     _execution_error_re = re.compile(
-         r'ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n')
-     _error_message_re = re.compile(r'ErrorMessage=(.+)\n')
-     _ascii_escape_re = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')
+     _out_re = re.compile(r"Out\[[\d\s]+]:\s")
+     _tag_re = re.compile(r"<[^>]*>")
+     _exception_re = re.compile(r".*Exception:\s+(.*)")
+     _execution_error_re = re.compile(r"ExecutionError: ([\s\S]*)\n(StatusCode=[0-9]*)\n(StatusDescription=.*)\n")
+     _error_message_re = re.compile(r"ErrorMessage=(.+)\n")
+     _ascii_escape_re = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")

      def _is_failed(self, results: compute.Results) -> bool:
          return results.result_type == compute.ResultType.ERROR

      def _text(self, results: compute.Results) -> str:
          if results.result_type != compute.ResultType.TEXT:
-             return ''
+             return ""
          return self._out_re.sub("", str(results.data))

      def _raise_if_failed(self, results: compute.Results):
@@ -399,17 +433,19 @@ class _ProxyCall:

      def __call__(self, *args, **kwargs):
          raw = json.dumps((args, kwargs))
-         code = f'''
+         code = f"""
  import json
  (args, kwargs) = json.loads('{raw}')
  result = dbutils.{self._util}.{self._method}(*args, **kwargs)
  dbutils.notebook.exit(json.dumps(result))
- '''
+ """
          ctx = self._context_factory()
-         result = self._commands.execute(cluster_id=self._cluster_id,
-                                         language=compute.Language.PYTHON,
-                                         context_id=ctx.id,
-                                         command=code).result()
+         result = self._commands.execute(
+             cluster_id=self._cluster_id,
+             language=compute.Language.PYTHON,
+             context_id=ctx.id,
+             command=code,
+         ).result()
          if result.status == compute.CommandStatus.FINISHED:
              self._raise_if_failed(result.results)
              raw = result.results.data
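
For orientation, a minimal usage sketch of RemoteDbUtils, which backs these utilities: fs and secrets calls go straight to workspace APIs, while other utilities are proxied through command execution on a running cluster (those need cluster_id configured, as the _cluster_id property above enforces). The paths and secret names below are hypothetical:

    from databricks.sdk.dbutils import RemoteDbUtils

    dbutils = RemoteDbUtils()  # resolves host/credentials from the environment or a config profile
    for f in dbutils.fs.ls("/tmp"):  # each entry is FileInfo(path, name, size, modificationTime)
        print(f.path, f.size)
    dbutils.fs.put("/tmp/example.txt", "hello", overwrite=True)
    value = dbutils.secrets.get(scope="my-scope", key="my-key")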
databricks/sdk/environments.py CHANGED
@@ -14,18 +14,24 @@ class AzureEnvironment:
      ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

  ENVIRONMENTS = dict(
-     PUBLIC=AzureEnvironment(name="PUBLIC",
-                             service_management_endpoint="https://management.core.windows.net/",
-                             resource_manager_endpoint="https://management.azure.com/",
-                             active_directory_endpoint="https://login.microsoftonline.com/"),
-     USGOVERNMENT=AzureEnvironment(name="USGOVERNMENT",
-                                   service_management_endpoint="https://management.core.usgovcloudapi.net/",
-                                   resource_manager_endpoint="https://management.usgovcloudapi.net/",
-                                   active_directory_endpoint="https://login.microsoftonline.us/"),
-     CHINA=AzureEnvironment(name="CHINA",
-                            service_management_endpoint="https://management.core.chinacloudapi.cn/",
-                            resource_manager_endpoint="https://management.chinacloudapi.cn/",
-                            active_directory_endpoint="https://login.chinacloudapi.cn/"),
+     PUBLIC=AzureEnvironment(
+         name="PUBLIC",
+         service_management_endpoint="https://management.core.windows.net/",
+         resource_manager_endpoint="https://management.azure.com/",
+         active_directory_endpoint="https://login.microsoftonline.com/",
+     ),
+     USGOVERNMENT=AzureEnvironment(
+         name="USGOVERNMENT",
+         service_management_endpoint="https://management.core.usgovcloudapi.net/",
+         resource_manager_endpoint="https://management.usgovcloudapi.net/",
+         active_directory_endpoint="https://login.microsoftonline.us/",
+     ),
+     CHINA=AzureEnvironment(
+         name="CHINA",
+         service_management_endpoint="https://management.core.chinacloudapi.cn/",
+         resource_manager_endpoint="https://management.chinacloudapi.cn/",
+         active_directory_endpoint="https://login.chinacloudapi.cn/",
+     ),
  )


@@ -69,34 +75,45 @@ DEFAULT_ENVIRONMENT = DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com")
  ALL_ENVS = [
      DatabricksEnvironment(Cloud.AWS, ".dev.databricks.com"),
      DatabricksEnvironment(Cloud.AWS, ".staging.cloud.databricks.com"),
-     DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"), DEFAULT_ENVIRONMENT,
-     DatabricksEnvironment(Cloud.AZURE,
-                           ".dev.azuredatabricks.net",
-                           azure_application_id="62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
-                           azure_environment=ENVIRONMENTS["PUBLIC"]),
-     DatabricksEnvironment(Cloud.AZURE,
-                           ".staging.azuredatabricks.net",
-                           azure_application_id="4a67d088-db5c-48f1-9ff2-0aace800ae68",
-                           azure_environment=ENVIRONMENTS["PUBLIC"]),
-     DatabricksEnvironment(Cloud.AZURE,
-                           ".azuredatabricks.net",
-                           azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
-                           azure_environment=ENVIRONMENTS["PUBLIC"]),
-     DatabricksEnvironment(Cloud.AZURE,
-                           ".databricks.azure.us",
-                           azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
-                           azure_environment=ENVIRONMENTS["USGOVERNMENT"]),
-     DatabricksEnvironment(Cloud.AZURE,
-                           ".databricks.azure.cn",
-                           azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
-                           azure_environment=ENVIRONMENTS["CHINA"]),
+     DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"),
+     DEFAULT_ENVIRONMENT,
+     DatabricksEnvironment(
+         Cloud.AZURE,
+         ".dev.azuredatabricks.net",
+         azure_application_id="62a912ac-b58e-4c1d-89ea-b2dbfc7358fc",
+         azure_environment=ENVIRONMENTS["PUBLIC"],
+     ),
+     DatabricksEnvironment(
+         Cloud.AZURE,
+         ".staging.azuredatabricks.net",
+         azure_application_id="4a67d088-db5c-48f1-9ff2-0aace800ae68",
+         azure_environment=ENVIRONMENTS["PUBLIC"],
+     ),
+     DatabricksEnvironment(
+         Cloud.AZURE,
+         ".azuredatabricks.net",
+         azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+         azure_environment=ENVIRONMENTS["PUBLIC"],
+     ),
+     DatabricksEnvironment(
+         Cloud.AZURE,
+         ".databricks.azure.us",
+         azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+         azure_environment=ENVIRONMENTS["USGOVERNMENT"],
+     ),
+     DatabricksEnvironment(
+         Cloud.AZURE,
+         ".databricks.azure.cn",
+         azure_application_id=ARM_DATABRICKS_RESOURCE_ID,
+         azure_environment=ENVIRONMENTS["CHINA"],
+     ),
      DatabricksEnvironment(Cloud.GCP, ".dev.gcp.databricks.com"),
      DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
-     DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com")
+     DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com"),
  ]


- def get_environment_for_hostname(hostname: str) -> DatabricksEnvironment:
+ def get_environment_for_hostname(hostname: Optional[str]) -> DatabricksEnvironment:
      if not hostname:
          return DEFAULT_ENVIRONMENT
      for env in ALL_ENVS:
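
A short sketch of the updated lookup, whose signature now admits a missing hostname. The workspace hostname below is hypothetical, and the suffix matching inside the loop is an assumption based on the DNS-zone constructor argument shown above:

    from databricks.sdk.environments import get_environment_for_hostname

    env = get_environment_for_hostname("adb-1234567890123456.7.azuredatabricks.net")
    print(env.cloud)  # Cloud.AZURE, matched on the ".azuredatabricks.net" suffix (assumed)

    env = get_environment_for_hostname(None)  # now valid: falls back to DEFAULT_ENVIRONMENT (AWS)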