zenml-nightly 0.68.1.dev20241105__py3-none-any.whl → 0.68.1.dev20241107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/{load_directory_materializer.py → preexisting_data_materializer.py} +8 -9
  3. zenml/artifacts/utils.py +121 -59
  4. zenml/constants.py +1 -0
  5. zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py +19 -31
  6. zenml/integrations/evidently/__init__.py +1 -1
  7. zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +8 -12
  8. zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py +17 -18
  9. zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py +2 -5
  10. zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py +17 -18
  11. zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py +2 -3
  12. zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py +8 -15
  13. zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py +11 -16
  14. zenml/integrations/pillow/materializers/pillow_image_materializer.py +17 -20
  15. zenml/integrations/polars/materializers/dataframe_materializer.py +26 -39
  16. zenml/integrations/pycaret/materializers/model_materializer.py +7 -22
  17. zenml/integrations/tensorflow/materializers/keras_materializer.py +11 -22
  18. zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py +8 -15
  19. zenml/integrations/vllm/services/vllm_deployment.py +16 -7
  20. zenml/integrations/whylogs/materializers/whylogs_materializer.py +11 -18
  21. zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py +11 -22
  22. zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py +10 -19
  23. zenml/materializers/base_materializer.py +68 -1
  24. zenml/orchestrators/step_runner.py +17 -11
  25. zenml/stack/flavor.py +9 -5
  26. zenml/steps/step_context.py +2 -0
  27. zenml/utils/callback_registry.py +71 -0
  28. zenml/zen_server/rbac/endpoint_utils.py +43 -1
  29. zenml/zen_server/routers/artifact_version_endpoints.py +27 -1
  30. zenml/zen_stores/rest_zen_store.py +52 -0
  31. zenml/zen_stores/sql_zen_store.py +16 -0
  32. zenml/zen_stores/zen_store_interface.py +13 -0
  33. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/METADATA +1 -1
  34. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/RECORD +37 -36
  35. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/LICENSE +0 -0
  36. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/WHEEL +0 -0
  37. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/entry_points.txt +0 -0
zenml/orchestrators/step_runner.py CHANGED
@@ -28,7 +28,8 @@ from typing import (
28
28
  )
29
29
 
30
30
  from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
31
- from zenml.artifacts.utils import save_artifact
31
+ from zenml.artifacts.utils import _store_artifact_data_and_prepare_request
32
+ from zenml.client import Client
32
33
  from zenml.config.step_configurations import StepConfiguration
33
34
  from zenml.config.step_run_info import StepRunInfo
34
35
  from zenml.constants import (
@@ -153,7 +154,7 @@ class StepRunner:
153
154
 
154
155
  # Initialize the step context singleton
155
156
  StepContext._clear()
156
- StepContext(
157
+ step_context = StepContext(
157
158
  pipeline_run=pipeline_run,
158
159
  step_run=step_run,
159
160
  output_materializers=output_materializers,
@@ -246,6 +247,9 @@ class StepRunner:
246
247
  model_version=model_version,
247
248
  )
248
249
  finally:
250
+ step_context._cleanup_registry.execute_callbacks(
251
+ raise_on_exception=False
252
+ )
249
253
  StepContext._clear() # Remove the step context singleton
250
254
 
251
255
  # Update the status and output artifacts of the step run.
@@ -531,7 +535,7 @@ class StepRunner:
531
535
  The IDs of the published output artifacts.
532
536
  """
533
537
  step_context = get_step_context()
534
- output_artifacts: Dict[str, "ArtifactVersionResponse"] = {}
538
+ artifact_requests = []
535
539
 
536
540
  for output_name, return_value in output_data.items():
537
541
  data_type = type(return_value)
@@ -592,22 +596,24 @@ class StepRunner:
592
596
  # Get full set of tags
593
597
  tags = step_context.get_output_tags(output_name)
594
598
 
595
- artifact = save_artifact(
599
+ artifact_request = _store_artifact_data_and_prepare_request(
596
600
  name=artifact_name,
597
601
  data=return_value,
598
- materializer=materializer_class,
602
+ materializer_class=materializer_class,
599
603
  uri=uri,
600
- extract_metadata=artifact_metadata_enabled,
601
- include_visualizations=artifact_visualization_enabled,
604
+ store_metadata=artifact_metadata_enabled,
605
+ store_visualizations=artifact_visualization_enabled,
602
606
  has_custom_name=has_custom_name,
603
607
  version=version,
604
608
  tags=tags,
605
- user_metadata=user_metadata,
606
- manual_save=False,
609
+ metadata=user_metadata,
607
610
  )
608
- output_artifacts[output_name] = artifact
611
+ artifact_requests.append(artifact_request)
609
612
 
610
- return output_artifacts
613
+ responses = Client().zen_store.batch_create_artifact_versions(
614
+ artifact_requests
615
+ )
616
+ return dict(zip(output_data.keys(), responses))
611
617
 
612
618
  def load_and_run_hook(
613
619
  self,
zenml/stack/flavor.py CHANGED
@@ -16,6 +16,7 @@
16
16
  from abc import abstractmethod
17
17
  from typing import Any, Dict, Optional, Type, cast
18
18
 
19
+ from zenml.client import Client
19
20
  from zenml.enums import StackComponentType
20
21
  from zenml.models import (
21
22
  FlavorRequest,
@@ -146,9 +147,6 @@ class Flavor:
146
147
  Returns:
147
148
  The model.
148
149
  """
149
- from zenml.client import Client
150
-
151
- client = Client()
152
150
  connector_requirements = self.service_connector_requirements
153
151
  connector_type = (
154
152
  connector_requirements.connector_type
@@ -165,10 +163,16 @@ class Flavor:
165
163
  if connector_requirements
166
164
  else None
167
165
  )
166
+ user = None
167
+ workspace = None
168
+ if is_custom:
169
+ user = Client().active_user.id
170
+ workspace = Client().active_workspace.id
171
+
168
172
  model_class = FlavorRequest if is_custom else InternalFlavorRequest
169
173
  model = model_class(
170
- user=client.active_user.id if is_custom else None,
171
- workspace=client.active_workspace.id if is_custom else None,
174
+ user=user,
175
+ workspace=workspace,
172
176
  name=self.name,
173
177
  type=self.type,
174
178
  source=source_utils.resolve(self.__class__).import_path,
zenml/steps/step_context.py CHANGED
@@ -26,6 +26,7 @@ from typing import (
26
26
 
27
27
  from zenml.exceptions import StepContextError
28
28
  from zenml.logger import get_logger
29
+ from zenml.utils.callback_registry import CallbackRegistry
29
30
  from zenml.utils.singleton import SingletonMetaClass
30
31
 
31
32
  if TYPE_CHECKING:
@@ -145,6 +146,7 @@ class StepContext(metaclass=SingletonMetaClass):
145
146
  )
146
147
  for key in output_materializers.keys()
147
148
  }
149
+ self._cleanup_registry = CallbackRegistry()
148
150
 
149
151
  @property
150
152
  def pipeline(self) -> "PipelineResponse":
zenml/utils/callback_registry.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Callback registry implementation."""
15
+
16
+ from typing import Any, Callable, Dict, List, Tuple
17
+
18
+ from typing_extensions import ParamSpec
19
+
20
+ from zenml.logger import get_logger
21
+
22
+ P = ParamSpec("P")
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
27
+ class CallbackRegistry:
28
+ """Callback registry class."""
29
+
30
+ def __init__(self) -> None:
31
+ """Initializes the callback registry."""
32
+ self._callbacks: List[
33
+ Tuple[Callable[P, Any], Tuple[Any], Dict[str, Any]]
34
+ ] = []
35
+
36
+ def register_callback(
37
+ self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
38
+ ) -> None:
39
+ """Register a callback.
40
+
41
+ Args:
42
+ callback: The callback to register.
43
+ *args: Arguments to call the callback with.
44
+ **kwargs: Keyword arguments to call the callback with.
45
+ """
46
+ self._callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
47
+
48
+ def reset(self) -> None:
49
+ """Reset the callbacks."""
50
+ self._callbacks = []
51
+
52
+ def execute_callbacks(self, raise_on_exception: bool) -> None:
53
+ """Execute all registered callbacks.
54
+
55
+ Args:
56
+ raise_on_exception: If True, exceptions raised during the execution
57
+ of the callbacks will be raised. If False, a warning with the
58
+ exception will be logged instead.
59
+
60
+ Raises:
61
+ Exception: Exceptions raised in any of the callbacks if
62
+ `raise_on_exception` is set to True.
63
+ """
64
+ for callback, args, kwargs in self._callbacks:
65
+ try:
66
+ callback(*args, **kwargs)
67
+ except Exception as e:
68
+ if raise_on_exception:
69
+ raise e
70
+ else:
71
+ logger.warning("Failed to run callback: %s", str(e))
zenml/zen_server/rbac/endpoint_utils.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """High-level helper functions to write endpoints with RBAC."""
15
15
 
16
- from typing import Any, Callable, TypeVar, Union
16
+ from typing import Any, Callable, List, TypeVar, Union
17
17
  from uuid import UUID
18
18
 
19
19
  from pydantic import BaseModel
@@ -96,6 +96,48 @@ def verify_permissions_and_create_entity(
96
96
  return created
97
97
 
98
98
 
99
+ def verify_permissions_and_batch_create_entity(
100
+ batch: List[AnyRequest],
101
+ resource_type: ResourceType,
102
+ create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
103
+ ) -> List[AnyResponse]:
104
+ """Verify permissions and create a batch of entities if authorized.
105
+
106
+ Args:
107
+ batch: The batch to create.
108
+ resource_type: The resource type of the entities to create.
109
+ create_method: The method to create the entities.
110
+
111
+ Raises:
112
+ IllegalOperationError: If the request model has a different owner then
113
+ the currently authenticated user.
114
+ RuntimeError: If the resource type is usage-tracked.
115
+
116
+ Returns:
117
+ The created entities.
118
+ """
119
+ auth_context = get_auth_context()
120
+ assert auth_context
121
+
122
+ for request_model in batch:
123
+ if isinstance(request_model, UserScopedRequest):
124
+ if request_model.user != auth_context.user.id:
125
+ raise IllegalOperationError(
126
+ f"Not allowed to create resource '{resource_type}' for a "
127
+ "different user."
128
+ )
129
+
130
+ verify_permission(resource_type=resource_type, action=Action.CREATE)
131
+
132
+ if resource_type in REPORTABLE_RESOURCES:
133
+ raise RuntimeError(
134
+ "Batch requests are currently not possible with usage-tracked features."
135
+ )
136
+
137
+ created = create_method(batch)
138
+ return created
139
+
140
+
99
141
  def verify_permissions_and_get_entity(
100
142
  id: UUIDOrStr,
101
143
  get_method: Callable[[UUIDOrStr], AnyResponse],
zenml/zen_server/routers/artifact_version_endpoints.py CHANGED
@@ -13,12 +13,13 @@
13
13
  # permissions and limitations under the License.
14
14
  """Endpoint definitions for artifact versions."""
15
15
 
16
+ from typing import List
16
17
  from uuid import UUID
17
18
 
18
19
  from fastapi import APIRouter, Depends, Security
19
20
 
20
21
  from zenml.artifacts.utils import load_artifact_visualization
21
- from zenml.constants import API, ARTIFACT_VERSIONS, VERSION_1, VISUALIZE
22
+ from zenml.constants import API, ARTIFACT_VERSIONS, BATCH, VERSION_1, VISUALIZE
22
23
  from zenml.models import (
23
24
  ArtifactVersionFilter,
24
25
  ArtifactVersionRequest,
@@ -30,6 +31,7 @@ from zenml.models import (
30
31
  from zenml.zen_server.auth import AuthContext, authorize
31
32
  from zenml.zen_server.exceptions import error_response
32
33
  from zenml.zen_server.rbac.endpoint_utils import (
34
+ verify_permissions_and_batch_create_entity,
33
35
  verify_permissions_and_create_entity,
34
36
  verify_permissions_and_delete_entity,
35
37
  verify_permissions_and_get_entity,
@@ -118,6 +120,30 @@ def create_artifact_version(
118
120
  )
119
121
 
120
122
 
123
+ @artifact_version_router.post(
124
+ BATCH,
125
+ responses={401: error_response, 409: error_response, 422: error_response},
126
+ )
127
+ @handle_exceptions
128
+ def batch_create_artifact_version(
129
+ artifact_versions: List[ArtifactVersionRequest],
130
+ _: AuthContext = Security(authorize),
131
+ ) -> List[ArtifactVersionResponse]:
132
+ """Create a batch of artifact versions.
133
+
134
+ Args:
135
+ artifact_versions: The artifact versions to create.
136
+
137
+ Returns:
138
+ The created artifact versions.
139
+ """
140
+ return verify_permissions_and_batch_create_entity(
141
+ batch=artifact_versions,
142
+ resource_type=ResourceType.ARTIFACT_VERSION,
143
+ create_method=zen_store().batch_create_artifact_versions,
144
+ )
145
+
146
+
121
147
  @artifact_version_router.get(
122
148
  "/{artifact_version_id}",
123
149
  response_model=ArtifactVersionResponse,
zenml/zen_stores/rest_zen_store.py CHANGED
@@ -57,6 +57,7 @@ from zenml.constants import (
57
57
  ARTIFACT_VERSIONS,
58
58
  ARTIFACT_VISUALIZATIONS,
59
59
  ARTIFACTS,
60
+ BATCH,
60
61
  CODE_REFERENCES,
61
62
  CODE_REPOSITORIES,
62
63
  CONFIG,
@@ -991,6 +992,23 @@ class RestZenStore(BaseZenStore):
991
992
  route=ARTIFACT_VERSIONS,
992
993
  )
993
994
 
995
+ def batch_create_artifact_versions(
996
+ self, artifact_versions: List[ArtifactVersionRequest]
997
+ ) -> List[ArtifactVersionResponse]:
998
+ """Creates a batch of artifact versions.
999
+
1000
+ Args:
1001
+ artifact_versions: The artifact versions to create.
1002
+
1003
+ Returns:
1004
+ The created artifact versions.
1005
+ """
1006
+ return self._batch_create_resources(
1007
+ resources=artifact_versions,
1008
+ response_model=ArtifactVersionResponse,
1009
+ route=ARTIFACT_VERSIONS,
1010
+ )
1011
+
994
1012
  def get_artifact_version(
995
1013
  self, artifact_version_id: UUID, hydrate: bool = True
996
1014
  ) -> ArtifactVersionResponse:
@@ -4518,6 +4536,40 @@ class RestZenStore(BaseZenStore):
4518
4536
 
4519
4537
  return response_model.model_validate(response_body)
4520
4538
 
4539
+ def _batch_create_resources(
4540
+ self,
4541
+ resources: List[AnyRequest],
4542
+ response_model: Type[AnyResponse],
4543
+ route: str,
4544
+ params: Optional[Dict[str, Any]] = None,
4545
+ ) -> List[AnyResponse]:
4546
+ """Create a new batch of resources.
4547
+
4548
+ Args:
4549
+ resources: The resources to create.
4550
+ response_model: The response model of an individual resource.
4551
+ route: The resource REST route to use.
4552
+ params: Optional query parameters to pass to the endpoint.
4553
+
4554
+ Returns:
4555
+ List of response models.
4556
+ """
4557
+ json_data = [
4558
+ resource.model_dump(mode="json") for resource in resources
4559
+ ]
4560
+ response = self._request(
4561
+ "POST",
4562
+ self.url + API + VERSION_1 + route + BATCH,
4563
+ json=json_data,
4564
+ params=params,
4565
+ )
4566
+ assert isinstance(response, list)
4567
+
4568
+ return [
4569
+ response_model.model_validate(model_data)
4570
+ for model_data in response
4571
+ ]
4572
+
4521
4573
  def _create_workspace_scoped_resource(
4522
4574
  self,
4523
4575
  resource: AnyWorkspaceScopedRequest,
zenml/zen_stores/sql_zen_store.py CHANGED
@@ -2915,6 +2915,22 @@ class SqlZenStore(BaseZenStore):
2915
2915
  include_metadata=True, include_resources=True
2916
2916
  )
2917
2917
 
2918
+ def batch_create_artifact_versions(
2919
+ self, artifact_versions: List[ArtifactVersionRequest]
2920
+ ) -> List[ArtifactVersionResponse]:
2921
+ """Creates a batch of artifact versions.
2922
+
2923
+ Args:
2924
+ artifact_versions: The artifact versions to create.
2925
+
2926
+ Returns:
2927
+ The created artifact versions.
2928
+ """
2929
+ return [
2930
+ self.create_artifact_version(artifact_version)
2931
+ for artifact_version in artifact_versions
2932
+ ]
2933
+
2918
2934
  def get_artifact_version(
2919
2935
  self, artifact_version_id: UUID, hydrate: bool = True
2920
2936
  ) -> ArtifactVersionResponse:
zenml/zen_stores/zen_store_interface.py CHANGED
@@ -663,6 +663,19 @@ class ZenStoreInterface(ABC):
663
663
  The created artifact version.
664
664
  """
665
665
 
666
+ @abstractmethod
667
+ def batch_create_artifact_versions(
668
+ self, artifact_versions: List[ArtifactVersionRequest]
669
+ ) -> List[ArtifactVersionResponse]:
670
+ """Creates a batch of artifact versions.
671
+
672
+ Args:
673
+ artifact_versions: The artifact versions to create.
674
+
675
+ Returns:
676
+ The created artifact versions.
677
+ """
678
+
666
679
  @abstractmethod
667
680
  def get_artifact_version(
668
681
  self, artifact_version_id: UUID, hydrate: bool = True
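{zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@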
1
1
  Metadata-Version: 2.1
2
2
  Name: zenml-nightly
3
- Version: 0.68.1.dev20241105
3
+ Version: 0.68.1.dev20241107
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  Home-page: https://zenml.io
6
6
  License: Apache-2.0