zenml-nightly 0.61.0.dev20240714__py3-none-any.whl → 0.62.0.dev20240719__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (133)
  1. README.md +1 -1
  2. RELEASE_NOTES.md +40 -0
  3. zenml/VERSION +1 -1
  4. zenml/__init__.py +2 -0
  5. zenml/cli/stack.py +87 -228
  6. zenml/cli/stack_components.py +5 -3
  7. zenml/constants.py +2 -0
  8. zenml/entrypoints/entrypoint.py +3 -1
  9. zenml/integrations/__init__.py +1 -0
  10. zenml/integrations/constants.py +1 -0
  11. zenml/integrations/databricks/__init__.py +52 -0
  12. zenml/integrations/databricks/flavors/__init__.py +30 -0
  13. zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py +118 -0
  14. zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +147 -0
  15. zenml/integrations/databricks/model_deployers/__init__.py +20 -0
  16. zenml/integrations/databricks/model_deployers/databricks_model_deployer.py +249 -0
  17. zenml/integrations/databricks/orchestrators/__init__.py +20 -0
  18. zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +498 -0
  19. zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py +97 -0
  20. zenml/integrations/databricks/services/__init__.py +19 -0
  21. zenml/integrations/databricks/services/databricks_deployment.py +407 -0
  22. zenml/integrations/databricks/utils/__init__.py +14 -0
  23. zenml/integrations/databricks/utils/databricks_utils.py +87 -0
  24. zenml/integrations/great_expectations/data_validators/ge_data_validator.py +12 -8
  25. zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +88 -3
  26. zenml/integrations/huggingface/steps/accelerate_runner.py +1 -7
  27. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +7 -0
  28. zenml/integrations/kubernetes/pod_settings.py +2 -0
  29. zenml/integrations/lightgbm/__init__.py +1 -0
  30. zenml/integrations/mlflow/__init__.py +1 -1
  31. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +6 -2
  32. zenml/integrations/mlflow/services/mlflow_deployment.py +1 -1
  33. zenml/integrations/skypilot_lambda/__init__.py +1 -1
  34. zenml/materializers/built_in_materializer.py +1 -1
  35. zenml/materializers/cloudpickle_materializer.py +1 -1
  36. zenml/model/model.py +1 -1
  37. zenml/models/v2/core/component.py +29 -0
  38. zenml/models/v2/misc/full_stack.py +32 -0
  39. zenml/orchestrators/__init__.py +4 -0
  40. zenml/orchestrators/wheeled_orchestrator.py +147 -0
  41. zenml/service_connectors/service_connector_utils.py +349 -0
  42. zenml/stack_deployments/gcp_stack_deployment.py +2 -4
  43. zenml/steps/base_step.py +7 -5
  44. zenml/utils/function_utils.py +1 -1
  45. zenml/utils/pipeline_docker_image_builder.py +8 -0
  46. zenml/zen_server/dashboard/assets/{404-DpJaNHKF.js → 404-B_YdvmwS.js} +1 -1
  47. zenml/zen_server/dashboard/assets/{@reactflow-DJfzkHO1.js → @reactflow-l_1hUr1S.js} +1 -1
  48. zenml/zen_server/dashboard/assets/{AwarenessChannel-BYDLT2xC.js → AwarenessChannel-CFg5iX4Z.js} +1 -1
  49. zenml/zen_server/dashboard/assets/{CodeSnippet-BkOuRmyq.js → CodeSnippet-Dvkx_82E.js} +1 -1
  50. zenml/zen_server/dashboard/assets/CollapsibleCard-opiuBHHc.js +1 -0
  51. zenml/zen_server/dashboard/assets/{Commands-ZvWR1BRs.js → Commands-DoN1xrEq.js} +1 -1
  52. zenml/zen_server/dashboard/assets/{CopyButton-DVwLkafa.js → CopyButton-Cr7xYEPb.js} +1 -1
  53. zenml/zen_server/dashboard/assets/{CsvVizualization-C2IiqX4I.js → CsvVizualization-Ck-nZ43m.js} +3 -3
  54. zenml/zen_server/dashboard/assets/{Error-CqX0VqW_.js → Error-kLtljEOM.js} +1 -1
  55. zenml/zen_server/dashboard/assets/{ExecutionStatus-BoLUXR9t.js → ExecutionStatus-DguLLgTK.js} +1 -1
  56. zenml/zen_server/dashboard/assets/{Helpbox-LFydyVwh.js → Helpbox-BXUMP21n.js} +1 -1
  57. zenml/zen_server/dashboard/assets/{Infobox-DnENC0sh.js → Infobox-DSt0O-dm.js} +1 -1
  58. zenml/zen_server/dashboard/assets/{InlineAvatar-CbJtYr0t.js → InlineAvatar-xsrsIGE-.js} +1 -1
  59. zenml/zen_server/dashboard/assets/Pagination-C6X-mifw.js +1 -0
  60. zenml/zen_server/dashboard/assets/{SetPassword-BYBdbQDo.js → SetPassword-BXGTWiwj.js} +1 -1
  61. zenml/zen_server/dashboard/assets/{SuccessStep-Nx743hll.js → SuccessStep-DZC60t0x.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-DF9gSzE0.js → UpdatePasswordSchemas-DGvwFWO1.js} +1 -1
  63. zenml/zen_server/dashboard/assets/{chevron-right-double-BiEMg7rd.js → chevron-right-double-CZBOf6JM.js} +1 -1
  64. zenml/zen_server/dashboard/assets/cloud-only-C_yFCAkP.js +1 -0
  65. zenml/zen_server/dashboard/assets/index-BczVOqUf.js +55 -0
  66. zenml/zen_server/dashboard/assets/index-EpMIKgrI.css +1 -0
  67. zenml/zen_server/dashboard/assets/{login-mutation-BUnVASxp.js → login-mutation-CrHrndTI.js} +1 -1
  68. zenml/zen_server/dashboard/assets/logs-D8k8BVFf.js +1 -0
  69. zenml/zen_server/dashboard/assets/{not-found-B4VnX8gK.js → not-found-DYa4pC-C.js} +1 -1
  70. zenml/zen_server/dashboard/assets/{package-CsUhPmou.js → package-B3fWP-Dh.js} +1 -1
  71. zenml/zen_server/dashboard/assets/page-1h_sD1jz.js +1 -0
  72. zenml/zen_server/dashboard/assets/{page-Sxn82W-5.js → page-1iL8aMqs.js} +1 -1
  73. zenml/zen_server/dashboard/assets/{page-DMOYZppS.js → page-2grKx_MY.js} +1 -1
  74. zenml/zen_server/dashboard/assets/page-5NCOHOsy.js +1 -0
  75. zenml/zen_server/dashboard/assets/{page-JyfeDUfu.js → page-8a4UMKXZ.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{page-Bx6o0ARS.js → page-B6h3iaHJ.js} +1 -1
  77. zenml/zen_server/dashboard/assets/page-BDns21Iz.js +1 -0
  78. zenml/zen_server/dashboard/assets/{page-3efNCDeb.js → page-BhgCDInH.js} +2 -2
  79. zenml/zen_server/dashboard/assets/{page-DKlIdAe5.js → page-Bi-wtWiO.js} +2 -2
  80. zenml/zen_server/dashboard/assets/{page-7zTHbhhI.js → page-BkeAAYwp.js} +1 -1
  81. zenml/zen_server/dashboard/assets/{page-CRTJ0UuR.js → page-BkuQDIf-.js} +1 -1
  82. zenml/zen_server/dashboard/assets/page-BnaevhnB.js +1 -0
  83. zenml/zen_server/dashboard/assets/{page-BEs6jK71.js → page-Bq0YxkLV.js} +1 -1
  84. zenml/zen_server/dashboard/assets/page-Bs2F4eoD.js +2 -0
  85. zenml/zen_server/dashboard/assets/{page-CUZIGO-3.js → page-C6-UGEbH.js} +1 -1
  86. zenml/zen_server/dashboard/assets/{page-Xu8JEjSU.js → page-CCNRIt_f.js} +1 -1
  87. zenml/zen_server/dashboard/assets/{page-DvCvroOM.js → page-CHNxpz3n.js} +1 -1
  88. zenml/zen_server/dashboard/assets/{page-BpSqIf4B.js → page-DgorQFqi.js} +1 -1
  89. zenml/zen_server/dashboard/assets/page-K8ebxVIs.js +1 -0
  90. zenml/zen_server/dashboard/assets/{page-Cx67M0QT.js → page-MFQyIJd3.js} +1 -1
  91. zenml/zen_server/dashboard/assets/page-TgCF0P_U.js +1 -0
  92. zenml/zen_server/dashboard/assets/page-ZnCEe-eK.js +9 -0
  93. zenml/zen_server/dashboard/assets/{page-Dc_7KMQE.js → page-uA5prJGY.js} +1 -1
  94. zenml/zen_server/dashboard/assets/persist-D7HJNBWx.js +1 -0
  95. zenml/zen_server/dashboard/assets/plus-C8WOyCzt.js +1 -0
  96. zenml/zen_server/dashboard/assets/stack-detail-query-Cficsl6d.js +1 -0
  97. zenml/zen_server/dashboard/assets/update-server-settings-mutation-7d8xi1tS.js +1 -0
  98. zenml/zen_server/dashboard/assets/{url-DuQMeqYA.js → url-D7mAQGUM.js} +1 -1
  99. zenml/zen_server/dashboard/index.html +4 -4
  100. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  101. zenml/zen_server/dashboard_legacy/index.html +1 -1
  102. zenml/zen_server/dashboard_legacy/{precache-manifest.c8c57fb0d2132b1d3c2119e776b7dfb3.js → precache-manifest.12246c7548e71e2c4438e496360de80c.js} +4 -4
  103. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  104. zenml/zen_server/dashboard_legacy/static/js/main.3b27024b.chunk.js +2 -0
  105. zenml/zen_server/dashboard_legacy/static/js/{main.382439a7.chunk.js.map → main.3b27024b.chunk.js.map} +1 -1
  106. zenml/zen_server/deploy/helm/Chart.yaml +1 -1
  107. zenml/zen_server/deploy/helm/README.md +2 -2
  108. zenml/zen_server/routers/service_connectors_endpoints.py +57 -0
  109. zenml/zen_stores/migrations/versions/0.62.0_release.py +23 -0
  110. zenml/zen_stores/rest_zen_store.py +4 -0
  111. zenml/zen_stores/schemas/component_schemas.py +14 -0
  112. {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/METADATA +2 -2
  113. {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/RECORD +116 -98
  114. zenml/zen_server/dashboard/assets/Pagination-DEbVUupy.js +0 -1
  115. zenml/zen_server/dashboard/assets/chevron-down-D_ZlKMqH.js +0 -1
  116. zenml/zen_server/dashboard/assets/cloud-only-DVbIeckv.js +0 -1
  117. zenml/zen_server/dashboard/assets/index-C_CrU4vI.js +0 -1
  118. zenml/zen_server/dashboard/assets/index-DK1ynKjA.js +0 -55
  119. zenml/zen_server/dashboard/assets/index-inApY3KQ.css +0 -1
  120. zenml/zen_server/dashboard/assets/page-C43QGHTt.js +0 -9
  121. zenml/zen_server/dashboard/assets/page-CR0OG7ss.js +0 -1
  122. zenml/zen_server/dashboard/assets/page-CaopxiU1.js +0 -1
  123. zenml/zen_server/dashboard/assets/page-D7Z399xy.js +0 -1
  124. zenml/zen_server/dashboard/assets/page-D93kd7Xj.js +0 -1
  125. zenml/zen_server/dashboard/assets/page-DMsSn3dv.js +0 -2
  126. zenml/zen_server/dashboard/assets/page-Hus2pr9T.js +0 -1
  127. zenml/zen_server/dashboard/assets/page-TKXERe16.js +0 -1
  128. zenml/zen_server/dashboard/assets/plus-DOeLmm7C.js +0 -1
  129. zenml/zen_server/dashboard/assets/update-server-settings-mutation-CR8e3Sir.js +0 -1
  130. zenml/zen_server/dashboard_legacy/static/js/main.382439a7.chunk.js +0 -2
  131. {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/LICENSE +0 -0
  132. {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/WHEEL +0 -0
  133. {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/entry_points.txt +0 -0
new file: zenml/integrations/databricks/services/databricks_deployment.py
@@ -0,0 +1,407 @@
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ #       https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ """Implementation of the Databricks Deployment service."""
+
+ import time
+ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ import requests
+ from databricks.sdk import WorkspaceClient as DatabricksClient
+ from databricks.sdk.service.serving import (
+     EndpointCoreConfigInput,
+     EndpointStateConfigUpdate,
+     EndpointStateReady,
+     EndpointTag,
+     ServedModelInput,
+     ServingEndpointDetailed,
+ )
+ from pydantic import Field
+
+ from zenml.client import Client
+ from zenml.integrations.databricks.flavors.databricks_model_deployer_flavor import (
+     DatabricksBaseConfig,
+ )
+ from zenml.integrations.databricks.utils.databricks_utils import (
+     sanitize_labels,
+ )
+ from zenml.logger import get_logger
+ from zenml.services import ServiceState, ServiceStatus, ServiceType
+ from zenml.services.service import BaseDeploymentService, ServiceConfig
+
+ logger = get_logger(__name__)
+
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+ POLLING_TIMEOUT = 1200
+ UUID_SLICE_LENGTH: int = 8
+
+
+ class DatabricksDeploymentConfig(DatabricksBaseConfig, ServiceConfig):
+     """Databricks service configurations."""
+
+     model_uri: Optional[str] = Field(
+         None,
+         description="URI of the model to deploy. This can be a local path or a cloud storage path.",
+     )
+     host: Optional[str] = Field(
+         None, description="Databricks host URL for the deployment."
+     )
+
+     def get_databricks_deployment_labels(self) -> Dict[str, str]:
+         """Generate labels for the Databricks deployment from the service configuration.
+
+         These labels are attached to the Databricks deployment resource
+         and may be used as label selectors in lookup operations.
+
+         Returns:
+             The labels for the Databricks deployment.
+         """
+         labels = {}
+         if self.pipeline_name:
+             labels["zenml_pipeline_name"] = self.pipeline_name
+         if self.pipeline_step_name:
+             labels["zenml_pipeline_step_name"] = self.pipeline_step_name
+         if self.model_name:
+             labels["zenml_model_name"] = self.model_name
+         if self.model_uri:
+             labels["zenml_model_uri"] = self.model_uri
+         sanitize_labels(labels)
+         return labels
+
+
+ class DatabricksServiceStatus(ServiceStatus):
+     """Databricks service status."""
+
+
+ class DatabricksDeploymentService(BaseDeploymentService):
+     """Databricks model deployment service.
+
+     Attributes:
+         SERVICE_TYPE: a service type descriptor with information describing
+             the Databricks deployment service class
+         config: service configuration
+     """
+
+     SERVICE_TYPE = ServiceType(
+         name="databricks-deployment",
+         type="model-serving",
+         flavor="databricks",
+         description="Databricks inference endpoint prediction service",
+     )
+     config: DatabricksDeploymentConfig
+     status: DatabricksServiceStatus = Field(
+         default_factory=lambda: DatabricksServiceStatus()
+     )
+
+     def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any):
+         """Initialize the Databricks deployment service.
+
+         Args:
+             config: service configuration
+             attrs: additional attributes to set on the service
+         """
+         super().__init__(config=config, **attrs)
+
+     def get_client_id_and_secret(self) -> Tuple[str, str, str]:
+         """Get the Databricks host, client ID and client secret.
+
+         Raises:
+             ValueError: If the host, client ID or client secret is not found.
+
+         Returns:
+             The Databricks host, client ID and client secret.
+         """
+         client = Client()
+         client_id = None
+         client_secret = None
+         host = None
+         from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
+             DatabricksModelDeployer,
+         )
+
+         model_deployer = client.active_stack.model_deployer
+         if not isinstance(model_deployer, DatabricksModelDeployer):
+             raise ValueError(
+                 "DatabricksModelDeployer is not active in the stack."
+             )
+         host = model_deployer.config.host
+         self.config.host = host
+         if model_deployer.config.secret_name:
+             secret = client.get_secret(model_deployer.config.secret_name)
+             client_id = secret.secret_values["client_id"]
+             client_secret = secret.secret_values["client_secret"]
+         else:
+             client_id = model_deployer.config.client_id
+             client_secret = model_deployer.config.client_secret
+         if not client_id:
+             raise ValueError("Client id not found.")
+         if not client_secret:
+             raise ValueError("Client secret not found.")
+         if not host:
+             raise ValueError("Host not found.")
+         return host, client_id, client_secret
+
+     def _get_databricks_deployment_labels(self) -> Dict[str, str]:
+         """Generate the labels for the Databricks deployment from the service configuration.
+
+         Returns:
+             The labels for the Databricks deployment.
+         """
+         labels = self.config.get_databricks_deployment_labels()
+         labels["zenml_service_uuid"] = str(self.uuid)
+         sanitize_labels(labels)
+         return labels
+
+     @property
+     def databricks_client(self) -> DatabricksClient:
+         """Get an authenticated Databricks workspace client.
+
+         Returns:
+             A Databricks workspace client.
+         """
+         host, client_id, client_secret = self.get_client_id_and_secret()
+         return DatabricksClient(
+             host=host,
+             client_id=client_id,
+             client_secret=client_secret,
+         )
+
+     @property
+     def databricks_endpoint(self) -> ServingEndpointDetailed:
+         """Get the deployed Databricks inference endpoint.
+
+         Returns:
+             Databricks inference endpoint.
+         """
+         return self.databricks_client.serving_endpoints.get(
+             name=self._generate_an_endpoint_name(),
+         )
+
+     @property
+     def prediction_url(self) -> Optional[str]:
+         """The prediction URI exposed by the prediction service.
+
+         Returns:
+             The prediction URI exposed by the prediction service, or None if
+             the service is not yet ready.
+         """
+         return f"{self.config.host}/serving-endpoints/{self._generate_an_endpoint_name()}/invocations"
+
+     def provision(self) -> None:
+         """Provision or update remote Databricks deployment instance."""
+         from databricks.sdk.service.serving import (
+             ServedModelInputWorkloadSize,
+             ServedModelInputWorkloadType,
+         )
+
+         tags = []
+         for key, value in self._get_databricks_deployment_labels().items():
+             tags.append(EndpointTag(key=key, value=value))
+         # Attempt to create and wait for the inference endpoint
+         served_model = ServedModelInput(
+             model_name=self.config.model_name,
+             model_version=self.config.model_version,
+             scale_to_zero_enabled=self.config.scale_to_zero_enabled,
+             workload_type=ServedModelInputWorkloadType(
+                 self.config.workload_type
+             ),
+             workload_size=ServedModelInputWorkloadSize(
+                 self.config.workload_size
+             ),
+         )
+
+         databricks_endpoint = (
+             self.databricks_client.serving_endpoints.create_and_wait(
+                 name=self._generate_an_endpoint_name(),
+                 config=EndpointCoreConfigInput(
+                     served_models=[served_model],
+                 ),
+                 tags=tags,
+             )
+         )
+         # Check if the endpoint URL is available after provisioning
+         if databricks_endpoint.endpoint_url:
+             logger.info(
+                 f"Databricks inference endpoint successfully deployed and available. Endpoint URL: {databricks_endpoint.endpoint_url}"
+             )
+         else:
+             logger.error(
+                 "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details."
+             )
+
+     def check_status(self) -> Tuple[ServiceState, str]:
+         """Check the current operational state of the Databricks deployment.
+
+         Returns:
+             The operational state of the Databricks deployment and a message
+             providing additional information about that state (e.g. a
+             description of the error, if one is encountered).
+         """
+         try:
+             status = self.databricks_endpoint.state or None
+             if (
+                 status
+                 and status.ready
+                 and status.ready == EndpointStateReady.READY
+             ):
+                 return (ServiceState.ACTIVE, "")
+             elif (
+                 status
+                 and status.config_update
+                 and status.config_update
+                 == EndpointStateConfigUpdate.UPDATE_FAILED
+             ):
+                 return (
+                     ServiceState.ERROR,
+                     "Databricks Inference Endpoint deployment update failed",
+                 )
+             elif (
+                 status
+                 and status.config_update
+                 and status.config_update
+                 == EndpointStateConfigUpdate.IN_PROGRESS
+             ):
+                 return (ServiceState.PENDING_STARTUP, "")
+             return (ServiceState.PENDING_STARTUP, "")
+         except Exception as e:
+             return (
+                 ServiceState.INACTIVE,
+                 f"Databricks Inference Endpoint deployment is inactive or not found: {e}",
+             )
+
+     def deprovision(self, force: bool = False) -> None:
+         """Deprovision the remote Databricks deployment instance.
+
+         Args:
+             force: if True, the remote deployment instance will be
+                 forcefully deprovisioned.
+         """
+         try:
+             self.databricks_client.serving_endpoints.delete(
+                 name=self._generate_an_endpoint_name()
+             )
+         except Exception:
+             logger.error(
+                 "Databricks Inference Endpoint is deleted or cannot be found."
+             )
+
+     def predict(
+         self, request: Union["NDArray[Any]", pd.DataFrame]
+     ) -> "NDArray[Any]":
+         """Make a prediction using the service.
+
+         Args:
+             request: The input data for the prediction.
+
+         Returns:
+             The prediction result.
+
+         Raises:
+             Exception: if the service is not running
+             ValueError: if the endpoint secret name is not provided.
+         """
+         if not self.is_running:
+             raise Exception(
+                 "Databricks endpoint inference service is not running. "
+                 "Please start the service before making predictions."
+             )
+         if self.prediction_url is not None:
+             if not self.config.endpoint_secret_name:
+                 raise ValueError(
+                     "No endpoint secret name is provided for prediction."
+                 )
+             databricks_token = Client().get_secret(
+                 self.config.endpoint_secret_name
+             )
+             if not databricks_token.secret_values["token"]:
+                 raise ValueError("No databricks token found.")
+             headers = {
+                 "Authorization": f"Bearer {databricks_token.secret_values['token']}",
+                 "Content-Type": "application/json",
+             }
+             if isinstance(request, pd.DataFrame):
+                 response = requests.post(  # nosec
+                     self.prediction_url,
+                     json={"instances": request.to_dict("records")},
+                     headers=headers,
+                 )
+             else:
+                 response = requests.post(  # nosec
+                     self.prediction_url,
+                     json={"instances": request.tolist()},
+                     headers=headers,
+                 )
+         else:
+             raise ValueError("No endpoint known for prediction.")
+         response.raise_for_status()
+
+         return np.array(response.json()["predictions"])
+
+     def get_logs(
+         self, follow: bool = False, tail: Optional[int] = None
+     ) -> Generator[str, bool, None]:
+         """Retrieve the service logs.
+
+         Args:
+             follow: if True, the logs will be streamed as they are written
+             tail: only retrieve the last NUM lines of log output.
+
+         Yields:
+             A generator that can be accessed to get the service logs.
+         """
+         logger.info(
+             "Databricks Endpoints provides access to the logs of your Endpoints through the UI in the `Logs` tab of your Endpoint"
+         )
+
+         def log_generator() -> Generator[str, bool, None]:
+             last_log_count = 0
+             while True:
+                 logs = self.databricks_client.serving_endpoints.logs(
+                     name=self._generate_an_endpoint_name(),
+                     served_model_name=self.config.model_name,
+                 )
+
+                 log_lines = logs.logs.split("\n")
+
+                 # Apply tail if specified and it's the first iteration
+                 if tail is not None and last_log_count == 0:
+                     log_lines = log_lines[-tail:]
+
+                 # Yield only new lines
+                 for line in log_lines[last_log_count:]:
+                     yield line
+
+                 last_log_count = len(log_lines)
+
+                 if not follow:
+                     break
+
+                 # Add a small delay to avoid excessive API calls
+                 time.sleep(1)
+
+         yield from log_generator()
+
+     def _generate_an_endpoint_name(self) -> str:
+         """Generate a unique name for the Databricks Inference Endpoint.
+
+         Returns:
+             A unique name for the Databricks Inference Endpoint.
+         """
+         return (
+             f"{self.config.service_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}"
+         )
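For orientation, predict above resolves a bearer token from the ZenML secret named by endpoint_secret_name and posts an {"instances": ...} payload to the invocations URL. A minimal standalone sketch of that raw HTTP call; the workspace URL, endpoint name, and token below are hypothetical placeholders (the real endpoint name is derived as <service_name>-<first 8 UUID chars>):

import numpy as np
import requests

host = "https://adb-1234567890123456.7.azuredatabricks.net"  # hypothetical workspace URL
endpoint = "databricks-deployment-abcd1234"  # hypothetical <service_name>-<uuid[:8]>
token = "dapi..."  # hypothetical token; in practice stored in the ZenML secret

response = requests.post(
    f"{host}/serving-endpoints/{endpoint}/invocations",
    json={"instances": np.array([[5.1, 3.5, 1.4, 0.2]]).tolist()},
    headers={
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    },
    timeout=30,
)
response.raise_for_status()
predictions = np.array(response.json()["predictions"])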
new file: zenml/integrations/databricks/utils/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ #       https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ """Utilities for Databricks integration."""
new file: zenml/integrations/databricks/utils/databricks_utils.py
@@ -0,0 +1,87 @@
+ # Copyright (c) ZenML GmbH 2023. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ #       https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ """Databricks utilities."""
+
+ import re
+ from typing import Dict, List, Optional
+
+ from databricks.sdk.service.compute import Library, PythonPyPiLibrary
+ from databricks.sdk.service.jobs import PythonWheelTask, TaskDependency
+ from databricks.sdk.service.jobs import Task as DatabricksTask
+
+ from zenml import __version__
+
+
+ def convert_step_to_task(
+     task_name: str,
+     command: str,
+     arguments: List[str],
+     libraries: Optional[List[str]] = None,
+     depends_on: Optional[List[str]] = None,
+     zenml_project_wheel: Optional[str] = None,
+     job_cluster_key: Optional[str] = None,
+ ) -> DatabricksTask:
+     """Convert a ZenML step to a Databricks task.
+
+     Args:
+         task_name: Name of the task.
+         command: Command to run.
+         arguments: Arguments to pass to the command.
+         libraries: List of libraries to install.
+         depends_on: List of tasks to depend on.
+         zenml_project_wheel: Path to the ZenML project wheel.
+         job_cluster_key: Key of the Databricks job cluster to run the task on.
+
+     Returns:
+         Databricks task.
+     """
+     db_libraries = []
+     if libraries:
+         for library in libraries:
+             db_libraries.append(Library(pypi=PythonPyPiLibrary(library)))
+     db_libraries.append(Library(whl=zenml_project_wheel))
+     db_libraries.append(
+         Library(pypi=PythonPyPiLibrary(f"zenml=={__version__}"))
+     )
+     return DatabricksTask(
+         task_key=task_name,
+         job_cluster_key=job_cluster_key,
+         libraries=db_libraries,
+         python_wheel_task=PythonWheelTask(
+             package_name="zenml",
+             entry_point=command,
+             parameters=arguments,
+         ),
+         depends_on=[TaskDependency(task) for task in depends_on]
+         if depends_on
+         else None,
+     )
+
+
+ def sanitize_labels(labels: Dict[str, str]) -> None:
+     """Update the label values to be valid Kubernetes labels.
+
+     See:
+     https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
+
+     Args:
+         labels: the labels to sanitize.
+     """
+     for key, value in labels.items():
+         # Kubernetes labels must be alphanumeric, no longer than
+         # 63 characters, and must begin and end with an alphanumeric
+         # character ([a-z0-9A-Z])
+         labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
+             "-_."
+         )
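As a usage sketch only: the call below shows how convert_step_to_task might be invoked by the new Databricks orchestrator; every argument value here is a hypothetical placeholder, not taken from the ZenML codebase.

from zenml.integrations.databricks.utils.databricks_utils import (
    convert_step_to_task,
)

task = convert_step_to_task(
    task_name="trainer",
    command="zenml.entrypoints.entrypoint",  # placeholder entry point name
    arguments=["--step_name", "trainer"],
    libraries=["scikit-learn==1.4.2"],  # installed via PythonPyPiLibrary
    depends_on=["data_loader"],  # becomes [TaskDependency("data_loader")]
    zenml_project_wheel="dbfs:/FileStore/wheels/project-0.1-py3-none-any.whl",
    job_cluster_key="zenml-job-cluster",
)
print(task.task_key)  # trainer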
zenml/integrations/great_expectations/data_validators/ge_data_validator.py
@@ -284,14 +284,18 @@ class GreatExpectationsDataValidator(BaseDataValidator):
                      store_name=store_name,
                      store_config=store_config,
                  )
-             for site_name, site_config in zenml_context_config[
-                 "data_docs_sites"
-             ].items():
-                 self._context.config.data_docs_sites[site_name] = (
-                     site_config
-                 )
-
-         if self.config.configure_local_docs:
+             if self._context.config.data_docs_sites is not None:
+                 for site_name, site_config in zenml_context_config[
+                     "data_docs_sites"
+                 ].items():
+                     self._context.config.data_docs_sites[site_name] = (
+                         site_config
+                     )
+
+         if (
+             self.config.configure_local_docs
+             and self._context.config.data_docs_sites is not None
+         ):
              client = Client()
              artifact_store = client.active_stack.artifact_store
              if artifact_store.flavor != "local":
zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
@@ -1,4 +1,4 @@
- # Copyright (c) ZenML GmbH 2021. All Rights Reserved.
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -16,12 +16,21 @@
  import os
  from collections import defaultdict
  from tempfile import TemporaryDirectory, mkdtemp
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, Union
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     ClassVar,
+     Dict,
+     Optional,
+     Tuple,
+     Type,
+     Union,
+ )

  from datasets import Dataset, load_from_disk
  from datasets.dataset_dict import DatasetDict

- from zenml.enums import ArtifactType
+ from zenml.enums import ArtifactType, VisualizationType
  from zenml.io import fileio
  from zenml.materializers.base_materializer import BaseMaterializer
  from zenml.materializers.pandas_materializer import PandasMaterializer
@@ -33,6 +42,31 @@ if TYPE_CHECKING:
  DEFAULT_DATASET_DIR = "hf_datasets"


+ def extract_repo_name(checksum_str: str) -> Optional[str]:
+     """Extracts the repo name from the checksum string.
+
+     An example of a checksum_str is:
+     "hf://datasets/nyu-mll/glue@bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c/mrpc/train-00000-of-00001.parquet"
+     and the expected output is "nyu-mll/glue".
+
+     Args:
+         checksum_str: The checksum_str to extract the repo name from.
+
+     Returns:
+         The extracted repo name, or None if it cannot be determined.
+     """
+     dataset = None
+     try:
+         parts = checksum_str.split("/")
+         if len(parts) >= 4:
+             # Case: nyu-mll/glue
+             dataset = f"{parts[3]}/{parts[4].split('@')[0]}"
+     except Exception:  # pylint: disable=broad-except
+         pass
+
+     return dataset
+
+
  class HFDatasetMaterializer(BaseMaterializer):
      """Materializer to read data to and from huggingface datasets."""

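A standalone check of the parsing logic against the docstring's example (the function body is repeated verbatim so the snippet runs on its own). Note that for a string with exactly four "/"-separated parts, the parts[4] access raises IndexError, which the broad except turns into a None return:

from typing import Optional


def extract_repo_name(checksum_str: str) -> Optional[str]:
    # Body copied verbatim from the hunk above.
    dataset = None
    try:
        parts = checksum_str.split("/")
        if len(parts) >= 4:
            dataset = f"{parts[3]}/{parts[4].split('@')[0]}"
    except Exception:
        pass
    return dataset


checksum = (
    "hf://datasets/nyu-mll/glue"
    "@bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c/mrpc/train-00000-of-00001.parquet"
)
print(extract_repo_name(checksum))  # nyu-mll/glue
print(extract_repo_name("not-a-checksum"))  # None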
@@ -103,3 +137,54 @@ class HFDatasetMaterializer(BaseMaterializer):
                      metadata[key][dataset_name] = value
              return dict(metadata)
          raise ValueError(f"Unsupported type {type(ds)}")
+
+     def save_visualizations(
+         self, ds: Union[Dataset, DatasetDict]
+     ) -> Dict[str, VisualizationType]:
+         """Save visualizations for the dataset.
+
+         Args:
+             ds: The Dataset or DatasetDict to visualize.
+
+         Returns:
+             A dictionary mapping visualization paths to their types.
+
+         Raises:
+             ValueError: If the given object is not a `Dataset` or `DatasetDict`.
+         """
+         visualizations = {}
+
+         if isinstance(ds, Dataset):
+             datasets = {"default": ds}
+         elif isinstance(ds, DatasetDict):
+             datasets = ds
+         else:
+             raise ValueError(f"Unsupported type {type(ds)}")
+
+         for name, dataset in datasets.items():
+             # Generate a unique identifier for the dataset
+             if dataset.info.download_checksums:
+                 dataset_id = extract_repo_name(
+                     [x for x in dataset.info.download_checksums.keys()][0]
+                 )
+                 if dataset_id:
+                     # Create the iframe HTML
+                     html = f"""
+                     <iframe
+                         src="https://huggingface.co/datasets/{dataset_id}/embed/viewer"
+                         frameborder="0"
+                         width="100%"
+                         height="560px"
+                     ></iframe>
+                     """
+
+                     # Save the HTML to a file
+                     visualization_path = os.path.join(
+                         self.uri, f"{name}_viewer.html"
+                     )
+                     with fileio.open(visualization_path, "w") as f:
+                         f.write(html)
+
+                     visualizations[visualization_path] = VisualizationType.HTML
+
+         return visualizations
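Concretely, for the example dataset above, save_visualizations writes one HTML file per split whose content is just an embedded Hugging Face dataset-viewer iframe (the split name below is hypothetical):

dataset_id = "nyu-mll/glue"  # from the extract_repo_name example above
name = "train"  # hypothetical split name
html = f"""
<iframe
    src="https://huggingface.co/datasets/{dataset_id}/embed/viewer"
    frameborder="0"
    width="100%"
    height="560px"
></iframe>
"""
# The materializer saves this as f"{name}_viewer.html" under the artifact URI.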
zenml/integrations/huggingface/steps/accelerate_runner.py
@@ -111,15 +111,9 @@ def run_with_accelerate(
          if isinstance(v, bool):
              if v:
                  commands.append(f"--{k}")
-         elif isinstance(v, str):
-             commands += [f"--{k}", '"{v}"']
          elif type(v) in (list, tuple, set):
              for each in v:
-                 commands.append(f"--{k}")
-                 if isinstance(each, str):
-                     commands.append(f'"{each}"')
-                 else:
-                     commands.append(f"{each}")
+                 commands += [f"--{k}", f"{each}"]
          else:
              commands += [f"--{k}", f"{v}"]

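Two problems disappear with this hunk: the removed scalar-string branch appended the literal text "{v}" (the f-prefix was missing), and string list elements were wrapped in extra double quotes. After the change, every non-bool value is rendered with plain f-string formatting. A standalone sketch of the resulting behavior (the helper name is hypothetical):

def to_cli_args(**kwargs):
    # Mirrors the post-change serialization logic from the hunk above.
    commands = []
    for k, v in kwargs.items():
        if isinstance(v, bool):
            if v:
                commands.append(f"--{k}")
        elif type(v) in (list, tuple, set):
            for each in v:
                commands += [f"--{k}", f"{each}"]
        else:
            commands += [f"--{k}", f"{v}"]
    return commands


print(to_cli_args(multi_gpu=True, num_processes=2, gpu_ids=["0", "1"]))
# ['--multi_gpu', '--num_processes', '2', '--gpu_ids', '0', '--gpu_ids', '1']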
zenml/integrations/kubernetes/orchestrators/manifest_utils.py
@@ -139,10 +139,17 @@
          ],
          security_context=security_context,
      )
+     image_pull_secrets = []
+     if pod_settings:
+         image_pull_secrets = [
+             k8s_client.V1LocalObjectReference(name=name)
+             for name in pod_settings.image_pull_secrets
+         ]

      pod_spec = k8s_client.V1PodSpec(
          containers=[container_spec],
          restart_policy="Never",
+         image_pull_secrets=image_pull_secrets,
      )

      if service_account_name is not None:
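A standalone sketch of what the new wiring produces, using the kubernetes Python client directly; the stand-in settings object and secret names are hypothetical:

from kubernetes import client as k8s_client


class FakePodSettings:
    # Stand-in for pod settings carrying the new image_pull_secrets field.
    image_pull_secrets = ["regcred", "ghcr-pull"]  # hypothetical secret names


pod_settings = FakePodSettings()

image_pull_secrets = []
if pod_settings:
    image_pull_secrets = [
        k8s_client.V1LocalObjectReference(name=name)
        for name in pod_settings.image_pull_secrets
    ]

pod_spec = k8s_client.V1PodSpec(
    containers=[],  # containers elided for brevity
    restart_policy="Never",
    image_pull_secrets=image_pull_secrets,
)
print([ref.name for ref in pod_spec.image_pull_secrets])  # ['regcred', 'ghcr-pull']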