zenml-nightly 0.61.0.dev20240712__py3-none-any.whl → 0.62.0.dev20240727__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. README.md +2 -2
  2. RELEASE_NOTES.md +40 -0
  3. zenml/VERSION +1 -1
  4. zenml/__init__.py +2 -0
  5. zenml/cli/stack.py +114 -248
  6. zenml/cli/stack_components.py +5 -3
  7. zenml/config/pipeline_spec.py +2 -2
  8. zenml/config/step_configurations.py +3 -3
  9. zenml/constants.py +3 -0
  10. zenml/enums.py +16 -0
  11. zenml/integrations/__init__.py +1 -0
  12. zenml/integrations/azure/__init__.py +2 -2
  13. zenml/integrations/constants.py +1 -0
  14. zenml/integrations/databricks/__init__.py +52 -0
  15. zenml/integrations/databricks/flavors/__init__.py +30 -0
  16. zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py +118 -0
  17. zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +147 -0
  18. zenml/integrations/databricks/model_deployers/__init__.py +20 -0
  19. zenml/integrations/databricks/model_deployers/databricks_model_deployer.py +249 -0
  20. zenml/integrations/databricks/orchestrators/__init__.py +20 -0
  21. zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +497 -0
  22. zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py +97 -0
  23. zenml/integrations/databricks/services/__init__.py +19 -0
  24. zenml/integrations/databricks/services/databricks_deployment.py +407 -0
  25. zenml/integrations/databricks/utils/__init__.py +14 -0
  26. zenml/integrations/databricks/utils/databricks_utils.py +87 -0
  27. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +44 -28
  28. zenml/integrations/great_expectations/data_validators/ge_data_validator.py +12 -8
  29. zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +88 -3
  30. zenml/integrations/huggingface/steps/accelerate_runner.py +1 -7
  31. zenml/integrations/kubernetes/__init__.py +3 -2
  32. zenml/integrations/kubernetes/flavors/__init__.py +8 -0
  33. zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +166 -0
  34. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +1 -13
  35. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +22 -4
  36. zenml/integrations/kubernetes/pod_settings.py +4 -0
  37. zenml/integrations/kubernetes/step_operators/__init__.py +22 -0
  38. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +235 -0
  39. zenml/integrations/lightgbm/__init__.py +1 -0
  40. zenml/integrations/mlflow/__init__.py +1 -1
  41. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +6 -2
  42. zenml/integrations/mlflow/services/mlflow_deployment.py +1 -1
  43. zenml/integrations/skypilot_azure/__init__.py +1 -3
  44. zenml/integrations/skypilot_lambda/__init__.py +1 -1
  45. zenml/logging/step_logging.py +34 -35
  46. zenml/materializers/built_in_materializer.py +1 -1
  47. zenml/materializers/cloudpickle_materializer.py +1 -1
  48. zenml/model/model.py +1 -1
  49. zenml/models/v2/core/code_repository.py +2 -2
  50. zenml/models/v2/core/component.py +29 -0
  51. zenml/models/v2/core/server_settings.py +0 -20
  52. zenml/models/v2/misc/full_stack.py +32 -0
  53. zenml/models/v2/misc/stack_deployment.py +5 -0
  54. zenml/new/pipelines/run_utils.py +1 -1
  55. zenml/orchestrators/__init__.py +4 -0
  56. zenml/orchestrators/step_launcher.py +1 -0
  57. zenml/orchestrators/wheeled_orchestrator.py +147 -0
  58. zenml/service_connectors/service_connector_utils.py +408 -0
  59. zenml/stack_deployments/azure_stack_deployment.py +179 -0
  60. zenml/stack_deployments/gcp_stack_deployment.py +13 -4
  61. zenml/stack_deployments/stack_deployment.py +10 -0
  62. zenml/stack_deployments/utils.py +4 -0
  63. zenml/steps/base_step.py +7 -5
  64. zenml/utils/function_utils.py +1 -1
  65. zenml/utils/pipeline_docker_image_builder.py +8 -0
  66. zenml/utils/source_utils.py +4 -1
  67. zenml/zen_server/dashboard/assets/{404-DpJaNHKF.js → 404-B_YdvmwS.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{@reactflow-DJfzkHO1.js → @reactflow-l_1hUr1S.js} +1 -1
  69. zenml/zen_server/dashboard/assets/{AwarenessChannel-BYDLT2xC.js → AwarenessChannel-CFg5iX4Z.js} +1 -1
  70. zenml/zen_server/dashboard/assets/{CodeSnippet-BkOuRmyq.js → CodeSnippet-Dvkx_82E.js} +1 -1
  71. zenml/zen_server/dashboard/assets/CollapsibleCard-opiuBHHc.js +1 -0
  72. zenml/zen_server/dashboard/assets/{Commands-ZvWR1BRs.js → Commands-DoN1xrEq.js} +1 -1
  73. zenml/zen_server/dashboard/assets/{CopyButton-DVwLkafa.js → CopyButton-Cr7xYEPb.js} +1 -1
  74. zenml/zen_server/dashboard/assets/{CsvVizualization-C2IiqX4I.js → CsvVizualization-Ck-nZ43m.js} +3 -3
  75. zenml/zen_server/dashboard/assets/{Error-CqX0VqW_.js → Error-kLtljEOM.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{ExecutionStatus-BoLUXR9t.js → ExecutionStatus-DguLLgTK.js} +1 -1
  77. zenml/zen_server/dashboard/assets/{Helpbox-LFydyVwh.js → Helpbox-BXUMP21n.js} +1 -1
  78. zenml/zen_server/dashboard/assets/{Infobox-DnENC0sh.js → Infobox-DSt0O-dm.js} +1 -1
  79. zenml/zen_server/dashboard/assets/{InlineAvatar-CbJtYr0t.js → InlineAvatar-xsrsIGE-.js} +1 -1
  80. zenml/zen_server/dashboard/assets/Pagination-C6X-mifw.js +1 -0
  81. zenml/zen_server/dashboard/assets/{SetPassword-BYBdbQDo.js → SetPassword-BXGTWiwj.js} +1 -1
  82. zenml/zen_server/dashboard/assets/{SuccessStep-Nx743hll.js → SuccessStep-DZC60t0x.js} +1 -1
  83. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-DF9gSzE0.js → UpdatePasswordSchemas-DGvwFWO1.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{chevron-right-double-BiEMg7rd.js → chevron-right-double-CZBOf6JM.js} +1 -1
  85. zenml/zen_server/dashboard/assets/cloud-only-C_yFCAkP.js +1 -0
  86. zenml/zen_server/dashboard/assets/index-BczVOqUf.js +55 -0
  87. zenml/zen_server/dashboard/assets/index-EpMIKgrI.css +1 -0
  88. zenml/zen_server/dashboard/assets/{login-mutation-BUnVASxp.js → login-mutation-CrHrndTI.js} +1 -1
  89. zenml/zen_server/dashboard/assets/logs-D8k8BVFf.js +1 -0
  90. zenml/zen_server/dashboard/assets/{not-found-B4VnX8gK.js → not-found-DYa4pC-C.js} +1 -1
  91. zenml/zen_server/dashboard/assets/{package-CsUhPmou.js → package-B3fWP-Dh.js} +1 -1
  92. zenml/zen_server/dashboard/assets/page-1h_sD1jz.js +1 -0
  93. zenml/zen_server/dashboard/assets/{page-Sxn82W-5.js → page-1iL8aMqs.js} +1 -1
  94. zenml/zen_server/dashboard/assets/{page-DMOYZppS.js → page-2grKx_MY.js} +1 -1
  95. zenml/zen_server/dashboard/assets/page-5NCOHOsy.js +1 -0
  96. zenml/zen_server/dashboard/assets/{page-JyfeDUfu.js → page-8a4UMKXZ.js} +1 -1
  97. zenml/zen_server/dashboard/assets/{page-Bx6o0ARS.js → page-B6h3iaHJ.js} +1 -1
  98. zenml/zen_server/dashboard/assets/page-BDns21Iz.js +1 -0
  99. zenml/zen_server/dashboard/assets/{page-3efNCDeb.js → page-BhgCDInH.js} +2 -2
  100. zenml/zen_server/dashboard/assets/{page-DKlIdAe5.js → page-Bi-wtWiO.js} +2 -2
  101. zenml/zen_server/dashboard/assets/{page-7zTHbhhI.js → page-BkeAAYwp.js} +1 -1
  102. zenml/zen_server/dashboard/assets/{page-CRTJ0UuR.js → page-BkuQDIf-.js} +1 -1
  103. zenml/zen_server/dashboard/assets/page-BnaevhnB.js +1 -0
  104. zenml/zen_server/dashboard/assets/{page-BEs6jK71.js → page-Bq0YxkLV.js} +1 -1
  105. zenml/zen_server/dashboard/assets/page-Bs2F4eoD.js +2 -0
  106. zenml/zen_server/dashboard/assets/{page-CUZIGO-3.js → page-C6-UGEbH.js} +1 -1
  107. zenml/zen_server/dashboard/assets/{page-Xu8JEjSU.js → page-CCNRIt_f.js} +1 -1
  108. zenml/zen_server/dashboard/assets/{page-DvCvroOM.js → page-CHNxpz3n.js} +1 -1
  109. zenml/zen_server/dashboard/assets/{page-BpSqIf4B.js → page-DgorQFqi.js} +1 -1
  110. zenml/zen_server/dashboard/assets/page-K8ebxVIs.js +1 -0
  111. zenml/zen_server/dashboard/assets/{page-Cx67M0QT.js → page-MFQyIJd3.js} +1 -1
  112. zenml/zen_server/dashboard/assets/page-TgCF0P_U.js +1 -0
  113. zenml/zen_server/dashboard/assets/page-ZnCEe-eK.js +9 -0
  114. zenml/zen_server/dashboard/assets/{page-Dc_7KMQE.js → page-uA5prJGY.js} +1 -1
  115. zenml/zen_server/dashboard/assets/persist-D7HJNBWx.js +1 -0
  116. zenml/zen_server/dashboard/assets/plus-C8WOyCzt.js +1 -0
  117. zenml/zen_server/dashboard/assets/stack-detail-query-Cficsl6d.js +1 -0
  118. zenml/zen_server/dashboard/assets/update-server-settings-mutation-7d8xi1tS.js +1 -0
  119. zenml/zen_server/dashboard/assets/{url-DuQMeqYA.js → url-D7mAQGUM.js} +1 -1
  120. zenml/zen_server/dashboard/index.html +4 -4
  121. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  122. zenml/zen_server/dashboard_legacy/index.html +1 -1
  123. zenml/zen_server/dashboard_legacy/{precache-manifest.c8c57fb0d2132b1d3c2119e776b7dfb3.js → precache-manifest.12246c7548e71e2c4438e496360de80c.js} +4 -4
  124. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  125. zenml/zen_server/dashboard_legacy/static/js/main.3b27024b.chunk.js +2 -0
  126. zenml/zen_server/dashboard_legacy/static/js/{main.382439a7.chunk.js.map → main.3b27024b.chunk.js.map} +1 -1
  127. zenml/zen_server/deploy/helm/Chart.yaml +1 -1
  128. zenml/zen_server/deploy/helm/README.md +2 -2
  129. zenml/zen_server/rbac/utils.py +10 -2
  130. zenml/zen_server/routers/devices_endpoints.py +4 -1
  131. zenml/zen_server/routers/server_endpoints.py +29 -2
  132. zenml/zen_server/routers/service_connectors_endpoints.py +57 -0
  133. zenml/zen_server/routers/steps_endpoints.py +2 -1
  134. zenml/zen_stores/migrations/versions/0.62.0_release.py +23 -0
  135. zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py +167 -0
  136. zenml/zen_stores/rest_zen_store.py +4 -0
  137. zenml/zen_stores/schemas/component_schemas.py +14 -0
  138. zenml/zen_stores/schemas/server_settings_schemas.py +23 -11
  139. zenml/zen_stores/sql_zen_store.py +151 -1
  140. {zenml_nightly-0.61.0.dev20240712.dist-info → zenml_nightly-0.62.0.dev20240727.dist-info}/METADATA +5 -5
  141. {zenml_nightly-0.61.0.dev20240712.dist-info → zenml_nightly-0.62.0.dev20240727.dist-info}/RECORD +144 -121
  142. zenml/zen_server/dashboard/assets/Pagination-DEbVUupy.js +0 -1
  143. zenml/zen_server/dashboard/assets/chevron-down-D_ZlKMqH.js +0 -1
  144. zenml/zen_server/dashboard/assets/cloud-only-DVbIeckv.js +0 -1
  145. zenml/zen_server/dashboard/assets/index-C_CrU4vI.js +0 -1
  146. zenml/zen_server/dashboard/assets/index-DK1ynKjA.js +0 -55
  147. zenml/zen_server/dashboard/assets/index-inApY3KQ.css +0 -1
  148. zenml/zen_server/dashboard/assets/page-C43QGHTt.js +0 -9
  149. zenml/zen_server/dashboard/assets/page-CR0OG7ss.js +0 -1
  150. zenml/zen_server/dashboard/assets/page-CaopxiU1.js +0 -1
  151. zenml/zen_server/dashboard/assets/page-D7Z399xy.js +0 -1
  152. zenml/zen_server/dashboard/assets/page-D93kd7Xj.js +0 -1
  153. zenml/zen_server/dashboard/assets/page-DMsSn3dv.js +0 -2
  154. zenml/zen_server/dashboard/assets/page-Hus2pr9T.js +0 -1
  155. zenml/zen_server/dashboard/assets/page-TKXERe16.js +0 -1
  156. zenml/zen_server/dashboard/assets/plus-DOeLmm7C.js +0 -1
  157. zenml/zen_server/dashboard/assets/update-server-settings-mutation-CR8e3Sir.js +0 -1
  158. zenml/zen_server/dashboard_legacy/static/js/main.382439a7.chunk.js +0 -2
  159. {zenml_nightly-0.61.0.dev20240712.dist-info → zenml_nightly-0.62.0.dev20240727.dist-info}/LICENSE +0 -0
  160. {zenml_nightly-0.61.0.dev20240712.dist-info → zenml_nightly-0.62.0.dev20240727.dist-info}/WHEEL +0 -0
  161. {zenml_nightly-0.61.0.dev20240712.dist-info → zenml_nightly-0.62.0.dev20240727.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,407 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Implementation of the Databricks Deployment service."""
15
+
16
+ import time
17
+ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+ import requests
22
+ from databricks.sdk import WorkspaceClient as DatabricksClient
23
+ from databricks.sdk.service.serving import (
24
+ EndpointCoreConfigInput,
25
+ EndpointStateConfigUpdate,
26
+ EndpointStateReady,
27
+ EndpointTag,
28
+ ServedModelInput,
29
+ ServingEndpointDetailed,
30
+ )
31
+ from pydantic import Field
32
+
33
+ from zenml.client import Client
34
+ from zenml.integrations.databricks.flavors.databricks_model_deployer_flavor import (
35
+ DatabricksBaseConfig,
36
+ )
37
+ from zenml.integrations.databricks.utils.databricks_utils import (
38
+ sanitize_labels,
39
+ )
40
+ from zenml.logger import get_logger
41
+ from zenml.services import ServiceState, ServiceStatus, ServiceType
42
+ from zenml.services.service import BaseDeploymentService, ServiceConfig
43
+
44
+ logger = get_logger(__name__)
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from numpy.typing import NDArray
49
+
50
# Polling timeout for the deployment; not referenced in the visible part of
# this module (presumably seconds -- TODO confirm against callers).
POLLING_TIMEOUT = 1200
# Number of leading characters of the service UUID that are appended to the
# service name when building the endpoint name.
UUID_SLICE_LENGTH: int = 8
52
+
53
+
54
class DatabricksDeploymentConfig(DatabricksBaseConfig, ServiceConfig):
    """Configuration of a Databricks deployment service.

    Attributes:
        model_uri: URI of the model to deploy.
        host: Databricks host URL for the deployment.
    """

    model_uri: Optional[str] = Field(
        None,
        description="URI of the model to deploy. This can be a local path or a cloud storage path.",
    )
    host: Optional[str] = Field(
        None, description="Databricks host URL for the deployment."
    )

    def get_databricks_deployment_labels(self) -> Dict[str, str]:
        """Build the labels describing this deployment configuration.

        Only configuration values that are set (truthy) produce a label. The
        resulting values are passed through `sanitize_labels` before being
        returned.

        Returns:
            The labels for the Databricks deployment.
        """
        # Label name -> configuration value, in the order the labels are
        # inserted into the result.
        label_sources = (
            ("zenml_pipeline_name", self.pipeline_name),
            ("zenml_pipeline_step_name", self.pipeline_step_name),
            ("zenml_model_name", self.model_name),
            ("zenml_model_uri", self.model_uri),
        )
        deployment_labels = {
            label: setting for label, setting in label_sources if setting
        }
        sanitize_labels(deployment_labels)
        return deployment_labels
85
+
86
+
87
class DatabricksServiceStatus(ServiceStatus):
    """Status of a Databricks deployment service.

    Declares no fields or behavior of its own on top of `ServiceStatus`.
    """
89
+
90
+
91
class DatabricksDeploymentService(BaseDeploymentService):
    """Databricks model deployment service.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the Databricks deployment service class
        config: service configuration
        status: service status
    """

    SERVICE_TYPE = ServiceType(
        name="databricks-deployment",
        type="model-serving",
        flavor="databricks",
        description="Databricks inference endpoint prediction service",
    )
    config: DatabricksDeploymentConfig
    status: DatabricksServiceStatus = Field(
        default_factory=lambda: DatabricksServiceStatus()
    )

    def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any):
        """Initialize the Databricks deployment service.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        super().__init__(config=config, **attrs)

    def get_client_id_and_secret(self) -> Tuple[str, str, str]:
        """Get the Databricks host, client id and client secret.

        The values are read from the Databricks model deployer of the active
        stack. The client id and secret come from the ZenML secret named by
        the deployer's `secret_name` if one is configured, otherwise from the
        deployer configuration itself. As a side effect, the resolved host is
        stored in `self.config.host`.

        Raises:
            ValueError: If the active model deployer is not a Databricks
                model deployer, or if the client id, client secret or host
                is missing.

        Returns:
            A `(host, client_id, client_secret)` tuple.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import with the model deployer module).
        from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
            DatabricksModelDeployer,
        )

        client = Client()
        model_deployer = client.active_stack.model_deployer
        if not isinstance(model_deployer, DatabricksModelDeployer):
            raise ValueError(
                "DatabricksModelDeployer is not active in the stack."
            )

        host = model_deployer.config.host
        # `prediction_url` reads the host from the service configuration.
        self.config.host = host

        if model_deployer.config.secret_name:
            secret = client.get_secret(model_deployer.config.secret_name)
            client_id = secret.secret_values["client_id"]
            client_secret = secret.secret_values["client_secret"]
        else:
            client_id = model_deployer.config.client_id
            client_secret = model_deployer.config.client_secret

        if not client_id:
            raise ValueError("Client id not found.")
        if not client_secret:
            raise ValueError("Client secret not found.")
        if not host:
            raise ValueError("Host not found.")
        return host, client_id, client_secret

    def _get_databricks_deployment_labels(self) -> Dict[str, str]:
        """Generate the labels for the Databricks deployment from the service configuration.

        Returns:
            The configuration labels plus a `zenml_service_uuid` label, all
            sanitized.
        """
        labels = self.config.get_databricks_deployment_labels()
        labels["zenml_service_uuid"] = str(self.uuid)
        sanitize_labels(labels)
        return labels

    @property
    def databricks_client(self) -> DatabricksClient:
        """Get a Databricks workspace client for the configured host.

        Returns:
            A Databricks client authenticated with the client id and secret.
        """
        # Resolve the credentials once: every call to
        # `get_client_id_and_secret` looks up the active stack and may fetch
        # a secret.
        host, client_id, client_secret = self.get_client_id_and_secret()
        return DatabricksClient(
            host=host,
            client_id=client_id,
            client_secret=client_secret,
        )

    @property
    def databricks_endpoint(self) -> ServingEndpointDetailed:
        """Get the deployed Databricks inference endpoint.

        Returns:
            Databricks inference endpoint.
        """
        return self.databricks_client.serving_endpoints.get(
            name=self._generate_an_endpoint_name(),
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URI exposed by the prediction service.

        Returns:
            The prediction URI exposed by the prediction service, or None if
            the Databricks host is not known yet.
        """
        host = self.config.host
        if not host:
            # Without a host the URL would start with the literal "None".
            return None
        return f"{host}/serving-endpoints/{self._generate_an_endpoint_name()}/invocations"

    def provision(self) -> None:
        """Provision or update remote Databricks deployment instance."""
        from databricks.sdk.service.serving import (
            ServedModelInputWorkloadSize,
            ServedModelInputWorkloadType,
        )

        tags = [
            EndpointTag(key=key, value=value)
            for key, value in self._get_databricks_deployment_labels().items()
        ]
        served_model = ServedModelInput(
            model_name=self.config.model_name,
            model_version=self.config.model_version,
            scale_to_zero_enabled=self.config.scale_to_zero_enabled,
            workload_type=ServedModelInputWorkloadType(
                self.config.workload_type
            ),
            workload_size=ServedModelInputWorkloadSize(
                self.config.workload_size
            ),
        )

        # Attempt to create and wait for the inference endpoint
        databricks_endpoint = (
            self.databricks_client.serving_endpoints.create_and_wait(
                name=self._generate_an_endpoint_name(),
                config=EndpointCoreConfigInput(
                    served_models=[served_model],
                ),
                tags=tags,
            )
        )
        # Check if the endpoint URL is available after provisioning
        if databricks_endpoint.endpoint_url:
            logger.info(
                f"Databricks inference endpoint successfully deployed and available. Endpoint URL: {databricks_endpoint.endpoint_url}"
            )
        else:
            logger.error(
                "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details."
            )

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the current operational state of the Databricks deployment.

        Returns:
            The operational state of the Databricks deployment and a message
            providing additional information about that state (e.g. a
            description of the error, if one is encountered).
        """
        try:
            state = self.databricks_endpoint.state
        except Exception as e:
            # Any failure to fetch the endpoint is reported as "inactive"
            # instead of being propagated to the caller.
            return (
                ServiceState.INACTIVE,
                f"Databricks Inference Endpoint deployment is inactive or not found: {e}",
            )

        if state and state.ready == EndpointStateReady.READY:
            return (ServiceState.ACTIVE, "")
        if (
            state
            and state.config_update
            == EndpointStateConfigUpdate.UPDATE_FAILED
        ):
            return (
                ServiceState.ERROR,
                "Databricks Inference Endpoint deployment update failed",
            )
        # Covers an update in progress as well as any other / missing state.
        return (ServiceState.PENDING_STARTUP, "")

    def deprovision(self, force: bool = False) -> None:
        """Deprovision the remote Databricks deployment instance.

        Failures to delete the endpoint are logged and not raised.

        Args:
            force: if True, the remote deployment instance will be
                forcefully deprovisioned. Not used by this implementation.
        """
        try:
            self.databricks_client.serving_endpoints.delete(
                name=self._generate_an_endpoint_name()
            )
        except Exception as e:
            # Best-effort cleanup: keep going, but report what went wrong.
            logger.error(
                f"Databricks Inference Endpoint is deleted or cannot be found: {e}"
            )

    def predict(
        self, request: Union["NDArray[Any]", pd.DataFrame]
    ) -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: The input data for the prediction.

        Returns:
            The prediction result.

        Raises:
            Exception: if the service is not running
            ValueError: if no prediction URL is known, the endpoint secret
                name is not provided, or the secret holds no token.
        """
        if not self.is_running:
            raise Exception(
                "Databricks endpoint inference service is not running. "
                "Please start the service before making predictions."
            )
        prediction_url = self.prediction_url
        if prediction_url is None:
            raise ValueError("No endpoint known for prediction.")
        if not self.config.endpoint_secret_name:
            raise ValueError(
                "No endpoint secret name is provided for prediction."
            )

        databricks_token = Client().get_secret(
            self.config.endpoint_secret_name
        )
        # `.get` so that a secret without a "token" key produces the
        # ValueError below rather than a KeyError.
        token = databricks_token.secret_values.get("token")
        if not token:
            raise ValueError("No databricks token found.")

        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        if isinstance(request, pd.DataFrame):
            instances = request.to_dict(orient="records")
        else:
            instances = request.tolist()

        response = requests.post(  # nosec
            prediction_url,
            json={"instances": instances},
            headers=headers,
        )
        response.raise_for_status()

        return np.array(response.json()["predictions"])

    def get_logs(
        self, follow: bool = False, tail: Optional[int] = None
    ) -> Generator[str, bool, None]:
        """Retrieve the service logs.

        Args:
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Yields:
            A generator that can be accessed to get the service logs.
        """
        logger.info(
            "Databricks Endpoints provides access to the logs of your Endpoints through the UI in the `Logs` tab of your Endpoint"
        )

        # Number of lines of the full log that were already yielded or
        # skipped by `tail`.
        seen_line_count = 0
        first_fetch = True
        while True:
            logs = self.databricks_client.serving_endpoints.logs(
                name=self._generate_an_endpoint_name(),
                served_model_name=self.config.model_name,
            )
            log_lines = logs.logs.split("\n")

            # Apply tail on the first fetch only. The offset is tracked
            # against the full log so that later fetches do not yield lines
            # that were already emitted.
            if first_fetch and tail is not None and tail > 0:
                seen_line_count = max(len(log_lines) - tail, 0)
            first_fetch = False

            # Yield only new lines
            yield from log_lines[seen_line_count:]
            seen_line_count = len(log_lines)

            if not follow:
                break

            # Add a small delay to avoid excessive API calls
            time.sleep(1)

    def _generate_an_endpoint_name(self) -> str:
        """Generate a unique name for the Databricks Inference Endpoint.

        Returns:
            The service name followed by the first `UUID_SLICE_LENGTH`
            characters of the service UUID.
        """
        return (
            f"{self.config.service_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}"
        )
@@ -0,0 +1,14 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Utilities for Databricks integration."""
@@ -0,0 +1,87 @@
1
+ # Copyright (c) ZenML GmbH 2023. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Databricks utilities."""
15
+
16
+ import re
17
+ from typing import Dict, List, Optional
18
+
19
+ from databricks.sdk.service.compute import Library, PythonPyPiLibrary
20
+ from databricks.sdk.service.jobs import PythonWheelTask, TaskDependency
21
+ from databricks.sdk.service.jobs import Task as DatabricksTask
22
+
23
+ from zenml import __version__
24
+
25
+
26
def convert_step_to_task(
    task_name: str,
    command: str,
    arguments: List[str],
    libraries: Optional[List[str]] = None,
    depends_on: Optional[List[str]] = None,
    zenml_project_wheel: Optional[str] = None,
    job_cluster_key: Optional[str] = None,
) -> DatabricksTask:
    """Convert a ZenML step to a Databricks task.

    The task runs `command` as an entry point of the `zenml` wheel package.
    Its libraries are, in order: the requested PyPI `libraries`, the project
    wheel (when given) and the `zenml` release matching the running version.

    Args:
        task_name: Name of the task.
        command: Command to run.
        arguments: Arguments to pass to the command.
        libraries: List of libraries to install.
        depends_on: List of tasks to depend on.
        zenml_project_wheel: Path to the ZenML project wheel.
        job_cluster_key: ID of the Databricks job_cluster_key.

    Returns:
        Databricks task.
    """
    db_libraries = [
        Library(pypi=PythonPyPiLibrary(package=library))
        for library in libraries or []
    ]
    # Only reference the project wheel when one was provided; otherwise an
    # empty library specification would be attached to the task.
    if zenml_project_wheel:
        db_libraries.append(Library(whl=zenml_project_wheel))
    db_libraries.append(
        Library(pypi=PythonPyPiLibrary(package=f"zenml=={__version__}"))
    )

    # An empty or missing dependency list is passed on as None.
    task_dependencies = (
        [TaskDependency(task_key=task) for task in depends_on]
        if depends_on
        else None
    )

    return DatabricksTask(
        task_key=task_name,
        job_cluster_key=job_cluster_key,
        libraries=db_libraries,
        python_wheel_task=PythonWheelTask(
            package_name="zenml",
            entry_point=command,
            parameters=arguments,
        ),
        depends_on=task_dependencies,
    )
70
+
71
+
72
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite the label values in place so they are valid Kubernetes labels.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: the labels to sanitize.
    """
    # Kubernetes labels must be alphanumeric, no longer than
    # 63 characters, and must begin and end with an alphanumeric
    # character ([a-z0-9A-Z])
    disallowed = re.compile(r"[^0-9a-zA-Z-_\.]+")
    # Each run of disallowed characters becomes one underscore; the value is
    # cut to 63 characters first and only then stripped of leading/trailing
    # separators.
    cleaned = {
        name: disallowed.sub("_", raw_value)[:63].strip("-_.")
        for name, raw_value in labels.items()
    }
    labels.update(cleaned)
@@ -363,7 +363,6 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
363
363
  pipeline_func
364
364
  """
365
365
  step_name_to_dynamic_component: Dict[str, Any] = {}
366
- node_selector_constraint: Optional[Tuple[str, str]] = None
367
366
 
368
367
  for step_name, step in deployment.step_configurations.items():
369
368
  image = self.get_image(
@@ -410,23 +409,17 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
410
409
  "Volume mounts are set but not supported in "
411
410
  "Vertex with Kubeflow Pipelines 2.x. Ignoring..."
412
411
  )
413
-
414
- # apply pod settings
415
- if (
416
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
417
- in pod_settings.node_selectors.keys()
418
- ):
419
- node_selector_constraint = (
420
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
421
- pod_settings.node_selectors[
422
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
423
- ],
424
- )
425
- elif step_settings.node_selector_constraint:
426
- node_selector_constraint = (
427
- GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
428
- step_settings.node_selector_constraint[1],
429
- )
412
+ for key in pod_settings.node_selectors:
413
+ if (
414
+ key
415
+ != GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
416
+ ):
417
+ logger.warning(
418
+ "Vertex only allows the %s node selector, "
419
+ "ignoring the node selector %s.",
420
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
421
+ key,
422
+ )
430
423
 
431
424
  step_name_to_dynamic_component[step_name] = dynamic_component
432
425
 
@@ -460,10 +453,33 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
460
453
  )
461
454
  .after(*upstream_step_components)
462
455
  )
456
+
457
+ step_settings = cast(
458
+ VertexOrchestratorSettings, self.get_settings(step)
459
+ )
460
+ pod_settings = step_settings.pod_settings
461
+
462
+ node_selector_constraint: Optional[Tuple[str, str]] = None
463
+ if pod_settings and (
464
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
465
+ in pod_settings.node_selectors.keys()
466
+ ):
467
+ node_selector_constraint = (
468
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
469
+ pod_settings.node_selectors[
470
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
471
+ ],
472
+ )
473
+ elif step_settings.node_selector_constraint:
474
+ node_selector_constraint = (
475
+ GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
476
+ step_settings.node_selector_constraint[1],
477
+ )
478
+
463
479
  self._configure_container_resources(
464
- task,
465
- step.config.resource_settings,
466
- node_selector_constraint,
480
+ dynamic_component=task,
481
+ resource_settings=step.config.resource_settings,
482
+ node_selector_constraint=node_selector_constraint,
467
483
  )
468
484
 
469
485
  return dynamic_pipeline
@@ -731,20 +747,20 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
731
747
  )
732
748
 
733
749
  if node_selector_constraint:
734
- (constraint_label, value) = node_selector_constraint
750
+ _, value = node_selector_constraint
735
751
  if gpu_limit is not None and gpu_limit > 0:
736
752
  dynamic_component = (
737
753
  dynamic_component.set_accelerator_type(value)
738
754
  .set_accelerator_limit(gpu_limit)
739
755
  .set_gpu_limit(gpu_limit)
740
756
  )
741
- elif (
742
- constraint_label
743
- == GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL
744
- and gpu_limit == 0
745
- ):
757
+ else:
746
758
  logger.warning(
747
- "GPU limit is set to 0 but a GPU type is specified. Ignoring GPU settings."
759
+ "Accelerator type %s specified, but the GPU limit is not "
760
+ "set or set to 0. The accelerator type will be ignored. "
761
+ "To fix this warning, either remove the specified "
762
+ "accelerator type or set the `gpu_count` using the "
763
+ "ResourceSettings (https://docs.zenml.io/how-to/training-with-gpus#specify-resource-requirements-for-steps)."
748
764
  )
749
765
 
750
766
  return dynamic_component
@@ -284,14 +284,18 @@ class GreatExpectationsDataValidator(BaseDataValidator):
284
284
  store_name=store_name,
285
285
  store_config=store_config,
286
286
  )
287
- for site_name, site_config in zenml_context_config[
288
- "data_docs_sites"
289
- ].items():
290
- self._context.config.data_docs_sites[site_name] = (
291
- site_config
292
- )
293
-
294
- if self.config.configure_local_docs:
287
+ if self._context.config.data_docs_sites is not None:
288
+ for site_name, site_config in zenml_context_config[
289
+ "data_docs_sites"
290
+ ].items():
291
+ self._context.config.data_docs_sites[site_name] = (
292
+ site_config
293
+ )
294
+
295
+ if (
296
+ self.config.configure_local_docs
297
+ and self._context.config.data_docs_sites is not None
298
+ ):
295
299
  client = Client()
296
300
  artifact_store = client.active_stack.artifact_store
297
301
  if artifact_store.flavor != "local":