zenml-nightly 0.70.0.dev20241125__py3-none-any.whl → 0.71.0.dev20241220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +4 -4
- RELEASE_NOTES.md +112 -0
- zenml/VERSION +1 -1
- zenml/artifact_stores/base_artifact_store.py +2 -2
- zenml/artifacts/artifact_config.py +15 -6
- zenml/artifacts/utils.py +59 -32
- zenml/cli/__init__.py +22 -4
- zenml/cli/base.py +5 -5
- zenml/cli/login.py +26 -0
- zenml/cli/pipeline.py +111 -62
- zenml/cli/server.py +20 -20
- zenml/cli/service_connectors.py +3 -3
- zenml/cli/stack.py +0 -3
- zenml/cli/stack_components.py +0 -1
- zenml/cli/utils.py +0 -5
- zenml/client.py +62 -20
- zenml/config/compiler.py +12 -3
- zenml/config/pipeline_configurations.py +20 -0
- zenml/config/pipeline_run_configuration.py +1 -0
- zenml/config/secret_reference_mixin.py +1 -1
- zenml/config/server_config.py +4 -0
- zenml/config/step_configurations.py +21 -0
- zenml/constants.py +10 -0
- zenml/enums.py +1 -0
- zenml/image_builders/base_image_builder.py +5 -2
- zenml/image_builders/build_context.py +7 -16
- zenml/image_builders/local_image_builder.py +13 -3
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/aws/__init__.py +3 -0
- zenml/integrations/aws/flavors/__init__.py +6 -0
- zenml/integrations/aws/flavors/aws_image_builder_flavor.py +146 -0
- zenml/integrations/aws/image_builders/__init__.py +20 -0
- zenml/integrations/aws/image_builders/aws_image_builder.py +307 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +14 -6
- zenml/integrations/constants.py +1 -0
- zenml/integrations/feast/__init__.py +1 -1
- zenml/integrations/feast/feature_stores/feast_feature_store.py +13 -9
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
- zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +2 -1
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +11 -0
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +46 -2
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +13 -2
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +3 -1
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +3 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +3 -2
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +11 -0
- zenml/integrations/modal/__init__.py +46 -0
- zenml/integrations/modal/flavors/__init__.py +26 -0
- zenml/integrations/modal/flavors/modal_step_operator_flavor.py +125 -0
- zenml/integrations/modal/step_operators/__init__.py +22 -0
- zenml/integrations/modal/step_operators/modal_step_operator.py +242 -0
- zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py +7 -5
- zenml/integrations/neptune/experiment_trackers/run_state.py +69 -53
- zenml/integrations/registry.py +2 -2
- zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +12 -0
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +13 -5
- zenml/io/filesystem.py +2 -2
- zenml/io/local_filesystem.py +3 -3
- zenml/materializers/built_in_materializer.py +18 -1
- zenml/materializers/structured_string_materializer.py +8 -3
- zenml/model/model.py +23 -101
- zenml/model/utils.py +21 -17
- zenml/models/__init__.py +6 -0
- zenml/models/v2/base/filter.py +26 -30
- zenml/models/v2/base/scoped.py +258 -5
- zenml/models/v2/core/artifact_version.py +21 -29
- zenml/models/v2/core/code_repository.py +1 -12
- zenml/models/v2/core/component.py +5 -68
- zenml/models/v2/core/flavor.py +1 -11
- zenml/models/v2/core/model.py +1 -57
- zenml/models/v2/core/model_version.py +11 -36
- zenml/models/v2/core/model_version_artifact.py +11 -3
- zenml/models/v2/core/model_version_pipeline_run.py +14 -3
- zenml/models/v2/core/pipeline.py +47 -55
- zenml/models/v2/core/pipeline_build.py +67 -12
- zenml/models/v2/core/pipeline_deployment.py +0 -10
- zenml/models/v2/core/pipeline_run.py +110 -32
- zenml/models/v2/core/run_metadata.py +30 -9
- zenml/models/v2/core/run_template.py +21 -29
- zenml/models/v2/core/schedule.py +0 -10
- zenml/models/v2/core/secret.py +0 -14
- zenml/models/v2/core/service.py +9 -16
- zenml/models/v2/core/service_connector.py +0 -11
- zenml/models/v2/core/stack.py +21 -30
- zenml/models/v2/core/step_run.py +24 -18
- zenml/models/v2/core/trigger.py +19 -3
- zenml/models/v2/misc/run_metadata.py +38 -0
- zenml/orchestrators/base_orchestrator.py +13 -1
- zenml/orchestrators/input_utils.py +19 -6
- zenml/orchestrators/output_utils.py +5 -1
- zenml/orchestrators/publish_utils.py +12 -5
- zenml/orchestrators/step_launcher.py +16 -16
- zenml/orchestrators/step_run_utils.py +18 -197
- zenml/orchestrators/step_runner.py +40 -3
- zenml/orchestrators/utils.py +79 -50
- zenml/pipelines/build_utils.py +12 -0
- zenml/pipelines/pipeline_decorator.py +4 -0
- zenml/pipelines/pipeline_definition.py +26 -8
- zenml/pipelines/run_utils.py +9 -5
- zenml/service_connectors/service_connector_utils.py +3 -9
- zenml/stack/stack_component.py +1 -1
- zenml/stack_deployments/aws_stack_deployment.py +22 -0
- zenml/steps/base_step.py +11 -1
- zenml/steps/entrypoint_function_utils.py +7 -3
- zenml/steps/step_decorator.py +4 -0
- zenml/steps/utils.py +23 -7
- zenml/types.py +4 -0
- zenml/utils/archivable.py +65 -36
- zenml/utils/code_utils.py +8 -4
- zenml/utils/docker_utils.py +9 -0
- zenml/utils/metadata_utils.py +186 -153
- zenml/utils/string_utils.py +41 -16
- zenml/utils/visualization_utils.py +4 -1
- zenml/zen_server/auth.py +9 -10
- zenml/zen_server/cloud_utils.py +3 -1
- zenml/zen_server/dashboard/assets/{404-NVXKFp-x.js → 404-Cqu3EDCm.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-CK0KJUen.js → @reactflow-D2Y7BWwz.js} +1 -1
- zenml/zen_server/dashboard/assets/{AlertDialogDropdownItem-DezXKmDf.js → AlertDialogDropdownItem-BHd71pVS.js} +1 -1
- zenml/zen_server/dashboard/assets/{CodeSnippet-JzR8CEtw.js → CodeSnippet-DIonwetW.js} +1 -1
- zenml/zen_server/dashboard/assets/{CollapsibleCard-DQW_ktMO.js → CollapsibleCard-CDnC97pB.js} +1 -1
- zenml/zen_server/dashboard/assets/{Commands-DL2kwkRd.js → Commands-BVEXKAOj.js} +1 -1
- zenml/zen_server/dashboard/assets/{ComponentBadge-D_g62Wv8.js → ComponentBadge-CrRvovox.js} +1 -1
- zenml/zen_server/dashboard/assets/{CopyButton-LNcWaa14.js → CopyButton-B6wGAhQv.js} +1 -1
- zenml/zen_server/dashboard/assets/{CsvVizualization-DknpE5ej.js → CsvVizualization-CjcT7LMm.js} +5 -5
- zenml/zen_server/dashboard/assets/DeleteAlertDialog-D2ELtM2W.js +1 -0
- zenml/zen_server/dashboard/assets/{DialogItem-Bxf8FuAT.js → DialogItem-DXIMhBgU.js} +1 -1
- zenml/zen_server/dashboard/assets/{Error-DYflYyps.js → Error-B8uUfTpL.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-C7zyIQKZ.js → ExecutionStatus-ibAdY-dG.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-oYSGpLqd.js → Helpbox-BfAfhKHw.js} +1 -1
- zenml/zen_server/dashboard/assets/{Infobox-Cx4xGoXR.js → Infobox-M_SMOu96.js} +1 -1
- zenml/zen_server/dashboard/assets/{InlineAvatar-DiGOWNKF.js → InlineAvatar-DBA0a0-a.js} +1 -1
- zenml/zen_server/dashboard/assets/{NestedCollapsible-DYbgyKxK.js → NestedCollapsible-DpgmEFKw.js} +1 -1
- zenml/zen_server/dashboard/assets/{Partials-03iZf8-N.js → Partials-D_ldD9if.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProBadge-D_EB8HNo.js → ProBadge-DQbfFotM.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProCta-DqNS4v3x.js → ProCta-Bcpb4rcY.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProviderIcon-Bki2aw8w.js → ProviderIcon-BZpgPigN.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProviderRadio-8f43sPD4.js → ProviderRadio-DWPnMuQ1.js} +1 -1
- zenml/zen_server/dashboard/assets/RunSelector-DgRGaAc6.js +1 -0
- zenml/zen_server/dashboard/assets/{RunsBody-07YEO7qI.js → RunsBody-KecfSkjY.js} +1 -1
- zenml/zen_server/dashboard/assets/{SearchField-lp1KgU4e.js → SearchField-n-ILHnaP.js} +1 -1
- zenml/zen_server/dashboard/assets/{SecretTooltip-CgnbyeOx.js → SecretTooltip-B8MrX5yu.js} +1 -1
- zenml/zen_server/dashboard/assets/{SetPassword-CpP418A2.js → SetPassword-B_IVq_wg.js} +1 -1
- zenml/zen_server/dashboard/assets/StackList-TWPBYnkF.js +1 -0
- zenml/zen_server/dashboard/assets/{Tabs-BktHkCJJ.js → Tabs-Rg857zmd.js} +1 -1
- zenml/zen_server/dashboard/assets/{Tick-BlMoIlJT.js → Tick-COg4A-xo.js} +1 -1
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-Sc0A0pP-.js → UpdatePasswordSchemas-C6Aj3hm6.js} +1 -1
- zenml/zen_server/dashboard/assets/{UsageReason-YYduL4fj.js → UsageReason-BTLbx7w4.js} +1 -1
- zenml/zen_server/dashboard/assets/{WizardFooter-dgmizSJC.js → WizardFooter-BCAj69Vj.js} +1 -1
- zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-D-c2G6lV.js → all-pipeline-runs-query-DMXkDrV2.js} +1 -1
- zenml/zen_server/dashboard/assets/code-snippets-CqONne41.js +13 -0
- zenml/zen_server/dashboard/assets/{create-stack-DM_JPgef.js → create-stack-HfdbhLs4.js} +1 -1
- zenml/zen_server/dashboard/assets/dates-3pMLCNrD.js +1 -0
- zenml/zen_server/dashboard/assets/delete-run-DZ4hIXff.js +1 -0
- zenml/zen_server/dashboard/assets/{form-schemas-K6FYKjwa.js → form-schemas-B0AVEd9b.js} +1 -1
- zenml/zen_server/dashboard/assets/{index-BAkC7FXi.js → index-DPqSWjug.js} +1 -1
- zenml/zen_server/dashboard/assets/{index-CEV4Cvaf.js → index-DScjfBRb.js} +1 -1
- zenml/zen_server/dashboard/assets/index-DXvT1_Um.css +1 -0
- zenml/zen_server/dashboard/assets/{index-CCOPpudF.js → index-FO-p0GU7.js} +5 -5
- zenml/zen_server/dashboard/assets/{index-B1mVPYxf.js → index-I3bKUGUj.js} +1 -1
- zenml/zen_server/dashboard/assets/key-icon-aH-QIa5R.js +1 -0
- zenml/zen_server/dashboard/assets/login-command-CkqxPtV3.js +1 -0
- zenml/zen_server/dashboard/assets/{login-mutation-hf-lK87O.js → login-mutation-BQeo4wTY.js} +1 -1
- zenml/zen_server/dashboard/assets/{not-found-BGirLjU-.js → not-found-gAJ5aDdR.js} +1 -1
- zenml/zen_server/dashboard/assets/page-9Y9-gig0.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DjRJCGb3.js → page-AUwiQ14W.js} +1 -1
- zenml/zen_server/dashboard/assets/page-B6XU7yYT.js +2 -0
- zenml/zen_server/dashboard/assets/{page-C00YAkaB.js → page-BKZYc2Zv.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CdMWnQak.js → page-BU9FG4sR.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-D7S3aCbF.js → page-B_Apk3xg.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-Djikxq_S.js → page-BdowiCbr.js} +1 -1
- zenml/zen_server/dashboard/assets/page-Bg8OjTRe.js +1 -0
- zenml/zen_server/dashboard/assets/page-BxL4qD4_.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DakHVWXF.js → page-CWxT5K5J.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CXuQufSe.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DLC-bNBP.js → page-CcQr8CPP.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CD-DcWoy.js → page-Ce4Hrjnr.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CiYxgZP_.js +1 -0
- zenml/zen_server/dashboard/assets/page-Cldq1mpe.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BDigxVpo.js → page-D4wdonLm.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-D6uU2ax4.js → page-D8ObrbH8.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DXSTpqRD.js → page-DFuAUGt4.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CbpvrsDL.js → page-DGazBpuP.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-COXXJj1k.js → page-DO1UcqPX.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DRYXdL5o.js +1 -0
- zenml/zen_server/dashboard/assets/{page-Df-Fw0aq.js → page-DYEquBC2.js} +1 -1
- zenml/zen_server/dashboard/assets/page-Dk32IeZm.js +1 -0
- zenml/zen_server/dashboard/assets/{page-yYC9OI-E.js → page-I3nKFGie.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-6m6yHHlE.js → page-M0w-n6vn.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-Vcxara9U.js → page-R5dx3xGF.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-BR68V0V1.js → page-bT5pOvcB.js} +1 -1
- zenml/zen_server/dashboard/assets/page-hUqK889I.js +6 -0
- zenml/zen_server/dashboard/assets/{page-CjGdWY13.js → page-h_Stveon.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-D01JhjQB.js → page-r8XK5vR7.js} +1 -1
- zenml/zen_server/dashboard/assets/page-u_-ZXBKb.js +1 -0
- zenml/zen_server/dashboard/assets/page-zaMqB_ao.js +1 -0
- zenml/zen_server/dashboard/assets/{persist-GjC8PZoC.js → persist-AppN1B0J.js} +1 -1
- zenml/zen_server/dashboard/assets/{persist-Coz7ZWvz.js → persist-DAUi_3za.js} +1 -1
- zenml/zen_server/dashboard/assets/service-BqqeXLEe.js +2 -0
- zenml/zen_server/dashboard/assets/{sharedSchema-CQb14VSr.js → sharedSchema-uXN9FLLk.js} +1 -1
- zenml/zen_server/dashboard/assets/{stack-detail-query-OPEW-cDJ.js → stack-detail-query-XfZBiBP2.js} +1 -1
- zenml/zen_server/dashboard/assets/{update-server-settings-mutation-LwuQfHYn.js → update-server-settings-mutation-BWmgVJwA.js} +1 -1
- zenml/zen_server/dashboard/assets/{url-CkvKAnwF.js → url-BLwMbzES.js} +1 -1
- zenml/zen_server/dashboard/index.html +4 -4
- zenml/zen_server/deploy/helm/Chart.yaml +1 -1
- zenml/zen_server/deploy/helm/README.md +2 -2
- zenml/zen_server/rbac/endpoint_utils.py +6 -4
- zenml/zen_server/rbac/models.py +3 -2
- zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
- zenml/zen_server/rbac/utils.py +4 -7
- zenml/zen_server/routers/auth_endpoints.py +22 -11
- zenml/zen_server/routers/steps_endpoints.py +7 -1
- zenml/zen_server/routers/users_endpoints.py +35 -37
- zenml/zen_server/routers/workspaces_endpoints.py +44 -55
- zenml/zen_server/template_execution/utils.py +4 -1
- zenml/zen_server/utils.py +4 -3
- zenml/zen_stores/base_zen_store.py +10 -2
- zenml/zen_stores/migrations/versions/0.71.0_release.py +23 -0
- zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py +37 -0
- zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
- zenml/zen_stores/migrations/versions/b73bc71f1106_remove_component_spec_path.py +36 -0
- zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py +135 -0
- zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py +7 -6
- zenml/zen_stores/rest_zen_store.py +76 -43
- zenml/zen_stores/schemas/__init__.py +5 -1
- zenml/zen_stores/schemas/artifact_schemas.py +12 -11
- zenml/zen_stores/schemas/component_schemas.py +0 -3
- zenml/zen_stores/schemas/model_schemas.py +55 -17
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_run_schemas.py +52 -18
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
- zenml/zen_stores/schemas/run_metadata_schemas.py +66 -31
- zenml/zen_stores/schemas/step_run_schemas.py +40 -13
- zenml/zen_stores/schemas/utils.py +47 -3
- zenml/zen_stores/sql_zen_store.py +462 -134
- {zenml_nightly-0.70.0.dev20241125.dist-info → zenml_nightly-0.71.0.dev20241220.dist-info}/METADATA +5 -5
- {zenml_nightly-0.70.0.dev20241125.dist-info → zenml_nightly-0.71.0.dev20241220.dist-info}/RECORD +239 -217
- zenml/utils/cloud_utils.py +0 -40
- zenml/zen_server/dashboard/assets/RunSelector-DkPiIiNr.js +0 -1
- zenml/zen_server/dashboard/assets/StackList-WvuKQusZ.js +0 -1
- zenml/zen_server/dashboard/assets/delete-run-CJdh1P_h.js +0 -1
- zenml/zen_server/dashboard/assets/index-DlGvJQPn.css +0 -1
- zenml/zen_server/dashboard/assets/page-0JE_-Ec1.js +0 -1
- zenml/zen_server/dashboard/assets/page-BRLpxOt0.js +0 -1
- zenml/zen_server/dashboard/assets/page-BU7huvKw.js +0 -6
- zenml/zen_server/dashboard/assets/page-BvqLv2Ky.js +0 -1
- zenml/zen_server/dashboard/assets/page-CwxrFarU.js +0 -1
- zenml/zen_server/dashboard/assets/page-DfbXf_8s.js +0 -1
- zenml/zen_server/dashboard/assets/page-Dnovpa0i.js +0 -3
- zenml/zen_server/dashboard/assets/page-Dot3LPmL.js +0 -1
- zenml/zen_server/dashboard/assets/page-Xynx4btY.js +0 -14
- zenml/zen_server/dashboard/assets/page-YpKAqVSa.js +0 -1
- {zenml_nightly-0.70.0.dev20241125.dist-info → zenml_nightly-0.71.0.dev20241220.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.70.0.dev20241125.dist-info → zenml_nightly-0.71.0.dev20241220.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.70.0.dev20241125.dist-info → zenml_nightly-0.71.0.dev20241220.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,242 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Modal step operator implementation."""
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast
|
18
|
+
|
19
|
+
import modal
|
20
|
+
from modal_proto import api_pb2
|
21
|
+
|
22
|
+
from zenml.client import Client
|
23
|
+
from zenml.config.build_configuration import BuildConfiguration
|
24
|
+
from zenml.config.resource_settings import ByteUnit, ResourceSettings
|
25
|
+
from zenml.enums import StackComponentType
|
26
|
+
from zenml.integrations.modal.flavors import (
|
27
|
+
ModalStepOperatorConfig,
|
28
|
+
ModalStepOperatorSettings,
|
29
|
+
)
|
30
|
+
from zenml.logger import get_logger
|
31
|
+
from zenml.stack import Stack, StackValidator
|
32
|
+
from zenml.step_operators import BaseStepOperator
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from zenml.config.base_settings import BaseSettings
|
36
|
+
from zenml.config.step_run_info import StepRunInfo
|
37
|
+
from zenml.models import PipelineDeploymentBase
|
38
|
+
|
39
|
+
logger = get_logger(__name__)

# Key under which the Docker image for steps executed by this step operator
# is requested in `get_docker_builds` and looked up again in `launch`.
MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY = "modal_step_operator"


def get_gpu_values(
    settings: ModalStepOperatorSettings, resource_settings: ResourceSettings
) -> Optional[str]:
    """Get the GPU values for the Modal step operator.

    Args:
        settings: The Modal step operator settings.
        resource_settings: The resource settings.

    Returns:
        `None` if no GPU type is configured in the Modal step operator
        settings. Otherwise the GPU type, suffixed with `:<count>` if a
        (non-zero) GPU count is specified in the resource settings.
    """
    gpu_type = settings.gpu
    gpu_count = resource_settings.gpu_count

    if not gpu_type:
        # Without a GPU type there is nothing to request from Modal. Warn
        # instead of silently ignoring an explicit GPU count, so users notice
        # that their step will not get the GPUs they asked for.
        if gpu_count:
            logger.warning(
                "The resource settings request %s GPU(s), but no GPU type is "
                "configured in the Modal step operator settings, so no GPU "
                "will be requested from Modal. Set the `gpu` attribute of the "
                "Modal step operator settings to request GPUs.",
                gpu_count,
            )
        return None

    return f"{gpu_type}:{gpu_count}" if gpu_count else gpu_type
|
60
|
+
|
61
|
+
|
62
|
+
class ModalStepOperator(BaseStepOperator):
    """Step operator to run a step on Modal.

    This class defines code that can set up a Modal environment and run
    functions in it.
    """

    @property
    def config(self) -> ModalStepOperatorConfig:
        """Get the Modal step operator configuration.

        Returns:
            The Modal step operator configuration.
        """
        return cast(ModalStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Get the settings class for the Modal step operator.

        Returns:
            The Modal step operator settings class.
        """
        return ModalStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Get the stack validator for the Modal step operator.

        The stack must contain a container registry and an image builder, and
        both the artifact store and the container registry must be remote,
        because the step runs on Modal's infrastructure.

        Returns:
            The stack validator.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The Modal step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the Modal "
                    "step operator."
                )

            container_registry = stack.container_registry
            # The container registry is listed in `required_components`
            # below; presumably that requirement is enforced before this
            # function runs, so the assert mainly narrows the type.
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The Modal step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "Modal step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Get the Docker build configurations for the Modal step operator.

        One build is requested for every step that is configured to run with
        this step operator.

        Args:
            deployment: The pipeline deployment.

        Returns:
            A list of Docker build configurations.
        """
        builds = []
        for step_name, step in deployment.step_configurations.items():
            if step.config.step_operator == self.name:
                build = BuildConfiguration(
                    key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                )
                builds.append(build)

        return builds

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launch a step run on Modal.

        The step runs inside a Modal sandbox based on the Docker image built
        for this step operator. The image is pulled with the credentials of
        the stack's container registry. This call blocks until the sandbox
        process has finished.

        Args:
            info: The step run information.
            entrypoint_command: The entrypoint command for the step.
            environment: The environment variables for the step.

        Raises:
            RuntimeError: If no Docker credentials are found for the container
                registry, or if the step process in the Modal sandbox exits
                with a non-zero return code.
            ValueError: If no container registry is found in the stack.
        """
        import shlex

        settings = cast(ModalStepOperatorSettings, self.get_settings(info))
        image_name = info.get_image(key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY)
        zc = Client()
        stack = zc.active_stack

        if not stack.container_registry:
            raise ValueError(
                "No Container registry found in the stack. "
                "Please add a container registry and ensure "
                "it is correctly configured."
            )

        if docker_creds := stack.container_registry.credentials:
            docker_username, docker_password = docker_creds
        else:
            raise RuntimeError(
                "No Docker credentials found for the container registry."
            )

        my_secret = modal.secret._Secret.from_dict(
            {
                "REGISTRY_USERNAME": docker_username,
                "REGISTRY_PASSWORD": docker_password,
            }
        )

        # A trivial Dockerfile that only references the image ZenML built, so
        # Modal pulls it from the (authenticated) container registry.
        spec = modal.image.DockerfileSpec(
            commands=[f"FROM {image_name}"], context_files={}
        )

        zenml_image = modal.Image._from_args(
            dockerfile_function=lambda *_, **__: spec,
            force_build=False,
            image_registry_config=modal.image._ImageRegistryConfig(
                api_pb2.REGISTRY_AUTH_TYPE_STATIC_CREDS, my_secret
            ),
        ).env(environment)

        resource_settings = info.config.resource_settings
        gpu_values = get_gpu_values(settings, resource_settings)
        memory_mb = resource_settings.get_memory(ByteUnit.MB)
        memory_int = int(memory_mb) if memory_mb is not None else None

        # The command is executed through `bash -c`, so each argument must be
        # shell-quoted; a plain space-join would split or interpret arguments
        # containing spaces or shell metacharacters.
        command = shlex.join(entrypoint_command)

        app = modal.App(
            f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}"
        )

        async def run_sandbox() -> Optional[int]:
            """Run the step in a Modal sandbox and wait for it to finish.

            Returns:
                The return code of the sandbox process.
            """
            with modal.enable_output():
                async with app.run():
                    sb = await modal.Sandbox.create.aio(
                        "bash",
                        "-c",
                        command,
                        image=zenml_image,
                        gpu=gpu_values,
                        cpu=resource_settings.cpu_count,
                        memory=memory_int,
                        cloud=settings.cloud,
                        region=settings.region,
                        app=app,
                        timeout=86400,  # 24h, the max Modal allows
                    )

                    await sb.wait.aio()
                    return sb.returncode

        return_code = asyncio.run(run_sandbox())

        # Propagate step failures to the caller instead of reporting success
        # for a step whose process failed remotely.
        if return_code:
            raise RuntimeError(
                f"Step `{info.pipeline_step_name}` failed on Modal: the "
                f"sandbox process exited with return code {return_code}."
            )
|
@@ -77,10 +77,12 @@ class NeptuneExperimentTracker(BaseExperimentTracker):
|
|
77
77
|
NeptuneExperimentTrackerSettings, self.get_settings(info)
|
78
78
|
)
|
79
79
|
|
80
|
-
self.run_state.
|
81
|
-
|
82
|
-
|
83
|
-
|
80
|
+
self.run_state.initialize(
|
81
|
+
project=self.config.project,
|
82
|
+
token=self.config.api_token,
|
83
|
+
run_name=info.run_name,
|
84
|
+
tags=list(settings.tags),
|
85
|
+
)
|
84
86
|
|
85
87
|
def get_step_run_metadata(
|
86
88
|
self, info: "StepRunInfo"
|
@@ -107,4 +109,4 @@ class NeptuneExperimentTracker(BaseExperimentTracker):
|
|
107
109
|
"""
|
108
110
|
self.run_state.active_run.sync()
|
109
111
|
self.run_state.active_run.stop()
|
110
|
-
self.run_state.
|
112
|
+
self.run_state.reset()
|
@@ -20,7 +20,6 @@ import neptune
|
|
20
20
|
|
21
21
|
import zenml
|
22
22
|
from zenml.client import Client
|
23
|
-
from zenml.integrations.constants import NEPTUNE
|
24
23
|
from zenml.utils.singleton import SingletonMetaClass
|
25
24
|
|
26
25
|
if TYPE_CHECKING:
|
@@ -29,20 +28,38 @@ if TYPE_CHECKING:
|
|
29
28
|
_INTEGRATION_VERSION_KEY = "source_code/integrations/zenml"
|
30
29
|
|
31
30
|
|
32
|
-
class InvalidExperimentTrackerSelected(Exception):
|
33
|
-
"""Raised if a Neptune run is fetched while using a different experiment tracker."""
|
34
|
-
|
35
|
-
|
36
31
|
class RunProvider(metaclass=SingletonMetaClass):
|
37
32
|
"""Singleton object used to store and persist a Neptune run state across the pipeline."""
|
38
33
|
|
39
34
|
def __init__(self) -> None:
|
40
35
|
"""Initialize RunProvider. Called with no arguments."""
|
41
36
|
self._active_run: Optional["Run"] = None
|
42
|
-
self._project: Optional[str]
|
43
|
-
self._run_name: Optional[str]
|
44
|
-
self._token: Optional[str]
|
45
|
-
self._tags: Optional[List[str]]
|
37
|
+
self._project: Optional[str] = None
|
38
|
+
self._run_name: Optional[str] = None
|
39
|
+
self._token: Optional[str] = None
|
40
|
+
self._tags: Optional[List[str]] = None
|
41
|
+
self._initialized = False
|
42
|
+
|
43
|
+
def initialize(
    self,
    project: Optional[str] = None,
    token: Optional[str] = None,
    run_name: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """Store the Neptune run details and mark the run state as initialized.

    Args:
        project: Name of the Neptune project.
        token: Neptune API token.
        run_name: Name of the Neptune run.
        tags: Tags to attach to the Neptune run.
    """
    self._project, self._token = project, token
    self._run_name, self._tags = run_name, tags
    self._initialized = True
|
46
63
|
|
47
64
|
@property
|
48
65
|
def project(self) -> Optional[Any]:
|
@@ -53,15 +70,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
53
70
|
"""
|
54
71
|
return self._project
|
55
72
|
|
56
|
-
@project.setter
|
57
|
-
def project(self, project: str) -> None:
|
58
|
-
"""Setter for project name.
|
59
|
-
|
60
|
-
Args:
|
61
|
-
project: Neptune project name
|
62
|
-
"""
|
63
|
-
self._project = project
|
64
|
-
|
65
73
|
@property
|
66
74
|
def token(self) -> Optional[Any]:
|
67
75
|
"""Getter for API token.
|
@@ -71,15 +79,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
71
79
|
"""
|
72
80
|
return self._token
|
73
81
|
|
74
|
-
@token.setter
|
75
|
-
def token(self, token: str) -> None:
|
76
|
-
"""Setter for API token.
|
77
|
-
|
78
|
-
Args:
|
79
|
-
token: Neptune API token
|
80
|
-
"""
|
81
|
-
self._token = token
|
82
|
-
|
83
82
|
@property
|
84
83
|
def run_name(self) -> Optional[Any]:
|
85
84
|
"""Getter for run name.
|
@@ -89,15 +88,6 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
89
88
|
"""
|
90
89
|
return self._run_name
|
91
90
|
|
92
|
-
@run_name.setter
|
93
|
-
def run_name(self, run_name: str) -> None:
|
94
|
-
"""Setter for run name.
|
95
|
-
|
96
|
-
Args:
|
97
|
-
run_name: name of the pipeline run
|
98
|
-
"""
|
99
|
-
self._run_name = run_name
|
100
|
-
|
101
91
|
@property
|
102
92
|
def tags(self) -> Optional[Any]:
|
103
93
|
"""Getter for run tags.
|
@@ -107,14 +97,14 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
107
97
|
"""
|
108
98
|
return self._tags
|
109
99
|
|
110
|
-
@
|
111
|
-
def
|
112
|
-
"""
|
100
|
+
@property
|
101
|
+
def initialized(self) -> bool:
|
102
|
+
"""If the run state is initialized.
|
113
103
|
|
114
|
-
|
115
|
-
|
104
|
+
Returns:
|
105
|
+
If the run state is initialized.
|
116
106
|
"""
|
117
|
-
self.
|
107
|
+
return self._initialized
|
118
108
|
|
119
109
|
@property
|
120
110
|
def active_run(self) -> "Run":
|
@@ -137,9 +127,14 @@ class RunProvider(metaclass=SingletonMetaClass):
|
|
137
127
|
self._active_run = run
|
138
128
|
return self._active_run
|
139
129
|
|
140
|
-
def
|
141
|
-
"""
|
130
|
+
def reset(self) -> None:
    """Drop the active run and clear all stored run details."""
    self._active_run = None
    for attribute_name in ("_project", "_run_name", "_token", "_tags"):
        setattr(self, attribute_name, None)
    self._initialized = False
|
143
138
|
|
144
139
|
|
145
140
|
def get_neptune_run() -> "Run":
|
@@ -149,14 +144,35 @@ def get_neptune_run() -> "Run":
|
|
149
144
|
Neptune run object
|
150
145
|
|
151
146
|
Raises:
|
152
|
-
|
147
|
+
RuntimeError: When unable to fetch the active neptune run.
|
153
148
|
"""
|
154
|
-
|
155
|
-
|
156
|
-
if experiment_tracker.flavor == NEPTUNE: # type: ignore
|
157
|
-
return experiment_tracker.run_state.active_run # type: ignore
|
158
|
-
raise InvalidExperimentTrackerSelected(
|
159
|
-
"Fetching a Neptune run works only with the 'neptune' flavor of "
|
160
|
-
"the experiment tracker. The flavor currently selected is %s"
|
161
|
-
% experiment_tracker.flavor # type: ignore
|
149
|
+
from zenml.integrations.neptune.experiment_trackers import (
|
150
|
+
NeptuneExperimentTracker,
|
162
151
|
)
|
152
|
+
|
153
|
+
experiment_tracker = Client().active_stack.experiment_tracker
|
154
|
+
|
155
|
+
if not experiment_tracker:
|
156
|
+
raise RuntimeError(
|
157
|
+
"Unable to get neptune run: Missing experiment tracker in the "
|
158
|
+
"active stack."
|
159
|
+
)
|
160
|
+
|
161
|
+
if not isinstance(experiment_tracker, NeptuneExperimentTracker):
|
162
|
+
raise RuntimeError(
|
163
|
+
"Unable to get neptune run: Experiment tracker in the active "
|
164
|
+
f"stack ({experiment_tracker.flavor}) is not a neptune experiment "
|
165
|
+
"tracker."
|
166
|
+
)
|
167
|
+
|
168
|
+
run_state = experiment_tracker.run_state
|
169
|
+
if not run_state.initialized:
|
170
|
+
raise RuntimeError(
|
171
|
+
"Unable to get neptune run: The experiment tracker has not been "
|
172
|
+
"initialized. To solve this, make sure you use the experiment "
|
173
|
+
"tracker in your step. See "
|
174
|
+
"https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it "
|
175
|
+
"for more information."
|
176
|
+
)
|
177
|
+
|
178
|
+
return experiment_tracker.run_state.active_run
|
zenml/integrations/registry.py
CHANGED
@@ -111,7 +111,7 @@ class IntegrationRegistry(object):
|
|
111
111
|
)
|
112
112
|
else:
|
113
113
|
raise KeyError(
|
114
|
-
f"
|
114
|
+
f"Integration {integration_name} does not exist. "
|
115
115
|
f"Currently the following integrations are implemented. "
|
116
116
|
f"{self.list_integration_names}"
|
117
117
|
)
|
@@ -148,7 +148,7 @@ class IntegrationRegistry(object):
|
|
148
148
|
].get_uninstall_requirements(target_os=target_os)
|
149
149
|
else:
|
150
150
|
raise KeyError(
|
151
|
-
f"
|
151
|
+
f"Integration {integration_name} does not exist. "
|
152
152
|
f"Currently the following integrations are implemented. "
|
153
153
|
f"{self.list_integration_names}"
|
154
154
|
)
|
@@ -144,3 +144,15 @@ class SkypilotBaseOrchestratorConfig(
|
|
144
144
|
True if this config is for a local component, False otherwise.
|
145
145
|
"""
|
146
146
|
return False
|
147
|
+
|
148
|
+
@property
|
149
|
+
def supports_client_side_caching(self) -> bool:
|
150
|
+
"""Whether the orchestrator supports client side caching.
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
Whether the orchestrator supports client side caching.
|
154
|
+
"""
|
155
|
+
# The Skypilot orchestrator runs the entire pipeline in a single VM, or
|
156
|
+
# starts additional VMs from the root VM. Both of those cases are
|
157
|
+
# currently not supported when using client-side caching.
|
158
|
+
return False
|
@@ -23,7 +23,7 @@ from typing import (
|
|
23
23
|
cast,
|
24
24
|
)
|
25
25
|
|
26
|
-
from pydantic import field_validator
|
26
|
+
from pydantic import field_validator, BaseModel
|
27
27
|
|
28
28
|
from zenml.config.base_settings import BaseSettings
|
29
29
|
from zenml.experiment_trackers.base_experiment_tracker import (
|
@@ -60,18 +60,26 @@ class WandbExperimentTrackerSettings(BaseSettings):
|
|
60
60
|
Args:
|
61
61
|
value: The settings.
|
62
62
|
|
63
|
+
Raises:
|
64
|
+
ValueError: If converting the settings failed.
|
65
|
+
|
63
66
|
Returns:
|
64
67
|
Dict representation of the settings.
|
65
68
|
"""
|
66
69
|
import wandb
|
67
70
|
|
68
71
|
if isinstance(value, wandb.Settings):
|
69
|
-
# Depending on the wandb version, either `
|
70
|
-
# is available to convert the settings
|
71
|
-
|
72
|
+
# Depending on the wandb version, either `model_dump`,
|
73
|
+
# `make_static` or `to_dict` is available to convert the settings
|
74
|
+
# to a dictionary
|
75
|
+
if isinstance(value, BaseModel):
|
76
|
+
return value.model_dump()
|
77
|
+
elif hasattr(value, "make_static"):
|
72
78
|
return cast(Dict[str, Any], value.make_static())
|
73
|
-
|
79
|
+
elif hasattr(value, "to_dict"):
|
74
80
|
return value.to_dict()
|
81
|
+
else:
|
82
|
+
raise ValueError("Unable to convert wandb settings to dict.")
|
75
83
|
else:
|
76
84
|
return value
|
77
85
|
|
zenml/io/filesystem.py
CHANGED
@@ -54,11 +54,11 @@ class BaseFilesystem(ABC):
|
|
54
54
|
|
55
55
|
@staticmethod
|
56
56
|
@abstractmethod
|
57
|
-
def open(
|
57
|
+
def open(path: PathType, mode: str = "r") -> Any:
|
58
58
|
"""Opens a file.
|
59
59
|
|
60
60
|
Args:
|
61
|
-
|
61
|
+
path: The path to the file.
|
62
62
|
mode: The mode to open the file in.
|
63
63
|
|
64
64
|
Returns:
|
zenml/io/local_filesystem.py
CHANGED
@@ -55,18 +55,18 @@ class LocalFilesystem(BaseFilesystem):
|
|
55
55
|
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {""}
|
56
56
|
|
57
57
|
@staticmethod
|
58
|
-
def open(
|
58
|
+
def open(path: PathType, mode: str = "r") -> Any:
|
59
59
|
"""Open a file at the given path.
|
60
60
|
|
61
61
|
Args:
|
62
|
-
|
62
|
+
path: The path to the file.
|
63
63
|
mode: The mode to open the file.
|
64
64
|
|
65
65
|
Returns:
|
66
66
|
Any: The file object.
|
67
67
|
"""
|
68
68
|
encoding = "utf-8" if "b" not in mode else None
|
69
|
-
return open(
|
69
|
+
return open(path, mode=mode, encoding=encoding)
|
70
70
|
|
71
71
|
@staticmethod
|
72
72
|
def copyfile(
|
@@ -28,7 +28,7 @@ from typing import (
|
|
28
28
|
)
|
29
29
|
|
30
30
|
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
|
31
|
-
from zenml.enums import ArtifactType
|
31
|
+
from zenml.enums import ArtifactType, VisualizationType
|
32
32
|
from zenml.logger import get_logger
|
33
33
|
from zenml.materializers.base_materializer import BaseMaterializer
|
34
34
|
from zenml.materializers.materializer_registry import materializer_registry
|
@@ -415,6 +415,23 @@ class BuiltInContainerMaterializer(BaseMaterializer):
|
|
415
415
|
self.artifact_store.rmtree(entry["path"])
|
416
416
|
raise e
|
417
417
|
|
418
|
+
# save dict type objects to JSON file with JSON visualization type
|
419
|
+
def save_visualizations(self, data: Any) -> Dict[str, "VisualizationType"]:
|
420
|
+
"""Save visualizations for the given data.
|
421
|
+
|
422
|
+
Args:
|
423
|
+
data: The data to save visualizations for.
|
424
|
+
|
425
|
+
Returns:
|
426
|
+
A dictionary of visualization URIs and their types.
|
427
|
+
"""
|
428
|
+
# dict/list type objects are always saved as JSON files
|
429
|
+
# doesn't work for non-serializable types as they
|
430
|
+
# are saved as list of lists in different files
|
431
|
+
if _is_serializable(data):
|
432
|
+
return {self.data_path.replace("\\", "/"): VisualizationType.JSON}
|
433
|
+
return {}
|
434
|
+
|
418
435
|
def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]:
|
419
436
|
"""Extract metadata from the given built-in container object.
|
420
437
|
|
@@ -19,22 +19,23 @@ from typing import Dict, Type, Union
|
|
19
19
|
from zenml.enums import ArtifactType, VisualizationType
|
20
20
|
from zenml.logger import get_logger
|
21
21
|
from zenml.materializers.base_materializer import BaseMaterializer
|
22
|
-
from zenml.types import CSVString, HTMLString, MarkdownString
|
22
|
+
from zenml.types import CSVString, HTMLString, JSONString, MarkdownString
|
23
23
|
|
24
24
|
logger = get_logger(__name__)
|
25
25
|
|
26
26
|
|
27
|
-
STRUCTURED_STRINGS = Union[CSVString, HTMLString, MarkdownString]
|
27
|
+
STRUCTURED_STRINGS = Union[CSVString, HTMLString, MarkdownString, JSONString]
|
28
28
|
|
29
29
|
HTML_FILENAME = "output.html"
|
30
30
|
MARKDOWN_FILENAME = "output.md"
|
31
31
|
CSV_FILENAME = "output.csv"
|
32
|
+
JSON_FILENAME = "output.json"
|
32
33
|
|
33
34
|
|
34
35
|
class StructuredStringMaterializer(BaseMaterializer):
|
35
36
|
"""Materializer for HTML or Markdown strings."""
|
36
37
|
|
37
|
-
ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString)
|
38
|
+
ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString, JSONString)
|
38
39
|
ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA_ANALYSIS
|
39
40
|
|
40
41
|
def load(self, data_type: Type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS:
|
@@ -94,6 +95,8 @@ class StructuredStringMaterializer(BaseMaterializer):
|
|
94
95
|
filename = HTML_FILENAME
|
95
96
|
elif issubclass(data_type, MarkdownString):
|
96
97
|
filename = MARKDOWN_FILENAME
|
98
|
+
elif issubclass(data_type, JSONString):
|
99
|
+
filename = JSON_FILENAME
|
97
100
|
else:
|
98
101
|
raise ValueError(
|
99
102
|
f"Data type {data_type} is not supported by this materializer."
|
@@ -120,6 +123,8 @@ class StructuredStringMaterializer(BaseMaterializer):
|
|
120
123
|
return VisualizationType.HTML
|
121
124
|
elif issubclass(data_type, MarkdownString):
|
122
125
|
return VisualizationType.MARKDOWN
|
126
|
+
elif issubclass(data_type, JSONString):
|
127
|
+
return VisualizationType.JSON
|
123
128
|
else:
|
124
129
|
raise ValueError(
|
125
130
|
f"Data type {data_type} is not supported by this materializer."
|