zenml-nightly 0.58.2.dev20240614__py3-none-any.whl → 0.58.2.dev20240622__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/_hub/client.py +8 -5
- zenml/actions/base_action.py +8 -10
- zenml/artifact_stores/base_artifact_store.py +20 -15
- zenml/artifact_stores/local_artifact_store.py +3 -2
- zenml/artifacts/artifact_config.py +34 -19
- zenml/artifacts/external_artifact.py +18 -8
- zenml/artifacts/external_artifact_config.py +14 -6
- zenml/artifacts/unmaterialized_artifact.py +2 -11
- zenml/cli/__init__.py +6 -0
- zenml/cli/artifact.py +20 -2
- zenml/cli/base.py +2 -2
- zenml/cli/served_model.py +0 -1
- zenml/cli/server.py +3 -3
- zenml/cli/utils.py +36 -40
- zenml/cli/web_login.py +2 -2
- zenml/client.py +198 -24
- zenml/client_lazy_loader.py +20 -14
- zenml/config/base_settings.py +5 -6
- zenml/config/build_configuration.py +1 -1
- zenml/config/compiler.py +3 -3
- zenml/config/docker_settings.py +27 -28
- zenml/config/global_config.py +33 -37
- zenml/config/pipeline_configurations.py +8 -11
- zenml/config/pipeline_run_configuration.py +6 -2
- zenml/config/pipeline_spec.py +3 -4
- zenml/config/resource_settings.py +8 -9
- zenml/config/schedule.py +16 -20
- zenml/config/secret_reference_mixin.py +6 -3
- zenml/config/secrets_store_config.py +16 -23
- zenml/config/server_config.py +50 -46
- zenml/config/settings_resolver.py +1 -1
- zenml/config/source.py +45 -35
- zenml/config/step_configurations.py +53 -31
- zenml/config/step_run_info.py +3 -0
- zenml/config/store_config.py +20 -19
- zenml/config/strict_base_model.py +2 -6
- zenml/constants.py +26 -2
- zenml/container_registries/base_container_registry.py +3 -2
- zenml/container_registries/default_container_registry.py +3 -3
- zenml/event_hub/base_event_hub.py +1 -1
- zenml/event_sources/base_event_source.py +11 -16
- zenml/exceptions.py +4 -0
- zenml/integrations/airflow/__init__.py +2 -6
- zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
- zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
- zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
- zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
- zenml/integrations/aws/__init__.py +1 -1
- zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
- zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
- zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
- zenml/integrations/aws/step_operators/sagemaker_step_operator.py +1 -1
- zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
- zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
- zenml/integrations/azure/step_operators/azureml_step_operator.py +2 -1
- zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
- zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
- zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
- zenml/integrations/constants.py +0 -1
- zenml/integrations/deepchecks/__init__.py +1 -0
- zenml/integrations/evidently/__init__.py +5 -3
- zenml/integrations/evidently/column_mapping.py +11 -3
- zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
- zenml/integrations/evidently/metrics.py +5 -6
- zenml/integrations/evidently/tests.py +5 -6
- zenml/integrations/facets/models.py +2 -6
- zenml/integrations/feast/__init__.py +3 -1
- zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
- zenml/integrations/gcp/__init__.py +1 -1
- zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
- zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -0
- zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
- zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
- zenml/integrations/great_expectations/__init__.py +1 -1
- zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
- zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
- zenml/integrations/great_expectations/ge_store_backend.py +24 -11
- zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
- zenml/integrations/great_expectations/utils.py +5 -5
- zenml/integrations/huggingface/__init__.py +3 -0
- zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
- zenml/integrations/huggingface/steps/__init__.py +3 -0
- zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
- zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
- zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
- zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
- zenml/integrations/kubeflow/__init__.py +1 -1
- zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
- zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
- zenml/integrations/kubernetes/pod_settings.py +17 -31
- zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
- zenml/integrations/label_studio/__init__.py +1 -3
- zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
- zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
- zenml/integrations/langchain/__init__.py +5 -1
- zenml/integrations/langchain/materializers/document_materializer.py +44 -8
- zenml/integrations/mlflow/__init__.py +9 -3
- zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
- zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
- zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
- zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
- zenml/integrations/neural_prophet/__init__.py +5 -1
- zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
- zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
- zenml/integrations/seldon/seldon_client.py +52 -67
- zenml/integrations/seldon/services/seldon_deployment.py +3 -3
- zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
- zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
- zenml/integrations/skypilot_aws/__init__.py +1 -1
- zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
- zenml/integrations/skypilot_azure/__init__.py +1 -1
- zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
- zenml/integrations/skypilot_gcp/__init__.py +2 -1
- zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
- zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
- zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
- zenml/integrations/spark/step_operators/spark_step_operator.py +2 -0
- zenml/integrations/tekton/__init__.py +1 -1
- zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
- zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
- zenml/integrations/tensorboard/__init__.py +1 -12
- zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
- zenml/integrations/tensorflow/__init__.py +2 -10
- zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
- zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
- zenml/lineage_graph/lineage_graph.py +1 -1
- zenml/logging/step_logging.py +15 -7
- zenml/materializers/built_in_materializer.py +3 -3
- zenml/materializers/pydantic_materializer.py +2 -2
- zenml/metadata/lazy_load.py +4 -4
- zenml/metadata/metadata_types.py +64 -4
- zenml/model/model.py +79 -54
- zenml/model_deployers/base_model_deployer.py +14 -12
- zenml/model_registries/base_model_registry.py +17 -15
- zenml/models/__init__.py +79 -206
- zenml/models/v2/base/base.py +54 -41
- zenml/models/v2/base/base_plugin_flavor.py +2 -6
- zenml/models/v2/base/filter.py +91 -76
- zenml/models/v2/base/page.py +2 -12
- zenml/models/v2/base/scoped.py +4 -7
- zenml/models/v2/core/api_key.py +22 -8
- zenml/models/v2/core/artifact.py +2 -2
- zenml/models/v2/core/artifact_version.py +74 -40
- zenml/models/v2/core/code_repository.py +37 -10
- zenml/models/v2/core/component.py +65 -16
- zenml/models/v2/core/device.py +14 -4
- zenml/models/v2/core/event_source.py +1 -2
- zenml/models/v2/core/flavor.py +74 -8
- zenml/models/v2/core/logs.py +68 -8
- zenml/models/v2/core/model.py +8 -4
- zenml/models/v2/core/model_version.py +25 -6
- zenml/models/v2/core/model_version_artifact.py +51 -21
- zenml/models/v2/core/model_version_pipeline_run.py +45 -13
- zenml/models/v2/core/pipeline.py +37 -72
- zenml/models/v2/core/pipeline_build.py +29 -17
- zenml/models/v2/core/pipeline_deployment.py +18 -6
- zenml/models/v2/core/pipeline_namespace.py +113 -0
- zenml/models/v2/core/pipeline_run.py +50 -22
- zenml/models/v2/core/run_metadata.py +59 -36
- zenml/models/v2/core/schedule.py +37 -24
- zenml/models/v2/core/secret.py +31 -12
- zenml/models/v2/core/service.py +64 -36
- zenml/models/v2/core/service_account.py +24 -11
- zenml/models/v2/core/service_connector.py +219 -44
- zenml/models/v2/core/stack.py +45 -17
- zenml/models/v2/core/step_run.py +28 -8
- zenml/models/v2/core/tag.py +8 -4
- zenml/models/v2/core/trigger.py +2 -2
- zenml/models/v2/core/trigger_execution.py +1 -0
- zenml/models/v2/core/user.py +18 -21
- zenml/models/v2/core/workspace.py +13 -3
- zenml/models/v2/misc/build_item.py +3 -3
- zenml/models/v2/misc/external_user.py +2 -6
- zenml/models/v2/misc/hub_plugin_models.py +9 -9
- zenml/models/v2/misc/loaded_visualization.py +2 -2
- zenml/models/v2/misc/service_connector_type.py +8 -17
- zenml/models/v2/misc/user_auth.py +7 -2
- zenml/new/pipelines/build_utils.py +3 -3
- zenml/new/pipelines/pipeline.py +17 -13
- zenml/new/pipelines/run_utils.py +103 -1
- zenml/orchestrators/base_orchestrator.py +10 -7
- zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
- zenml/orchestrators/step_launcher.py +28 -4
- zenml/orchestrators/step_runner.py +3 -6
- zenml/orchestrators/utils.py +1 -1
- zenml/plugins/base_plugin_flavor.py +6 -10
- zenml/plugins/plugin_flavor_registry.py +3 -7
- zenml/secret/base_secret.py +7 -8
- zenml/service_connectors/docker_service_connector.py +4 -3
- zenml/service_connectors/service_connector.py +5 -12
- zenml/service_connectors/service_connector_registry.py +2 -4
- zenml/services/container/container_service.py +1 -1
- zenml/services/container/container_service_endpoint.py +1 -1
- zenml/services/local/local_service.py +1 -1
- zenml/services/local/local_service_endpoint.py +1 -1
- zenml/services/service.py +16 -10
- zenml/services/service_type.py +4 -5
- zenml/services/terraform/terraform_service.py +1 -1
- zenml/stack/flavor.py +2 -6
- zenml/stack/flavor_registry.py +4 -4
- zenml/stack/stack.py +4 -1
- zenml/stack/stack_component.py +55 -31
- zenml/step_operators/step_operator_entrypoint_configuration.py +1 -0
- zenml/steps/base_step.py +34 -28
- zenml/steps/entrypoint_function_utils.py +3 -5
- zenml/steps/utils.py +12 -14
- zenml/utils/cuda_utils.py +50 -0
- zenml/utils/deprecation_utils.py +18 -20
- zenml/utils/dict_utils.py +1 -1
- zenml/utils/filesync_model.py +65 -28
- zenml/utils/function_utils.py +260 -0
- zenml/utils/json_utils.py +131 -0
- zenml/utils/mlstacks_utils.py +2 -2
- zenml/utils/package_utils.py +1 -1
- zenml/utils/pipeline_docker_image_builder.py +9 -0
- zenml/utils/pydantic_utils.py +270 -62
- zenml/utils/secret_utils.py +65 -12
- zenml/utils/source_utils.py +2 -2
- zenml/utils/typed_model.py +5 -3
- zenml/utils/typing_utils.py +243 -0
- zenml/utils/yaml_utils.py +1 -1
- zenml/zen_server/auth.py +2 -2
- zenml/zen_server/cloud_utils.py +6 -6
- zenml/zen_server/deploy/base_provider.py +1 -1
- zenml/zen_server/deploy/deployment.py +6 -8
- zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
- zenml/zen_server/deploy/local/local_provider.py +0 -1
- zenml/zen_server/deploy/local/local_zen_server.py +6 -6
- zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
- zenml/zen_server/exceptions.py +4 -1
- zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
- zenml/zen_server/pipeline_deployment/utils.py +48 -68
- zenml/zen_server/rbac/models.py +2 -5
- zenml/zen_server/rbac/utils.py +11 -14
- zenml/zen_server/routers/auth_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
- zenml/zen_server/routers/runs_endpoints.py +1 -1
- zenml/zen_server/routers/secrets_endpoints.py +3 -2
- zenml/zen_server/routers/server_endpoints.py +1 -1
- zenml/zen_server/routers/steps_endpoints.py +1 -1
- zenml/zen_server/routers/workspaces_endpoints.py +1 -1
- zenml/zen_stores/base_zen_store.py +46 -9
- zenml/zen_stores/migrations/utils.py +42 -46
- zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
- zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
- zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
- zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
- zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
- zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
- zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
- zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
- zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
- zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
- zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
- zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
- zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
- zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
- zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
- zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
- zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
- zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
- zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
- zenml/zen_stores/rest_zen_store.py +109 -49
- zenml/zen_stores/schemas/api_key_schemas.py +1 -1
- zenml/zen_stores/schemas/artifact_schemas.py +8 -8
- zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
- zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
- zenml/zen_stores/schemas/component_schemas.py +8 -3
- zenml/zen_stores/schemas/device_schemas.py +8 -6
- zenml/zen_stores/schemas/event_source_schemas.py +3 -4
- zenml/zen_stores/schemas/flavor_schemas.py +5 -3
- zenml/zen_stores/schemas/model_schemas.py +26 -1
- zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
- zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
- zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
- zenml/zen_stores/schemas/secret_schemas.py +8 -5
- zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
- zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
- zenml/zen_stores/schemas/service_schemas.py +11 -2
- zenml/zen_stores/schemas/stack_schemas.py +1 -1
- zenml/zen_stores/schemas/step_run_schemas.py +11 -11
- zenml/zen_stores/schemas/tag_schemas.py +6 -2
- zenml/zen_stores/schemas/trigger_schemas.py +2 -2
- zenml/zen_stores/schemas/user_schemas.py +2 -2
- zenml/zen_stores/schemas/workspace_schemas.py +3 -1
- zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
- zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
- zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
- zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
- zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
- zenml/zen_stores/sql_zen_store.py +196 -120
- zenml/zen_stores/zen_store_interface.py +33 -0
- {zenml_nightly-0.58.2.dev20240614.dist-info → zenml_nightly-0.58.2.dev20240622.dist-info}/METADATA +9 -7
- {zenml_nightly-0.58.2.dev20240614.dist-info → zenml_nightly-0.58.2.dev20240622.dist-info}/RECORD +311 -308
- zenml/integrations/kubeflow/utils.py +0 -95
- zenml/models/v2/base/internal.py +0 -37
- zenml/models/v2/base/update.py +0 -44
- {zenml_nightly-0.58.2.dev20240614.dist-info → zenml_nightly-0.58.2.dev20240622.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.58.2.dev20240614.dist-info → zenml_nightly-0.58.2.dev20240622.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.58.2.dev20240614.dist-info → zenml_nightly-0.58.2.dev20240622.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,7 @@ import urllib
|
|
17
17
|
from typing import Any, Dict, List, Optional, Type, Union
|
18
18
|
from uuid import UUID
|
19
19
|
|
20
|
-
from pydantic import BaseModel,
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field
|
21
21
|
|
22
22
|
from zenml.enums import SecretScope
|
23
23
|
from zenml.event_sources.base_event import (
|
@@ -111,14 +111,10 @@ class GithubEvent(BaseEvent):
|
|
111
111
|
after: str
|
112
112
|
repository: Repository
|
113
113
|
commits: List[Commit]
|
114
|
-
head_commit: Optional[Commit]
|
115
|
-
tags: Optional[List[Tag]]
|
116
|
-
pull_requests: Optional[List[PullRequest]]
|
117
|
-
|
118
|
-
class Config:
|
119
|
-
"""Pydantic configuration class."""
|
120
|
-
|
121
|
-
extra = Extra.allow
|
114
|
+
head_commit: Optional[Commit] = None
|
115
|
+
tags: Optional[List[Tag]] = None
|
116
|
+
pull_requests: Optional[List[PullRequest]] = None
|
117
|
+
model_config = ConfigDict(extra="allow")
|
122
118
|
|
123
119
|
@property
|
124
120
|
def branch(self) -> Optional[str]:
|
@@ -157,9 +153,9 @@ class GithubEvent(BaseEvent):
|
|
157
153
|
class GithubWebhookEventFilterConfiguration(WebhookEventFilterConfig):
|
158
154
|
"""Configuration for github event filters."""
|
159
155
|
|
160
|
-
repo: Optional[str]
|
161
|
-
branch: Optional[str]
|
162
|
-
event_type: Optional[GithubEventType]
|
156
|
+
repo: Optional[str] = None
|
157
|
+
branch: Optional[str] = None
|
158
|
+
event_type: Optional[GithubEventType] = None
|
163
159
|
|
164
160
|
def event_matches_filter(self, event: BaseEvent) -> bool:
|
165
161
|
"""Checks the filter against the inbound event.
|
@@ -442,7 +438,7 @@ class GithubWebhookEventSourceHandler(BaseWebhookEventSourceHandler):
|
|
442
438
|
if config.rotate_secret:
|
443
439
|
# In case the secret is being rotated
|
444
440
|
secret_key_value = random_str(12)
|
445
|
-
webhook_secret = SecretUpdate(
|
441
|
+
webhook_secret = SecretUpdate(
|
446
442
|
values={"webhook_secret": secret_key_value}
|
447
443
|
)
|
448
444
|
self.zen_store.update_secret(
|
@@ -17,21 +17,25 @@ import os
|
|
17
17
|
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, cast
|
18
18
|
|
19
19
|
import pandas as pd
|
20
|
-
import yaml
|
21
20
|
from great_expectations.checkpoint.types.checkpoint_result import ( # type: ignore[import-untyped]
|
22
21
|
CheckpointResult,
|
23
22
|
)
|
24
23
|
from great_expectations.core import ( # type: ignore[import-untyped]
|
25
24
|
ExpectationSuite,
|
26
25
|
)
|
27
|
-
from great_expectations.data_context.data_context import (
|
28
|
-
|
29
|
-
DataContext,
|
26
|
+
from great_expectations.data_context.data_context.abstract_data_context import (
|
27
|
+
AbstractDataContext,
|
30
28
|
)
|
31
|
-
from great_expectations.data_context.
|
29
|
+
from great_expectations.data_context.data_context.context_factory import (
|
30
|
+
get_context,
|
31
|
+
)
|
32
|
+
from great_expectations.data_context.data_context.ephemeral_data_context import (
|
33
|
+
EphemeralDataContext,
|
34
|
+
)
|
35
|
+
from great_expectations.data_context.types.base import (
|
32
36
|
DataContextConfig,
|
33
37
|
)
|
34
|
-
from great_expectations.data_context.types.resource_identifiers import (
|
38
|
+
from great_expectations.data_context.types.resource_identifiers import (
|
35
39
|
ExpectationSuiteIdentifier,
|
36
40
|
)
|
37
41
|
from great_expectations.profile.user_configurable_profiler import ( # type: ignore[import-untyped]
|
@@ -65,8 +69,8 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
65
69
|
GreatExpectationsDataValidatorFlavor
|
66
70
|
)
|
67
71
|
|
68
|
-
_context:
|
69
|
-
_context_config: Optional[
|
72
|
+
_context: Optional[AbstractDataContext] = None
|
73
|
+
_context_config: Optional[DataContextConfig] = None
|
70
74
|
|
71
75
|
@property
|
72
76
|
def config(self) -> GreatExpectationsDataValidatorConfig:
|
@@ -78,7 +82,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
78
82
|
return cast(GreatExpectationsDataValidatorConfig, self._config)
|
79
83
|
|
80
84
|
@classmethod
|
81
|
-
def get_data_context(cls) ->
|
85
|
+
def get_data_context(cls) -> AbstractDataContext:
|
82
86
|
"""Get the Great Expectations data context managed by ZenML.
|
83
87
|
|
84
88
|
Call this method to retrieve the data context managed by ZenML
|
@@ -94,15 +98,11 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
94
98
|
return data_validator.data_context
|
95
99
|
|
96
100
|
@property
|
97
|
-
def context_config(self) -> Optional[
|
101
|
+
def context_config(self) -> Optional[DataContextConfig]:
|
98
102
|
"""Get the Great Expectations data context configuration.
|
99
103
|
|
100
|
-
The first time the context config is loaded from the stack component
|
101
|
-
config, it is converted from JSON/YAML string format to a dict.
|
102
|
-
|
103
104
|
Raises:
|
104
|
-
ValueError:
|
105
|
-
if the GE configuration extracted from it fails GE validation.
|
105
|
+
ValueError: In case there is an invalid context_config value
|
106
106
|
|
107
107
|
Returns:
|
108
108
|
A dictionary with the GE data context configuration.
|
@@ -111,31 +111,18 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
111
111
|
if self._context_config is not None:
|
112
112
|
return self._context_config
|
113
113
|
|
114
|
-
# Otherwise,
|
115
|
-
|
116
|
-
|
114
|
+
# Otherwise, use the configuration from the stack component config, if
|
115
|
+
# set
|
116
|
+
context_config_dict = self.config.context_config
|
117
|
+
if context_config_dict is None:
|
117
118
|
return None
|
118
|
-
if isinstance(context_config, dict):
|
119
|
-
self._context_config = context_config
|
120
|
-
return self._context_config
|
121
|
-
|
122
|
-
# If the context config is a string, try to parse it as JSON/YAML
|
123
|
-
try:
|
124
|
-
context_config_dict = yaml.safe_load(context_config)
|
125
|
-
except yaml.parser.ParserError as e:
|
126
|
-
raise ValueError(
|
127
|
-
f"Malformed `context_config` value. Only JSON and YAML "
|
128
|
-
f"formats are supported: {str(e)}"
|
129
|
-
)
|
130
119
|
|
131
120
|
# Validate that the context config is a valid GE config
|
132
121
|
try:
|
133
|
-
|
134
|
-
BaseDataContext(project_config=context_config)
|
122
|
+
self._context_config = DataContextConfig(**context_config_dict)
|
135
123
|
except Exception as e:
|
136
124
|
raise ValueError(f"Invalid `context_config` value: {str(e)}")
|
137
125
|
|
138
|
-
self._context_config = cast(Dict[str, Any], context_config_dict)
|
139
126
|
return self._context_config
|
140
127
|
|
141
128
|
@property
|
@@ -203,7 +190,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
203
190
|
}
|
204
191
|
|
205
192
|
@property
|
206
|
-
def data_context(self) ->
|
193
|
+
def data_context(self) -> AbstractDataContext:
|
207
194
|
"""Returns the Great Expectations data context configured for this component.
|
208
195
|
|
209
196
|
Returns:
|
@@ -216,7 +203,9 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
216
203
|
profiler_store_name = "zenml_profiler_store"
|
217
204
|
evaluation_parameter_store_name = "evaluation_parameter_store"
|
218
205
|
|
219
|
-
|
206
|
+
# Define default configuration options that plug the GX stores
|
207
|
+
# in the active ZenML artifact store
|
208
|
+
zenml_context_config: Dict[str, Any] = dict(
|
220
209
|
stores={
|
221
210
|
expectations_store_name: self.get_store_config(
|
222
211
|
"ExpectationsStore", "expectations"
|
@@ -250,18 +239,29 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
250
239
|
if self.config.context_root_dir:
|
251
240
|
# initialize the local data context, if a local path was
|
252
241
|
# configured
|
253
|
-
self._context =
|
242
|
+
self._context = get_context(
|
243
|
+
context_root_dir=self.config.context_root_dir
|
244
|
+
)
|
245
|
+
|
254
246
|
else:
|
255
|
-
# create an in-memory data context
|
256
|
-
# backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/
|
247
|
+
# create an ephemeral in-memory data context that is not
|
248
|
+
# backed by a local YAML file (see https://docs.greatexpectations.io/docs/oss/guides/setup/configuring_data_contexts/instantiating_data_contexts/instantiate_data_context/).
|
257
249
|
if self.context_config:
|
258
|
-
|
250
|
+
# Use the data context configuration provided in the stack
|
251
|
+
# component configuration
|
252
|
+
context_config = self.context_config
|
259
253
|
else:
|
254
|
+
# Initialize the data context with the default ZenML
|
255
|
+
# configuration options effectively plugging the GX stores
|
256
|
+
# into the ZenML artifact store
|
260
257
|
context_config = DataContextConfig(**zenml_context_config)
|
261
258
|
# skip adding the stores after initialization, as they are
|
262
259
|
# already baked in the initial configuration
|
263
260
|
configure_zenml_stores = False
|
264
|
-
|
261
|
+
|
262
|
+
self._context = EphemeralDataContext(
|
263
|
+
project_config=context_config
|
264
|
+
)
|
265
265
|
|
266
266
|
if configure_zenml_stores:
|
267
267
|
self._context.config.expectations_store_name = (
|
@@ -277,14 +277,14 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
277
277
|
self._context.config.evaluation_parameter_store_name = (
|
278
278
|
evaluation_parameter_store_name
|
279
279
|
)
|
280
|
-
for store_name, store_config in zenml_context_config[
|
280
|
+
for store_name, store_config in zenml_context_config[
|
281
281
|
"stores"
|
282
282
|
].items():
|
283
283
|
self._context.add_store(
|
284
284
|
store_name=store_name,
|
285
285
|
store_config=store_config,
|
286
286
|
)
|
287
|
-
for site_name, site_config in zenml_context_config[
|
287
|
+
for site_name, site_config in zenml_context_config[
|
288
288
|
"data_docs_sites"
|
289
289
|
].items():
|
290
290
|
self._context.config.data_docs_sites[site_name] = (
|
@@ -509,7 +509,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
509
509
|
},
|
510
510
|
]
|
511
511
|
|
512
|
-
checkpoint_config = {
|
512
|
+
checkpoint_config: Dict[str, Any] = {
|
513
513
|
"name": checkpoint_name,
|
514
514
|
"run_name_template": run_name,
|
515
515
|
"config_version": 1,
|
@@ -517,7 +517,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
|
|
517
517
|
"expectation_suite_name": expectation_suite_name,
|
518
518
|
"action_list": action_list,
|
519
519
|
}
|
520
|
-
context.add_checkpoint(**checkpoint_config)
|
520
|
+
context.add_checkpoint(**checkpoint_config) # type: ignore[has-type]
|
521
521
|
|
522
522
|
try:
|
523
523
|
results = context.run_checkpoint(
|
@@ -16,7 +16,9 @@
|
|
16
16
|
import os
|
17
17
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
|
18
18
|
|
19
|
-
|
19
|
+
import yaml
|
20
|
+
from pydantic import field_validator, model_validator
|
21
|
+
from yaml.parser import ParserError
|
20
22
|
|
21
23
|
from zenml.data_validators.base_data_validator import (
|
22
24
|
BaseDataValidatorConfig,
|
@@ -26,6 +28,7 @@ from zenml.integrations.great_expectations import (
|
|
26
28
|
GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
|
27
29
|
)
|
28
30
|
from zenml.io import fileio
|
31
|
+
from zenml.utils.pydantic_utils import before_validator_handler
|
29
32
|
|
30
33
|
if TYPE_CHECKING:
|
31
34
|
from zenml.integrations.great_expectations.data_validators import (
|
@@ -41,6 +44,8 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
|
|
41
44
|
data context. If configured, the data validator will only be usable
|
42
45
|
with local orchestrators.
|
43
46
|
context_config: in-line Great Expectations data context configuration.
|
47
|
+
If the `context_root_dir` attribute is also set, this configuration
|
48
|
+
will be ignored.
|
44
49
|
configure_zenml_stores: if set, ZenML will automatically configure
|
45
50
|
stores that use the Artifact Store as a backend. If neither
|
46
51
|
`context_root_dir` nor `context_config` are set, this is the default
|
@@ -54,7 +59,8 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
|
|
54
59
|
configure_zenml_stores: bool = False
|
55
60
|
configure_local_docs: bool = True
|
56
61
|
|
57
|
-
@
|
62
|
+
@field_validator("context_root_dir")
|
63
|
+
@classmethod
|
58
64
|
def _ensure_valid_context_root_dir(
|
59
65
|
cls, context_root_dir: Optional[str] = None
|
60
66
|
) -> Optional[str]:
|
@@ -78,6 +84,33 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
|
|
78
84
|
)
|
79
85
|
return context_root_dir
|
80
86
|
|
87
|
+
@model_validator(mode="before")
|
88
|
+
@classmethod
|
89
|
+
@before_validator_handler
|
90
|
+
def validate_context_config(cls, data: Dict[str, Any]) -> Dict[str, Any]:
|
91
|
+
"""Convert the context configuration if given in JSON/YAML format.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
data: The configuration values.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
The validated configuration values.
|
98
|
+
|
99
|
+
Raises:
|
100
|
+
ValueError: If the context configuration is not a valid
|
101
|
+
JSON/YAML object.
|
102
|
+
"""
|
103
|
+
if isinstance(data.get("context_config"), str):
|
104
|
+
try:
|
105
|
+
data["context_config"] = yaml.safe_load(data["context_config"])
|
106
|
+
except ParserError as e:
|
107
|
+
raise ValueError(
|
108
|
+
f"Malformed `context_config` value. Only JSON and YAML "
|
109
|
+
f"formats are supported: {str(e)}"
|
110
|
+
)
|
111
|
+
|
112
|
+
return data
|
113
|
+
|
81
114
|
@property
|
82
115
|
def is_local(self) -> bool:
|
83
116
|
"""Checks if this stack component is running locally.
|
@@ -17,14 +17,16 @@ import os
|
|
17
17
|
from pathlib import Path
|
18
18
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
19
19
|
|
20
|
-
from great_expectations.data_context.store.tuple_store_backend import (
|
20
|
+
from great_expectations.data_context.store.tuple_store_backend import (
|
21
21
|
TupleStoreBackend,
|
22
|
-
filter_properties_dict,
|
23
22
|
)
|
24
23
|
from great_expectations.exceptions import ( # type: ignore[import-untyped]
|
25
24
|
InvalidKeyError,
|
26
25
|
StoreBackendError,
|
27
26
|
)
|
27
|
+
from great_expectations.util import ( # type: ignore[import-untyped]
|
28
|
+
filter_properties_dict,
|
29
|
+
)
|
28
30
|
|
29
31
|
from zenml.client import Client
|
30
32
|
from zenml.io import fileio
|
@@ -34,7 +36,7 @@ from zenml.utils import io_utils
|
|
34
36
|
logger = get_logger(__name__)
|
35
37
|
|
36
38
|
|
37
|
-
class ZenMLArtifactStoreBackend(TupleStoreBackend):
|
39
|
+
class ZenMLArtifactStoreBackend(TupleStoreBackend):
|
38
40
|
"""Great Expectations store backend that uses the active ZenML Artifact Store as a store."""
|
39
41
|
|
40
42
|
def __init__(
|
@@ -105,7 +107,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
105
107
|
if not isinstance(key, tuple):
|
106
108
|
key = key.to_tuple()
|
107
109
|
if not is_prefix:
|
108
|
-
object_relative_path = self._convert_key_to_filepath(key)
|
110
|
+
object_relative_path = self._convert_key_to_filepath(key) # type: ignore[no-untyped-call]
|
109
111
|
elif key:
|
110
112
|
object_relative_path = os.path.join(*key)
|
111
113
|
else:
|
@@ -116,7 +118,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
116
118
|
object_key = object_relative_path
|
117
119
|
return os.path.join(self.root_path, object_key)
|
118
120
|
|
119
|
-
def _get(self, key: Tuple[str, ...]) -> str:
|
121
|
+
def _get(self, key: Tuple[str, ...]) -> str: # type: ignore[override]
|
120
122
|
"""Get the value of an object from the store.
|
121
123
|
|
122
124
|
Args:
|
@@ -140,7 +142,18 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
140
142
|
)
|
141
143
|
return contents
|
142
144
|
|
143
|
-
def
|
145
|
+
def _get_all(self) -> List[Any]:
|
146
|
+
"""Get all objects in the store.
|
147
|
+
|
148
|
+
Raises:
|
149
|
+
NotImplementedError: if the method is not implemented for this store
|
150
|
+
backend.
|
151
|
+
"""
|
152
|
+
raise NotImplementedError(
|
153
|
+
"Method `_get_all` is not implemented for this store backend."
|
154
|
+
)
|
155
|
+
|
156
|
+
def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str: # type: ignore[override]
|
144
157
|
"""Set the value of an object in the store.
|
145
158
|
|
146
159
|
Args:
|
@@ -212,12 +225,12 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
212
225
|
self.filepath_suffix
|
213
226
|
):
|
214
227
|
continue
|
215
|
-
key = self._convert_filepath_to_key(filepath)
|
216
|
-
if key and not self.is_ignored_key(key):
|
228
|
+
key = self._convert_filepath_to_key(filepath) # type: ignore[no-untyped-call]
|
229
|
+
if key and not self.is_ignored_key(key): # type: ignore[no-untyped-call]
|
217
230
|
key_list.append(key)
|
218
231
|
return key_list
|
219
232
|
|
220
|
-
def remove_key(self, key: Tuple[str, ...]) -> bool:
|
233
|
+
def remove_key(self, key: Tuple[str, ...]) -> bool: # type: ignore[override]
|
221
234
|
"""Delete an object from the store.
|
222
235
|
|
223
236
|
Args:
|
@@ -250,7 +263,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
250
263
|
result = fileio.exists(filepath)
|
251
264
|
return result
|
252
265
|
|
253
|
-
def get_url_for_key(
|
266
|
+
def get_url_for_key( # type: ignore[override]
|
254
267
|
self, key: Tuple[str, ...], protocol: Optional[str] = None
|
255
268
|
) -> str:
|
256
269
|
"""Get the URL of an object in the store.
|
@@ -292,7 +305,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
|
|
292
305
|
f"requested but `base_public_path` was not configured for the "
|
293
306
|
f"{self.__class__.__name__}"
|
294
307
|
)
|
295
|
-
filepath = self._convert_key_to_filepath(key)
|
308
|
+
filepath = self._convert_key_to_filepath(key) # type: ignore[no-untyped-call]
|
296
309
|
public_url = self.base_public_path + filepath.replace(self.proto, "")
|
297
310
|
return cast(str, public_url)
|
298
311
|
|
@@ -25,10 +25,10 @@ from great_expectations.core import ( # type: ignore[import-untyped]
|
|
25
25
|
from great_expectations.core.expectation_validation_result import ( # type: ignore[import-untyped]
|
26
26
|
ExpectationSuiteValidationResult,
|
27
27
|
)
|
28
|
-
from great_expectations.data_context.types.base import (
|
28
|
+
from great_expectations.data_context.types.base import (
|
29
29
|
CheckpointConfig,
|
30
30
|
)
|
31
|
-
from great_expectations.data_context.types.resource_identifiers import (
|
31
|
+
from great_expectations.data_context.types.resource_identifiers import (
|
32
32
|
ExpectationSuiteIdentifier,
|
33
33
|
ValidationResultIdentifier,
|
34
34
|
)
|
@@ -86,7 +86,7 @@ class GreatExpectationsMaterializer(BaseMaterializer):
|
|
86
86
|
validation_dict = {}
|
87
87
|
for result_ident, results in artifact_dict["run_results"].items():
|
88
88
|
validation_ident = (
|
89
|
-
ValidationResultIdentifier.from_fixed_length_tuple(
|
89
|
+
ValidationResultIdentifier.from_fixed_length_tuple( # type: ignore[no-untyped-call]
|
90
90
|
result_ident.split("::")[1].split("/")
|
91
91
|
)
|
92
92
|
)
|
@@ -13,14 +13,14 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Great Expectations data profiling standard step."""
|
15
15
|
|
16
|
-
from typing import Optional
|
16
|
+
from typing import Any, Dict, Optional
|
17
17
|
|
18
18
|
import pandas as pd
|
19
19
|
from great_expectations.core.batch import ( # type: ignore[import-untyped]
|
20
20
|
RuntimeBatchRequest,
|
21
21
|
)
|
22
|
-
from great_expectations.data_context.data_context import (
|
23
|
-
|
22
|
+
from great_expectations.data_context.data_context.abstract_data_context import (
|
23
|
+
AbstractDataContext,
|
24
24
|
)
|
25
25
|
|
26
26
|
from zenml import get_step_context
|
@@ -31,7 +31,7 @@ logger = get_logger(__name__)
|
|
31
31
|
|
32
32
|
|
33
33
|
def create_batch_request(
|
34
|
-
context:
|
34
|
+
context: AbstractDataContext,
|
35
35
|
dataset: pd.DataFrame,
|
36
36
|
data_asset_name: Optional[str],
|
37
37
|
) -> RuntimeBatchRequest:
|
@@ -62,7 +62,7 @@ def create_batch_request(
|
|
62
62
|
data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
|
63
63
|
batch_identifier = "default"
|
64
64
|
|
65
|
-
datasource_config = {
|
65
|
+
datasource_config: Dict[str, Any] = {
|
66
66
|
"name": datasource_name,
|
67
67
|
"class_name": "Datasource",
|
68
68
|
"module_name": "great_expectations.datasource",
|
@@ -30,6 +30,9 @@ class HuggingfaceIntegration(Integration):
|
|
30
30
|
"transformers<=4.31",
|
31
31
|
"datasets",
|
32
32
|
"huggingface_hub>0.19.0",
|
33
|
+
"accelerate",
|
34
|
+
"bitsandbytes>=0.41.3",
|
35
|
+
"peft",
|
33
36
|
# temporary fix for CI issue similar to:
|
34
37
|
# - https://github.com/huggingface/datasets/issues/6737
|
35
38
|
# - https://github.com/huggingface/datasets/issues/6697
|
@@ -61,7 +61,7 @@ class HuggingFaceModelDeployerConfig(
|
|
61
61
|
namespace: Hugging Face namespace used to list endpoints
|
62
62
|
"""
|
63
63
|
|
64
|
-
token: Optional[str] = SecretField()
|
64
|
+
token: Optional[str] = SecretField(default=None)
|
65
65
|
|
66
66
|
# The namespace to list endpoints for. Set to `"*"` to list all endpoints
|
67
67
|
# from all namespaces (i.e. personal namespace and all orgs the user belongs to).
|
@@ -0,0 +1,149 @@
|
|
1
|
+
# Apache Software License 2.0
|
2
|
+
#
|
3
|
+
# Copyright (c) ZenML GmbH 2024. All rights reserved.
|
4
|
+
#
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6
|
+
# you may not use this file except in compliance with the License.
|
7
|
+
# You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
|
+
# See the License for the specific language governing permissions and
|
15
|
+
# limitations under the License.
|
16
|
+
#
|
17
|
+
"""Step function to run any ZenML step using Accelerate."""
|
18
|
+
|
19
|
+
import functools
|
20
|
+
from typing import Any, Callable, Optional, TypeVar, cast
|
21
|
+
|
22
|
+
import cloudpickle as pickle
|
23
|
+
from accelerate.commands.launch import ( # type: ignore[import-untyped]
|
24
|
+
launch_command,
|
25
|
+
launch_command_parser,
|
26
|
+
)
|
27
|
+
|
28
|
+
from zenml.logger import get_logger
|
29
|
+
from zenml.steps import BaseStep
|
30
|
+
from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script
|
31
|
+
|
32
|
+
logger = get_logger(__name__)
|
33
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
34
|
+
|
35
|
+
|
36
|
+
def run_with_accelerate(
    step_function: BaseStep,
    num_processes: Optional[int] = None,
    use_cpu: bool = False,
) -> BaseStep:
    """Run a function with accelerate.

    The step's entrypoint is replaced in place by a wrapper that writes a
    CLI-wrapped script for the original entrypoint, launches that script
    through the accelerate launch command and returns the unpickled content
    of the script's output file.

    Accelerate package: https://huggingface.co/docs/accelerate/en/index
    Example:
        ```python
        from zenml import step, pipeline
        from zenml.integrations.huggingface.steps import run_with_accelerate
        @step
        def training_step(some_param: int, ...):
            # your training code is below
            ...

        @pipeline
        def training_pipeline(some_param: int, ...):
            run_with_accelerate(training_step, num_processes=4)(some_param, ...)
        ```

    Args:
        step_function: The step function to run.
        num_processes: The number of processes to use.
        use_cpu: Whether to use the CPU.

    Returns:
        The accelerate-enabled version of the step.
    """

    def _decorator(entrypoint: F) -> F:
        @functools.wraps(entrypoint)
        def inner(*args: Any, **kwargs: Any) -> Any:
            # Arguments are forwarded as `--name value` CLI options, so only
            # keyword arguments can be mapped.
            if args:
                raise ValueError(
                    "Accelerated steps do not support positional arguments."
                )

            if not use_cpu:
                import torch

                logger.info("Starting accelerate job...")

                device_count = torch.cuda.device_count()
                if num_processes is None:
                    _num_processes = device_count
                elif num_processes > device_count:
                    logger.warning(
                        f"Number of processes ({num_processes}) is greater than "
                        f"the number of available GPUs ({device_count}). Using all GPUs."
                    )
                    _num_processes = device_count
                else:
                    _num_processes = num_processes
            else:
                _num_processes = num_processes or 1

            with create_cli_wrapped_script(
                entrypoint, flavour="accelerate"
            ) as (
                script_path,
                output_path,
            ):
                commands = ["--num_processes", str(_num_processes)]
                if use_cpu:
                    commands += [
                        "--cpu",
                        "--num_cpu_threads_per_process",
                        "10",
                    ]
                commands.append(str(script_path.absolute()))
                for k, v in kwargs.items():
                    k = _cli_arg_name(k)
                    if isinstance(v, bool):
                        # Booleans are passed as bare flags; a false value is
                        # expressed by omitting the flag.
                        if v:
                            commands.append(f"--{k}")
                    elif isinstance(v, str):
                        # Quote the actual value, matching how string items
                        # of collections are formatted below.
                        commands += [f"--{k}", f'"{v}"']
                    elif isinstance(v, (list, tuple, set)):
                        # Collections repeat the option once per element.
                        for each in v:
                            commands.append(f"--{k}")
                            if isinstance(each, str):
                                commands.append(f'"{each}"')
                            else:
                                commands.append(f"{each}")
                    else:
                        commands += [f"--{k}", f"{v}"]

                logger.debug(commands)

                parser = launch_command_parser()
                # Use a distinct name so the wrapper's `args` is not rebound.
                launch_args = parser.parse_args(commands)
                try:
                    launch_command(launch_args)
                except Exception as e:
                    logger.error(
                        "Accelerate training job failed... See error message for details."
                    )
                    raise RuntimeError(
                        "Accelerate training job failed."
                    ) from e
                else:
                    logger.info(
                        "Accelerate training job finished successfully."
                    )
                    # NOTE(review): the output file is unpickled as-is; this
                    # presumes it was written by the wrapped script created
                    # above and is therefore trusted.
                    with open(output_path, "rb") as output_file:
                        return pickle.load(output_file)

        return cast(F, inner)

    setattr(step_function, "entrypoint", _decorator(step_function.entrypoint))

    return step_function
|