zenml-nightly 0.58.2.dev20240618__py3-none-any.whl → 0.58.2.dev20240620__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/_hub/client.py +8 -5
- zenml/actions/base_action.py +8 -10
- zenml/artifact_stores/base_artifact_store.py +20 -15
- zenml/artifact_stores/local_artifact_store.py +3 -2
- zenml/artifacts/artifact_config.py +34 -19
- zenml/artifacts/external_artifact.py +18 -8
- zenml/artifacts/external_artifact_config.py +14 -6
- zenml/artifacts/unmaterialized_artifact.py +2 -11
- zenml/cli/__init__.py +6 -0
- zenml/cli/artifact.py +20 -2
- zenml/cli/served_model.py +0 -1
- zenml/cli/server.py +3 -3
- zenml/cli/utils.py +36 -40
- zenml/cli/web_login.py +2 -2
- zenml/client.py +198 -24
- zenml/client_lazy_loader.py +20 -14
- zenml/config/base_settings.py +5 -6
- zenml/config/build_configuration.py +1 -1
- zenml/config/compiler.py +3 -3
- zenml/config/docker_settings.py +27 -28
- zenml/config/global_config.py +33 -37
- zenml/config/pipeline_configurations.py +8 -11
- zenml/config/pipeline_run_configuration.py +6 -2
- zenml/config/pipeline_spec.py +3 -4
- zenml/config/resource_settings.py +8 -9
- zenml/config/schedule.py +16 -20
- zenml/config/secret_reference_mixin.py +6 -3
- zenml/config/secrets_store_config.py +16 -23
- zenml/config/server_config.py +50 -46
- zenml/config/settings_resolver.py +1 -1
- zenml/config/source.py +45 -35
- zenml/config/step_configurations.py +53 -31
- zenml/config/store_config.py +20 -19
- zenml/config/strict_base_model.py +2 -6
- zenml/constants.py +26 -2
- zenml/container_registries/base_container_registry.py +3 -2
- zenml/container_registries/default_container_registry.py +3 -3
- zenml/event_hub/base_event_hub.py +1 -1
- zenml/event_sources/base_event_source.py +11 -16
- zenml/exceptions.py +4 -0
- zenml/integrations/airflow/__init__.py +2 -10
- zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
- zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
- zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
- zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
- zenml/integrations/aws/__init__.py +1 -1
- zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
- zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
- zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
- zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
- zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
- zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
- zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
- zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
- zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
- zenml/integrations/evidently/__init__.py +3 -4
- zenml/integrations/evidently/column_mapping.py +11 -3
- zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
- zenml/integrations/evidently/metrics.py +5 -6
- zenml/integrations/evidently/tests.py +5 -6
- zenml/integrations/facets/models.py +2 -6
- zenml/integrations/feast/__init__.py +3 -1
- zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
- zenml/integrations/gcp/__init__.py +1 -1
- zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
- zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
- zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
- zenml/integrations/great_expectations/__init__.py +1 -1
- zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
- zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
- zenml/integrations/great_expectations/ge_store_backend.py +24 -11
- zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
- zenml/integrations/great_expectations/utils.py +5 -5
- zenml/integrations/huggingface/__init__.py +3 -0
- zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
- zenml/integrations/huggingface/steps/__init__.py +3 -0
- zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
- zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
- zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
- zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
- zenml/integrations/kubeflow/__init__.py +1 -1
- zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
- zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
- zenml/integrations/kubernetes/pod_settings.py +17 -31
- zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
- zenml/integrations/label_studio/__init__.py +1 -3
- zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
- zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
- zenml/integrations/langchain/materializers/document_materializer.py +44 -8
- zenml/integrations/mlflow/__init__.py +9 -3
- zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
- zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
- zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
- zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
- zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
- zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
- zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
- zenml/integrations/seldon/seldon_client.py +52 -67
- zenml/integrations/seldon/services/seldon_deployment.py +3 -3
- zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
- zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
- zenml/integrations/skypilot_aws/__init__.py +1 -1
- zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
- zenml/integrations/skypilot_azure/__init__.py +1 -1
- zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
- zenml/integrations/skypilot_gcp/__init__.py +2 -1
- zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
- zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
- zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
- zenml/integrations/tekton/__init__.py +1 -1
- zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
- zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
- zenml/integrations/tensorboard/__init__.py +1 -12
- zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
- zenml/integrations/tensorflow/__init__.py +2 -10
- zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
- zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
- zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
- zenml/lineage_graph/lineage_graph.py +1 -1
- zenml/materializers/built_in_materializer.py +3 -3
- zenml/materializers/pydantic_materializer.py +2 -2
- zenml/metadata/lazy_load.py +4 -4
- zenml/metadata/metadata_types.py +64 -4
- zenml/model/model.py +79 -54
- zenml/model_deployers/base_model_deployer.py +14 -12
- zenml/model_registries/base_model_registry.py +17 -15
- zenml/models/__init__.py +79 -206
- zenml/models/v2/base/base.py +54 -41
- zenml/models/v2/base/base_plugin_flavor.py +2 -6
- zenml/models/v2/base/filter.py +91 -76
- zenml/models/v2/base/page.py +2 -12
- zenml/models/v2/base/scoped.py +4 -7
- zenml/models/v2/core/api_key.py +22 -8
- zenml/models/v2/core/artifact.py +2 -2
- zenml/models/v2/core/artifact_version.py +74 -40
- zenml/models/v2/core/code_repository.py +37 -10
- zenml/models/v2/core/component.py +65 -16
- zenml/models/v2/core/device.py +14 -4
- zenml/models/v2/core/event_source.py +1 -2
- zenml/models/v2/core/flavor.py +74 -8
- zenml/models/v2/core/logs.py +68 -8
- zenml/models/v2/core/model.py +8 -4
- zenml/models/v2/core/model_version.py +25 -6
- zenml/models/v2/core/model_version_artifact.py +51 -21
- zenml/models/v2/core/model_version_pipeline_run.py +45 -13
- zenml/models/v2/core/pipeline.py +37 -72
- zenml/models/v2/core/pipeline_build.py +29 -17
- zenml/models/v2/core/pipeline_deployment.py +18 -6
- zenml/models/v2/core/pipeline_namespace.py +113 -0
- zenml/models/v2/core/pipeline_run.py +50 -22
- zenml/models/v2/core/run_metadata.py +59 -36
- zenml/models/v2/core/schedule.py +37 -24
- zenml/models/v2/core/secret.py +31 -12
- zenml/models/v2/core/service.py +64 -36
- zenml/models/v2/core/service_account.py +24 -11
- zenml/models/v2/core/service_connector.py +219 -44
- zenml/models/v2/core/stack.py +45 -17
- zenml/models/v2/core/step_run.py +28 -8
- zenml/models/v2/core/tag.py +8 -4
- zenml/models/v2/core/trigger.py +2 -2
- zenml/models/v2/core/trigger_execution.py +1 -0
- zenml/models/v2/core/user.py +18 -21
- zenml/models/v2/core/workspace.py +13 -3
- zenml/models/v2/misc/build_item.py +3 -3
- zenml/models/v2/misc/external_user.py +2 -6
- zenml/models/v2/misc/hub_plugin_models.py +9 -9
- zenml/models/v2/misc/loaded_visualization.py +2 -2
- zenml/models/v2/misc/service_connector_type.py +8 -17
- zenml/models/v2/misc/user_auth.py +7 -2
- zenml/new/pipelines/build_utils.py +3 -3
- zenml/new/pipelines/pipeline.py +17 -13
- zenml/new/pipelines/run_utils.py +103 -1
- zenml/orchestrators/base_orchestrator.py +10 -7
- zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
- zenml/orchestrators/step_runner.py +3 -6
- zenml/orchestrators/utils.py +1 -1
- zenml/plugins/base_plugin_flavor.py +6 -10
- zenml/plugins/plugin_flavor_registry.py +3 -7
- zenml/secret/base_secret.py +7 -8
- zenml/service_connectors/docker_service_connector.py +4 -3
- zenml/service_connectors/service_connector.py +5 -12
- zenml/service_connectors/service_connector_registry.py +2 -4
- zenml/services/container/container_service.py +1 -1
- zenml/services/container/container_service_endpoint.py +1 -1
- zenml/services/local/local_service.py +1 -1
- zenml/services/local/local_service_endpoint.py +1 -1
- zenml/services/service.py +16 -10
- zenml/services/service_type.py +4 -5
- zenml/services/terraform/terraform_service.py +1 -1
- zenml/stack/flavor.py +1 -5
- zenml/stack/flavor_registry.py +4 -4
- zenml/stack/stack.py +4 -1
- zenml/stack/stack_component.py +55 -31
- zenml/steps/base_step.py +34 -28
- zenml/steps/entrypoint_function_utils.py +3 -5
- zenml/steps/utils.py +12 -14
- zenml/utils/cuda_utils.py +50 -0
- zenml/utils/deprecation_utils.py +18 -20
- zenml/utils/dict_utils.py +1 -1
- zenml/utils/filesync_model.py +65 -28
- zenml/utils/function_utils.py +260 -0
- zenml/utils/json_utils.py +131 -0
- zenml/utils/mlstacks_utils.py +2 -2
- zenml/utils/pydantic_utils.py +270 -62
- zenml/utils/secret_utils.py +65 -12
- zenml/utils/source_utils.py +2 -2
- zenml/utils/typed_model.py +5 -3
- zenml/utils/typing_utils.py +243 -0
- zenml/utils/yaml_utils.py +1 -1
- zenml/zen_server/auth.py +2 -2
- zenml/zen_server/cloud_utils.py +6 -6
- zenml/zen_server/deploy/base_provider.py +1 -1
- zenml/zen_server/deploy/deployment.py +6 -8
- zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
- zenml/zen_server/deploy/local/local_provider.py +0 -1
- zenml/zen_server/deploy/local/local_zen_server.py +6 -6
- zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
- zenml/zen_server/exceptions.py +4 -1
- zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
- zenml/zen_server/pipeline_deployment/utils.py +48 -68
- zenml/zen_server/rbac/models.py +2 -5
- zenml/zen_server/rbac/utils.py +11 -14
- zenml/zen_server/routers/auth_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
- zenml/zen_server/routers/runs_endpoints.py +1 -1
- zenml/zen_server/routers/secrets_endpoints.py +3 -2
- zenml/zen_server/routers/server_endpoints.py +1 -1
- zenml/zen_server/routers/steps_endpoints.py +1 -1
- zenml/zen_server/routers/workspaces_endpoints.py +1 -1
- zenml/zen_stores/base_zen_store.py +46 -9
- zenml/zen_stores/migrations/utils.py +42 -46
- zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
- zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
- zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
- zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
- zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
- zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
- zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
- zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
- zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
- zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
- zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
- zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
- zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
- zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
- zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
- zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
- zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
- zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
- zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
- zenml/zen_stores/rest_zen_store.py +109 -49
- zenml/zen_stores/schemas/api_key_schemas.py +1 -1
- zenml/zen_stores/schemas/artifact_schemas.py +8 -8
- zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
- zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
- zenml/zen_stores/schemas/component_schemas.py +8 -3
- zenml/zen_stores/schemas/device_schemas.py +8 -6
- zenml/zen_stores/schemas/event_source_schemas.py +3 -4
- zenml/zen_stores/schemas/flavor_schemas.py +5 -3
- zenml/zen_stores/schemas/model_schemas.py +26 -1
- zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
- zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
- zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
- zenml/zen_stores/schemas/secret_schemas.py +8 -5
- zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
- zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
- zenml/zen_stores/schemas/service_schemas.py +11 -2
- zenml/zen_stores/schemas/stack_schemas.py +1 -1
- zenml/zen_stores/schemas/step_run_schemas.py +11 -11
- zenml/zen_stores/schemas/tag_schemas.py +6 -2
- zenml/zen_stores/schemas/trigger_schemas.py +2 -2
- zenml/zen_stores/schemas/user_schemas.py +2 -2
- zenml/zen_stores/schemas/workspace_schemas.py +3 -1
- zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
- zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
- zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
- zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
- zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
- zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
- zenml/zen_stores/sql_zen_store.py +196 -120
- zenml/zen_stores/zen_store_interface.py +33 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/METADATA +8 -7
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/RECORD +297 -294
- zenml/integrations/kubeflow/utils.py +0 -95
- zenml/models/v2/base/internal.py +0 -37
- zenml/models/v2/base/update.py +0 -44
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,260 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Utility functions for python functions."""
|
15
|
+
|
16
|
+
import inspect
|
17
|
+
import os
|
18
|
+
from contextlib import contextmanager
|
19
|
+
from pathlib import Path
|
20
|
+
from typing import Any, Callable, Iterator, List, Tuple, TypeVar, Union
|
21
|
+
|
22
|
+
import click
|
23
|
+
|
24
|
+
from zenml.logger import get_logger
|
25
|
+
from zenml.utils.string_utils import random_str
|
26
|
+
|
27
|
+
# Type variable used to annotate the functions handled by the helpers below.
# NOTE(review): the bound is `Callable[..., None]`, yet the "accelerate" main
# template pickles the result of calling the wrapped command - confirm whether
# the bound should admit functions that return a value.
F = TypeVar("F", bound=Callable[..., None])

# Module-level logger.
logger = get_logger(__name__)
|
30
|
+
|
31
|
+
# Header of the generated script. `create_cli_wrapped_script` fills in the
# `{func_path}`, `{func_module}` and `{func_name}` placeholders via
# `str.format`. The rendered code appends the function's directory to
# `sys.path`, imports the function, and wraps either its `entrypoint`
# attribute (when it has a truthy one) or the function itself with
# `_cli_wrapped_function`.
_CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER = """
from zenml.utils.function_utils import _cli_wrapped_function

import sys
sys.path.append(r"{func_path}")

from {func_module} import {func_name} as func_to_wrap

if entrypoint:=getattr(func_to_wrap, "entrypoint", None):
    func = _cli_wrapped_function(entrypoint)
else:
    func = _cli_wrapped_function(func_to_wrap)
"""
# `__main__` sections of the generated script, keyed by flavour name. The
# "accelerate" variant calls the wrapped command with `standalone_mode=False`
# and, on the main process only, pickles the result into `{output_file}`.
_CLI_WRAPPED_MAINS = {
    "accelerate": """
if __name__=="__main__":
    from accelerate import Accelerator
    import cloudpickle as pickle
    accelerator = Accelerator()
    ret = func(standalone_mode=False)
    if accelerator.is_main_process:
        pickle.dump(ret, open(r"{output_file}", "wb"))
"""
}
|
55
|
+
# Scalar annotation types that can be exposed as CLI options.
_ALLOWED_TYPES = (str, int, float, bool, Path)
# Generic origins accepted for collection annotations (e.g. `Tuple[int, ...]`).
_ALLOWED_COLLECTIONS = (tuple,)
# Maps an annotation type to the `click` parameter type used for its option.
# `Path` arguments are parsed as plain strings.
# NOTE(review): within this module the mapper is only indexed with members of
# `_ALLOWED_TYPES`, so the `None` key looks unreachable - confirm before
# relying on it.
_CLICK_TYPES_MAPPER = {
    str: click.STRING,
    int: click.INT,
    float: click.FLOAT,
    bool: click.BOOL,
    Path: click.STRING,
    None: click.STRING,
}
|
65
|
+
|
66
|
+
|
67
|
+
def _cli_arg_name(arg_name: str) -> str:
|
68
|
+
return arg_name.replace("_", "-")
|
69
|
+
|
70
|
+
|
71
|
+
def _is_valid_collection_arg(arg_type: Any) -> bool:
    """Check if the given argument type is a valid collection type.

    A valid collection type is a parametrized generic whose origin is one of
    `_ALLOWED_COLLECTIONS` and whose first type parameter is one of
    `_ALLOWED_TYPES`. Only the first parameter is inspected.

    Args:
        arg_type: The type to check.

    Returns:
        True if the argument type is a valid collection type, False otherwise.
    """
    if getattr(arg_type, "__origin__", None) not in _ALLOWED_COLLECTIONS:
        return False
    # An unparametrized `typing.Tuple` exposes no `__args__`, and an empty
    # tuple annotation can carry an empty `__args__`; neither names a usable
    # member type, so report them as invalid instead of raising.
    member_types = getattr(arg_type, "__args__", None)
    if not member_types:
        return False
    return member_types[0] in _ALLOWED_TYPES
|
85
|
+
|
86
|
+
|
87
|
+
def _is_valid_optional_arg(arg_type: Any) -> bool:
    """Tell whether an annotation is an `Optional[...]` the CLI can handle.

    The annotation qualifies when it is a two-member Union named "Optional"
    where:
    - the first member is an allowed type or a valid collection type, and
    - the second member is the NoneType.

    Args:
        arg_type: The annotation to inspect.

    Returns:
        True if the annotation is a supported Optional type, False otherwise.
    """
    looks_optional = (
        getattr(arg_type, "_name", None) == "Optional"
        and getattr(arg_type, "__origin__", None) == Union
    )
    if not looks_optional:
        return False

    members = getattr(arg_type, "__args__", None)
    if not members or len(members) != 2:
        return False

    inner, none_slot = members
    # The inner type is checked first, mirroring the order in which the two
    # conditions short-circuit.
    inner_supported = inner in _ALLOWED_TYPES or _is_valid_collection_arg(
        inner
    )
    return inner_supported and none_slot == type(None)
|
114
|
+
|
115
|
+
|
116
|
+
def _cli_wrapped_function(func: F) -> F:
    """Wrap a function into a `click` command exposing its arguments.

    Every positional-or-keyword argument of `func` becomes a `--dashed-name`
    option:
    - `bool` arguments become flags that are False unless the flag is given.
    - valid collection arguments become repeatable (`multiple=True`) options.
    - other allowed types become single-value options, which are required
      only when the function declares no default for them.

    Supported `Optional[...]` annotations are unwrapped to their inner type
    first. As a side effect, a note is appended to `func.__doc__`.

    Args:
        func: The function to wrap.

    Returns:
        The `click` command built around `func`.

    Raises:
        ValueError: If the function arguments are not valid.
    """
    fullargspec = inspect.getfullargspec(func)
    arg_names = fullargspec.args
    declared_defaults = fullargspec.defaults or ()
    # Defaults always belong to the trailing arguments; everything before
    # this index has no default at all.
    first_default_idx = len(arg_names) - len(declared_defaults)

    options: List[Any] = []
    invalid_types = {}
    for idx, python_name in enumerate(arg_names):
        has_default = idx >= first_default_idx
        arg_default = (
            declared_defaults[idx - first_default_idx] if has_default else None
        )
        arg_type = fullargspec.annotations.get(python_name, None)
        if _is_valid_optional_arg(arg_type):
            arg_type = arg_type.__args__[0]
        arg_name = _cli_arg_name(python_name)
        if arg_type == bool:
            # Flags are deliberately False unless passed, regardless of the
            # Python-side default.
            options.append(
                click.option(
                    f"--{arg_name}",
                    type=click.BOOL,
                    is_flag=True,
                    default=False,
                    required=False,
                )
            )
        elif _is_valid_collection_arg(arg_type):
            member_type = arg_type.__args__[0]
            options.append(
                click.option(
                    f"--{arg_name}",
                    type=member_type,
                    default=arg_default,
                    required=False,
                    multiple=True,
                )
            )
        elif arg_type in _ALLOWED_TYPES:
            options.append(
                click.option(
                    f"--{arg_name}",
                    type=_CLICK_TYPES_MAPPER[arg_type],
                    default=arg_default,
                    # An explicit `None` default is still a default: only
                    # arguments without any default must be supplied.
                    required=not has_default,
                )
            )
        else:
            invalid_types[arg_name] = arg_type
    if invalid_types:
        raise ValueError(
            f"Invalid argument types: {invalid_types}. CLI functions only "
            f"supports: {_ALLOWED_TYPES} types (including Optional) and "
            f"{_ALLOWED_COLLECTIONS} collections."
        )
    options.append(
        click.command(
            help="Technical wrapper to pass into the `accelerate launch` command."
        )
    )

    note = "This is ZenML-generated CLI wrapper function."
    # Avoid rendering a missing docstring as the literal text "None".
    func.__doc__ = f"{func.__doc__}\n\n{note}" if func.__doc__ else note

    # `click.command` was appended last, so iterating in reverse applies it
    # first and attaches the options to the resulting command.
    wrapped: Any = func
    for decorate in reversed(options):
        wrapped = decorate(wrapped)
    return wrapped
|
203
|
+
|
204
|
+
|
205
|
+
@contextmanager
def create_cli_wrapped_script(
    func: F, flavour: str = "accelerate"
) -> Iterator[Tuple[Path, Path]]:
    """Create a script with the CLI-wrapped function.

    The script is written into the current working directory under a random
    name. Both the script and the output file are removed when the context
    exits.

    Args:
        func: The function to use.
        flavour: The flavour to use; must be a key of `_CLI_WRAPPED_MAINS`.

    Yields:
        The paths of the script and the output.

    Raises:
        ValueError: If the function is not defined in a module, or the
            module has no file path.
    """
    # Bind the paths before entering `try` so the cleanup below can always
    # refer to them.
    random_name = random_str(20)
    script_path = Path(random_name + ".py")
    output_path = Path(random_name + ".out")

    try:
        module = inspect.getmodule(func)
        if module is None:
            raise ValueError(
                f"Function `{func.__name__}` must be defined in a "
                "module to be used with Accelerate."
            )

        module_file = getattr(module, "__file__", None)
        if not module_file:
            raise ValueError(
                f"Cannot find module file path for function `{func.__name__}`."
            )

        # The generated script adds the module's directory to `sys.path`, so
        # the importable name is just the file name without its extension.
        # Taking the stem avoids string replacement on the full path, which
        # mangles names such as `/app/app.py`.
        source_file = Path(module_file)
        script = _CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER.format(
            func_path=str(source_file.parent),
            func_module=source_file.stem,
            func_name=func.__name__,
        )
        script += _CLI_WRAPPED_MAINS[flavour].format(
            output_file=str(output_path.absolute())
        )

        # Only touch the file system once the script is fully rendered.
        with open(script_path, "w") as f:
            f.write(script)

        logger.debug(f"Created script:\n\n{script}")

        yield script_path, output_path
    finally:
        if script_path.exists():
            script_path.unlink()
        if output_path.exists():
            output_path.unlink()
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Carried over version of some functions from the pydantic v1 json module.
|
15
|
+
|
16
|
+
Check out the latest version here:
|
17
|
+
https://github.com/pydantic/pydantic/blob/v1.10.15/pydantic/json.py
|
18
|
+
"""
|
19
|
+
|
20
|
+
import datetime
|
21
|
+
from collections import deque
|
22
|
+
from decimal import Decimal
|
23
|
+
from enum import Enum
|
24
|
+
from ipaddress import (
|
25
|
+
IPv4Address,
|
26
|
+
IPv4Interface,
|
27
|
+
IPv4Network,
|
28
|
+
IPv6Address,
|
29
|
+
IPv6Interface,
|
30
|
+
IPv6Network,
|
31
|
+
)
|
32
|
+
from pathlib import Path
|
33
|
+
from re import Pattern
|
34
|
+
from types import GeneratorType
|
35
|
+
from typing import Any, Callable, Dict, Type, Union
|
36
|
+
from uuid import UUID
|
37
|
+
|
38
|
+
from pydantic import NameEmail, SecretBytes, SecretStr
|
39
|
+
from pydantic.color import Color
|
40
|
+
|
41
|
+
# `__all__` must be a sequence of names; a bare string would be iterated
# character by character by `from ... import *`.
__all__ = ["pydantic_encoder"]
|
42
|
+
|
43
|
+
|
44
|
+
def isoformat(obj: Union[datetime.date, datetime.time]) -> str:
    """Render a date, datetime or time object in ISO format.

    Args:
        obj: The date/time object to convert.

    Returns:
        The string produced by the object's own `isoformat()` method.
    """
    iso_string: str = obj.isoformat()
    return iso_string
|
54
|
+
|
55
|
+
|
56
|
+
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """Encodes a Decimal as int if there's no exponent, otherwise float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our ID type is a prime example of this.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1

    Special values (NaN, Infinity) carry a non-numeric exponent and are
    converted with `float()`.

    Args:
        dec_value: The input Decimal value

    Returns:
        the encoded result
    """
    exponent = dec_value.as_tuple().exponent
    # For NaN/Infinity the exponent is a letter code rather than an int, so
    # it must not be compared with 0.
    if isinstance(exponent, int) and exponent >= 0:
        return int(dec_value)
    return float(dec_value)
|
80
|
+
|
81
|
+
|
82
|
+
# Lookup table used by `pydantic_encoder`: maps a type to the callable that
# converts an instance of that type into a value to serialize. The encoder
# walks an object's MRO, so an entry also applies to subclasses of its key.
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
    bytes: lambda obj: obj.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    Enum: lambda obj: obj.value,
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda obj: obj.pattern,
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
}
|
108
|
+
|
109
|
+
|
110
|
+
def pydantic_encoder(obj: Any) -> Any:
    """Convert an object into a value that can be serialized to JSON.

    Pydantic models are dumped with `model_dump()`, dataclasses with
    `asdict()`. Any other object is converted by the first encoder in
    `ENCODERS_BY_TYPE` registered for a class in its MRO.

    Args:
        obj: The object to encode.

    Returns:
        The encoded representation of the object.

    Raises:
        TypeError: If no encoder is registered for the object's type or any
            of its base classes.
    """
    from dataclasses import asdict, is_dataclass

    from pydantic import BaseModel

    if isinstance(obj, BaseModel):
        return obj.model_dump()
    if is_dataclass(obj):
        return asdict(obj)

    # Walk the class hierarchy from most to least specific; the trailing
    # `object` entry is skipped.
    for klass in obj.__class__.__mro__[:-1]:
        if klass in ENCODERS_BY_TYPE:
            return ENCODERS_BY_TYPE[klass](obj)

    raise TypeError(
        f"Object of type '{obj.__class__.__name__}' is not JSON "
        f"serializable."
    )
|
zenml/utils/mlstacks_utils.py
CHANGED
@@ -422,9 +422,9 @@ def convert_mlstacks_primitives_to_dicts(
|
|
422
422
|
verify_mlstacks_prerequisites_installation()
|
423
423
|
|
424
424
|
# convert to json first to strip out Enums objects
|
425
|
-
stack_dict = json.loads(stack.
|
425
|
+
stack_dict = json.loads(stack.model_dump_json())
|
426
426
|
components_dicts = [
|
427
|
-
json.loads(component.
|
427
|
+
json.loads(component.model_dump_json()) for component in components
|
428
428
|
]
|
429
429
|
|
430
430
|
return stack_dict, components_dicts
|