zenml-nightly 0.61.0.dev20240714__py3-none-any.whl → 0.62.0.dev20240719__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +1 -1
- RELEASE_NOTES.md +40 -0
- zenml/VERSION +1 -1
- zenml/__init__.py +2 -0
- zenml/cli/stack.py +87 -228
- zenml/cli/stack_components.py +5 -3
- zenml/constants.py +2 -0
- zenml/entrypoints/entrypoint.py +3 -1
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/constants.py +1 -0
- zenml/integrations/databricks/__init__.py +52 -0
- zenml/integrations/databricks/flavors/__init__.py +30 -0
- zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py +118 -0
- zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +147 -0
- zenml/integrations/databricks/model_deployers/__init__.py +20 -0
- zenml/integrations/databricks/model_deployers/databricks_model_deployer.py +249 -0
- zenml/integrations/databricks/orchestrators/__init__.py +20 -0
- zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +498 -0
- zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py +97 -0
- zenml/integrations/databricks/services/__init__.py +19 -0
- zenml/integrations/databricks/services/databricks_deployment.py +407 -0
- zenml/integrations/databricks/utils/__init__.py +14 -0
- zenml/integrations/databricks/utils/databricks_utils.py +87 -0
- zenml/integrations/great_expectations/data_validators/ge_data_validator.py +12 -8
- zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +88 -3
- zenml/integrations/huggingface/steps/accelerate_runner.py +1 -7
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +7 -0
- zenml/integrations/kubernetes/pod_settings.py +2 -0
- zenml/integrations/lightgbm/__init__.py +1 -0
- zenml/integrations/mlflow/__init__.py +1 -1
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +6 -2
- zenml/integrations/mlflow/services/mlflow_deployment.py +1 -1
- zenml/integrations/skypilot_lambda/__init__.py +1 -1
- zenml/materializers/built_in_materializer.py +1 -1
- zenml/materializers/cloudpickle_materializer.py +1 -1
- zenml/model/model.py +1 -1
- zenml/models/v2/core/component.py +29 -0
- zenml/models/v2/misc/full_stack.py +32 -0
- zenml/orchestrators/__init__.py +4 -0
- zenml/orchestrators/wheeled_orchestrator.py +147 -0
- zenml/service_connectors/service_connector_utils.py +349 -0
- zenml/stack_deployments/gcp_stack_deployment.py +2 -4
- zenml/steps/base_step.py +7 -5
- zenml/utils/function_utils.py +1 -1
- zenml/utils/pipeline_docker_image_builder.py +8 -0
- zenml/zen_server/dashboard/assets/{404-DpJaNHKF.js → 404-B_YdvmwS.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-DJfzkHO1.js → @reactflow-l_1hUr1S.js} +1 -1
- zenml/zen_server/dashboard/assets/{AwarenessChannel-BYDLT2xC.js → AwarenessChannel-CFg5iX4Z.js} +1 -1
- zenml/zen_server/dashboard/assets/{CodeSnippet-BkOuRmyq.js → CodeSnippet-Dvkx_82E.js} +1 -1
- zenml/zen_server/dashboard/assets/CollapsibleCard-opiuBHHc.js +1 -0
- zenml/zen_server/dashboard/assets/{Commands-ZvWR1BRs.js → Commands-DoN1xrEq.js} +1 -1
- zenml/zen_server/dashboard/assets/{CopyButton-DVwLkafa.js → CopyButton-Cr7xYEPb.js} +1 -1
- zenml/zen_server/dashboard/assets/{CsvVizualization-C2IiqX4I.js → CsvVizualization-Ck-nZ43m.js} +3 -3
- zenml/zen_server/dashboard/assets/{Error-CqX0VqW_.js → Error-kLtljEOM.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-BoLUXR9t.js → ExecutionStatus-DguLLgTK.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-LFydyVwh.js → Helpbox-BXUMP21n.js} +1 -1
- zenml/zen_server/dashboard/assets/{Infobox-DnENC0sh.js → Infobox-DSt0O-dm.js} +1 -1
- zenml/zen_server/dashboard/assets/{InlineAvatar-CbJtYr0t.js → InlineAvatar-xsrsIGE-.js} +1 -1
- zenml/zen_server/dashboard/assets/Pagination-C6X-mifw.js +1 -0
- zenml/zen_server/dashboard/assets/{SetPassword-BYBdbQDo.js → SetPassword-BXGTWiwj.js} +1 -1
- zenml/zen_server/dashboard/assets/{SuccessStep-Nx743hll.js → SuccessStep-DZC60t0x.js} +1 -1
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-DF9gSzE0.js → UpdatePasswordSchemas-DGvwFWO1.js} +1 -1
- zenml/zen_server/dashboard/assets/{chevron-right-double-BiEMg7rd.js → chevron-right-double-CZBOf6JM.js} +1 -1
- zenml/zen_server/dashboard/assets/cloud-only-C_yFCAkP.js +1 -0
- zenml/zen_server/dashboard/assets/index-BczVOqUf.js +55 -0
- zenml/zen_server/dashboard/assets/index-EpMIKgrI.css +1 -0
- zenml/zen_server/dashboard/assets/{login-mutation-BUnVASxp.js → login-mutation-CrHrndTI.js} +1 -1
- zenml/zen_server/dashboard/assets/logs-D8k8BVFf.js +1 -0
- zenml/zen_server/dashboard/assets/{not-found-B4VnX8gK.js → not-found-DYa4pC-C.js} +1 -1
- zenml/zen_server/dashboard/assets/{package-CsUhPmou.js → package-B3fWP-Dh.js} +1 -1
- zenml/zen_server/dashboard/assets/page-1h_sD1jz.js +1 -0
- zenml/zen_server/dashboard/assets/{page-Sxn82W-5.js → page-1iL8aMqs.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DMOYZppS.js → page-2grKx_MY.js} +1 -1
- zenml/zen_server/dashboard/assets/page-5NCOHOsy.js +1 -0
- zenml/zen_server/dashboard/assets/{page-JyfeDUfu.js → page-8a4UMKXZ.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-Bx6o0ARS.js → page-B6h3iaHJ.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BDns21Iz.js +1 -0
- zenml/zen_server/dashboard/assets/{page-3efNCDeb.js → page-BhgCDInH.js} +2 -2
- zenml/zen_server/dashboard/assets/{page-DKlIdAe5.js → page-Bi-wtWiO.js} +2 -2
- zenml/zen_server/dashboard/assets/{page-7zTHbhhI.js → page-BkeAAYwp.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CRTJ0UuR.js → page-BkuQDIf-.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BnaevhnB.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BEs6jK71.js → page-Bq0YxkLV.js} +1 -1
- zenml/zen_server/dashboard/assets/page-Bs2F4eoD.js +2 -0
- zenml/zen_server/dashboard/assets/{page-CUZIGO-3.js → page-C6-UGEbH.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-Xu8JEjSU.js → page-CCNRIt_f.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DvCvroOM.js → page-CHNxpz3n.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-BpSqIf4B.js → page-DgorQFqi.js} +1 -1
- zenml/zen_server/dashboard/assets/page-K8ebxVIs.js +1 -0
- zenml/zen_server/dashboard/assets/{page-Cx67M0QT.js → page-MFQyIJd3.js} +1 -1
- zenml/zen_server/dashboard/assets/page-TgCF0P_U.js +1 -0
- zenml/zen_server/dashboard/assets/page-ZnCEe-eK.js +9 -0
- zenml/zen_server/dashboard/assets/{page-Dc_7KMQE.js → page-uA5prJGY.js} +1 -1
- zenml/zen_server/dashboard/assets/persist-D7HJNBWx.js +1 -0
- zenml/zen_server/dashboard/assets/plus-C8WOyCzt.js +1 -0
- zenml/zen_server/dashboard/assets/stack-detail-query-Cficsl6d.js +1 -0
- zenml/zen_server/dashboard/assets/update-server-settings-mutation-7d8xi1tS.js +1 -0
- zenml/zen_server/dashboard/assets/{url-DuQMeqYA.js → url-D7mAQGUM.js} +1 -1
- zenml/zen_server/dashboard/index.html +4 -4
- zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
- zenml/zen_server/dashboard_legacy/index.html +1 -1
- zenml/zen_server/dashboard_legacy/{precache-manifest.c8c57fb0d2132b1d3c2119e776b7dfb3.js → precache-manifest.12246c7548e71e2c4438e496360de80c.js} +4 -4
- zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
- zenml/zen_server/dashboard_legacy/static/js/main.3b27024b.chunk.js +2 -0
- zenml/zen_server/dashboard_legacy/static/js/{main.382439a7.chunk.js.map → main.3b27024b.chunk.js.map} +1 -1
- zenml/zen_server/deploy/helm/Chart.yaml +1 -1
- zenml/zen_server/deploy/helm/README.md +2 -2
- zenml/zen_server/routers/service_connectors_endpoints.py +57 -0
- zenml/zen_stores/migrations/versions/0.62.0_release.py +23 -0
- zenml/zen_stores/rest_zen_store.py +4 -0
- zenml/zen_stores/schemas/component_schemas.py +14 -0
- {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/METADATA +2 -2
- {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/RECORD +116 -98
- zenml/zen_server/dashboard/assets/Pagination-DEbVUupy.js +0 -1
- zenml/zen_server/dashboard/assets/chevron-down-D_ZlKMqH.js +0 -1
- zenml/zen_server/dashboard/assets/cloud-only-DVbIeckv.js +0 -1
- zenml/zen_server/dashboard/assets/index-C_CrU4vI.js +0 -1
- zenml/zen_server/dashboard/assets/index-DK1ynKjA.js +0 -55
- zenml/zen_server/dashboard/assets/index-inApY3KQ.css +0 -1
- zenml/zen_server/dashboard/assets/page-C43QGHTt.js +0 -9
- zenml/zen_server/dashboard/assets/page-CR0OG7ss.js +0 -1
- zenml/zen_server/dashboard/assets/page-CaopxiU1.js +0 -1
- zenml/zen_server/dashboard/assets/page-D7Z399xy.js +0 -1
- zenml/zen_server/dashboard/assets/page-D93kd7Xj.js +0 -1
- zenml/zen_server/dashboard/assets/page-DMsSn3dv.js +0 -2
- zenml/zen_server/dashboard/assets/page-Hus2pr9T.js +0 -1
- zenml/zen_server/dashboard/assets/page-TKXERe16.js +0 -1
- zenml/zen_server/dashboard/assets/plus-DOeLmm7C.js +0 -1
- zenml/zen_server/dashboard/assets/update-server-settings-mutation-CR8e3Sir.js +0 -1
- zenml/zen_server/dashboard_legacy/static/js/main.382439a7.chunk.js +0 -2
- {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.61.0.dev20240714.dist-info → zenml_nightly-0.62.0.dev20240719.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,407 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Implementation of the Databricks Deployment service."""
|
15
|
+
|
16
|
+
import time
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union
|
18
|
+
|
19
|
+
import numpy as np
|
20
|
+
import pandas as pd
|
21
|
+
import requests
|
22
|
+
from databricks.sdk import WorkspaceClient as DatabricksClient
|
23
|
+
from databricks.sdk.service.serving import (
|
24
|
+
EndpointCoreConfigInput,
|
25
|
+
EndpointStateConfigUpdate,
|
26
|
+
EndpointStateReady,
|
27
|
+
EndpointTag,
|
28
|
+
ServedModelInput,
|
29
|
+
ServingEndpointDetailed,
|
30
|
+
)
|
31
|
+
from pydantic import Field
|
32
|
+
|
33
|
+
from zenml.client import Client
|
34
|
+
from zenml.integrations.databricks.flavors.databricks_model_deployer_flavor import (
|
35
|
+
DatabricksBaseConfig,
|
36
|
+
)
|
37
|
+
from zenml.integrations.databricks.utils.databricks_utils import (
|
38
|
+
sanitize_labels,
|
39
|
+
)
|
40
|
+
from zenml.logger import get_logger
|
41
|
+
from zenml.services import ServiceState, ServiceStatus, ServiceType
|
42
|
+
from zenml.services.service import BaseDeploymentService, ServiceConfig
|
43
|
+
|
44
|
+
# Module-level logger for the Databricks deployment service.
logger = get_logger(__name__)


if TYPE_CHECKING:
    # Only needed for the `NDArray` annotations used by `predict`.
    from numpy.typing import NDArray

# NOTE(review): not referenced anywhere in this module; presumably a timeout
# in seconds for endpoint polling — confirm before relying on it.
POLLING_TIMEOUT = 1200
# Number of leading characters of the service UUID appended to the endpoint
# name (see `_generate_an_endpoint_name`).
UUID_SLICE_LENGTH: int = 8
|
52
|
+
|
53
|
+
|
54
|
+
class DatabricksDeploymentConfig(DatabricksBaseConfig, ServiceConfig):
    """Configuration of a Databricks model deployment service.

    Combines the shared Databricks settings with the generic service
    configuration, and adds the model URI and the workspace host used to
    build the prediction URL.
    """

    model_uri: Optional[str] = Field(
        None,
        description="URI of the model to deploy. This can be a local path or a cloud storage path.",
    )
    host: Optional[str] = Field(
        None, description="Databricks host URL for the deployment."
    )

    def get_databricks_deployment_labels(self) -> Dict[str, str]:
        """Build the labels describing this deployment.

        Only pipeline name, step name, model name and model URI values that
        are set (truthy) are included, in that order, and their values are
        sanitized with `sanitize_labels`.

        Returns:
            The labels for the Databricks deployment.
        """
        candidates = (
            ("zenml_pipeline_name", self.pipeline_name),
            ("zenml_pipeline_step_name", self.pipeline_step_name),
            ("zenml_model_name", self.model_name),
            ("zenml_model_uri", self.model_uri),
        )
        labels = {key: value for key, value in candidates if value}
        sanitize_labels(labels)
        return labels
|
85
|
+
|
86
|
+
|
87
|
+
class DatabricksServiceStatus(ServiceStatus):
    """Status of a Databricks model deployment service.

    Declares no fields of its own; it is a Databricks-specific subclass of
    `ServiceStatus`, used as the `status` type of
    `DatabricksDeploymentService`.
    """
|
89
|
+
|
90
|
+
|
91
|
+
class DatabricksDeploymentService(BaseDeploymentService):
    """Databricks model deployment service.

    Manages a Databricks model serving endpoint (create, status, delete,
    predict, logs) for a model described by `DatabricksDeploymentConfig`.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the Databricks deployment service class
        config: service configuration
    """

    SERVICE_TYPE = ServiceType(
        name="databricks-deployment",
        type="model-serving",
        flavor="databricks",
        description="Databricks inference endpoint prediction service",
    )
    config: DatabricksDeploymentConfig
    status: DatabricksServiceStatus = Field(
        default_factory=lambda: DatabricksServiceStatus()
    )

    def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any):
        """Initialize the Databricks deployment service.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        super().__init__(config=config, **attrs)

    def get_client_id_and_secret(self) -> Tuple[str, str, str]:
        """Resolve the Databricks host and OAuth client credentials.

        The values come from the active stack's Databricks model deployer:
        either from the ZenML secret named by its `secret_name` (keys
        `client_id` and `client_secret`), or, when no secret is configured,
        from its `client_id` / `client_secret` settings. As a side effect the
        resolved host is stored on `self.config.host`.

        Raises:
            ValueError: If the active model deployer is not a
                `DatabricksModelDeployer`, or if the host, client id or client
                secret cannot be found.

        Returns:
            A `(host, client_id, client_secret)` tuple.
        """
        # Imported locally, presumably to avoid a circular import between the
        # model deployer module and this service module.
        from zenml.integrations.databricks.model_deployers.databricks_model_deployer import (
            DatabricksModelDeployer,
        )

        client = Client()
        model_deployer = client.active_stack.model_deployer
        if not isinstance(model_deployer, DatabricksModelDeployer):
            raise ValueError(
                "DatabricksModelDeployer is not active in the stack."
            )

        host = model_deployer.config.host
        # Remember the host so that `prediction_url` can be built later.
        self.config.host = host
        if model_deployer.config.secret_name:
            secret = client.get_secret(model_deployer.config.secret_name)
            # `.get` so that a secret lacking one of the keys surfaces as the
            # documented ValueError below instead of a bare KeyError.
            client_id = secret.secret_values.get("client_id")
            client_secret = secret.secret_values.get("client_secret")
        else:
            client_id = model_deployer.config.client_id
            client_secret = model_deployer.config.client_secret

        if not client_id:
            raise ValueError("Client id not found.")
        if not client_secret:
            raise ValueError("Client secret not found.")
        if not host:
            raise ValueError("Host not found.")
        return host, client_id, client_secret

    def _get_databricks_deployment_labels(self) -> Dict[str, str]:
        """Generate the labels for the Databricks deployment.

        Extends the configuration labels with the service UUID.

        Returns:
            The sanitized labels for the Databricks deployment.
        """
        labels = self.config.get_databricks_deployment_labels()
        labels["zenml_service_uuid"] = str(self.uuid)
        sanitize_labels(labels)
        return labels

    @property
    def databricks_client(self) -> DatabricksClient:
        """Create a Databricks workspace client for the configured workspace.

        Credentials are resolved once per access through
        `get_client_id_and_secret`.

        Returns:
            A Databricks workspace client.
        """
        host, client_id, client_secret = self.get_client_id_and_secret()
        return DatabricksClient(
            host=host,
            client_id=client_id,
            client_secret=client_secret,
        )

    @property
    def databricks_endpoint(self) -> ServingEndpointDetailed:
        """Get the deployed Databricks serving endpoint.

        Returns:
            The Databricks serving endpoint details.
        """
        return self.databricks_client.serving_endpoints.get(
            name=self._generate_an_endpoint_name(),
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URL exposed by the prediction service.

        Returns:
            The invocation URL of the serving endpoint, or None if the
            Databricks host is not yet known.
        """
        if not self.config.host:
            # Without a host we would otherwise produce "None/serving-...".
            return None
        host = self.config.host.rstrip("/")
        return f"{host}/serving-endpoints/{self._generate_an_endpoint_name()}/invocations"

    def provision(self) -> None:
        """Create the remote Databricks serving endpoint.

        The endpoint serves `config.model_name` at `config.model_version` with
        the configured workload type/size and scale-to-zero setting, and is
        tagged with the sanitized deployment labels. Note: this always calls
        `create_and_wait`; it does not update an existing endpoint.
        """
        from databricks.sdk.service.serving import (
            ServedModelInputWorkloadSize,
            ServedModelInputWorkloadType,
        )

        tags = [
            EndpointTag(key=key, value=value)
            for key, value in self._get_databricks_deployment_labels().items()
        ]
        served_model = ServedModelInput(
            model_name=self.config.model_name,
            model_version=self.config.model_version,
            scale_to_zero_enabled=self.config.scale_to_zero_enabled,
            workload_type=ServedModelInputWorkloadType(
                self.config.workload_type
            ),
            workload_size=ServedModelInputWorkloadSize(
                self.config.workload_size
            ),
        )

        # Attempt to create and wait for the inference endpoint.
        databricks_endpoint = (
            self.databricks_client.serving_endpoints.create_and_wait(
                name=self._generate_an_endpoint_name(),
                config=EndpointCoreConfigInput(
                    served_models=[served_model],
                ),
                tags=tags,
            )
        )
        # Check if the endpoint URL is available after provisioning.
        if databricks_endpoint.endpoint_url:
            logger.info(
                f"Databricks inference endpoint successfully deployed and available. Endpoint URL: {databricks_endpoint.endpoint_url}"
            )
        else:
            logger.error(
                "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details."
            )

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the current operational state of the Databricks deployment.

        Returns:
            The operational state of the Databricks deployment and a message
            providing additional information about that state (e.g. a
            description of the error, if one is encountered).
        """
        try:
            state = self.databricks_endpoint.state
        except Exception as e:
            # Any failure to fetch the endpoint (missing endpoint, missing
            # credentials, ...) is reported as an inactive service.
            return (
                ServiceState.INACTIVE,
                f"Databricks Inference Endpoint deployment is inactive or not found: {e}",
            )

        if state is None:
            return (ServiceState.PENDING_STARTUP, "")
        if state.ready == EndpointStateReady.READY:
            return (ServiceState.ACTIVE, "")
        if state.config_update == EndpointStateConfigUpdate.UPDATE_FAILED:
            return (
                ServiceState.ERROR,
                "Databricks Inference Endpoint deployment update failed",
            )
        # Update in progress or any other not-yet-ready state.
        return (ServiceState.PENDING_STARTUP, "")

    def deprovision(self, force: bool = False) -> None:
        """Deprovision the remote Databricks deployment instance.

        Deletion is best effort: failures are logged, not raised.

        Args:
            force: if True, the remote deployment instance will be
                forcefully deprovisioned.
        """
        try:
            self.databricks_client.serving_endpoints.delete(
                name=self._generate_an_endpoint_name()
            )
        except Exception as e:
            # Include the underlying error so that e.g. authentication
            # problems are not mistaken for an already-deleted endpoint.
            logger.error(
                f"Databricks Inference Endpoint is deleted or cannot be found: {e}"
            )

    def predict(
        self, request: Union["NDArray[Any]", pd.DataFrame]
    ) -> "NDArray[Any]":
        """Make a prediction using the service.

        A DataFrame is sent as a list of record dicts, any other input via
        its `tolist()` representation, both under the `instances` key.

        Args:
            request: The input data for the prediction.

        Returns:
            The `predictions` field of the endpoint response as an array.

        Raises:
            Exception: if the service is not running
            ValueError: if no endpoint URL is known, the endpoint secret
                name is not provided, or the secret has no `token`.
        """
        if not self.is_running:
            raise Exception(
                "Databricks endpoint inference service is not running. "
                "Please start the service before making predictions."
            )
        prediction_url = self.prediction_url
        if prediction_url is None:
            raise ValueError("No endpoint known for prediction.")
        if not self.config.endpoint_secret_name:
            raise ValueError(
                "No endpoint secret name is provided for prediction."
            )

        databricks_token = Client().get_secret(
            self.config.endpoint_secret_name
        )
        # `.get` so that a missing key raises the ValueError below rather
        # than a KeyError.
        token = databricks_token.secret_values.get("token")
        if not token:
            raise ValueError("No databricks token found.")
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

        if isinstance(request, pd.DataFrame):
            instances = request.to_dict("records")
        else:
            instances = request.tolist()
        response = requests.post(  # nosec
            prediction_url,
            json={"instances": instances},
            headers=headers,
        )
        response.raise_for_status()

        return np.array(response.json()["predictions"])

    def get_logs(
        self, follow: bool = False, tail: Optional[int] = None
    ) -> Generator[str, bool, None]:
        """Retrieve the logs of the served model.

        Args:
            follow: if True, keep polling the endpoint (once per second) and
                yield new log lines as they are written.
            tail: only retrieve the last NUM lines of the initial log output;
                lines written afterwards (when following) are all yielded.

        Yields:
            Individual log lines.
        """
        logger.info(
            "Databricks Endpoints provides access to the logs of your Endpoints through the UI in the `Logs` tab of your Endpoint"
        )

        def log_generator() -> Generator[str, bool, None]:
            # Build the client and name once instead of on every poll.
            client = self.databricks_client
            endpoint_name = self._generate_an_endpoint_name()
            seen_lines = 0
            first_fetch = True
            while True:
                response = client.serving_endpoints.logs(
                    name=endpoint_name,
                    served_model_name=self.config.model_name,
                )
                text = response.logs or ""
                log_lines = text.split("\n") if text else []

                if first_fetch and tail is not None:
                    # Last `tail` lines of the initial output; max() also
                    # makes tail=0 yield nothing.
                    start = max(len(log_lines) - tail, 0)
                else:
                    # Only lines not seen in a previous fetch.
                    start = seen_lines
                yield from log_lines[start:]

                # Track the full line count (not the tailed one) so lines
                # already seen are not re-yielded on the next poll.
                seen_lines = len(log_lines)
                first_fetch = False

                if not follow:
                    break

                # Add a small delay to avoid excessive API calls.
                time.sleep(1)

        yield from log_generator()

    def _generate_an_endpoint_name(self) -> str:
        """Generate a unique name for the Databricks Inference Endpoint.

        Returns:
            The service name followed by the first `UUID_SLICE_LENGTH`
            characters of the service UUID.
        """
        return (
            f"{self.config.service_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}"
        )
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Utilities for Databricks integration."""
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Databricks utilities."""
|
15
|
+
|
16
|
+
import re
|
17
|
+
from typing import Dict, List, Optional
|
18
|
+
|
19
|
+
from databricks.sdk.service.compute import Library, PythonPyPiLibrary
|
20
|
+
from databricks.sdk.service.jobs import PythonWheelTask, TaskDependency
|
21
|
+
from databricks.sdk.service.jobs import Task as DatabricksTask
|
22
|
+
|
23
|
+
from zenml import __version__
|
24
|
+
|
25
|
+
|
26
|
+
def convert_step_to_task(
    task_name: str,
    command: str,
    arguments: List[str],
    libraries: Optional[List[str]] = None,
    depends_on: Optional[List[str]] = None,
    zenml_project_wheel: Optional[str] = None,
    job_cluster_key: Optional[str] = None,
) -> DatabricksTask:
    """Convert a ZenML step to a Databricks task.

    The task runs `command` as an entry point of the `zenml` Python wheel
    package with `arguments` as parameters. Installed libraries are the given
    PyPI requirements, the project wheel (if one is given), and the exact
    ZenML version of this client.

    Args:
        task_name: Name of the task.
        command: Command to run.
        arguments: Arguments to pass to the command.
        libraries: List of libraries to install.
        depends_on: List of tasks to depend on.
        zenml_project_wheel: Path to the ZenML project wheel.
        job_cluster_key: ID of the Databricks job_cluster_key.

    Returns:
        Databricks task.
    """
    db_libraries = [
        Library(pypi=PythonPyPiLibrary(library)) for library in libraries or []
    ]
    # Only reference the project wheel when one was provided; an empty
    # `Library(whl=None)` entry would be an invalid library specification.
    if zenml_project_wheel:
        db_libraries.append(Library(whl=zenml_project_wheel))
    # Pin ZenML on the cluster to the version running this code.
    db_libraries.append(
        Library(pypi=PythonPyPiLibrary(f"zenml=={__version__}"))
    )
    return DatabricksTask(
        task_key=task_name,
        job_cluster_key=job_cluster_key,
        libraries=db_libraries,
        python_wheel_task=PythonWheelTask(
            package_name="zenml",
            entry_point=command,
            parameters=arguments,
        ),
        depends_on=[TaskDependency(task) for task in depends_on]
        if depends_on
        else None,
    )
|
70
|
+
|
71
|
+
|
72
|
+
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Rewrite label values in place so they follow Kubernetes label syntax.

    For every value, each run of characters outside `[0-9a-zA-Z-_.]` becomes
    a single underscore. The result is cut to 63 characters and stripped of
    leading/trailing `-`, `_` and `.`. Keys are not changed.

    See:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: the labels to sanitize; modified in place.
    """
    invalid_run = re.compile(r"[^0-9a-zA-Z-_\.]+")
    for name in labels:
        # Overwriting existing keys does not resize the dict, so mutating
        # while iterating is safe here.
        cleaned = invalid_run.sub("_", labels[name])
        labels[name] = cleaned[:63].strip("-_.")
|
@@ -284,14 +284,18 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
284
284
|
store_name=store_name,
|
285
285
|
store_config=store_config,
|
286
286
|
)
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
287
|
+
if self._context.config.data_docs_sites is not None:
|
288
|
+
for site_name, site_config in zenml_context_config[
|
289
|
+
"data_docs_sites"
|
290
|
+
].items():
|
291
|
+
self._context.config.data_docs_sites[site_name] = (
|
292
|
+
site_config
|
293
|
+
)
|
294
|
+
|
295
|
+
if (
|
296
|
+
self.config.configure_local_docs
|
297
|
+
and self._context.config.data_docs_sites is not None
|
298
|
+
):
|
295
299
|
client = Client()
|
296
300
|
artifact_store = client.active_stack.artifact_store
|
297
301
|
if artifact_store.flavor != "local":
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright (c) ZenML GmbH
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,12 +16,21 @@
|
|
16
16
|
import os
|
17
17
|
from collections import defaultdict
|
18
18
|
from tempfile import TemporaryDirectory, mkdtemp
|
19
|
-
from typing import
|
19
|
+
from typing import (
|
20
|
+
TYPE_CHECKING,
|
21
|
+
Any,
|
22
|
+
ClassVar,
|
23
|
+
Dict,
|
24
|
+
Optional,
|
25
|
+
Tuple,
|
26
|
+
Type,
|
27
|
+
Union,
|
28
|
+
)
|
20
29
|
|
21
30
|
from datasets import Dataset, load_from_disk
|
22
31
|
from datasets.dataset_dict import DatasetDict
|
23
32
|
|
24
|
-
from zenml.enums import ArtifactType
|
33
|
+
from zenml.enums import ArtifactType, VisualizationType
|
25
34
|
from zenml.io import fileio
|
26
35
|
from zenml.materializers.base_materializer import BaseMaterializer
|
27
36
|
from zenml.materializers.pandas_materializer import PandasMaterializer
|
@@ -33,6 +42,31 @@ if TYPE_CHECKING:
|
|
33
42
|
DEFAULT_DATASET_DIR = "hf_datasets"
|
34
43
|
|
35
44
|
|
45
|
+
def extract_repo_name(checksum_str: str) -> Optional[str]:
    """Extracts the Hugging Face Hub dataset repo name from a checksum key.

    An example of a checksum_str is:
    "hf://datasets/nyu-mll/glue@bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c/mrpc/train-00000-of-00001.parquet"
    and the expected output is "nyu-mll/glue". Resolve URLs of the form
    "https://huggingface.co/datasets/<owner>/<name>/resolve/..." are also
    supported. Any other source (e.g. a non-Hub URL) yields None, so that no
    bogus Hub repo id is produced.

    Args:
        checksum_str: The checksum_str to extract the repo name from.

    Returns:
        The extracted repo name, or None if it cannot be determined.
    """
    if not isinstance(checksum_str, str):
        return None

    for prefix in ("hf://datasets/", "https://huggingface.co/datasets/"):
        if checksum_str.startswith(prefix):
            remainder = checksum_str[len(prefix) :]
            break
    else:
        # Not a Hugging Face Hub dataset location.
        return None

    segments = remainder.split("/")
    first = segments[0]
    if "@" in first:
        # Repo without a namespace, e.g. "glue@<revision>/...".
        return first.split("@", 1)[0] or None
    if len(segments) < 2:
        return None

    owner = first
    # Drop an optional "@<revision>" suffix from the dataset name.
    name = segments[1].split("@", 1)[0]
    if not owner or not name:
        return None
    return f"{owner}/{name}"
|
68
|
+
|
69
|
+
|
36
70
|
class HFDatasetMaterializer(BaseMaterializer):
|
37
71
|
"""Materializer to read data to and from huggingface datasets."""
|
38
72
|
|
@@ -103,3 +137,54 @@ class HFDatasetMaterializer(BaseMaterializer):
|
|
103
137
|
metadata[key][dataset_name] = value
|
104
138
|
return dict(metadata)
|
105
139
|
raise ValueError(f"Unsupported type {type(ds)}")
|
140
|
+
|
141
|
+
def save_visualizations(
    self, ds: Union[Dataset, DatasetDict]
) -> Dict[str, VisualizationType]:
    """Save visualizations for the dataset.

    For each dataset (or each split of a DatasetDict) whose download
    checksums point to a Hugging Face Hub dataset, an HTML file embedding the
    Hub dataset viewer is written to `<uri>/<name>_viewer.html`. Datasets
    without a recognizable Hub repo produce no visualization.

    Args:
        ds: The Dataset or DatasetDict to visualize.

    Returns:
        A dictionary mapping visualization paths to their types.

    Raises:
        ValueError: If the given object is not a `Dataset` or `DatasetDict`.
    """
    import html

    if isinstance(ds, Dataset):
        datasets = {"default": ds}
    elif isinstance(ds, DatasetDict):
        datasets = ds
    else:
        raise ValueError(f"Unsupported type {type(ds)}")

    visualizations: Dict[str, VisualizationType] = {}
    for name, dataset in datasets.items():
        checksums = dataset.info.download_checksums
        if not checksums:
            continue
        # Only the first download source is used to identify the Hub repo.
        dataset_id = extract_repo_name(next(iter(checksums)))
        if not dataset_id:
            continue

        # The repo id comes from dataset metadata; escape it before putting
        # it into an HTML attribute.
        safe_id = html.escape(dataset_id, quote=True)
        viewer_html = f"""
        <iframe
            src="https://huggingface.co/datasets/{safe_id}/embed/viewer"
            frameborder="0"
            width="100%"
            height="560px"
        ></iframe>
        """

        visualization_path = os.path.join(self.uri, f"{name}_viewer.html")
        with fileio.open(visualization_path, "w") as f:
            f.write(viewer_html)

        visualizations[visualization_path] = VisualizationType.HTML

    return visualizations
|
@@ -111,15 +111,9 @@ def run_with_accelerate(
|
|
111
111
|
if isinstance(v, bool):
|
112
112
|
if v:
|
113
113
|
commands.append(f"--{k}")
|
114
|
-
elif isinstance(v, str):
|
115
|
-
commands += [f"--{k}", '"{v}"']
|
116
114
|
elif type(v) in (list, tuple, set):
|
117
115
|
for each in v:
|
118
|
-
commands
|
119
|
-
if isinstance(each, str):
|
120
|
-
commands.append(f'"{each}"')
|
121
|
-
else:
|
122
|
-
commands.append(f"{each}")
|
116
|
+
commands += [f"--{k}", f"{each}"]
|
123
117
|
else:
|
124
118
|
commands += [f"--{k}", f"{v}"]
|
125
119
|
|
@@ -139,10 +139,17 @@ def build_pod_manifest(
|
|
139
139
|
],
|
140
140
|
security_context=security_context,
|
141
141
|
)
|
142
|
+
image_pull_secrets = []
|
143
|
+
if pod_settings:
|
144
|
+
image_pull_secrets = [
|
145
|
+
k8s_client.V1LocalObjectReference(name=name)
|
146
|
+
for name in pod_settings.image_pull_secrets
|
147
|
+
]
|
142
148
|
|
143
149
|
pod_spec = k8s_client.V1PodSpec(
|
144
150
|
containers=[container_spec],
|
145
151
|
restart_policy="Never",
|
152
|
+
image_pull_secrets=image_pull_secrets,
|
146
153
|
)
|
147
154
|
|
148
155
|
if service_account_name is not None:
|