zenml-nightly 0.68.1.dev20241105__py3-none-any.whl → 0.68.1.dev20241107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/{load_directory_materializer.py → preexisting_data_materializer.py} +8 -9
- zenml/artifacts/utils.py +121 -59
- zenml/constants.py +1 -0
- zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py +19 -31
- zenml/integrations/evidently/__init__.py +1 -1
- zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +8 -12
- zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py +17 -18
- zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py +2 -5
- zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py +17 -18
- zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py +2 -3
- zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py +8 -15
- zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py +11 -16
- zenml/integrations/pillow/materializers/pillow_image_materializer.py +17 -20
- zenml/integrations/polars/materializers/dataframe_materializer.py +26 -39
- zenml/integrations/pycaret/materializers/model_materializer.py +7 -22
- zenml/integrations/tensorflow/materializers/keras_materializer.py +11 -22
- zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py +8 -15
- zenml/integrations/vllm/services/vllm_deployment.py +16 -7
- zenml/integrations/whylogs/materializers/whylogs_materializer.py +11 -18
- zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py +11 -22
- zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py +10 -19
- zenml/materializers/base_materializer.py +68 -1
- zenml/orchestrators/step_runner.py +17 -11
- zenml/stack/flavor.py +9 -5
- zenml/steps/step_context.py +2 -0
- zenml/utils/callback_registry.py +71 -0
- zenml/zen_server/rbac/endpoint_utils.py +43 -1
- zenml/zen_server/routers/artifact_version_endpoints.py +27 -1
- zenml/zen_stores/rest_zen_store.py +52 -0
- zenml/zen_stores/sql_zen_store.py +16 -0
- zenml/zen_stores/zen_store_interface.py +13 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/METADATA +1 -1
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/RECORD +37 -36
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/entry_points.txt +0 -0
@@ -28,7 +28,8 @@ from typing import (
|
|
28
28
|
)
|
29
29
|
|
30
30
|
from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
|
31
|
-
from zenml.artifacts.utils import
|
31
|
+
from zenml.artifacts.utils import _store_artifact_data_and_prepare_request
|
32
|
+
from zenml.client import Client
|
32
33
|
from zenml.config.step_configurations import StepConfiguration
|
33
34
|
from zenml.config.step_run_info import StepRunInfo
|
34
35
|
from zenml.constants import (
|
@@ -153,7 +154,7 @@ class StepRunner:
|
|
153
154
|
|
154
155
|
# Initialize the step context singleton
|
155
156
|
StepContext._clear()
|
156
|
-
StepContext(
|
157
|
+
step_context = StepContext(
|
157
158
|
pipeline_run=pipeline_run,
|
158
159
|
step_run=step_run,
|
159
160
|
output_materializers=output_materializers,
|
@@ -246,6 +247,9 @@ class StepRunner:
|
|
246
247
|
model_version=model_version,
|
247
248
|
)
|
248
249
|
finally:
|
250
|
+
step_context._cleanup_registry.execute_callbacks(
|
251
|
+
raise_on_exception=False
|
252
|
+
)
|
249
253
|
StepContext._clear() # Remove the step context singleton
|
250
254
|
|
251
255
|
# Update the status and output artifacts of the step run.
|
@@ -531,7 +535,7 @@ class StepRunner:
|
|
531
535
|
The IDs of the published output artifacts.
|
532
536
|
"""
|
533
537
|
step_context = get_step_context()
|
534
|
-
|
538
|
+
artifact_requests = []
|
535
539
|
|
536
540
|
for output_name, return_value in output_data.items():
|
537
541
|
data_type = type(return_value)
|
@@ -592,22 +596,24 @@ class StepRunner:
|
|
592
596
|
# Get full set of tags
|
593
597
|
tags = step_context.get_output_tags(output_name)
|
594
598
|
|
595
|
-
|
599
|
+
artifact_request = _store_artifact_data_and_prepare_request(
|
596
600
|
name=artifact_name,
|
597
601
|
data=return_value,
|
598
|
-
|
602
|
+
materializer_class=materializer_class,
|
599
603
|
uri=uri,
|
600
|
-
|
601
|
-
|
604
|
+
store_metadata=artifact_metadata_enabled,
|
605
|
+
store_visualizations=artifact_visualization_enabled,
|
602
606
|
has_custom_name=has_custom_name,
|
603
607
|
version=version,
|
604
608
|
tags=tags,
|
605
|
-
|
606
|
-
manual_save=False,
|
609
|
+
metadata=user_metadata,
|
607
610
|
)
|
608
|
-
|
611
|
+
artifact_requests.append(artifact_request)
|
609
612
|
|
610
|
-
|
613
|
+
responses = Client().zen_store.batch_create_artifact_versions(
|
614
|
+
artifact_requests
|
615
|
+
)
|
616
|
+
return dict(zip(output_data.keys(), responses))
|
611
617
|
|
612
618
|
def load_and_run_hook(
|
613
619
|
self,
|
zenml/stack/flavor.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
from abc import abstractmethod
|
17
17
|
from typing import Any, Dict, Optional, Type, cast
|
18
18
|
|
19
|
+
from zenml.client import Client
|
19
20
|
from zenml.enums import StackComponentType
|
20
21
|
from zenml.models import (
|
21
22
|
FlavorRequest,
|
@@ -146,9 +147,6 @@ class Flavor:
|
|
146
147
|
Returns:
|
147
148
|
The model.
|
148
149
|
"""
|
149
|
-
from zenml.client import Client
|
150
|
-
|
151
|
-
client = Client()
|
152
150
|
connector_requirements = self.service_connector_requirements
|
153
151
|
connector_type = (
|
154
152
|
connector_requirements.connector_type
|
@@ -165,10 +163,16 @@ class Flavor:
|
|
165
163
|
if connector_requirements
|
166
164
|
else None
|
167
165
|
)
|
166
|
+
user = None
|
167
|
+
workspace = None
|
168
|
+
if is_custom:
|
169
|
+
user = Client().active_user.id
|
170
|
+
workspace = Client().active_workspace.id
|
171
|
+
|
168
172
|
model_class = FlavorRequest if is_custom else InternalFlavorRequest
|
169
173
|
model = model_class(
|
170
|
-
user=
|
171
|
-
workspace=
|
174
|
+
user=user,
|
175
|
+
workspace=workspace,
|
172
176
|
name=self.name,
|
173
177
|
type=self.type,
|
174
178
|
source=source_utils.resolve(self.__class__).import_path,
|
zenml/steps/step_context.py
CHANGED
@@ -26,6 +26,7 @@ from typing import (
|
|
26
26
|
|
27
27
|
from zenml.exceptions import StepContextError
|
28
28
|
from zenml.logger import get_logger
|
29
|
+
from zenml.utils.callback_registry import CallbackRegistry
|
29
30
|
from zenml.utils.singleton import SingletonMetaClass
|
30
31
|
|
31
32
|
if TYPE_CHECKING:
|
@@ -145,6 +146,7 @@ class StepContext(metaclass=SingletonMetaClass):
|
|
145
146
|
)
|
146
147
|
for key in output_materializers.keys()
|
147
148
|
}
|
149
|
+
self._cleanup_registry = CallbackRegistry()
|
148
150
|
|
149
151
|
@property
|
150
152
|
def pipeline(self) -> "PipelineResponse":
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Callback registry implementation."""
|
15
|
+
|
16
|
+
from typing import Any, Callable, Dict, List, Tuple
|
17
|
+
|
18
|
+
from typing_extensions import ParamSpec
|
19
|
+
|
20
|
+
from zenml.logger import get_logger
|
21
|
+
|
22
|
+
P = ParamSpec("P")
|
23
|
+
|
24
|
+
logger = get_logger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class CallbackRegistry:
|
28
|
+
"""Callback registry class."""
|
29
|
+
|
30
|
+
def __init__(self) -> None:
|
31
|
+
"""Initializes the callback registry."""
|
32
|
+
self._callbacks: List[
|
33
|
+
Tuple[Callable[P, Any], Tuple[Any], Dict[str, Any]]
|
34
|
+
] = []
|
35
|
+
|
36
|
+
def register_callback(
|
37
|
+
self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
|
38
|
+
) -> None:
|
39
|
+
"""Register a callback.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
callback: The callback to register.
|
43
|
+
*args: Arguments to call the callback with.
|
44
|
+
**kwargs: Keyword arguments to call the callback with.
|
45
|
+
"""
|
46
|
+
self._callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
47
|
+
|
48
|
+
def reset(self) -> None:
|
49
|
+
"""Reset the callbacks."""
|
50
|
+
self._callbacks = []
|
51
|
+
|
52
|
+
def execute_callbacks(self, raise_on_exception: bool) -> None:
|
53
|
+
"""Execute all registered callbacks.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
raise_on_exception: If True, exceptions raised during the execution
|
57
|
+
of the callbacks will be raised. If False, a warning with the
|
58
|
+
exception will be logged instead.
|
59
|
+
|
60
|
+
Raises:
|
61
|
+
Exception: Exceptions raised in any of the callbacks if
|
62
|
+
`raise_on_exception` is set to True.
|
63
|
+
"""
|
64
|
+
for callback, args, kwargs in self._callbacks:
|
65
|
+
try:
|
66
|
+
callback(*args, **kwargs)
|
67
|
+
except Exception as e:
|
68
|
+
if raise_on_exception:
|
69
|
+
raise e
|
70
|
+
else:
|
71
|
+
logger.warning("Failed to run callback: %s", str(e))
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""High-level helper functions to write endpoints with RBAC."""
|
15
15
|
|
16
|
-
from typing import Any, Callable, TypeVar, Union
|
16
|
+
from typing import Any, Callable, List, TypeVar, Union
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
19
|
from pydantic import BaseModel
|
@@ -96,6 +96,48 @@ def verify_permissions_and_create_entity(
|
|
96
96
|
return created
|
97
97
|
|
98
98
|
|
99
|
+
def verify_permissions_and_batch_create_entity(
    batch: List[AnyRequest],
    resource_type: ResourceType,
    create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
) -> List[AnyResponse]:
    """Verify permissions and create a batch of entities if authorized.

    All checks run before `create_method` is called, so a single offending
    request rejects the whole batch and nothing is created.

    Args:
        batch: The batch to create.
        resource_type: The resource type of the entities to create.
        create_method: The method to create the entities.

    Raises:
        IllegalOperationError: If any request model in the batch has a
            different owner than the currently authenticated user.
        RuntimeError: If the resource type is usage-tracked.

    Returns:
        The created entities.
    """
    auth_context = get_auth_context()
    assert auth_context

    # Users may only create user-scoped resources owned by themselves.
    for request_model in batch:
        if (
            isinstance(request_model, UserScopedRequest)
            and request_model.user != auth_context.user.id
        ):
            raise IllegalOperationError(
                f"Not allowed to create resource '{resource_type}' for a "
                "different user."
            )

    # NOTE(review): presumably raises if the user lacks CREATE permission
    # for this resource type — behavior defined in `verify_permission`.
    verify_permission(resource_type=resource_type, action=Action.CREATE)

    # Batch creation is rejected for usage-tracked (reportable) resources.
    if resource_type in REPORTABLE_RESOURCES:
        raise RuntimeError(
            "Batch requests are currently not possible with usage-tracked features."
        )

    return create_method(batch)
|
139
|
+
|
140
|
+
|
99
141
|
def verify_permissions_and_get_entity(
|
100
142
|
id: UUIDOrStr,
|
101
143
|
get_method: Callable[[UUIDOrStr], AnyResponse],
|
@@ -13,12 +13,13 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Endpoint definitions for artifact versions."""
|
15
15
|
|
16
|
+
from typing import List
|
16
17
|
from uuid import UUID
|
17
18
|
|
18
19
|
from fastapi import APIRouter, Depends, Security
|
19
20
|
|
20
21
|
from zenml.artifacts.utils import load_artifact_visualization
|
21
|
-
from zenml.constants import API, ARTIFACT_VERSIONS, VERSION_1, VISUALIZE
|
22
|
+
from zenml.constants import API, ARTIFACT_VERSIONS, BATCH, VERSION_1, VISUALIZE
|
22
23
|
from zenml.models import (
|
23
24
|
ArtifactVersionFilter,
|
24
25
|
ArtifactVersionRequest,
|
@@ -30,6 +31,7 @@ from zenml.models import (
|
|
30
31
|
from zenml.zen_server.auth import AuthContext, authorize
|
31
32
|
from zenml.zen_server.exceptions import error_response
|
32
33
|
from zenml.zen_server.rbac.endpoint_utils import (
|
34
|
+
verify_permissions_and_batch_create_entity,
|
33
35
|
verify_permissions_and_create_entity,
|
34
36
|
verify_permissions_and_delete_entity,
|
35
37
|
verify_permissions_and_get_entity,
|
@@ -118,6 +120,30 @@ def create_artifact_version(
|
|
118
120
|
)
|
119
121
|
|
120
122
|
|
123
|
+
@artifact_version_router.post(
    BATCH,
    responses={401: error_response, 409: error_response, 422: error_response},
)
@handle_exceptions
def batch_create_artifact_version(
    artifact_versions: List[ArtifactVersionRequest],
    _: AuthContext = Security(authorize),
) -> List[ArtifactVersionResponse]:
    """Create a batch of artifact versions.

    Args:
        artifact_versions: The artifact versions to create in one request.

    Returns:
        The created artifact versions.
    """
    store = zen_store()
    return verify_permissions_and_batch_create_entity(
        batch=artifact_versions,
        resource_type=ResourceType.ARTIFACT_VERSION,
        create_method=store.batch_create_artifact_versions,
    )
|
145
|
+
|
146
|
+
|
121
147
|
@artifact_version_router.get(
|
122
148
|
"/{artifact_version_id}",
|
123
149
|
response_model=ArtifactVersionResponse,
|
@@ -57,6 +57,7 @@ from zenml.constants import (
|
|
57
57
|
ARTIFACT_VERSIONS,
|
58
58
|
ARTIFACT_VISUALIZATIONS,
|
59
59
|
ARTIFACTS,
|
60
|
+
BATCH,
|
60
61
|
CODE_REFERENCES,
|
61
62
|
CODE_REPOSITORIES,
|
62
63
|
CONFIG,
|
@@ -991,6 +992,23 @@ class RestZenStore(BaseZenStore):
|
|
991
992
|
route=ARTIFACT_VERSIONS,
|
992
993
|
)
|
993
994
|
|
995
|
+
def batch_create_artifact_versions(
    self, artifact_versions: List[ArtifactVersionRequest]
) -> List[ArtifactVersionResponse]:
    """Creates several artifact versions through the batch REST endpoint.

    Args:
        artifact_versions: The artifact versions to create.

    Returns:
        The created artifact versions.
    """
    created_versions = self._batch_create_resources(
        route=ARTIFACT_VERSIONS,
        resources=artifact_versions,
        response_model=ArtifactVersionResponse,
    )
    return created_versions
|
1011
|
+
|
994
1012
|
def get_artifact_version(
|
995
1013
|
self, artifact_version_id: UUID, hydrate: bool = True
|
996
1014
|
) -> ArtifactVersionResponse:
|
@@ -4518,6 +4536,40 @@ class RestZenStore(BaseZenStore):
|
|
4518
4536
|
|
4519
4537
|
return response_model.model_validate(response_body)
|
4520
4538
|
|
4539
|
+
def _batch_create_resources(
    self,
    resources: List[AnyRequest],
    response_model: Type[AnyResponse],
    route: str,
    params: Optional[Dict[str, Any]] = None,
) -> List[AnyResponse]:
    """Create a new batch of resources.

    All resources are JSON-serialized and sent together in a single POST
    request to `route + BATCH`.

    Args:
        resources: The resources to create.
        response_model: The response model of an individual resource.
        route: The resource REST route to use.
        params: Optional query parameters to pass to the endpoint.

    Returns:
        List of response models, one per element of the server's
        JSON list response.

    Raises:
        ValueError: If the server response is not a JSON list.
    """
    json_data = [resource.model_dump(mode="json") for resource in resources]
    response = self._request(
        "POST",
        self.url + API + VERSION_1 + route + BATCH,
        json=json_data,
        params=params,
    )
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, after which a dict response would be iterated key-wise.
    if not isinstance(response, list):
        raise ValueError(
            f"Bad API Response. Expected list, got {type(response)}"
        )

    return [
        response_model.model_validate(model_data) for model_data in response
    ]
|
4572
|
+
|
4521
4573
|
def _create_workspace_scoped_resource(
|
4522
4574
|
self,
|
4523
4575
|
resource: AnyWorkspaceScopedRequest,
|
@@ -2915,6 +2915,22 @@ class SqlZenStore(BaseZenStore):
|
|
2915
2915
|
include_metadata=True, include_resources=True
|
2916
2916
|
)
|
2917
2917
|
|
2918
|
+
def batch_create_artifact_versions(
    self, artifact_versions: List[ArtifactVersionRequest]
) -> List[ArtifactVersionResponse]:
    """Creates a batch of artifact versions, one request at a time.

    Each request is passed to `create_artifact_version` in order, and the
    responses are returned in the same order.

    NOTE(review): no shared transaction is visible here, so if a later item
    fails, earlier ones are presumably already persisted — confirm whether
    callers need all-or-nothing semantics.

    Args:
        artifact_versions: The artifact versions to create.

    Returns:
        The created artifact versions.
    """
    return list(map(self.create_artifact_version, artifact_versions))
|
2933
|
+
|
2918
2934
|
def get_artifact_version(
|
2919
2935
|
self, artifact_version_id: UUID, hydrate: bool = True
|
2920
2936
|
) -> ArtifactVersionResponse:
|
@@ -663,6 +663,19 @@ class ZenStoreInterface(ABC):
|
|
663
663
|
The created artifact version.
|
664
664
|
"""
|
665
665
|
|
666
|
+
@abstractmethod
def batch_create_artifact_versions(
    self, artifact_versions: List[ArtifactVersionRequest]
) -> List[ArtifactVersionResponse]:
    """Creates a batch of artifact versions.

    Implementations return one response per request.

    Args:
        artifact_versions: The artifact version requests to create.

    Returns:
        The created artifact versions.
    """
|
678
|
+
|
666
679
|
@abstractmethod
|
667
680
|
def get_artifact_version(
|
668
681
|
self, artifact_version_id: UUID, hydrate: bool = True
|