databricks-sdk 0.34.0__tar.gz → 0.36.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of databricks-sdk might be problematic. Click here for more details.

Files changed (96) hide show
  1. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/NOTICE +14 -0
  2. {databricks_sdk-0.34.0/databricks_sdk.egg-info → databricks_sdk-0.36.0}/PKG-INFO +8 -1
  3. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/__init__.py +3 -2
  4. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/_base_client.py +20 -0
  5. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/config.py +10 -34
  6. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/credentials_provider.py +18 -13
  7. databricks_sdk-0.36.0/databricks/sdk/mixins/open_ai_client.py +52 -0
  8. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/oauth.py +179 -51
  9. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/apps.py +1 -1
  10. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/catalog.py +12 -3
  11. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/dashboards.py +8 -1
  12. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/jobs.py +52 -1
  13. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/pipelines.py +53 -3
  14. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/sql.py +20 -0
  15. databricks_sdk-0.36.0/databricks/sdk/version.py +1 -0
  16. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0/databricks_sdk.egg-info}/PKG-INFO +8 -1
  17. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks_sdk.egg-info/SOURCES.txt +2 -0
  18. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks_sdk.egg-info/requires.txt +12 -0
  19. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/setup.py +4 -2
  20. databricks_sdk-0.36.0/tests/test_oauth.py +126 -0
  21. databricks_sdk-0.36.0/tests/test_open_ai_mixin.py +30 -0
  22. databricks_sdk-0.34.0/databricks/sdk/version.py +0 -1
  23. databricks_sdk-0.34.0/tests/test_oauth.py +0 -29
  24. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/LICENSE +0 -0
  25. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/README.md +0 -0
  26. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/__init__.py +0 -0
  27. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/_property.py +0 -0
  28. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/_widgets/__init__.py +0 -0
  29. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/_widgets/default_widgets_utils.py +0 -0
  30. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/_widgets/ipywidgets_utils.py +0 -0
  31. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/azure.py +0 -0
  32. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/casing.py +0 -0
  33. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/clock.py +0 -0
  34. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/core.py +0 -0
  35. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/data_plane.py +0 -0
  36. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/dbutils.py +0 -0
  37. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/environments.py +0 -0
  38. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/__init__.py +0 -0
  39. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/base.py +0 -0
  40. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/customizer.py +0 -0
  41. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/deserializer.py +0 -0
  42. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/mapper.py +0 -0
  43. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/overrides.py +0 -0
  44. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/parser.py +0 -0
  45. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/platform.py +0 -0
  46. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/private_link.py +0 -0
  47. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/errors/sdk.py +0 -0
  48. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/logger/__init__.py +0 -0
  49. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/logger/round_trip_logger.py +0 -0
  50. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/mixins/__init__.py +0 -0
  51. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/mixins/compute.py +0 -0
  52. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/mixins/files.py +0 -0
  53. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/mixins/workspace.py +0 -0
  54. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/py.typed +0 -0
  55. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/retries.py +0 -0
  56. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/runtime/__init__.py +0 -0
  57. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/runtime/dbutils_stub.py +0 -0
  58. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/__init__.py +0 -0
  59. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/_internal.py +0 -0
  60. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/billing.py +0 -0
  61. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/compute.py +0 -0
  62. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/files.py +0 -0
  63. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/iam.py +0 -0
  64. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/marketplace.py +0 -0
  65. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/ml.py +0 -0
  66. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/oauth2.py +0 -0
  67. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/provisioning.py +0 -0
  68. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/serving.py +0 -0
  69. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/settings.py +0 -0
  70. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/sharing.py +0 -0
  71. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/vectorsearch.py +0 -0
  72. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/service/workspace.py +0 -0
  73. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks/sdk/useragent.py +0 -0
  74. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks_sdk.egg-info/dependency_links.txt +0 -0
  75. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/databricks_sdk.egg-info/top_level.txt +0 -0
  76. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/setup.cfg +0 -0
  77. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_auth.py +0 -0
  78. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_auth_manual_tests.py +0 -0
  79. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_base_client.py +0 -0
  80. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_client.py +0 -0
  81. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_compute_mixins.py +0 -0
  82. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_config.py +0 -0
  83. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_core.py +0 -0
  84. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_data_plane.py +0 -0
  85. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_dbfs_mixins.py +0 -0
  86. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_dbutils.py +0 -0
  87. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_environments.py +0 -0
  88. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_errors.py +0 -0
  89. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_init_file.py +0 -0
  90. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_internal.py +0 -0
  91. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_jobs.py +0 -0
  92. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_metadata_service_auth.py +0 -0
  93. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_misc.py +0 -0
  94. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_model_serving_auth.py +0 -0
  95. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_retries.py +0 -0
  96. {databricks_sdk-0.34.0 → databricks_sdk-0.36.0}/tests/test_user_agent.py +0 -0
@@ -12,8 +12,22 @@ googleapis/google-auth-library-python - https://github.com/googleapis/google-aut
12
12
  Copyright google-auth-library-python authors
13
13
  License - https://github.com/googleapis/google-auth-library-python/blob/main/LICENSE
14
14
 
15
+ openai/openai-python - https://github.com/openai/openai-python
16
+ Copyright 2024 OpenAI
17
+ License - https://github.com/openai/openai-python/blob/main/LICENSE
18
+
15
19
  This software contains code from the following open source projects, licensed under the BSD (3-clause) license.
16
20
 
17
21
  x/oauth2 - https://cs.opensource.google/go/x/oauth2/+/master:oauth2.go
18
22
  Copyright 2014 The Go Authors. All rights reserved.
19
23
  License - https://cs.opensource.google/go/x/oauth2/+/master:LICENSE
24
+
25
+ encode/httpx - https://github.com/encode/httpx
26
+ Copyright 2019, Encode OSS Ltd
27
+ License - https://github.com/encode/httpx/blob/master/LICENSE.md
28
+
29
+ This software contains code from the following open source projects, licensed under the MIT license:
30
+
31
+ langchain-ai/langchain - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai
32
+ Copyright 2023 LangChain, Inc.
33
+ License - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/LICENSE
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: databricks-sdk
3
- Version: 0.34.0
3
+ Version: 0.36.0
4
4
  Summary: Databricks SDK for Python (Beta)
5
5
  Home-page: https://databricks-sdk-py.readthedocs.io
6
6
  Author: Serge Smertin
@@ -40,9 +40,16 @@ Requires-Dist: requests-mock; extra == "dev"
40
40
  Requires-Dist: pyfakefs; extra == "dev"
41
41
  Requires-Dist: databricks-connect; extra == "dev"
42
42
  Requires-Dist: pytest-rerunfailures; extra == "dev"
43
+ Requires-Dist: openai; extra == "dev"
44
+ Requires-Dist: langchain-openai; python_version > "3.7" and extra == "dev"
45
+ Requires-Dist: httpx; extra == "dev"
43
46
  Provides-Extra: notebook
44
47
  Requires-Dist: ipython<9,>=8; extra == "notebook"
45
48
  Requires-Dist: ipywidgets<9,>=8; extra == "notebook"
49
+ Provides-Extra: openai
50
+ Requires-Dist: openai; extra == "openai"
51
+ Requires-Dist: langchain-openai; python_version > "3.7" and extra == "openai"
52
+ Requires-Dist: httpx; extra == "openai"
46
53
 
47
54
  # Databricks SDK for Python (Beta)
48
55
 
@@ -6,6 +6,7 @@ from databricks.sdk import azure
6
6
  from databricks.sdk.credentials_provider import CredentialsStrategy
7
7
  from databricks.sdk.mixins.compute import ClustersExt
8
8
  from databricks.sdk.mixins.files import DbfsExt
9
+ from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
9
10
  from databricks.sdk.mixins.workspace import WorkspaceExt
10
11
  from databricks.sdk.service.apps import AppsAPI
11
12
  from databricks.sdk.service.billing import (BillableUsageAPI, BudgetsAPI,
@@ -175,7 +176,7 @@ class WorkspaceClient:
175
176
  self._config = config.copy()
176
177
  self._dbutils = _make_dbutils(self._config)
177
178
  self._api_client = client.ApiClient(self._config)
178
- serving_endpoints = ServingEndpointsAPI(self._api_client)
179
+ serving_endpoints = ServingEndpointsExt(self._api_client)
179
180
  self._account_access_control_proxy = AccountAccessControlProxyAPI(self._api_client)
180
181
  self._alerts = AlertsAPI(self._api_client)
181
182
  self._alerts_legacy = AlertsLegacyAPI(self._api_client)
@@ -637,7 +638,7 @@ class WorkspaceClient:
637
638
  return self._service_principals
638
639
 
639
640
  @property
640
- def serving_endpoints(self) -> ServingEndpointsAPI:
641
+ def serving_endpoints(self) -> ServingEndpointsExt:
641
642
  """The Serving Endpoints API allows you to create, update, and delete model serving endpoints."""
642
643
  return self._serving_endpoints
643
644
 
@@ -1,4 +1,5 @@
1
1
  import logging
2
+ import urllib.parse
2
3
  from datetime import timedelta
3
4
  from types import TracebackType
4
5
  from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
@@ -17,6 +18,25 @@ from .retries import retried
17
18
  logger = logging.getLogger('databricks.sdk')
18
19
 
19
20
 
21
+ def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
22
+ if not host:
23
+ return host
24
+
25
+ # Add a default scheme if it's missing
26
+ if '://' not in host:
27
+ host = 'https://' + host
28
+
29
+ o = urllib.parse.urlparse(host)
30
+ # remove trailing slash
31
+ path = o.path.rstrip('/')
32
+ # remove port if 443
33
+ netloc = o.netloc
34
+ if o.port == 443:
35
+ netloc = netloc.split(':')[0]
36
+
37
+ return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
38
+
39
+
20
40
  class _BaseClient:
21
41
 
22
42
  def __init__(self,
@@ -10,11 +10,14 @@ from typing import Dict, Iterable, Optional
10
10
  import requests
11
11
 
12
12
  from . import useragent
13
+ from ._base_client import _fix_host_if_needed
13
14
  from .clock import Clock, RealClock
14
15
  from .credentials_provider import CredentialsStrategy, DefaultCredentials
15
16
  from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
16
17
  DatabricksEnvironment, get_environment_for_hostname)
17
- from .oauth import OidcEndpoints, Token
18
+ from .oauth import (OidcEndpoints, Token, get_account_endpoints,
19
+ get_azure_entra_id_workspace_endpoints,
20
+ get_workspace_endpoints)
18
21
 
19
22
  logger = logging.getLogger('databricks.sdk')
20
23
 
@@ -254,24 +257,10 @@ class Config:
254
257
  if not self.host:
255
258
  return None
256
259
  if self.is_azure and self.azure_client_id:
257
- # Retrieve authorize endpoint to retrieve token endpoint after
258
- res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
259
- real_auth_url = res.headers.get('location')
260
- if not real_auth_url:
261
- return None
262
- return OidcEndpoints(authorization_endpoint=real_auth_url,
263
- token_endpoint=real_auth_url.replace('/authorize', '/token'))
260
+ return get_azure_entra_id_workspace_endpoints(self.host)
264
261
  if self.is_account_client and self.account_id:
265
- prefix = f'{self.host}/oidc/accounts/{self.account_id}'
266
- return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize',
267
- token_endpoint=f'{prefix}/v1/token')
268
- oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server'
269
- res = requests.get(oidc)
270
- if res.status_code != 200:
271
- return None
272
- auth_metadata = res.json()
273
- return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'),
274
- token_endpoint=auth_metadata.get('token_endpoint'))
262
+ return get_account_endpoints(self.host, self.account_id)
263
+ return get_workspace_endpoints(self.host)
275
264
 
276
265
  def debug_string(self) -> str:
277
266
  """ Returns log-friendly representation of configured attributes """
@@ -346,22 +335,9 @@ class Config:
346
335
  return cls._attributes
347
336
 
348
337
  def _fix_host_if_needed(self):
349
- if not self.host:
350
- return
351
-
352
- # Add a default scheme if it's missing
353
- if '://' not in self.host:
354
- self.host = 'https://' + self.host
355
-
356
- o = urllib.parse.urlparse(self.host)
357
- # remove trailing slash
358
- path = o.path.rstrip('/')
359
- # remove port if 443
360
- netloc = o.netloc
361
- if o.port == 443:
362
- netloc = netloc.split(':')[0]
363
-
364
- self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
338
+ updated_host = _fix_host_if_needed(self.host)
339
+ if updated_host:
340
+ self.host = updated_host
365
341
 
366
342
  def load_azure_tenant_id(self):
367
343
  """[Internal] Load the Azure tenant ID from the Azure Databricks login page.
@@ -187,30 +187,35 @@ def oauth_service_principal(cfg: 'Config') -> Optional[CredentialsProvider]:
187
187
  def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
188
188
  if cfg.auth_type != 'external-browser':
189
189
  return None
190
+ client_id, client_secret = None, None
190
191
  if cfg.client_id:
191
192
  client_id = cfg.client_id
192
- elif cfg.is_aws:
193
+ client_secret = cfg.client_secret
194
+ elif cfg.azure_client_id:
195
+ client_id = cfg.azure_client
196
+ client_secret = cfg.azure_client_secret
197
+
198
+ if not client_id:
193
199
  client_id = 'databricks-cli'
194
- elif cfg.is_azure:
195
- # Use Azure AD app for cases when Azure CLI is not available on the machine.
196
- # App has to be registered as Single-page multi-tenant to support PKCE
197
- # TODO: temporary app ID, change it later.
198
- client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
199
- else:
200
- raise ValueError(f'local browser SSO is not supported')
201
- oauth_client = OAuthClient(host=cfg.host,
202
- client_id=client_id,
203
- redirect_url='http://localhost:8020',
204
- client_secret=cfg.client_secret)
205
200
 
206
201
  # Load cached credentials from disk if they exist.
207
202
  # Note that these are local to the Python SDK and not reused by other SDKs.
208
- token_cache = TokenCache(oauth_client)
203
+ oidc_endpoints = cfg.oidc_endpoints
204
+ redirect_url = 'http://localhost:8020'
205
+ token_cache = TokenCache(host=cfg.host,
206
+ oidc_endpoints=oidc_endpoints,
207
+ client_id=client_id,
208
+ client_secret=client_secret,
209
+ redirect_url=redirect_url)
209
210
  credentials = token_cache.load()
210
211
  if credentials:
211
212
  # Force a refresh in case the loaded credentials are expired.
212
213
  credentials.token()
213
214
  else:
215
+ oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
216
+ client_id=client_id,
217
+ redirect_url=redirect_url,
218
+ client_secret=client_secret)
214
219
  consent = oauth_client.initiate_consent()
215
220
  if not consent:
216
221
  return None
@@ -0,0 +1,52 @@
1
+ from databricks.sdk.service.serving import ServingEndpointsAPI
2
+
3
+
4
class ServingEndpointsExt(ServingEndpointsAPI):
    """Serving Endpoints API with helpers that build OpenAI-compatible clients.

    Both helpers point the client at ``<workspace host>/serving-endpoints``. They route its
    HTTP traffic through an ``httpx.Client`` that sets the Databricks ``Authorization``
    header on every request.
    """

    # Using the HTTP Client to pass in the databricks authorization
    # This method will be called on every invocation, so when using with model serving will always get the refreshed token
    def _get_authorized_http_client(self):
        """Return an ``httpx.Client`` that authenticates each request with Databricks credentials.

        ``self._api._cfg.authenticate`` is called inside ``auth_flow``, i.e. once per request.
        A refreshed token is therefore picked up without rebuilding the client.
        """
        from typing import Generator

        import httpx

        class BearerAuth(httpx.Auth):

            def __init__(self, get_headers_func):
                self.get_headers_func = get_headers_func

            def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
                # httpx drives ``auth_flow`` as a generator: set the header, then yield the
                # request back so the transport can send it.
                auth_headers = self.get_headers_func()
                request.headers["Authorization"] = auth_headers["Authorization"]
                yield request

        databricks_token_auth = BearerAuth(self._api._cfg.authenticate)

        # Create an HTTP client with Bearer Token authentication
        http_client = httpx.Client(auth=databricks_token_auth)
        return http_client

    def get_open_ai_client(self):
        """Return an ``openai.OpenAI`` client bound to this workspace's serving endpoints.

        :raises ImportError: if the ``openai`` package cannot be imported.
        """
        try:
            from openai import OpenAI
        except Exception as e:
            raise ImportError(
                "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
            ) from e

        return OpenAI(
            base_url=self._api._cfg.host + "/serving-endpoints",
            api_key="no-token",  # Passing in a placeholder to pass validations, this will not be used
            http_client=self._get_authorized_http_client())

    def get_langchain_chat_open_ai_client(self, model):
        """Return a ``langchain_openai.ChatOpenAI`` model bound to this workspace's serving endpoints.

        :param model: Forwarded to ``ChatOpenAI(model=...)``. This is presumably the serving
            endpoint name; confirm against the serving API.
        :raises ImportError: if the ``langchain_openai`` package cannot be imported.
        """
        try:
            from langchain_openai import ChatOpenAI
        except Exception as e:
            raise ImportError(
                "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7"
            ) from e

        return ChatOpenAI(
            model=model,
            openai_api_base=self._api._cfg.host + "/serving-endpoints",
            api_key="no-token",  # Passing in a placeholder to pass validations, this will not be used
            http_client=self._get_authorized_http_client())
@@ -17,6 +17,8 @@ from typing import Any, Dict, List, Optional
17
17
  import requests
18
18
  import requests.auth
19
19
 
20
+ from ._base_client import _BaseClient, _fix_host_if_needed
21
+
20
22
  # Error code for PKCE flow in Azure Active Directory, that gets additional retry.
21
23
  # See https://stackoverflow.com/a/75466778/277035 for more info
22
24
  NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327'
@@ -46,8 +48,24 @@ class IgnoreNetrcAuth(requests.auth.AuthBase):
46
48
 
47
49
  @dataclass
48
50
  class OidcEndpoints:
51
+ """
52
+ The endpoints used for OAuth-based authentication in Databricks.
53
+ """
54
+
49
55
  authorization_endpoint: str # ../v1/authorize
56
+ """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for
57
+ the user to login and authorize the client for user-to-machine (U2M) flows."""
58
+
50
59
  token_endpoint: str # ../v1/token
60
+ """The token endpoint for the OAuth flow."""
61
+
62
+ @staticmethod
63
+ def from_dict(d: dict) -> 'OidcEndpoints':
64
+ return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'),
65
+ token_endpoint=d.get('token_endpoint'))
66
+
67
+ def as_dict(self) -> dict:
68
+ return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint}
51
69
 
52
70
 
53
71
  @dataclass
@@ -220,18 +238,76 @@ class _OAuthCallback(BaseHTTPRequestHandler):
220
238
  self.wfile.write(b'You can close this tab.')
221
239
 
222
240
 
241
def get_account_endpoints(host: str, account_id: str, client: Optional[_BaseClient] = None) -> OidcEndpoints:
    """
    Get the OIDC endpoints for a given account.
    :param host: The Databricks account host.
    :param account_id: The account ID.
    :param client: HTTP client used to fetch the discovery document. When omitted, a new
        ``_BaseClient`` is created for this call. This replaces the previous default, which was
        built once at import time and shared by every call.
    :return: The account's OIDC endpoints.
    """
    if client is None:
        client = _BaseClient()
    host = _fix_host_if_needed(host)
    oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server'
    resp = client.do('GET', oidc)
    return OidcEndpoints.from_dict(resp)
252
+
253
+
254
def get_workspace_endpoints(host: str, client: Optional[_BaseClient] = None) -> OidcEndpoints:
    """
    Get the OIDC endpoints for a given workspace.
    :param host: The Databricks workspace host.
    :param client: HTTP client used to fetch the discovery document. When omitted, a new
        ``_BaseClient`` is created for this call. This replaces the previous default, which was
        built once at import time and shared by every call.
    :return: The workspace's OIDC endpoints.
    """
    if client is None:
        client = _BaseClient()
    host = _fix_host_if_needed(host)
    oidc = f'{host}/oidc/.well-known/oauth-authorization-server'
    resp = client.do('GET', oidc)
    return OidcEndpoints.from_dict(resp)
264
+
265
+
266
def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]:
    """
    Resolve the Azure Entra ID OAuth endpoints behind an Azure Databricks workspace.

    Only applicable when authenticating with an application registered in Azure Entra ID.
    The workspace's ``/oidc/oauth2/v2.0/authorize`` URL is requested without following
    redirects. The ``Location`` header of the response is used as the authorization endpoint.
    The token endpoint is derived from it by replacing ``/authorize`` with ``/token``.

    :param host: The Databricks workspace host.
    :return: The OIDC endpoints of the workspace's Entra ID tenant, or ``None`` when the
        response carries no (or an empty) ``Location`` header.
    """
    normalized_host = _fix_host_if_needed(host)
    authorize_probe = f'{normalized_host}/oidc/oauth2/v2.0/authorize'
    # In Azure this workspace endpoint redirects to Entra ID. The redirect target is
    # exactly what we need, so the redirect is not followed.
    response = requests.get(authorize_probe, allow_redirects=False)
    entra_authorize_url = response.headers.get('location')
    if entra_authorize_url:
        return OidcEndpoints(authorization_endpoint=entra_authorize_url,
                             token_endpoint=entra_authorize_url.replace('/authorize', '/token'))
    return None
281
+
282
+
223
283
  class SessionCredentials(Refreshable):
224
284
 
225
- def __init__(self, client: 'OAuthClient', token: Token):
226
- self._client = client
285
+ def __init__(self,
286
+ token: Token,
287
+ token_endpoint: str,
288
+ client_id: str,
289
+ client_secret: str = None,
290
+ redirect_url: str = None):
291
+ self._token_endpoint = token_endpoint
292
+ self._client_id = client_id
293
+ self._client_secret = client_secret
294
+ self._redirect_url = redirect_url
227
295
  super().__init__(token)
228
296
 
229
297
  def as_dict(self) -> dict:
230
298
  return {'token': self._token.as_dict()}
231
299
 
232
300
  @staticmethod
233
- def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials':
234
- return SessionCredentials(client=client, token=Token.from_dict(raw['token']))
301
+ def from_dict(raw: dict,
302
+ token_endpoint: str,
303
+ client_id: str,
304
+ client_secret: str = None,
305
+ redirect_url: str = None) -> 'SessionCredentials':
306
+ return SessionCredentials(token=Token.from_dict(raw['token']),
307
+ token_endpoint=token_endpoint,
308
+ client_id=client_id,
309
+ client_secret=client_secret,
310
+ redirect_url=redirect_url)
235
311
 
236
312
  def auth_type(self):
237
313
  """Implementing CredentialsProvider protocol"""
@@ -252,13 +328,13 @@ class SessionCredentials(Refreshable):
252
328
  raise ValueError('oauth2: token expired and refresh token is not set')
253
329
  params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token}
254
330
  headers = {}
255
- if 'microsoft' in self._client.token_url:
331
+ if 'microsoft' in self._token_endpoint:
256
332
  # Tokens issued for the 'Single-Page Application' client-type may
257
333
  # only be redeemed via cross-origin requests
258
- headers = {'Origin': self._client.redirect_url}
259
- return retrieve_token(client_id=self._client.client_id,
260
- client_secret=self._client.client_secret,
261
- token_url=self._client.token_url,
334
+ headers = {'Origin': self._redirect_url}
335
+ return retrieve_token(client_id=self._client_id,
336
+ client_secret=self._client_secret,
337
+ token_url=self._token_endpoint,
262
338
  params=params,
263
339
  use_params=True,
264
340
  headers=headers)
@@ -266,27 +342,53 @@ class SessionCredentials(Refreshable):
266
342
 
267
343
  class Consent:
268
344
 
269
- def __init__(self, client: 'OAuthClient', state: str, verifier: str, auth_url: str = None) -> None:
270
- self.auth_url = auth_url
271
-
345
+ def __init__(self,
346
+ state: str,
347
+ verifier: str,
348
+ authorization_url: str,
349
+ redirect_url: str,
350
+ token_endpoint: str,
351
+ client_id: str,
352
+ client_secret: str = None) -> None:
272
353
  self._verifier = verifier
273
354
  self._state = state
274
- self._client = client
355
+ self._authorization_url = authorization_url
356
+ self._redirect_url = redirect_url
357
+ self._token_endpoint = token_endpoint
358
+ self._client_id = client_id
359
+ self._client_secret = client_secret
275
360
 
276
361
  def as_dict(self) -> dict:
277
- return {'state': self._state, 'verifier': self._verifier}
362
+ return {
363
+ 'state': self._state,
364
+ 'verifier': self._verifier,
365
+ 'authorization_url': self._authorization_url,
366
+ 'redirect_url': self._redirect_url,
367
+ 'token_endpoint': self._token_endpoint,
368
+ 'client_id': self._client_id,
369
+ }
370
+
371
+ @property
372
+ def authorization_url(self) -> str:
373
+ return self._authorization_url
278
374
 
279
375
  @staticmethod
280
- def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent':
281
- return Consent(client, raw['state'], raw['verifier'])
376
+ def from_dict(raw: dict, client_secret: str = None) -> 'Consent':
377
+ return Consent(raw['state'],
378
+ raw['verifier'],
379
+ authorization_url=raw['authorization_url'],
380
+ redirect_url=raw['redirect_url'],
381
+ token_endpoint=raw['token_endpoint'],
382
+ client_id=raw['client_id'],
383
+ client_secret=client_secret)
282
384
 
283
385
  def launch_external_browser(self) -> SessionCredentials:
284
- redirect_url = urllib.parse.urlparse(self._client.redirect_url)
386
+ redirect_url = urllib.parse.urlparse(self._redirect_url)
285
387
  if redirect_url.hostname not in ('localhost', '127.0.0.1'):
286
388
  raise ValueError(f'cannot listen on {redirect_url.hostname}')
287
389
  feedback = []
288
- logger.info(f'Opening {self.auth_url} in a browser')
289
- webbrowser.open_new(self.auth_url)
390
+ logger.info(f'Opening {self._authorization_url} in a browser')
391
+ webbrowser.open_new(self._authorization_url)
290
392
  port = redirect_url.port
291
393
  handler_factory = functools.partial(_OAuthCallback, feedback)
292
394
  with HTTPServer(("localhost", port), handler_factory) as httpd:
@@ -308,7 +410,7 @@ class Consent:
308
410
  if self._state != state:
309
411
  raise ValueError('state mismatch')
310
412
  params = {
311
- 'redirect_uri': self._client.redirect_url,
413
+ 'redirect_uri': self._redirect_url,
312
414
  'grant_type': 'authorization_code',
313
415
  'code_verifier': self._verifier,
314
416
  'code': code
@@ -316,19 +418,20 @@ class Consent:
316
418
  headers = {}
317
419
  while True:
318
420
  try:
319
- token = retrieve_token(client_id=self._client.client_id,
320
- client_secret=self._client.client_secret,
321
- token_url=self._client.token_url,
421
+ token = retrieve_token(client_id=self._client_id,
422
+ client_secret=self._client_secret,
423
+ token_url=self._token_endpoint,
322
424
  params=params,
323
425
  headers=headers,
324
426
  use_params=True)
325
- return SessionCredentials(self._client, token)
427
+ return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret,
428
+ self._redirect_url)
326
429
  except ValueError as e:
327
430
  if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
328
431
  # Retry in cases of 'Single-Page Application' client-type with
329
432
  # 'Origin' header equal to client's redirect URL.
330
- headers['Origin'] = self._client.redirect_url
331
- msg = f'Retrying OAuth token exchange with {self._client.redirect_url} origin'
433
+ headers['Origin'] = self._redirect_url
434
+ msg = f'Retrying OAuth token exchange with {self._redirect_url} origin'
332
435
  logger.debug(msg)
333
436
  continue
334
437
  raise e
@@ -354,13 +457,28 @@ class OAuthClient:
354
457
  """
355
458
 
356
459
  def __init__(self,
357
- host: str,
358
- client_id: str,
460
+ oidc_endpoints: OidcEndpoints,
359
461
  redirect_url: str,
360
- *,
462
+ client_id: str,
361
463
  scopes: List[str] = None,
362
464
  client_secret: str = None):
363
- # TODO: is it a circular dependency?..
465
+
466
+ if not scopes:
467
+ scopes = ['all-apis']
468
+
469
+ self.redirect_url = redirect_url
470
+ self._client_id = client_id
471
+ self._client_secret = client_secret
472
+ self._oidc_endpoints = oidc_endpoints
473
+ self._scopes = scopes
474
+
475
+ @staticmethod
476
+ def from_host(host: str,
477
+ client_id: str,
478
+ redirect_url: str,
479
+ *,
480
+ scopes: List[str] = None,
481
+ client_secret: str = None) -> 'OAuthClient':
364
482
  from .core import Config
365
483
  from .credentials_provider import credentials_strategy
366
484
 
@@ -374,18 +492,7 @@ class OAuthClient:
374
492
  oidc = config.oidc_endpoints
375
493
  if not oidc:
376
494
  raise ValueError(f'{host} does not support OAuth')
377
-
378
- self.host = host
379
- self.redirect_url = redirect_url
380
- self.client_id = client_id
381
- self.client_secret = client_secret
382
- self.token_url = oidc.token_endpoint
383
- self.is_aws = config.is_aws
384
- self.is_azure = config.is_azure
385
- self.is_gcp = config.is_gcp
386
-
387
- self._auth_url = oidc.authorization_endpoint
388
- self._scopes = scopes
495
+ return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret)
389
496
 
390
497
  def initiate_consent(self) -> Consent:
391
498
  state = secrets.token_urlsafe(16)
@@ -397,18 +504,24 @@ class OAuthClient:
397
504
 
398
505
  params = {
399
506
  'response_type': 'code',
400
- 'client_id': self.client_id,
507
+ 'client_id': self._client_id,
401
508
  'redirect_uri': self.redirect_url,
402
509
  'scope': ' '.join(self._scopes),
403
510
  'state': state,
404
511
  'code_challenge': challenge,
405
512
  'code_challenge_method': 'S256'
406
513
  }
407
- url = f'{self._auth_url}?{urllib.parse.urlencode(params)}'
408
- return Consent(self, state, verifier, auth_url=url)
514
+ auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}'
515
+ return Consent(state,
516
+ verifier,
517
+ authorization_url=auth_url,
518
+ redirect_url=self.redirect_url,
519
+ token_endpoint=self._oidc_endpoints.token_endpoint,
520
+ client_id=self._client_id,
521
+ client_secret=self._client_secret)
409
522
 
410
523
  def __repr__(self) -> str:
411
- return f'<OAuthClient {self.host} client_id={self.client_id}>'
524
+ return f'<OAuthClient client_id={self._client_id} token_url={self._oidc_endpoints.token_endpoint} auth_url={self._oidc_endpoints.authorization_endpoint}>'
412
525
 
413
526
 
414
527
  @dataclass
@@ -448,17 +561,28 @@ class ClientCredentials(Refreshable):
448
561
  use_header=self.use_header)
449
562
 
450
563
 
451
- class TokenCache():
564
+ class TokenCache:
452
565
  BASE_PATH = "~/.config/databricks-sdk-py/oauth"
453
566
 
454
- def __init__(self, client: OAuthClient) -> None:
455
- self.client = client
567
+ def __init__(self,
568
+ host: str,
569
+ oidc_endpoints: OidcEndpoints,
570
+ client_id: str,
571
+ redirect_url: str = None,
572
+ client_secret: str = None,
573
+ scopes: List[str] = None) -> None:
574
+ self._host = host
575
+ self._client_id = client_id
576
+ self._oidc_endpoints = oidc_endpoints
577
+ self._redirect_url = redirect_url
578
+ self._client_secret = client_secret
579
+ self._scopes = scopes or []
456
580
 
457
581
  @property
458
582
  def filename(self) -> str:
459
583
  # Include host, client_id, and scopes in the cache filename to make it unique.
460
584
  hash = hashlib.sha256()
461
- for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]:
585
+ for chunk in [self._host, self._client_id, ",".join(self._scopes), ]:
462
586
  hash.update(chunk.encode('utf-8'))
463
587
  return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))
464
588
 
@@ -472,7 +596,11 @@ class TokenCache():
472
596
  try:
473
597
  with open(self.filename, 'r') as f:
474
598
  raw = json.load(f)
475
- return SessionCredentials.from_dict(self.client, raw)
599
+ return SessionCredentials.from_dict(raw,
600
+ token_endpoint=self._oidc_endpoints.token_endpoint,
601
+ client_id=self._client_id,
602
+ client_secret=self._client_secret,
603
+ redirect_url=self._redirect_url)
476
604
  except Exception:
477
605
  return None
478
606
 
@@ -787,7 +787,7 @@ class AppsAPI:
787
787
  callback: Optional[Callable[[App], None]] = None) -> App:
788
788
  deadline = time.time() + timeout.total_seconds()
789
789
  target_states = (ComputeState.ACTIVE, )
790
- failure_states = (ComputeState.ERROR, )
790
+ failure_states = (ComputeState.ERROR, ComputeState.STOPPED, )
791
791
  status_message = 'polling...'
792
792
  attempt = 1
793
793
  while time.time() < deadline: