zenml-nightly 0.58.2.dev20240618__py3-none-any.whl → 0.58.2.dev20240620__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/_hub/client.py +8 -5
- zenml/actions/base_action.py +8 -10
- zenml/artifact_stores/base_artifact_store.py +20 -15
- zenml/artifact_stores/local_artifact_store.py +3 -2
- zenml/artifacts/artifact_config.py +34 -19
- zenml/artifacts/external_artifact.py +18 -8
- zenml/artifacts/external_artifact_config.py +14 -6
- zenml/artifacts/unmaterialized_artifact.py +2 -11
- zenml/cli/__init__.py +6 -0
- zenml/cli/artifact.py +20 -2
- zenml/cli/served_model.py +0 -1
- zenml/cli/server.py +3 -3
- zenml/cli/utils.py +36 -40
- zenml/cli/web_login.py +2 -2
- zenml/client.py +198 -24
- zenml/client_lazy_loader.py +20 -14
- zenml/config/base_settings.py +5 -6
- zenml/config/build_configuration.py +1 -1
- zenml/config/compiler.py +3 -3
- zenml/config/docker_settings.py +27 -28
- zenml/config/global_config.py +33 -37
- zenml/config/pipeline_configurations.py +8 -11
- zenml/config/pipeline_run_configuration.py +6 -2
- zenml/config/pipeline_spec.py +3 -4
- zenml/config/resource_settings.py +8 -9
- zenml/config/schedule.py +16 -20
- zenml/config/secret_reference_mixin.py +6 -3
- zenml/config/secrets_store_config.py +16 -23
- zenml/config/server_config.py +50 -46
- zenml/config/settings_resolver.py +1 -1
- zenml/config/source.py +45 -35
- zenml/config/step_configurations.py +53 -31
- zenml/config/store_config.py +20 -19
- zenml/config/strict_base_model.py +2 -6
- zenml/constants.py +26 -2
- zenml/container_registries/base_container_registry.py +3 -2
- zenml/container_registries/default_container_registry.py +3 -3
- zenml/event_hub/base_event_hub.py +1 -1
- zenml/event_sources/base_event_source.py +11 -16
- zenml/exceptions.py +4 -0
- zenml/integrations/airflow/__init__.py +2 -10
- zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
- zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
- zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
- zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
- zenml/integrations/aws/__init__.py +1 -1
- zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
- zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
- zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
- zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
- zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
- zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
- zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
- zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
- zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
- zenml/integrations/evidently/__init__.py +3 -4
- zenml/integrations/evidently/column_mapping.py +11 -3
- zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
- zenml/integrations/evidently/metrics.py +5 -6
- zenml/integrations/evidently/tests.py +5 -6
- zenml/integrations/facets/models.py +2 -6
- zenml/integrations/feast/__init__.py +3 -1
- zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
- zenml/integrations/gcp/__init__.py +1 -1
- zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
- zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
- zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
- zenml/integrations/great_expectations/__init__.py +1 -1
- zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
- zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
- zenml/integrations/great_expectations/ge_store_backend.py +24 -11
- zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
- zenml/integrations/great_expectations/utils.py +5 -5
- zenml/integrations/huggingface/__init__.py +3 -0
- zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
- zenml/integrations/huggingface/steps/__init__.py +3 -0
- zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
- zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
- zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
- zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
- zenml/integrations/kubeflow/__init__.py +1 -1
- zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
- zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
- zenml/integrations/kubernetes/pod_settings.py +17 -31
- zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
- zenml/integrations/label_studio/__init__.py +1 -3
- zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
- zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
- zenml/integrations/langchain/materializers/document_materializer.py +44 -8
- zenml/integrations/mlflow/__init__.py +9 -3
- zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
- zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
- zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
- zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
- zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
- zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
- zenml/integrations/seldon/seldon_client.py +52 -67
- zenml/integrations/seldon/services/seldon_deployment.py +3 -3
- zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
- zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
- zenml/integrations/skypilot_aws/__init__.py +1 -1
- zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
- zenml/integrations/skypilot_azure/__init__.py +1 -1
- zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
- zenml/integrations/skypilot_gcp/__init__.py +2 -1
- zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
- zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
- zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
- zenml/integrations/tekton/__init__.py +1 -1
- zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
- zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
- zenml/integrations/tensorboard/__init__.py +1 -12
- zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
- zenml/integrations/tensorflow/__init__.py +2 -10
- zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
- zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
- zenml/lineage_graph/lineage_graph.py +1 -1
- zenml/materializers/built_in_materializer.py +3 -3
- zenml/materializers/pydantic_materializer.py +2 -2
- zenml/metadata/lazy_load.py +4 -4
- zenml/metadata/metadata_types.py +64 -4
- zenml/model/model.py +79 -54
- zenml/model_deployers/base_model_deployer.py +14 -12
- zenml/model_registries/base_model_registry.py +17 -15
- zenml/models/__init__.py +79 -206
- zenml/models/v2/base/base.py +54 -41
- zenml/models/v2/base/base_plugin_flavor.py +2 -6
- zenml/models/v2/base/filter.py +91 -76
- zenml/models/v2/base/page.py +2 -12
- zenml/models/v2/base/scoped.py +4 -7
- zenml/models/v2/core/api_key.py +22 -8
- zenml/models/v2/core/artifact.py +2 -2
- zenml/models/v2/core/artifact_version.py +74 -40
- zenml/models/v2/core/code_repository.py +37 -10
- zenml/models/v2/core/component.py +65 -16
- zenml/models/v2/core/device.py +14 -4
- zenml/models/v2/core/event_source.py +1 -2
- zenml/models/v2/core/flavor.py +74 -8
- zenml/models/v2/core/logs.py +68 -8
- zenml/models/v2/core/model.py +8 -4
- zenml/models/v2/core/model_version.py +25 -6
- zenml/models/v2/core/model_version_artifact.py +51 -21
- zenml/models/v2/core/model_version_pipeline_run.py +45 -13
- zenml/models/v2/core/pipeline.py +37 -72
- zenml/models/v2/core/pipeline_build.py +29 -17
- zenml/models/v2/core/pipeline_deployment.py +18 -6
- zenml/models/v2/core/pipeline_namespace.py +113 -0
- zenml/models/v2/core/pipeline_run.py +50 -22
- zenml/models/v2/core/run_metadata.py +59 -36
- zenml/models/v2/core/schedule.py +37 -24
- zenml/models/v2/core/secret.py +31 -12
- zenml/models/v2/core/service.py +64 -36
- zenml/models/v2/core/service_account.py +24 -11
- zenml/models/v2/core/service_connector.py +219 -44
- zenml/models/v2/core/stack.py +45 -17
- zenml/models/v2/core/step_run.py +28 -8
- zenml/models/v2/core/tag.py +8 -4
- zenml/models/v2/core/trigger.py +2 -2
- zenml/models/v2/core/trigger_execution.py +1 -0
- zenml/models/v2/core/user.py +18 -21
- zenml/models/v2/core/workspace.py +13 -3
- zenml/models/v2/misc/build_item.py +3 -3
- zenml/models/v2/misc/external_user.py +2 -6
- zenml/models/v2/misc/hub_plugin_models.py +9 -9
- zenml/models/v2/misc/loaded_visualization.py +2 -2
- zenml/models/v2/misc/service_connector_type.py +8 -17
- zenml/models/v2/misc/user_auth.py +7 -2
- zenml/new/pipelines/build_utils.py +3 -3
- zenml/new/pipelines/pipeline.py +17 -13
- zenml/new/pipelines/run_utils.py +103 -1
- zenml/orchestrators/base_orchestrator.py +10 -7
- zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
- zenml/orchestrators/step_runner.py +3 -6
- zenml/orchestrators/utils.py +1 -1
- zenml/plugins/base_plugin_flavor.py +6 -10
- zenml/plugins/plugin_flavor_registry.py +3 -7
- zenml/secret/base_secret.py +7 -8
- zenml/service_connectors/docker_service_connector.py +4 -3
- zenml/service_connectors/service_connector.py +5 -12
- zenml/service_connectors/service_connector_registry.py +2 -4
- zenml/services/container/container_service.py +1 -1
- zenml/services/container/container_service_endpoint.py +1 -1
- zenml/services/local/local_service.py +1 -1
- zenml/services/local/local_service_endpoint.py +1 -1
- zenml/services/service.py +16 -10
- zenml/services/service_type.py +4 -5
- zenml/services/terraform/terraform_service.py +1 -1
- zenml/stack/flavor.py +1 -5
- zenml/stack/flavor_registry.py +4 -4
- zenml/stack/stack.py +4 -1
- zenml/stack/stack_component.py +55 -31
- zenml/steps/base_step.py +34 -28
- zenml/steps/entrypoint_function_utils.py +3 -5
- zenml/steps/utils.py +12 -14
- zenml/utils/cuda_utils.py +50 -0
- zenml/utils/deprecation_utils.py +18 -20
- zenml/utils/dict_utils.py +1 -1
- zenml/utils/filesync_model.py +65 -28
- zenml/utils/function_utils.py +260 -0
- zenml/utils/json_utils.py +131 -0
- zenml/utils/mlstacks_utils.py +2 -2
- zenml/utils/pydantic_utils.py +270 -62
- zenml/utils/secret_utils.py +65 -12
- zenml/utils/source_utils.py +2 -2
- zenml/utils/typed_model.py +5 -3
- zenml/utils/typing_utils.py +243 -0
- zenml/utils/yaml_utils.py +1 -1
- zenml/zen_server/auth.py +2 -2
- zenml/zen_server/cloud_utils.py +6 -6
- zenml/zen_server/deploy/base_provider.py +1 -1
- zenml/zen_server/deploy/deployment.py +6 -8
- zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
- zenml/zen_server/deploy/local/local_provider.py +0 -1
- zenml/zen_server/deploy/local/local_zen_server.py +6 -6
- zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
- zenml/zen_server/exceptions.py +4 -1
- zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
- zenml/zen_server/pipeline_deployment/utils.py +48 -68
- zenml/zen_server/rbac/models.py +2 -5
- zenml/zen_server/rbac/utils.py +11 -14
- zenml/zen_server/routers/auth_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
- zenml/zen_server/routers/runs_endpoints.py +1 -1
- zenml/zen_server/routers/secrets_endpoints.py +3 -2
- zenml/zen_server/routers/server_endpoints.py +1 -1
- zenml/zen_server/routers/steps_endpoints.py +1 -1
- zenml/zen_server/routers/workspaces_endpoints.py +1 -1
- zenml/zen_stores/base_zen_store.py +46 -9
- zenml/zen_stores/migrations/utils.py +42 -46
- zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
- zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
- zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
- zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
- zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
- zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
- zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
- zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
- zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
- zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
- zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
- zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
- zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
- zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
- zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
- zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
- zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
- zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
- zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
- zenml/zen_stores/rest_zen_store.py +109 -49
- zenml/zen_stores/schemas/api_key_schemas.py +1 -1
- zenml/zen_stores/schemas/artifact_schemas.py +8 -8
- zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
- zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
- zenml/zen_stores/schemas/component_schemas.py +8 -3
- zenml/zen_stores/schemas/device_schemas.py +8 -6
- zenml/zen_stores/schemas/event_source_schemas.py +3 -4
- zenml/zen_stores/schemas/flavor_schemas.py +5 -3
- zenml/zen_stores/schemas/model_schemas.py +26 -1
- zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
- zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
- zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
- zenml/zen_stores/schemas/secret_schemas.py +8 -5
- zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
- zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
- zenml/zen_stores/schemas/service_schemas.py +11 -2
- zenml/zen_stores/schemas/stack_schemas.py +1 -1
- zenml/zen_stores/schemas/step_run_schemas.py +11 -11
- zenml/zen_stores/schemas/tag_schemas.py +6 -2
- zenml/zen_stores/schemas/trigger_schemas.py +2 -2
- zenml/zen_stores/schemas/user_schemas.py +2 -2
- zenml/zen_stores/schemas/workspace_schemas.py +3 -1
- zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
- zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
- zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
- zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
- zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
- zenml/zen_stores/sql_zen_store.py +196 -120
- zenml/zen_stores/zen_store_interface.py +33 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/METADATA +8 -7
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/RECORD +297 -294
- zenml/integrations/kubeflow/utils.py +0 -95
- zenml/models/v2/base/internal.py +0 -37
- zenml/models/v2/base/update.py +0 -44
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/entry_points.txt +0 -0
zenml/steps/base_step.py
CHANGED
@@ -34,7 +34,7 @@ from typing import (
|
|
34
34
|
cast,
|
35
35
|
)
|
36
36
|
|
37
|
-
from pydantic import BaseModel,
|
37
|
+
from pydantic import BaseModel, ConfigDict, ValidationError
|
38
38
|
|
39
39
|
from zenml.client_lazy_loader import ClientLazyLoader
|
40
40
|
from zenml.config.retry_config import StepRetryConfig
|
@@ -59,6 +59,7 @@ from zenml.utils import (
|
|
59
59
|
settings_utils,
|
60
60
|
source_code_utils,
|
61
61
|
source_utils,
|
62
|
+
typing_utils,
|
62
63
|
)
|
63
64
|
|
64
65
|
if TYPE_CHECKING:
|
@@ -513,17 +514,17 @@ class BaseStep(metaclass=BaseStepMeta):
|
|
513
514
|
)
|
514
515
|
elif isinstance(value, LazyArtifactVersionResponse):
|
515
516
|
model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
|
516
|
-
model=value.
|
517
|
-
artifact_name=value.
|
518
|
-
artifact_version=value.
|
517
|
+
model=value.lazy_load_model,
|
518
|
+
artifact_name=value.lazy_load_name,
|
519
|
+
artifact_version=value.lazy_load_version,
|
519
520
|
metadata_name=None,
|
520
521
|
)
|
521
522
|
elif isinstance(value, LazyRunMetadataResponse):
|
522
523
|
model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
|
523
|
-
model=value.
|
524
|
-
artifact_name=value.
|
525
|
-
artifact_version=value.
|
526
|
-
metadata_name=value.
|
524
|
+
model=value.lazy_load_model,
|
525
|
+
artifact_name=value.lazy_load_artifact_name,
|
526
|
+
artifact_version=value.lazy_load_artifact_version,
|
527
|
+
metadata_name=value.lazy_load_metadata_name,
|
527
528
|
)
|
528
529
|
elif isinstance(value, ClientLazyLoader):
|
529
530
|
client_lazy_loaders[key] = value
|
@@ -583,7 +584,7 @@ class BaseStep(metaclass=BaseStepMeta):
|
|
583
584
|
from zenml.new.pipelines.pipeline import Pipeline
|
584
585
|
|
585
586
|
if not Pipeline.ACTIVE_PIPELINE:
|
586
|
-
# The step is being called outside
|
587
|
+
# The step is being called outside the context of a pipeline,
|
587
588
|
# we simply call the entrypoint
|
588
589
|
return self.call_entrypoint(*args, **kwargs)
|
589
590
|
|
@@ -645,12 +646,15 @@ class BaseStep(metaclass=BaseStepMeta):
|
|
645
646
|
try:
|
646
647
|
validated_args = pydantic_utils.validate_function_args(
|
647
648
|
self.entrypoint,
|
648
|
-
|
649
|
+
ConfigDict(arbitrary_types_allowed=True),
|
649
650
|
*args,
|
650
651
|
**kwargs,
|
651
652
|
)
|
652
653
|
except ValidationError as e:
|
653
|
-
raise StepInterfaceError(
|
654
|
+
raise StepInterfaceError(
|
655
|
+
"Invalid step function entrypoint arguments. Check out the "
|
656
|
+
"pydantic error above for more details."
|
657
|
+
) from e
|
654
658
|
|
655
659
|
return self.entrypoint(**validated_args)
|
656
660
|
|
@@ -796,7 +800,7 @@ class BaseStep(metaclass=BaseStepMeta):
|
|
796
800
|
success_hook_source = resolve_and_validate_hook(on_success)
|
797
801
|
|
798
802
|
if isinstance(parameters, BaseParameters):
|
799
|
-
parameters = parameters.
|
803
|
+
parameters = parameters.model_dump()
|
800
804
|
|
801
805
|
values = dict_utils.remove_none_values(
|
802
806
|
{
|
@@ -1111,12 +1115,6 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1111
1115
|
output_name, PartialArtifactConfiguration()
|
1112
1116
|
)
|
1113
1117
|
|
1114
|
-
from pydantic.typing import (
|
1115
|
-
get_origin,
|
1116
|
-
is_none_type,
|
1117
|
-
is_union,
|
1118
|
-
)
|
1119
|
-
|
1120
1118
|
from zenml.steps.utils import get_args
|
1121
1119
|
|
1122
1120
|
if not output.materializer_source:
|
@@ -1129,13 +1127,15 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1129
1127
|
)
|
1130
1128
|
continue
|
1131
1129
|
|
1132
|
-
if is_union(
|
1133
|
-
get_origin(
|
1130
|
+
if typing_utils.is_union(
|
1131
|
+
typing_utils.get_origin(
|
1132
|
+
output_annotation.resolved_annotation
|
1133
|
+
)
|
1134
1134
|
or output_annotation.resolved_annotation
|
1135
1135
|
):
|
1136
1136
|
output_types = tuple(
|
1137
1137
|
type(None)
|
1138
|
-
if is_none_type(output_type)
|
1138
|
+
if typing_utils.is_none_type(output_type)
|
1139
1139
|
else output_type
|
1140
1140
|
for output_type in get_args(
|
1141
1141
|
output_annotation.resolved_annotation
|
@@ -1169,7 +1169,7 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1169
1169
|
config = StepConfigurationUpdate(**values)
|
1170
1170
|
self._apply_configuration(config)
|
1171
1171
|
|
1172
|
-
self._configuration = self._configuration.
|
1172
|
+
self._configuration = self._configuration.model_copy(
|
1173
1173
|
update={
|
1174
1174
|
"caching_parameters": self.caching_parameters,
|
1175
1175
|
"external_input_artifacts": external_artifacts,
|
@@ -1178,7 +1178,9 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1178
1178
|
}
|
1179
1179
|
)
|
1180
1180
|
|
1181
|
-
return StepConfiguration.
|
1181
|
+
return StepConfiguration.model_validate(
|
1182
|
+
self._configuration.model_dump()
|
1183
|
+
)
|
1182
1184
|
|
1183
1185
|
def _finalize_parameters(self) -> Dict[str, Any]:
|
1184
1186
|
"""Finalizes the config parameters for running this step.
|
@@ -1199,7 +1201,7 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1199
1201
|
# Make sure we have all necessary values to instantiate the
|
1200
1202
|
# pydantic model later
|
1201
1203
|
model = annotation(**value)
|
1202
|
-
params[key] = model.
|
1204
|
+
params[key] = model.model_dump()
|
1203
1205
|
else:
|
1204
1206
|
params[key] = value
|
1205
1207
|
|
@@ -1254,7 +1256,7 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1254
1256
|
for (
|
1255
1257
|
name,
|
1256
1258
|
field,
|
1257
|
-
) in self.entrypoint_definition.legacy_params.annotation.
|
1259
|
+
) in self.entrypoint_definition.legacy_params.annotation.model_fields.items():
|
1258
1260
|
if name in self.configuration.parameters:
|
1259
1261
|
# a value for this parameter has been set already
|
1260
1262
|
values[name] = self.configuration.parameters[name]
|
@@ -1262,7 +1264,7 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1262
1264
|
# a value for this parameter has been set in the "new" way
|
1263
1265
|
# already
|
1264
1266
|
values[name] = params_defined_in_new_way[name]
|
1265
|
-
elif field.
|
1267
|
+
elif field.is_required():
|
1266
1268
|
# this field has no default value set and therefore needs
|
1267
1269
|
# to be passed via an initialized config object
|
1268
1270
|
missing_keys.append(name)
|
@@ -1278,8 +1280,12 @@ To avoid this consider setting step parameters only in one place (config or code
|
|
1278
1280
|
)
|
1279
1281
|
|
1280
1282
|
if (
|
1281
|
-
|
1282
|
-
|
1283
|
+
getattr(
|
1284
|
+
self.entrypoint_definition.legacy_params.annotation.model_config,
|
1285
|
+
"extra",
|
1286
|
+
None,
|
1287
|
+
)
|
1288
|
+
== "allow"
|
1283
1289
|
):
|
1284
1290
|
# Add all parameters for the config class for backwards
|
1285
1291
|
# compatibility if the config class allows extra attributes
|
@@ -27,7 +27,7 @@ from typing import (
|
|
27
27
|
Union,
|
28
28
|
)
|
29
29
|
|
30
|
-
from pydantic import
|
30
|
+
from pydantic import ConfigDict, ValidationError, create_model
|
31
31
|
|
32
32
|
from zenml.constants import ENFORCE_TYPE_ANNOTATIONS
|
33
33
|
from zenml.exceptions import StepInterfaceError
|
@@ -235,16 +235,14 @@ class EntrypointFunctionDefinition(NamedTuple):
|
|
235
235
|
parameter: The function parameter for which the value was provided.
|
236
236
|
value: The input value.
|
237
237
|
"""
|
238
|
-
|
239
|
-
class ModelConfig(BaseConfig):
|
240
|
-
arbitrary_types_allowed = False
|
238
|
+
config_dict = ConfigDict(arbitrary_types_allowed=False)
|
241
239
|
|
242
240
|
# Create a pydantic model with just a single required field with the
|
243
241
|
# type annotation of the parameter to verify the input type including
|
244
242
|
# pydantics type coercion
|
245
243
|
validation_model_class = create_model(
|
246
244
|
"input_validation_model",
|
247
|
-
__config__=
|
245
|
+
__config__=config_dict,
|
248
246
|
value=(parameter.annotation, ...),
|
249
247
|
)
|
250
248
|
validation_model_class(value=value)
|
zenml/steps/utils.py
CHANGED
@@ -21,7 +21,6 @@ import textwrap
|
|
21
21
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
22
22
|
from uuid import UUID
|
23
23
|
|
24
|
-
import pydantic.typing as pydantic_typing
|
25
24
|
from pydantic import BaseModel
|
26
25
|
from typing_extensions import Annotated
|
27
26
|
|
@@ -32,7 +31,7 @@ from zenml.logger import get_logger
|
|
32
31
|
from zenml.metadata.metadata_types import MetadataType
|
33
32
|
from zenml.new.steps.step_context import get_step_context
|
34
33
|
from zenml.steps.step_output import Output
|
35
|
-
from zenml.utils import source_code_utils
|
34
|
+
from zenml.utils import source_code_utils, typing_utils
|
36
35
|
|
37
36
|
logger = get_logger(__name__)
|
38
37
|
|
@@ -42,8 +41,8 @@ SINGLE_RETURN_OUT_NAME = "output"
|
|
42
41
|
class OutputSignature(BaseModel):
|
43
42
|
"""The signature of an output artifact."""
|
44
43
|
|
45
|
-
resolved_annotation: Any
|
46
|
-
artifact_config: Optional[ArtifactConfig]
|
44
|
+
resolved_annotation: Any = None
|
45
|
+
artifact_config: Optional[ArtifactConfig] = None
|
47
46
|
has_custom_name: bool = False
|
48
47
|
|
49
48
|
|
@@ -60,8 +59,7 @@ def get_args(obj: Any) -> Tuple[Any, ...]:
|
|
60
59
|
The args of the annotation.
|
61
60
|
"""
|
62
61
|
return tuple(
|
63
|
-
|
64
|
-
for v in pydantic_typing.get_args(obj)
|
62
|
+
typing_utils.get_origin(v) or v for v in typing_utils.get_args(obj)
|
65
63
|
)
|
66
64
|
|
67
65
|
|
@@ -123,11 +121,11 @@ def parse_return_type_annotations(
|
|
123
121
|
for output_name, output_type in return_annotation.items()
|
124
122
|
}
|
125
123
|
|
126
|
-
elif
|
124
|
+
elif typing_utils.get_origin(return_annotation) is tuple:
|
127
125
|
requires_multiple_artifacts = has_tuple_return(func)
|
128
126
|
if requires_multiple_artifacts:
|
129
127
|
output_signature: Dict[str, Any] = {}
|
130
|
-
args =
|
128
|
+
args = typing_utils.get_args(return_annotation)
|
131
129
|
if args[-1] is Ellipsis:
|
132
130
|
raise RuntimeError(
|
133
131
|
"Variable length output annotations are not allowed."
|
@@ -179,12 +177,12 @@ def resolve_type_annotation(obj: Any) -> Any:
|
|
179
177
|
Returns:
|
180
178
|
The non-generic class for generic aliases of the typing module.
|
181
179
|
"""
|
182
|
-
origin =
|
180
|
+
origin = typing_utils.get_origin(obj) or obj
|
183
181
|
|
184
182
|
if origin is Annotated:
|
185
|
-
annotation, *_ =
|
183
|
+
annotation, *_ = typing_utils.get_args(obj)
|
186
184
|
return resolve_type_annotation(annotation)
|
187
|
-
elif
|
185
|
+
elif typing_utils.is_union(origin):
|
188
186
|
return obj
|
189
187
|
|
190
188
|
return origin
|
@@ -212,10 +210,10 @@ def get_artifact_config_from_annotation_metadata(
|
|
212
210
|
Returns:
|
213
211
|
The artifact config.
|
214
212
|
"""
|
215
|
-
if (
|
213
|
+
if (typing_utils.get_origin(annotation) or annotation) is not Annotated:
|
216
214
|
return None
|
217
215
|
|
218
|
-
annotation, *metadata =
|
216
|
+
annotation, *metadata = typing_utils.get_args(annotation)
|
219
217
|
|
220
218
|
error_message = (
|
221
219
|
"Artifact annotation should only contain two elements: the artifact "
|
@@ -251,7 +249,7 @@ def get_artifact_config_from_annotation_metadata(
|
|
251
249
|
if not artifact_config:
|
252
250
|
artifact_config = ArtifactConfig(name=output_name)
|
253
251
|
elif not artifact_config.name:
|
254
|
-
artifact_config = artifact_config.
|
252
|
+
artifact_config = artifact_config.model_copy()
|
255
253
|
artifact_config.name = output_name
|
256
254
|
|
257
255
|
if artifact_config and artifact_config.name == "":
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Apache Software License 2.0
|
2
|
+
#
|
3
|
+
# Copyright (c) ZenML GmbH 2024. All rights reserved.
|
4
|
+
#
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6
|
+
# you may not use this file except in compliance with the License.
|
7
|
+
# You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
|
+
# See the License for the specific language governing permissions and
|
15
|
+
# limitations under the License.
|
16
|
+
#
|
17
|
+
"""Utilities for managing GPU memory."""
|
18
|
+
|
19
|
+
import gc
|
20
|
+
|
21
|
+
from zenml.logger import get_logger
|
22
|
+
|
23
|
+
logger = get_logger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
def cleanup_gpu_memory(force: bool = False) -> None:
|
27
|
+
"""Clean up GPU memory.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
force: whether to force the cleanup of GPU memory (must be passed explicitly)
|
31
|
+
"""
|
32
|
+
if not force:
|
33
|
+
logger.warning(
|
34
|
+
"This will clean up all GPU memory on current physical machine. "
|
35
|
+
"This action is considered to be dangerous by default, since "
|
36
|
+
"it might affect other processes running in the same environment. "
|
37
|
+
"If this is intended, please explicitly pass `force=True`."
|
38
|
+
)
|
39
|
+
else:
|
40
|
+
try:
|
41
|
+
import torch
|
42
|
+
except ModuleNotFoundError:
|
43
|
+
logger.warning(
|
44
|
+
"No PyTorch installed. Skipping GPU memory cleanup."
|
45
|
+
)
|
46
|
+
return
|
47
|
+
|
48
|
+
logger.info("Cleaning up GPU memory...")
|
49
|
+
while gc.collect():
|
50
|
+
torch.cuda.empty_cache()
|
zenml/utils/deprecation_utils.py
CHANGED
@@ -14,14 +14,12 @@
|
|
14
14
|
"""Deprecation utilities."""
|
15
15
|
|
16
16
|
import warnings
|
17
|
-
from typing import
|
17
|
+
from typing import Any, Dict, Set, Tuple, Type, Union
|
18
18
|
|
19
|
-
from pydantic import BaseModel,
|
19
|
+
from pydantic import BaseModel, model_validator
|
20
20
|
|
21
21
|
from zenml.logger import get_logger
|
22
|
-
|
23
|
-
if TYPE_CHECKING:
|
24
|
-
AnyClassMethod = classmethod[Any] # type: ignore[type-arg]
|
22
|
+
from zenml.utils.pydantic_utils import before_validator_handler
|
25
23
|
|
26
24
|
logger = get_logger(__name__)
|
27
25
|
|
@@ -30,7 +28,7 @@ PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE = "__previous_deprecation_warnings"
|
|
30
28
|
|
31
29
|
def deprecate_pydantic_attributes(
|
32
30
|
*attributes: Union[str, Tuple[str, str]],
|
33
|
-
) ->
|
31
|
+
) -> Any:
|
34
32
|
"""Utility function for deprecating and migrating pydantic attributes.
|
35
33
|
|
36
34
|
**Usage**:
|
@@ -55,22 +53,24 @@ def deprecate_pydantic_attributes(
|
|
55
53
|
Args:
|
56
54
|
*attributes: List of attributes to deprecate. This is either the name
|
57
55
|
of the attribute to deprecate, or a tuple containing the name of
|
58
|
-
the deprecated attribute and it's replacement.
|
56
|
+
the deprecated attribute, and it's replacement.
|
59
57
|
|
60
58
|
Returns:
|
61
59
|
Pydantic validator class method to be used on BaseModel subclasses
|
62
60
|
to deprecate or migrate attributes.
|
63
61
|
"""
|
64
62
|
|
65
|
-
@
|
63
|
+
@model_validator(mode="before") # type: ignore[misc]
|
64
|
+
@classmethod
|
65
|
+
@before_validator_handler
|
66
66
|
def _deprecation_validator(
|
67
|
-
cls: Type[BaseModel],
|
67
|
+
cls: Type[BaseModel], data: Dict[str, Any]
|
68
68
|
) -> Dict[str, Any]:
|
69
69
|
"""Pydantic validator function for deprecating pydantic attributes.
|
70
70
|
|
71
71
|
Args:
|
72
72
|
cls: The class on which the attributes are defined.
|
73
|
-
|
73
|
+
data: All values passed at model initialization.
|
74
74
|
|
75
75
|
Raises:
|
76
76
|
AssertionError: If either the deprecated or replacement attribute
|
@@ -110,14 +110,14 @@ def deprecate_pydantic_attributes(
|
|
110
110
|
deprecated_attribute, replacement_attribute = attribute
|
111
111
|
|
112
112
|
assert (
|
113
|
-
replacement_attribute in cls.
|
113
|
+
replacement_attribute in cls.model_fields
|
114
114
|
), f"Unable to find attribute {replacement_attribute}."
|
115
115
|
|
116
116
|
assert (
|
117
|
-
deprecated_attribute in cls.
|
117
|
+
deprecated_attribute in cls.model_fields
|
118
118
|
), f"Unable to find attribute {deprecated_attribute}."
|
119
119
|
|
120
|
-
if cls.
|
120
|
+
if cls.model_fields[deprecated_attribute].is_required():
|
121
121
|
raise TypeError(
|
122
122
|
f"Unable to deprecate attribute '{deprecated_attribute}' "
|
123
123
|
f"of class {cls.__name__}. In order to deprecate an "
|
@@ -126,7 +126,7 @@ def deprecate_pydantic_attributes(
|
|
126
126
|
"annotation."
|
127
127
|
)
|
128
128
|
|
129
|
-
if
|
129
|
+
if data.get(deprecated_attribute, None) is None:
|
130
130
|
continue
|
131
131
|
|
132
132
|
if replacement_attribute is None:
|
@@ -144,17 +144,15 @@ def deprecate_pydantic_attributes(
|
|
144
144
|
attribute=deprecated_attribute,
|
145
145
|
)
|
146
146
|
|
147
|
-
if
|
147
|
+
if data.get(replacement_attribute, None) is None:
|
148
148
|
logger.debug(
|
149
149
|
"Migrating value of deprecated attribute %s to "
|
150
150
|
"replacement attribute %s.",
|
151
151
|
deprecated_attribute,
|
152
152
|
replacement_attribute,
|
153
153
|
)
|
154
|
-
|
155
|
-
|
156
|
-
)
|
157
|
-
elif values[deprecated_attribute] != values[replacement_attribute]:
|
154
|
+
data[replacement_attribute] = data.pop(deprecated_attribute)
|
155
|
+
elif data[deprecated_attribute] != data[replacement_attribute]:
|
158
156
|
raise ValueError(
|
159
157
|
"Got different values for deprecated attribute "
|
160
158
|
f"{deprecated_attribute} and replacement "
|
@@ -170,6 +168,6 @@ def deprecate_pydantic_attributes(
|
|
170
168
|
previous_deprecation_warnings,
|
171
169
|
)
|
172
170
|
|
173
|
-
return
|
171
|
+
return data
|
174
172
|
|
175
173
|
return _deprecation_validator
|
zenml/utils/dict_utils.py
CHANGED
zenml/utils/filesync_model.py
CHANGED
@@ -13,11 +13,16 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Filesync utils for ZenML."""
|
15
15
|
|
16
|
-
import json
|
17
16
|
import os
|
18
17
|
from typing import Any, Optional
|
19
18
|
|
20
|
-
from pydantic import
|
19
|
+
from pydantic import (
|
20
|
+
BaseModel,
|
21
|
+
ValidationError,
|
22
|
+
ValidationInfo,
|
23
|
+
ValidatorFunctionWrapHandler,
|
24
|
+
model_validator,
|
25
|
+
)
|
21
26
|
|
22
27
|
from zenml.io import fileio
|
23
28
|
from zenml.logger import get_logger
|
@@ -40,29 +45,69 @@ class FileSyncModel(BaseModel):
|
|
40
45
|
_config_file: str
|
41
46
|
_config_file_timestamp: Optional[float] = None
|
42
47
|
|
43
|
-
|
44
|
-
|
48
|
+
@model_validator(mode="wrap")
|
49
|
+
@classmethod
|
50
|
+
def config_validator(
|
51
|
+
cls,
|
52
|
+
data: Any,
|
53
|
+
handler: ValidatorFunctionWrapHandler,
|
54
|
+
info: ValidationInfo,
|
55
|
+
) -> "FileSyncModel":
|
56
|
+
"""Wrap model validator to infer the config_file during initialization.
|
45
57
|
|
46
58
|
Args:
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
59
|
+
data: The raw data that is provided before the validation.
|
60
|
+
handler: The actual validation function pydantic would use for the
|
61
|
+
built-in validation function.
|
62
|
+
info: The context information during the execution of this
|
63
|
+
validation function.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
the actual instance after the validation
|
67
|
+
|
68
|
+
Raises:
|
69
|
+
ValidationError: if you try to validate through a JSON string. You
|
70
|
+
need to provide a config_file path when you create a
|
71
|
+
FileSyncModel.
|
72
|
+
AssertionError: if the raw input does not include a config_file
|
73
|
+
path for the configuration file.
|
52
74
|
"""
|
53
|
-
|
54
|
-
if
|
55
|
-
|
75
|
+
# Disable json validation
|
76
|
+
if info.mode == "json":
|
77
|
+
raise ValidationError(
|
78
|
+
"You can not instantiate filesync models using the JSON mode."
|
79
|
+
)
|
56
80
|
|
57
|
-
|
58
|
-
|
81
|
+
if isinstance(data, dict):
|
82
|
+
# Assert that the config file is defined
|
83
|
+
assert (
|
84
|
+
"config_file" in data
|
85
|
+
), "You have to provide a path for the configuration file."
|
59
86
|
|
60
|
-
|
61
|
-
super(FileSyncModel, self).__init__(**config_dict)
|
87
|
+
config_file = data.pop("config_file")
|
62
88
|
|
63
|
-
|
64
|
-
|
65
|
-
|
89
|
+
# Load the current values and update with new values
|
90
|
+
config_dict = {}
|
91
|
+
if fileio.exists(config_file):
|
92
|
+
config_dict = yaml_utils.read_yaml(config_file)
|
93
|
+
config_dict.update(data)
|
94
|
+
|
95
|
+
# Execute the regular validation
|
96
|
+
model = handler(config_dict)
|
97
|
+
|
98
|
+
assert isinstance(model, cls)
|
99
|
+
|
100
|
+
# Assign the private attribute and save the config
|
101
|
+
model._config_file = config_file
|
102
|
+
model.write_config()
|
103
|
+
|
104
|
+
else:
|
105
|
+
# If the raw value is not a dict, apply proper validation.
|
106
|
+
model = handler(data)
|
107
|
+
|
108
|
+
assert isinstance(model, cls)
|
109
|
+
|
110
|
+
return model
|
66
111
|
|
67
112
|
def __setattr__(self, key: str, value: Any) -> None:
|
68
113
|
"""Sets an attribute on the model and persists it in the configuration file.
|
@@ -91,8 +136,7 @@ class FileSyncModel(BaseModel):
|
|
91
136
|
|
92
137
|
def write_config(self) -> None:
|
93
138
|
"""Writes the model to the configuration file."""
|
94
|
-
|
95
|
-
yaml_utils.write_yaml(self._config_file, config_dict)
|
139
|
+
yaml_utils.write_yaml(self._config_file, self.model_dump(mode="json"))
|
96
140
|
self._config_file_timestamp = os.path.getmtime(self._config_file)
|
97
141
|
|
98
142
|
def load_config(self) -> None:
|
@@ -115,10 +159,3 @@ class FileSyncModel(BaseModel):
|
|
115
159
|
super(FileSyncModel, self).__setattr__(key, value)
|
116
160
|
|
117
161
|
self._config_file_timestamp = file_timestamp
|
118
|
-
|
119
|
-
class Config:
|
120
|
-
"""Pydantic configuration class."""
|
121
|
-
|
122
|
-
# all attributes with leading underscore are private and therefore
|
123
|
-
# are mutable and not included in serialization
|
124
|
-
underscore_attrs_are_private = True
|