zenml-nightly 0.58.2.dev20240618__py3-none-any.whl → 0.58.2.dev20240620__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300) hide show
  1. zenml/VERSION +1 -1
  2. zenml/_hub/client.py +8 -5
  3. zenml/actions/base_action.py +8 -10
  4. zenml/artifact_stores/base_artifact_store.py +20 -15
  5. zenml/artifact_stores/local_artifact_store.py +3 -2
  6. zenml/artifacts/artifact_config.py +34 -19
  7. zenml/artifacts/external_artifact.py +18 -8
  8. zenml/artifacts/external_artifact_config.py +14 -6
  9. zenml/artifacts/unmaterialized_artifact.py +2 -11
  10. zenml/cli/__init__.py +6 -0
  11. zenml/cli/artifact.py +20 -2
  12. zenml/cli/served_model.py +0 -1
  13. zenml/cli/server.py +3 -3
  14. zenml/cli/utils.py +36 -40
  15. zenml/cli/web_login.py +2 -2
  16. zenml/client.py +198 -24
  17. zenml/client_lazy_loader.py +20 -14
  18. zenml/config/base_settings.py +5 -6
  19. zenml/config/build_configuration.py +1 -1
  20. zenml/config/compiler.py +3 -3
  21. zenml/config/docker_settings.py +27 -28
  22. zenml/config/global_config.py +33 -37
  23. zenml/config/pipeline_configurations.py +8 -11
  24. zenml/config/pipeline_run_configuration.py +6 -2
  25. zenml/config/pipeline_spec.py +3 -4
  26. zenml/config/resource_settings.py +8 -9
  27. zenml/config/schedule.py +16 -20
  28. zenml/config/secret_reference_mixin.py +6 -3
  29. zenml/config/secrets_store_config.py +16 -23
  30. zenml/config/server_config.py +50 -46
  31. zenml/config/settings_resolver.py +1 -1
  32. zenml/config/source.py +45 -35
  33. zenml/config/step_configurations.py +53 -31
  34. zenml/config/store_config.py +20 -19
  35. zenml/config/strict_base_model.py +2 -6
  36. zenml/constants.py +26 -2
  37. zenml/container_registries/base_container_registry.py +3 -2
  38. zenml/container_registries/default_container_registry.py +3 -3
  39. zenml/event_hub/base_event_hub.py +1 -1
  40. zenml/event_sources/base_event_source.py +11 -16
  41. zenml/exceptions.py +4 -0
  42. zenml/integrations/airflow/__init__.py +2 -10
  43. zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
  44. zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
  45. zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
  46. zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
  47. zenml/integrations/aws/__init__.py +1 -1
  48. zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
  49. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
  50. zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
  51. zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
  52. zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
  53. zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
  54. zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
  55. zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
  56. zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
  57. zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
  58. zenml/integrations/evidently/__init__.py +3 -4
  59. zenml/integrations/evidently/column_mapping.py +11 -3
  60. zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
  61. zenml/integrations/evidently/metrics.py +5 -6
  62. zenml/integrations/evidently/tests.py +5 -6
  63. zenml/integrations/facets/models.py +2 -6
  64. zenml/integrations/feast/__init__.py +3 -1
  65. zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
  66. zenml/integrations/gcp/__init__.py +1 -1
  67. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
  68. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
  69. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
  70. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
  71. zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
  72. zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
  73. zenml/integrations/great_expectations/__init__.py +1 -1
  74. zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
  75. zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
  76. zenml/integrations/great_expectations/ge_store_backend.py +24 -11
  77. zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
  78. zenml/integrations/great_expectations/utils.py +5 -5
  79. zenml/integrations/huggingface/__init__.py +3 -0
  80. zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
  81. zenml/integrations/huggingface/steps/__init__.py +3 -0
  82. zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
  83. zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
  84. zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
  85. zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
  86. zenml/integrations/kubeflow/__init__.py +1 -1
  87. zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
  88. zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
  89. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
  90. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
  91. zenml/integrations/kubernetes/pod_settings.py +17 -31
  92. zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
  93. zenml/integrations/label_studio/__init__.py +1 -3
  94. zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
  95. zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
  96. zenml/integrations/langchain/materializers/document_materializer.py +44 -8
  97. zenml/integrations/mlflow/__init__.py +9 -3
  98. zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
  99. zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
  100. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
  101. zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
  102. zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
  103. zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
  104. zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
  105. zenml/integrations/seldon/seldon_client.py +52 -67
  106. zenml/integrations/seldon/services/seldon_deployment.py +3 -3
  107. zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
  108. zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
  109. zenml/integrations/skypilot_aws/__init__.py +1 -1
  110. zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
  111. zenml/integrations/skypilot_azure/__init__.py +1 -1
  112. zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
  113. zenml/integrations/skypilot_gcp/__init__.py +2 -1
  114. zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
  115. zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
  116. zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
  117. zenml/integrations/tekton/__init__.py +1 -1
  118. zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
  119. zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
  120. zenml/integrations/tensorboard/__init__.py +1 -12
  121. zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
  122. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
  123. zenml/integrations/tensorflow/__init__.py +2 -10
  124. zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
  125. zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
  126. zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
  127. zenml/lineage_graph/lineage_graph.py +1 -1
  128. zenml/materializers/built_in_materializer.py +3 -3
  129. zenml/materializers/pydantic_materializer.py +2 -2
  130. zenml/metadata/lazy_load.py +4 -4
  131. zenml/metadata/metadata_types.py +64 -4
  132. zenml/model/model.py +79 -54
  133. zenml/model_deployers/base_model_deployer.py +14 -12
  134. zenml/model_registries/base_model_registry.py +17 -15
  135. zenml/models/__init__.py +79 -206
  136. zenml/models/v2/base/base.py +54 -41
  137. zenml/models/v2/base/base_plugin_flavor.py +2 -6
  138. zenml/models/v2/base/filter.py +91 -76
  139. zenml/models/v2/base/page.py +2 -12
  140. zenml/models/v2/base/scoped.py +4 -7
  141. zenml/models/v2/core/api_key.py +22 -8
  142. zenml/models/v2/core/artifact.py +2 -2
  143. zenml/models/v2/core/artifact_version.py +74 -40
  144. zenml/models/v2/core/code_repository.py +37 -10
  145. zenml/models/v2/core/component.py +65 -16
  146. zenml/models/v2/core/device.py +14 -4
  147. zenml/models/v2/core/event_source.py +1 -2
  148. zenml/models/v2/core/flavor.py +74 -8
  149. zenml/models/v2/core/logs.py +68 -8
  150. zenml/models/v2/core/model.py +8 -4
  151. zenml/models/v2/core/model_version.py +25 -6
  152. zenml/models/v2/core/model_version_artifact.py +51 -21
  153. zenml/models/v2/core/model_version_pipeline_run.py +45 -13
  154. zenml/models/v2/core/pipeline.py +37 -72
  155. zenml/models/v2/core/pipeline_build.py +29 -17
  156. zenml/models/v2/core/pipeline_deployment.py +18 -6
  157. zenml/models/v2/core/pipeline_namespace.py +113 -0
  158. zenml/models/v2/core/pipeline_run.py +50 -22
  159. zenml/models/v2/core/run_metadata.py +59 -36
  160. zenml/models/v2/core/schedule.py +37 -24
  161. zenml/models/v2/core/secret.py +31 -12
  162. zenml/models/v2/core/service.py +64 -36
  163. zenml/models/v2/core/service_account.py +24 -11
  164. zenml/models/v2/core/service_connector.py +219 -44
  165. zenml/models/v2/core/stack.py +45 -17
  166. zenml/models/v2/core/step_run.py +28 -8
  167. zenml/models/v2/core/tag.py +8 -4
  168. zenml/models/v2/core/trigger.py +2 -2
  169. zenml/models/v2/core/trigger_execution.py +1 -0
  170. zenml/models/v2/core/user.py +18 -21
  171. zenml/models/v2/core/workspace.py +13 -3
  172. zenml/models/v2/misc/build_item.py +3 -3
  173. zenml/models/v2/misc/external_user.py +2 -6
  174. zenml/models/v2/misc/hub_plugin_models.py +9 -9
  175. zenml/models/v2/misc/loaded_visualization.py +2 -2
  176. zenml/models/v2/misc/service_connector_type.py +8 -17
  177. zenml/models/v2/misc/user_auth.py +7 -2
  178. zenml/new/pipelines/build_utils.py +3 -3
  179. zenml/new/pipelines/pipeline.py +17 -13
  180. zenml/new/pipelines/run_utils.py +103 -1
  181. zenml/orchestrators/base_orchestrator.py +10 -7
  182. zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
  183. zenml/orchestrators/step_runner.py +3 -6
  184. zenml/orchestrators/utils.py +1 -1
  185. zenml/plugins/base_plugin_flavor.py +6 -10
  186. zenml/plugins/plugin_flavor_registry.py +3 -7
  187. zenml/secret/base_secret.py +7 -8
  188. zenml/service_connectors/docker_service_connector.py +4 -3
  189. zenml/service_connectors/service_connector.py +5 -12
  190. zenml/service_connectors/service_connector_registry.py +2 -4
  191. zenml/services/container/container_service.py +1 -1
  192. zenml/services/container/container_service_endpoint.py +1 -1
  193. zenml/services/local/local_service.py +1 -1
  194. zenml/services/local/local_service_endpoint.py +1 -1
  195. zenml/services/service.py +16 -10
  196. zenml/services/service_type.py +4 -5
  197. zenml/services/terraform/terraform_service.py +1 -1
  198. zenml/stack/flavor.py +1 -5
  199. zenml/stack/flavor_registry.py +4 -4
  200. zenml/stack/stack.py +4 -1
  201. zenml/stack/stack_component.py +55 -31
  202. zenml/steps/base_step.py +34 -28
  203. zenml/steps/entrypoint_function_utils.py +3 -5
  204. zenml/steps/utils.py +12 -14
  205. zenml/utils/cuda_utils.py +50 -0
  206. zenml/utils/deprecation_utils.py +18 -20
  207. zenml/utils/dict_utils.py +1 -1
  208. zenml/utils/filesync_model.py +65 -28
  209. zenml/utils/function_utils.py +260 -0
  210. zenml/utils/json_utils.py +131 -0
  211. zenml/utils/mlstacks_utils.py +2 -2
  212. zenml/utils/pydantic_utils.py +270 -62
  213. zenml/utils/secret_utils.py +65 -12
  214. zenml/utils/source_utils.py +2 -2
  215. zenml/utils/typed_model.py +5 -3
  216. zenml/utils/typing_utils.py +243 -0
  217. zenml/utils/yaml_utils.py +1 -1
  218. zenml/zen_server/auth.py +2 -2
  219. zenml/zen_server/cloud_utils.py +6 -6
  220. zenml/zen_server/deploy/base_provider.py +1 -1
  221. zenml/zen_server/deploy/deployment.py +6 -8
  222. zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
  223. zenml/zen_server/deploy/local/local_provider.py +0 -1
  224. zenml/zen_server/deploy/local/local_zen_server.py +6 -6
  225. zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
  226. zenml/zen_server/exceptions.py +4 -1
  227. zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
  228. zenml/zen_server/pipeline_deployment/utils.py +48 -68
  229. zenml/zen_server/rbac/models.py +2 -5
  230. zenml/zen_server/rbac/utils.py +11 -14
  231. zenml/zen_server/routers/auth_endpoints.py +2 -2
  232. zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
  233. zenml/zen_server/routers/runs_endpoints.py +1 -1
  234. zenml/zen_server/routers/secrets_endpoints.py +3 -2
  235. zenml/zen_server/routers/server_endpoints.py +1 -1
  236. zenml/zen_server/routers/steps_endpoints.py +1 -1
  237. zenml/zen_server/routers/workspaces_endpoints.py +1 -1
  238. zenml/zen_stores/base_zen_store.py +46 -9
  239. zenml/zen_stores/migrations/utils.py +42 -46
  240. zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
  241. zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
  242. zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
  243. zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
  244. zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
  245. zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
  246. zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
  247. zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
  248. zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
  249. zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
  250. zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
  251. zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
  252. zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
  253. zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
  254. zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
  255. zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
  256. zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
  257. zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
  258. zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
  259. zenml/zen_stores/rest_zen_store.py +109 -49
  260. zenml/zen_stores/schemas/api_key_schemas.py +1 -1
  261. zenml/zen_stores/schemas/artifact_schemas.py +8 -8
  262. zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
  263. zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
  264. zenml/zen_stores/schemas/component_schemas.py +8 -3
  265. zenml/zen_stores/schemas/device_schemas.py +8 -6
  266. zenml/zen_stores/schemas/event_source_schemas.py +3 -4
  267. zenml/zen_stores/schemas/flavor_schemas.py +5 -3
  268. zenml/zen_stores/schemas/model_schemas.py +26 -1
  269. zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
  270. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
  271. zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
  272. zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
  273. zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
  274. zenml/zen_stores/schemas/secret_schemas.py +8 -5
  275. zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
  276. zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
  277. zenml/zen_stores/schemas/service_schemas.py +11 -2
  278. zenml/zen_stores/schemas/stack_schemas.py +1 -1
  279. zenml/zen_stores/schemas/step_run_schemas.py +11 -11
  280. zenml/zen_stores/schemas/tag_schemas.py +6 -2
  281. zenml/zen_stores/schemas/trigger_schemas.py +2 -2
  282. zenml/zen_stores/schemas/user_schemas.py +2 -2
  283. zenml/zen_stores/schemas/workspace_schemas.py +3 -1
  284. zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
  285. zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
  286. zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
  287. zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
  288. zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
  289. zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
  290. zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
  291. zenml/zen_stores/sql_zen_store.py +196 -120
  292. zenml/zen_stores/zen_store_interface.py +33 -0
  293. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/METADATA +8 -7
  294. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/RECORD +297 -294
  295. zenml/integrations/kubeflow/utils.py +0 -95
  296. zenml/models/v2/base/internal.py +0 -37
  297. zenml/models/v2/base/update.py +0 -44
  298. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/LICENSE +0 -0
  299. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/WHEEL +0 -0
  300. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240620.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,260 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Utility functions for python functions."""
15
+
16
+ import inspect
17
+ import os
18
+ from contextlib import contextmanager
19
+ from pathlib import Path
20
+ from typing import Any, Callable, Iterator, List, Tuple, TypeVar, Union
21
+
22
+ import click
23
+
24
+ from zenml.logger import get_logger
25
+ from zenml.utils.string_utils import random_str
26
+
27
# Type variable for the functions handled by this module; typed as callables
# returning None (this bound only affects static type checking).
F = TypeVar("F", bound=Callable[..., None])

# Module-level logger.
logger = get_logger(__name__)
30
+
31
+ _CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER = """
32
+ from zenml.utils.function_utils import _cli_wrapped_function
33
+
34
+ import sys
35
+ sys.path.append(r"{func_path}")
36
+
37
+ from {func_module} import {func_name} as func_to_wrap
38
+
39
+ if entrypoint:=getattr(func_to_wrap, "entrypoint", None):
40
+ func = _cli_wrapped_function(entrypoint)
41
+ else:
42
+ func = _cli_wrapped_function(func_to_wrap)
43
+ """
44
+ _CLI_WRAPPED_MAINS = {
45
+ "accelerate": """
46
+ if __name__=="__main__":
47
+ from accelerate import Accelerator
48
+ import cloudpickle as pickle
49
+ accelerator = Accelerator()
50
+ ret = func(standalone_mode=False)
51
+ if accelerator.is_main_process:
52
+ pickle.dump(ret, open(r"{output_file}", "wb"))
53
+ """
54
+ }
55
# Scalar annotation types that can be exposed as CLI options.
_ALLOWED_TYPES = (str, int, float, bool, Path)
# Generic origins accepted as repeatable (`multiple=True`) options; their
# first type argument must be one of `_ALLOWED_TYPES`.
_ALLOWED_COLLECTIONS = (tuple,)
# Maps annotation types to the click parameter types used for single-valued
# options; `Path` parameters are received as plain strings.
# NOTE(review): unannotated arguments resolve to `None`, but `None` is not in
# `_ALLOWED_TYPES`, so the `None` entry looks unreachable from
# `_cli_wrapped_function` — confirm before relying on it.
_CLICK_TYPES_MAPPER = {
    str: click.STRING,
    int: click.INT,
    float: click.FLOAT,
    bool: click.BOOL,
    Path: click.STRING,
    None: click.STRING,
}
65
+
66
+
67
+ def _cli_arg_name(arg_name: str) -> str:
68
+ return arg_name.replace("_", "-")
69
+
70
+
71
def _is_valid_collection_arg(arg_type: Any) -> bool:
    """Check if the given argument type is a valid collection type.

    A valid collection type is a parametrized allowed collection (e.g.
    `Tuple[int, ...]`) whose first type argument is one of the allowed
    scalar types. That first type argument is used as the member type of the
    generated repeatable CLI option.

    Args:
        arg_type: The type to check.

    Returns:
        True if the argument type is a valid collection type, False otherwise
        (including unparametrized collections such as a bare `Tuple`).
    """
    if getattr(arg_type, "__origin__", None) not in _ALLOWED_COLLECTIONS:
        return False
    # A bare `typing.Tuple` has no `__args__` and `Tuple[()]` may have an
    # empty one: guard before indexing instead of raising.
    type_args = getattr(arg_type, "__args__", None) or ()
    return bool(type_args) and type_args[0] in _ALLOWED_TYPES
85
+
86
+
87
def _is_valid_optional_arg(arg_type: Any) -> bool:
    """Check if the given argument type is a valid Optional type.

    A valid Optional type is a union with exactly two arguments, written as
    `Optional[X]`, `Union[X, None]` or (Python 3.10+) `X | None`, where:
    - The first argument is either an allowed type or a valid collection type.
    - The second argument is the NoneType.

    Args:
        arg_type: The type to check.

    Returns:
        True if the argument type is a valid Optional type, False otherwise.
    """
    import types
    from typing import get_args, get_origin

    union_origins: Tuple[Any, ...] = (Union,)
    # PEP 604 unions (`X | None`) use a distinct origin on Python 3.10+.
    if union_type := getattr(types, "UnionType", None):
        union_origins += (union_type,)

    # `get_origin` works on every supported Python version, unlike the
    # private `_name == "Optional"` marker that only exists on 3.10+.
    if get_origin(arg_type) not in union_origins:
        return False
    args = get_args(arg_type)
    # Callers unwrap the Optional via `__args__[0]`, so the wrapped type must
    # come first and `None` second.
    if len(args) != 2 or args[1] is not type(None):
        return False
    return args[0] in _ALLOWED_TYPES or _is_valid_collection_arg(args[0])
114
+
115
+
116
def _cli_wrapped_function(func: F) -> F:
    """Turn a function into a `click` command exposing its arguments as options.

    Every positional-or-keyword parameter of `func` becomes a `--kebab-case`
    option. Valid `Optional[...]` annotations are unwrapped first, then:

    - `bool` parameters become flags defaulting to `False`;
    - valid collection parameters (e.g. `Tuple[int, ...]`) become repeatable
      options of their member type;
    - other allowed scalar types become single-valued options, required when
      the parameter has no default value.

    Args:
        func: The function to wrap.

    Returns:
        The `click` command wrapping `func`.

    Raises:
        ValueError: If the function arguments are not valid.
    """
    fullargspec = inspect.getfullargspec(func)
    # `defaults` only covers the trailing parameters: left-pad it with `None`
    # so every positional parameter has a matching entry.
    declared_defaults = list(fullargspec.defaults or ())
    defaults = [None] * (
        len(fullargspec.args) - len(declared_defaults)
    ) + declared_defaults

    options: List[Any] = []
    invalid_types = {}
    for arg_name, arg_default in zip(fullargspec.args, defaults):
        arg_type = fullargspec.annotations.get(arg_name, None)
        if _is_valid_optional_arg(arg_type):
            # Unwrap `Optional[X]` into `X`.
            arg_type = arg_type.__args__[0]
        cli_name = _cli_arg_name(arg_name)
        if arg_type == bool:
            # NOTE: the parameter's own default is not forwarded; the flag
            # always defaults to False.
            options.append(
                click.option(
                    f"--{cli_name}",
                    type=click.BOOL,
                    is_flag=True,
                    default=False,
                    required=False,
                )
            )
        elif _is_valid_collection_arg(arg_type):
            options.append(
                click.option(
                    f"--{cli_name}",
                    type=arg_type.__args__[0],
                    default=arg_default,
                    required=False,
                    multiple=True,
                )
            )
        elif arg_type in _ALLOWED_TYPES:
            options.append(
                click.option(
                    f"--{cli_name}",
                    type=_CLICK_TYPES_MAPPER[arg_type],
                    default=arg_default,
                    required=arg_default is None,
                )
            )
        else:
            invalid_types[cli_name] = arg_type
    if invalid_types:
        raise ValueError(
            f"Invalid argument types: {invalid_types}. CLI functions only "
            f"supports: {_ALLOWED_TYPES} types (including Optional) and "
            f"{_ALLOWED_COLLECTIONS} collections."
        )
    options.append(
        click.command(
            help="Technical wrapper to pass into the `accelerate launch` command."
        )
    )

    # Avoid producing a literal "None" prefix for undocumented functions.
    note = "This is ZenML-generated CLI wrapper function."
    func.__doc__ = f"{func.__doc__}\n\n{note}" if func.__doc__ else note

    # Apply the decorators bottom-up, as stacked `@` decorators would be.
    command = func
    for option in reversed(options):
        command = option(command)
    return command
203
+
204
+
205
@contextmanager
def create_cli_wrapped_script(
    func: F, flavour: str = "accelerate"
) -> Iterator[Tuple[Path, Path]]:
    """Create a script with the CLI-wrapped function.

    The script is written to the current working directory under a random
    name. It puts the directory of `func`'s source file on `sys.path`,
    imports `func` from that file, and appends the `flavour`-specific main
    section, which receives the output path as `output_file`. Both the
    script and the output file are removed when the context exits.

    Args:
        func: The function to use.
        flavour: The flavour to use. Must be a key of `_CLI_WRAPPED_MAINS`.

    Yields:
        The paths of the script and the output.

    Raises:
        ValueError: If the function is not defined in a module, or if the
            module has no file path.
    """
    random_name = random_str(20)
    script_path = Path(random_name + ".py")
    output_path = Path(random_name + ".out")
    try:
        module = inspect.getmodule(func)
        if module is None:
            raise ValueError(
                f"Function `{func.__name__}` must be defined in a "
                "module to be used with Accelerate."
            )
        if not (path := getattr(module, "__file__", None)):
            raise ValueError(
                f"Cannot find module file path for function `{func.__name__}`."
            )

        # Make the path absolute so the directory/stem split also works for a
        # relative `__file__` (e.g. the `__main__` module on Python < 3.9).
        module_file = Path(path).absolute()
        script = _CLI_WRAPPED_SCRIPT_TEMPLATE_HEADER.format(
            func_path=str(module_file.parent),
            # The directory is appended to `sys.path`, so the module is
            # imported by its file name without the extension.
            func_module=module_file.stem,
            func_name=func.__name__,
        )
        script += _CLI_WRAPPED_MAINS[flavour].format(
            output_file=str(output_path.absolute())
        )
        # Python reads source files as UTF-8 by default.
        script_path.write_text(script, encoding="utf-8")

        logger.debug(f"Created script:\n\n{script}")

        yield script_path, output_path
    finally:
        script_path.unlink(missing_ok=True)
        output_path.unlink(missing_ok=True)
@@ -0,0 +1,131 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Carried over version of some functions from the pydantic v1 json module.
15
+
16
+ Check out the latest version here:
17
+ https://github.com/pydantic/pydantic/blob/v1.10.15/pydantic/json.py
18
+ """
19
+
20
+ import datetime
21
+ from collections import deque
22
+ from decimal import Decimal
23
+ from enum import Enum
24
+ from ipaddress import (
25
+ IPv4Address,
26
+ IPv4Interface,
27
+ IPv4Network,
28
+ IPv6Address,
29
+ IPv6Interface,
30
+ IPv6Network,
31
+ )
32
+ from pathlib import Path
33
+ from re import Pattern
34
+ from types import GeneratorType
35
+ from typing import Any, Callable, Dict, Type, Union
36
+ from uuid import UUID
37
+
38
+ from pydantic import NameEmail, SecretBytes, SecretStr
39
+ from pydantic.color import Color
40
+
41
# Public API of this module. It must be a sequence of names: a bare string
# would make `from ... import *` iterate over its individual characters.
__all__ = ["pydantic_encoder"]
42
+
43
+
44
def isoformat(obj: Union[datetime.date, datetime.time]) -> str:
    """Render a date, datetime or time as an ISO 8601 string.

    Args:
        obj: The date, datetime or time value to render.

    Returns:
        The ISO 8601 representation produced by the object's own
        `isoformat` method.
    """
    to_iso = obj.isoformat
    return to_iso()
54
+
55
+
56
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """Encodes a Decimal as int if there's no exponent, otherwise float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our ID type is a prime example of this.

    Infinities and quiet NaNs have no integer exponent and are encoded as
    floats.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1

    Args:
        dec_value: The input Decimal value

    Returns:
        the encoded result
    """
    exponent = dec_value.as_tuple().exponent
    # For special values the exponent is a string ("n", "N" or "F"), which
    # cannot be compared with an int.
    if isinstance(exponent, int) and exponent >= 0:
        return int(dec_value)
    return float(dec_value)
80
+
81
+
82
# Maps a type to the function converting its instances into JSON-compatible
# values. `pydantic_encoder` looks up every class of an object's MRO (except
# `object`) here, so each entry also applies to subclasses of its key.
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
    # `bytes.decode()` uses strict UTF-8; other encodings raise
    # `UnicodeDecodeError`.
    bytes: lambda obj: obj.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    # Durations are encoded as a (possibly fractional) number of seconds.
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    # Enum members are encoded by their value.
    Enum: lambda obj: obj.value,
    # Iterables without a JSON counterpart are materialized as lists.
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    # Compiled regexes are encoded as their source pattern.
    Pattern: lambda obj: obj.pattern,
    # NOTE(review): presumably `str()` masks pydantic secret values rather
    # than exposing them — confirm against the pydantic version in use.
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
}
108
+
109
+
110
def pydantic_encoder(obj: Any) -> Any:
    """Convert an object into a JSON-serializable value.

    Pydantic models are dumped to dicts and dataclass instances are converted
    with `dataclasses.asdict`. Any other object is encoded by the first
    encoder in `ENCODERS_BY_TYPE` matching a class of its MRO. Because it
    raises `TypeError` for unsupported objects, it can be used as the
    `default` hook of `json.dumps`.

    Args:
        obj: The object to encode.

    Returns:
        A JSON-serializable representation of `obj`.

    Raises:
        TypeError: If no encoder is available for the object's type.
    """
    from dataclasses import asdict, is_dataclass

    from pydantic import BaseModel

    if isinstance(obj, BaseModel):
        return obj.model_dump()
    # `is_dataclass` is also true for dataclass *classes*, which `asdict`
    # rejects; only convert instances.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)

    # Check the class type and its superclasses (excluding `object`) for a
    # matching encoder, so subclasses of supported types are handled too.
    for base in obj.__class__.__mro__[:-1]:
        encoder = ENCODERS_BY_TYPE.get(base)
        if encoder is not None:
            return encoder(obj)
    raise TypeError(
        f"Object of type '{obj.__class__.__name__}' is not JSON "
        f"serializable."
    )
@@ -422,9 +422,9 @@ def convert_mlstacks_primitives_to_dicts(
422
422
  verify_mlstacks_prerequisites_installation()
423
423
 
424
424
  # convert to json first to strip out Enums objects
425
- stack_dict = json.loads(stack.json())
425
+ stack_dict = json.loads(stack.model_dump_json())
426
426
  components_dicts = [
427
- json.loads(component.json()) for component in components
427
+ json.loads(component.model_dump_json()) for component in components
428
428
  ]
429
429
 
430
430
  return stack_dict, components_dicts