databricks-sdk 0.32.1__tar.gz → 0.32.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic; see the details below.

Files changed (92)
  1. {databricks_sdk-0.32.1/databricks_sdk.egg-info → databricks_sdk-0.32.3}/PKG-INFO +1 -1
  2. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/__init__.py +48 -46
  3. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/config.py +3 -3
  4. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/core.py +17 -30
  5. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/credentials_provider.py +130 -11
  6. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/__init__.py +2 -2
  7. databricks_sdk-0.32.3/databricks/sdk/errors/customizer.py +50 -0
  8. databricks_sdk-0.32.3/databricks/sdk/errors/deserializer.py +106 -0
  9. databricks_sdk-0.32.3/databricks/sdk/errors/parser.py +83 -0
  10. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/logger/round_trip_logger.py +2 -1
  11. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/mixins/files.py +9 -9
  12. databricks_sdk-0.32.3/databricks/sdk/version.py +1 -0
  13. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3/databricks_sdk.egg-info}/PKG-INFO +1 -1
  14. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks_sdk.egg-info/SOURCES.txt +3 -0
  15. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_auth_manual_tests.py +12 -0
  16. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_core.py +69 -21
  17. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_dbfs_mixins.py +8 -5
  18. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_errors.py +23 -11
  19. databricks_sdk-0.32.3/tests/test_model_serving_auth.py +98 -0
  20. databricks_sdk-0.32.1/databricks/sdk/errors/parser.py +0 -147
  21. databricks_sdk-0.32.1/databricks/sdk/version.py +0 -1
  22. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/LICENSE +0 -0
  23. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/NOTICE +0 -0
  24. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/README.md +0 -0
  25. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/__init__.py +0 -0
  26. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/_property.py +0 -0
  27. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/_widgets/__init__.py +0 -0
  28. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/_widgets/default_widgets_utils.py +0 -0
  29. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/_widgets/ipywidgets_utils.py +0 -0
  30. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/azure.py +0 -0
  31. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/casing.py +0 -0
  32. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/clock.py +0 -0
  33. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/data_plane.py +0 -0
  34. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/dbutils.py +0 -0
  35. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/environments.py +0 -0
  36. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/base.py +0 -0
  37. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/mapper.py +0 -0
  38. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/overrides.py +0 -0
  39. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/platform.py +0 -0
  40. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/private_link.py +0 -0
  41. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/errors/sdk.py +0 -0
  42. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/logger/__init__.py +0 -0
  43. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/mixins/__init__.py +0 -0
  44. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/mixins/compute.py +0 -0
  45. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/mixins/workspace.py +0 -0
  46. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/oauth.py +0 -0
  47. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/py.typed +0 -0
  48. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/retries.py +0 -0
  49. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/runtime/__init__.py +0 -0
  50. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/runtime/dbutils_stub.py +0 -0
  51. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/__init__.py +0 -0
  52. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/_internal.py +0 -0
  53. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/apps.py +0 -0
  54. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/billing.py +0 -0
  55. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/catalog.py +0 -0
  56. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/compute.py +0 -0
  57. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/dashboards.py +0 -0
  58. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/files.py +0 -0
  59. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/iam.py +0 -0
  60. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/jobs.py +0 -0
  61. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/marketplace.py +0 -0
  62. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/ml.py +0 -0
  63. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/oauth2.py +0 -0
  64. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/pipelines.py +0 -0
  65. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/provisioning.py +0 -0
  66. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/serving.py +0 -0
  67. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/settings.py +0 -0
  68. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/sharing.py +0 -0
  69. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/sql.py +0 -0
  70. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/vectorsearch.py +0 -0
  71. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/service/workspace.py +0 -0
  72. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks/sdk/useragent.py +0 -0
  73. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks_sdk.egg-info/dependency_links.txt +0 -0
  74. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks_sdk.egg-info/requires.txt +0 -0
  75. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/databricks_sdk.egg-info/top_level.txt +0 -0
  76. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/setup.cfg +0 -0
  77. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/setup.py +0 -0
  78. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_auth.py +0 -0
  79. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_client.py +0 -0
  80. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_compute_mixins.py +0 -0
  81. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_config.py +0 -0
  82. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_data_plane.py +0 -0
  83. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_dbutils.py +0 -0
  84. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_environments.py +0 -0
  85. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_init_file.py +0 -0
  86. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_internal.py +0 -0
  87. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_jobs.py +0 -0
  88. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_metadata_service_auth.py +0 -0
  89. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_misc.py +0 -0
  90. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_oauth.py +0 -0
  91. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_retries.py +0 -0
  92. {databricks_sdk-0.32.1 → databricks_sdk-0.32.3}/tests/test_user_agent.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: databricks-sdk
3
- Version: 0.32.1
3
+ Version: 0.32.3
4
4
  Summary: Databricks SDK for Python (Beta)
5
5
  Home-page: https://databricks-sdk-py.readthedocs.io
6
6
  Author: Serge Smertin
@@ -1,3 +1,5 @@
1
+ from typing import Optional
2
+
1
3
  import databricks.sdk.core as client
2
4
  import databricks.sdk.dbutils as dbutils
3
5
  from databricks.sdk import azure
@@ -116,31 +118,31 @@ class WorkspaceClient:
116
118
 
117
119
  def __init__(self,
118
120
  *,
119
- host: str = None,
120
- account_id: str = None,
121
- username: str = None,
122
- password: str = None,
123
- client_id: str = None,
124
- client_secret: str = None,
125
- token: str = None,
126
- profile: str = None,
127
- config_file: str = None,
128
- azure_workspace_resource_id: str = None,
129
- azure_client_secret: str = None,
130
- azure_client_id: str = None,
131
- azure_tenant_id: str = None,
132
- azure_environment: str = None,
133
- auth_type: str = None,
134
- cluster_id: str = None,
135
- google_credentials: str = None,
136
- google_service_account: str = None,
137
- debug_truncate_bytes: int = None,
138
- debug_headers: bool = None,
121
+ host: Optional[str] = None,
122
+ account_id: Optional[str] = None,
123
+ username: Optional[str] = None,
124
+ password: Optional[str] = None,
125
+ client_id: Optional[str] = None,
126
+ client_secret: Optional[str] = None,
127
+ token: Optional[str] = None,
128
+ profile: Optional[str] = None,
129
+ config_file: Optional[str] = None,
130
+ azure_workspace_resource_id: Optional[str] = None,
131
+ azure_client_secret: Optional[str] = None,
132
+ azure_client_id: Optional[str] = None,
133
+ azure_tenant_id: Optional[str] = None,
134
+ azure_environment: Optional[str] = None,
135
+ auth_type: Optional[str] = None,
136
+ cluster_id: Optional[str] = None,
137
+ google_credentials: Optional[str] = None,
138
+ google_service_account: Optional[str] = None,
139
+ debug_truncate_bytes: Optional[int] = None,
140
+ debug_headers: Optional[bool] = None,
139
141
  product="unknown",
140
142
  product_version="0.0.0",
141
- credentials_strategy: CredentialsStrategy = None,
142
- credentials_provider: CredentialsStrategy = None,
143
- config: client.Config = None):
143
+ credentials_strategy: Optional[CredentialsStrategy] = None,
144
+ credentials_provider: Optional[CredentialsStrategy] = None,
145
+ config: Optional[client.Config] = None):
144
146
  if not config:
145
147
  config = client.Config(host=host,
146
148
  account_id=account_id,
@@ -742,31 +744,31 @@ class AccountClient:
742
744
 
743
745
  def __init__(self,
744
746
  *,
745
- host: str = None,
746
- account_id: str = None,
747
- username: str = None,
748
- password: str = None,
749
- client_id: str = None,
750
- client_secret: str = None,
751
- token: str = None,
752
- profile: str = None,
753
- config_file: str = None,
754
- azure_workspace_resource_id: str = None,
755
- azure_client_secret: str = None,
756
- azure_client_id: str = None,
757
- azure_tenant_id: str = None,
758
- azure_environment: str = None,
759
- auth_type: str = None,
760
- cluster_id: str = None,
761
- google_credentials: str = None,
762
- google_service_account: str = None,
763
- debug_truncate_bytes: int = None,
764
- debug_headers: bool = None,
747
+ host: Optional[str] = None,
748
+ account_id: Optional[str] = None,
749
+ username: Optional[str] = None,
750
+ password: Optional[str] = None,
751
+ client_id: Optional[str] = None,
752
+ client_secret: Optional[str] = None,
753
+ token: Optional[str] = None,
754
+ profile: Optional[str] = None,
755
+ config_file: Optional[str] = None,
756
+ azure_workspace_resource_id: Optional[str] = None,
757
+ azure_client_secret: Optional[str] = None,
758
+ azure_client_id: Optional[str] = None,
759
+ azure_tenant_id: Optional[str] = None,
760
+ azure_environment: Optional[str] = None,
761
+ auth_type: Optional[str] = None,
762
+ cluster_id: Optional[str] = None,
763
+ google_credentials: Optional[str] = None,
764
+ google_service_account: Optional[str] = None,
765
+ debug_truncate_bytes: Optional[int] = None,
766
+ debug_headers: Optional[bool] = None,
765
767
  product="unknown",
766
768
  product_version="0.0.0",
767
- credentials_strategy: CredentialsStrategy = None,
768
- credentials_provider: CredentialsStrategy = None,
769
- config: client.Config = None):
769
+ credentials_strategy: Optional[CredentialsStrategy] = None,
770
+ credentials_provider: Optional[CredentialsStrategy] = None,
771
+ config: Optional[client.Config] = None):
770
772
  if not config:
771
773
  config = client.Config(host=host,
772
774
  account_id=account_id,
@@ -92,11 +92,11 @@ class Config:
92
92
  def __init__(self,
93
93
  *,
94
94
  # Deprecated. Use credentials_strategy instead.
95
- credentials_provider: CredentialsStrategy = None,
96
- credentials_strategy: CredentialsStrategy = None,
95
+ credentials_provider: Optional[CredentialsStrategy] = None,
96
+ credentials_strategy: Optional[CredentialsStrategy] = None,
97
97
  product=None,
98
98
  product_version=None,
99
- clock: Clock = None,
99
+ clock: Optional[Clock] = None,
100
100
  **kwargs):
101
101
  self._header_factory = None
102
102
  self._inner = {}
@@ -10,7 +10,7 @@ from .casing import Casing
10
10
  from .config import *
11
11
  # To preserve backwards compatibility (as these definitions were previously in this module)
12
12
  from .credentials_provider import *
13
- from .errors import DatabricksError, get_api_error
13
+ from .errors import DatabricksError, _ErrorCustomizer, _Parser
14
14
  from .logger import RoundTrip
15
15
  from .oauth import retrieve_token
16
16
  from .retries import retried
@@ -71,6 +71,8 @@ class ApiClient:
71
71
  # Default to 60 seconds
72
72
  self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60
73
73
 
74
+ self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])
75
+
74
76
  @property
75
77
  def account_id(self) -> str:
76
78
  return self._cfg.account_id
@@ -219,27 +221,6 @@ class ApiClient:
219
221
  return f'matched {substring}'
220
222
  return None
221
223
 
222
- @classmethod
223
- def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
224
- retry_after = response.headers.get("Retry-After")
225
- if retry_after is None:
226
- # 429 requests should include a `Retry-After` header, but if it's missing,
227
- # we default to 1 second.
228
- return cls._RETRY_AFTER_DEFAULT
229
- # If the request is throttled, try parse the `Retry-After` header and sleep
230
- # for the specified number of seconds. Note that this header can contain either
231
- # an integer or a RFC1123 datetime string.
232
- # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
233
- #
234
- # For simplicity, we only try to parse it as an integer, as this is what Databricks
235
- # platform returns. Otherwise, we fall back and don't sleep.
236
- try:
237
- return int(retry_after)
238
- except ValueError:
239
- logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
240
- # defaulting to 1 sleep second to make self._is_retryable() simpler
241
- return cls._RETRY_AFTER_DEFAULT
242
-
243
224
  def _perform(self,
244
225
  method: str,
245
226
  url: str,
@@ -261,15 +242,8 @@ class ApiClient:
261
242
  stream=raw,
262
243
  timeout=self._http_timeout_seconds)
263
244
  self._record_request_log(response, raw=raw or data is not None or files is not None)
264
- error = get_api_error(response)
245
+ error = self._error_parser.get_api_error(response)
265
246
  if error is not None:
266
- status_code = response.status_code
267
- is_http_unauthorized_or_forbidden = status_code in (401, 403)
268
- is_too_many_requests_or_unavailable = status_code in (429, 503)
269
- if is_http_unauthorized_or_forbidden:
270
- error.message = self._cfg.wrap_debug_info(error.message)
271
- if is_too_many_requests_or_unavailable:
272
- error.retry_after_secs = self._parse_retry_after(response)
273
247
  raise error from None
274
248
  return response
275
249
 
@@ -279,6 +253,19 @@ class ApiClient:
279
253
  logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())
280
254
 
281
255
 
256
+ class _AddDebugErrorCustomizer(_ErrorCustomizer):
257
+ """An error customizer that adds debug information about the configuration to unauthenticated and
258
+ unauthorized errors."""
259
+
260
+ def __init__(self, cfg: Config):
261
+ self._cfg = cfg
262
+
263
+ def customize_error(self, response: requests.Response, kwargs: dict):
264
+ if response.status_code in (401, 403):
265
+ message = kwargs.get('message', 'request failed')
266
+ kwargs['message'] = self._cfg.wrap_debug_info(message)
267
+
268
+
282
269
  class StreamingResponse(BinaryIO):
283
270
  _response: requests.Response
284
271
  _buffer: bytes
@@ -9,14 +9,15 @@ import pathlib
9
9
  import platform
10
10
  import subprocess
11
11
  import sys
12
+ import time
12
13
  from datetime import datetime
13
- from typing import Callable, Dict, List, Optional, Union
14
+ from typing import Callable, Dict, List, Optional, Tuple, Union
14
15
 
15
- import google.auth
16
+ import google.auth # type: ignore
16
17
  import requests
17
- from google.auth import impersonated_credentials
18
- from google.auth.transport.requests import Request
19
- from google.oauth2 import service_account
18
+ from google.auth import impersonated_credentials # type: ignore
19
+ from google.auth.transport.requests import Request # type: ignore
20
+ from google.oauth2 import service_account # type: ignore
20
21
 
21
22
  from .azure import add_sp_management_token, add_workspace_id_header
22
23
  from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
@@ -411,10 +412,7 @@ class CliTokenSource(Refreshable):
411
412
 
412
413
  def refresh(self) -> Token:
413
414
  try:
414
- is_windows = sys.platform.startswith('win')
415
- # windows requires shell=True to be able to execute 'az login' or other commands
416
- # cannot use shell=True all the time, as it breaks macOS
417
- out = subprocess.run(self._cmd, capture_output=True, check=True, shell=is_windows)
415
+ out = _run_subprocess(self._cmd, capture_output=True, check=True)
418
416
  it = json.loads(out.stdout.decode())
419
417
  expires_on = self._parse_expiry(it[self._expiry_field])
420
418
  return Token(access_token=it[self._access_token_field],
@@ -429,6 +427,26 @@ class CliTokenSource(Refreshable):
429
427
  raise IOError(f'cannot get access token: {message}') from e
430
428
 
431
429
 
430
+ def _run_subprocess(popenargs,
431
+ input=None,
432
+ capture_output=True,
433
+ timeout=None,
434
+ check=False,
435
+ **kwargs) -> subprocess.CompletedProcess:
436
+ """Runs subprocess with given arguments.
437
+ This handles OS-specific modifications that need to be made to the invocation of subprocess.run."""
438
+ kwargs['shell'] = sys.platform.startswith('win')
439
+ # windows requires shell=True to be able to execute 'az login' or other commands
440
+ # cannot use shell=True all the time, as it breaks macOS
441
+ logging.debug(f'Running command: {" ".join(popenargs)}')
442
+ return subprocess.run(popenargs,
443
+ input=input,
444
+ capture_output=capture_output,
445
+ timeout=timeout,
446
+ check=check,
447
+ **kwargs)
448
+
449
+
432
450
  class AzureCliTokenSource(CliTokenSource):
433
451
  """ Obtain the token granted by `az login` CLI command """
434
452
 
@@ -437,13 +455,30 @@ class AzureCliTokenSource(CliTokenSource):
437
455
  if subscription is not None:
438
456
  cmd.append("--subscription")
439
457
  cmd.append(subscription)
440
- if tenant:
458
+ if tenant and not self.__is_cli_using_managed_identity():
441
459
  cmd.extend(["--tenant", tenant])
442
460
  super().__init__(cmd=cmd,
443
461
  token_type_field='tokenType',
444
462
  access_token_field='accessToken',
445
463
  expiry_field='expiresOn')
446
464
 
465
+ @staticmethod
466
+ def __is_cli_using_managed_identity() -> bool:
467
+ """Checks whether the current CLI session is authenticated using managed identity."""
468
+ try:
469
+ cmd = ["az", "account", "show", "--output", "json"]
470
+ out = _run_subprocess(cmd, capture_output=True, check=True)
471
+ account = json.loads(out.stdout.decode())
472
+ user = account.get("user")
473
+ if user is None:
474
+ return False
475
+ return user.get("type") == "servicePrincipal" and user.get("name") in [
476
+ 'systemAssignedIdentity', 'userAssignedIdentity'
477
+ ]
478
+ except subprocess.CalledProcessError as e:
479
+ logger.debug("Failed to get account information from Azure CLI", exc_info=e)
480
+ return False
481
+
447
482
  def is_human_user(self) -> bool:
448
483
  """The UPN claim is the username of the user, but not the Service Principal.
449
484
 
@@ -664,6 +699,90 @@ def metadata_service(cfg: 'Config') -> Optional[CredentialsProvider]:
664
699
  return inner
665
700
 
666
701
 
702
+ # This Code is derived from Mlflow DatabricksModelServingConfigProvider
703
+ # https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
704
+ class ModelServingAuthProvider():
705
+ _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
706
+
707
+ def __init__(self):
708
+ self.expiry_time = -1
709
+ self.current_token = None
710
+ self.refresh_duration = 300 # 300 Seconds
711
+
712
+ def should_fetch_model_serving_environment_oauth(self) -> bool:
713
+ """
714
+ Check whether this is the model serving environment
715
+ Additionally check if the oauth token file path exists
716
+ """
717
+
718
+ is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
719
+ or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
720
+ return (is_in_model_serving_env == "true"
721
+ and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
722
+
723
+ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
724
+ # Use Cached value if it is valid
725
+ if self.current_token is not None and self.expiry_time > time.time():
726
+ return self.current_token
727
+
728
+ try:
729
+ with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
730
+ oauth_dict = json.load(f)
731
+ self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
732
+ self.expiry_time = time.time() + self.refresh_duration
733
+ except Exception as e:
734
+ # sleep and retry in case of any race conditions with OAuth refreshing
735
+ if should_retry:
736
+ logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
737
+ exc_info=e)
738
+ time.sleep(0.5)
739
+ return self.get_model_dependency_oauth_token(should_retry=False)
740
+ else:
741
+ raise RuntimeError(
742
+ "Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
743
+ ) from e
744
+ return self.current_token
745
+
746
+ def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
747
+ if not self.should_fetch_model_serving_environment_oauth():
748
+ return None
749
+
750
+ # read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
751
+ host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
752
+ "DB_MODEL_SERVING_HOST_URL")
753
+ token = self.get_model_dependency_oauth_token()
754
+
755
+ return (host, token)
756
+
757
+
758
+ @credentials_strategy('model-serving', [])
759
+ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
760
+ try:
761
+ model_serving_auth_provider = ModelServingAuthProvider()
762
+ if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
763
+ logger.debug("model-serving: Not in Databricks Model Serving, skipping")
764
+ return None
765
+ host, token = model_serving_auth_provider.get_databricks_host_token()
766
+ if token is None:
767
+ raise ValueError(
768
+ "Got malformed auth (empty token) when fetching auth implicitly available in Model Serving Environment. Please contact Databricks support"
769
+ )
770
+ if cfg.host is None:
771
+ cfg.host = host
772
+ except Exception as e:
773
+ logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
774
+ return None
775
+
776
+ logger.info("Using Databricks Model Serving Authentication")
777
+
778
+ def inner() -> Dict[str, str]:
779
+ # Call here again to get the refreshed token
780
+ _, token = model_serving_auth_provider.get_databricks_host_token()
781
+ return {"Authorization": f"Bearer {token}"}
782
+
783
+ return inner
784
+
785
+
667
786
  class DefaultCredentials:
668
787
  """ Select the first applicable credential provider from the chain """
669
788
 
@@ -672,7 +791,7 @@ class DefaultCredentials:
672
791
  self._auth_providers = [
673
792
  pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
674
793
  github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
675
- google_credentials, google_id
794
+ google_credentials, google_id, model_serving_auth
676
795
  ]
677
796
 
678
797
  def auth_type(self) -> str:
@@ -1,6 +1,6 @@
1
1
  from .base import DatabricksError, ErrorDetail
2
- from .mapper import _error_mapper
3
- from .parser import get_api_error
2
+ from .customizer import _ErrorCustomizer
3
+ from .parser import _Parser
4
4
  from .platform import *
5
5
  from .private_link import PrivateLinkValidationError
6
6
  from .sdk import *
@@ -0,0 +1,50 @@
1
+ import abc
2
+ import logging
3
+
4
+ import requests
5
+
6
+
7
+ class _ErrorCustomizer(abc.ABC):
8
+ """A customizer for errors from the Databricks REST API."""
9
+
10
+ @abc.abstractmethod
11
+ def customize_error(self, response: requests.Response, kwargs: dict):
12
+ """Customize the error constructor parameters."""
13
+
14
+
15
+ class _RetryAfterCustomizer(_ErrorCustomizer):
16
+ """An error customizer that sets the retry_after_secs parameter based on the Retry-After header."""
17
+
18
+ _DEFAULT_RETRY_AFTER_SECONDS = 1
19
+ """The default number of seconds to wait before retrying a request if the Retry-After header is missing or is not
20
+ a valid integer."""
21
+
22
+ @classmethod
23
+ def _parse_retry_after(cls, response: requests.Response) -> int:
24
+ retry_after = response.headers.get("Retry-After")
25
+ if retry_after is None:
26
+ logging.debug(
27
+ f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
28
+ )
29
+ # 429 requests should include a `Retry-After` header, but if it's missing,
30
+ # we default to 1 second.
31
+ return cls._DEFAULT_RETRY_AFTER_SECONDS
32
+ # If the request is throttled, try parse the `Retry-After` header and sleep
33
+ # for the specified number of seconds. Note that this header can contain either
34
+ # an integer or a RFC1123 datetime string.
35
+ # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
36
+ #
37
+ # For simplicity, we only try to parse it as an integer, as this is what Databricks
38
+ # platform returns. Otherwise, we fall back and don't sleep.
39
+ try:
40
+ return int(retry_after)
41
+ except ValueError:
42
+ logging.debug(
43
+ f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
44
+ )
45
+ # defaulting to 1 sleep second to make self._is_retryable() simpler
46
+ return cls._DEFAULT_RETRY_AFTER_SECONDS
47
+
48
+ def customize_error(self, response: requests.Response, kwargs: dict):
49
+ if response.status_code in (429, 503):
50
+ kwargs['retry_after_secs'] = self._parse_retry_after(response)
@@ -0,0 +1,106 @@
1
+ import abc
2
+ import json
3
+ import logging
4
+ import re
5
+ from typing import Optional
6
+
7
+ import requests
8
+
9
+
10
+ class _ErrorDeserializer(abc.ABC):
11
+ """A parser for errors from the Databricks REST API."""
12
+
13
+ @abc.abstractmethod
14
+ def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
15
+ """Parses an error from the Databricks REST API. If the error cannot be parsed, returns None."""
16
+
17
+
18
+ class _EmptyDeserializer(_ErrorDeserializer):
19
+ """A parser that handles empty responses."""
20
+
21
+ def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
22
+ if len(response_body) == 0:
23
+ return {'message': response.reason}
24
+ return None
25
+
26
+
27
+ class _StandardErrorDeserializer(_ErrorDeserializer):
28
+ """
29
+ Parses errors from the Databricks REST API using the standard error format.
30
+ """
31
+
32
+ def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
33
+ try:
34
+ payload_str = response_body.decode('utf-8')
35
+ resp = json.loads(payload_str)
36
+ except UnicodeDecodeError as e:
37
+ logging.debug('_StandardErrorParser: unable to decode response using utf-8', exc_info=e)
38
+ return None
39
+ except json.JSONDecodeError as e:
40
+ logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e)
41
+ return None
42
+ if not isinstance(resp, dict):
43
+ logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary')
44
+ return None
45
+
46
+ error_args = {
47
+ 'message': resp.get('message', 'request failed'),
48
+ 'error_code': resp.get('error_code'),
49
+ 'details': resp.get('details'),
50
+ }
51
+
52
+ # Handle API 1.2-style errors
53
+ if 'error' in resp:
54
+ error_args['message'] = resp['error']
55
+
56
+ # Handle SCIM Errors
57
+ detail = resp.get('detail')
58
+ status = resp.get('status')
59
+ scim_type = resp.get('scimType')
60
+ if detail:
61
+ # Handle SCIM error message details
62
+ # @see https://tools.ietf.org/html/rfc7644#section-3.7.3
63
+ if detail == "null":
64
+ detail = "SCIM API Internal Error"
65
+ error_args['message'] = f"{scim_type} {detail}".strip(" ")
66
+ error_args['error_code'] = f"SCIM_{status}"
67
+ return error_args
68
+
69
+
70
+ class _StringErrorDeserializer(_ErrorDeserializer):
71
+ """
72
+ Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE".
73
+ """
74
+
75
+ __STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)')
76
+
77
+ def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
78
+ payload_str = response_body.decode('utf-8')
79
+ match = self.__STRING_ERROR_REGEX.match(payload_str)
80
+ if not match:
81
+ logging.debug('_StringErrorParser: unable to parse response as string')
82
+ return None
83
+ error_code, message = match.groups()
84
+ return {'error_code': error_code, 'message': message, 'status': response.status_code, }
85
+
86
+
87
+ class _HtmlErrorDeserializer(_ErrorDeserializer):
88
+ """
89
+ Parses errors from the Databricks REST API in HTML format.
90
+ """
91
+
92
+ __HTML_ERROR_REGEXES = [re.compile(r'<pre>(.*)</pre>'), re.compile(r'<title>(.*)</title>'), ]
93
+
94
+ def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
95
+ payload_str = response_body.decode('utf-8')
96
+ for regex in self.__HTML_ERROR_REGEXES:
97
+ match = regex.search(payload_str)
98
+ if match:
99
+ message = match.group(1) if match.group(1) else response.reason
100
+ return {
101
+ 'status': response.status_code,
102
+ 'message': message,
103
+ 'error_code': response.reason.upper().replace(' ', '_')
104
+ }
105
+ logging.debug('_HtmlErrorParser: no <pre> tag found in error response')
106
+ return None
@@ -0,0 +1,83 @@
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import requests
5
+
6
+ from ..logger import RoundTrip
7
+ from .base import DatabricksError
8
+ from .customizer import _ErrorCustomizer, _RetryAfterCustomizer
9
+ from .deserializer import (_EmptyDeserializer, _ErrorDeserializer,
10
+ _HtmlErrorDeserializer, _StandardErrorDeserializer,
11
+ _StringErrorDeserializer)
12
+ from .mapper import _error_mapper
13
+ from .private_link import (_get_private_link_validation_error,
14
+ _is_private_link_redirect)
15
+
16
# Deserializers attempted in sequence when parsing an API error from a response body.
# _StandardErrorDeserializer covers most errors; the remaining entries handle specific
# non-standard formats. Ordering is not significant because the sets of errors each
# deserializer can parse are disjoint.
_error_deserializers = [
    _EmptyDeserializer(),
    _StandardErrorDeserializer(),
    _StringErrorDeserializer(),
    _HtmlErrorDeserializer(),
]

# Customizers applied, in sequence, to the parsed error arguments. A customizer may
# add, modify or remove fields; later customizers can override changes made by
# earlier ones.
_error_customizers = [_RetryAfterCustomizer()]
30
+
31
+
32
def _unknown_error(response: requests.Response) -> str:
    """A standard error message that can be shown when an API response cannot be parsed.

    This error message includes a link to the issue tracker for the SDK for users to report the issue to us.

    :param response: the HTTP response that could not be parsed.
    :return: a message embedding a truncated request/response log suitable for bug reports.
    """
    request_log = RoundTrip(response, debug_headers=True, debug_truncate_bytes=10 * 1024).generate()
    return (
        'This is likely a bug in the Databricks SDK for Python or the underlying '
        'API. Please report this issue with the following debugging information to the SDK issue tracker at '
        # Bug fix: previously linked to the Go SDK's tracker (databricks-sdk-go),
        # contradicting the "SDK for Python" wording above.
        f'https://github.com/databricks/databricks-sdk-py/issues. Request log:```{request_log}```')
42
+
43
+
44
class _Parser:
    """
    A parser for errors from the Databricks REST API. It attempts to deserialize an error using a sequence of
    deserializers, and then customizes the deserialized error using a sequence of customizers. If the error cannot be
    deserialized, it returns a generic error with debugging information and instructions to report the issue to the SDK
    issue tracker.
    """

    def __init__(self,
                 extra_error_parsers: Optional[List[_ErrorDeserializer]] = None,
                 extra_error_customizers: Optional[List[_ErrorCustomizer]] = None):
        """
        :param extra_error_parsers: additional deserializers tried after the built-in ones.
        :param extra_error_customizers: additional customizers applied after the built-in ones.
        """
        # Defaults changed from mutable [] to None (the classic mutable-default-argument
        # pitfall); the pre-existing `is not None` guards already handle None, so this is
        # backward-compatible for all callers.
        self._error_parsers = _error_deserializers + (extra_error_parsers
                                                      if extra_error_parsers is not None else [])
        self._error_customizers = _error_customizers + (extra_error_customizers
                                                        if extra_error_customizers is not None else [])

    def get_api_error(self, response: requests.Response) -> Optional[DatabricksError]:
        """
        Handles responses from the REST API and returns a DatabricksError if the response indicates an error.
        :param response: The response from the REST API.
        :return: A DatabricksError if the response indicates an error, otherwise None.
        """
        if not response.ok:
            content = response.content
            for parser in self._error_parsers:
                try:
                    error_args = parser.deserialize_error(response, content)
                    if error_args:
                        for customizer in self._error_customizers:
                            customizer.customize_error(response, error_args)
                        return _error_mapper(response, error_args)
                except Exception as e:
                    # A failing deserializer must not mask the others; log and try the next one.
                    logging.debug(f'Error parsing response with {parser}, continuing', exc_info=e)
            # No deserializer succeeded: surface a generic error with debugging info.
            return _error_mapper(response,
                                 {'message': 'unable to parse response. ' + _unknown_error(response)})

        # Private link failures happen via a redirect to the login page. From a requests-perspective, the request
        # is successful, but the response is not what we expect. We need to handle this case separately.
        if _is_private_link_redirect(response):
            return _get_private_link_validation_error(response.url)
        # Explicit terminal path: a successful, non-redirected response is not an error.
        return None