zenml-nightly 0.58.2.dev20240618__py3-none-any.whl → 0.58.2.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300):
  1. zenml/VERSION +1 -1
  2. zenml/_hub/client.py +8 -5
  3. zenml/actions/base_action.py +8 -10
  4. zenml/artifact_stores/base_artifact_store.py +20 -15
  5. zenml/artifact_stores/local_artifact_store.py +3 -2
  6. zenml/artifacts/artifact_config.py +34 -19
  7. zenml/artifacts/external_artifact.py +18 -8
  8. zenml/artifacts/external_artifact_config.py +14 -6
  9. zenml/artifacts/unmaterialized_artifact.py +2 -11
  10. zenml/cli/__init__.py +6 -0
  11. zenml/cli/artifact.py +20 -2
  12. zenml/cli/served_model.py +0 -1
  13. zenml/cli/server.py +3 -3
  14. zenml/cli/utils.py +36 -40
  15. zenml/cli/web_login.py +2 -2
  16. zenml/client.py +198 -24
  17. zenml/client_lazy_loader.py +20 -14
  18. zenml/config/base_settings.py +5 -6
  19. zenml/config/build_configuration.py +1 -1
  20. zenml/config/compiler.py +3 -3
  21. zenml/config/docker_settings.py +27 -28
  22. zenml/config/global_config.py +33 -37
  23. zenml/config/pipeline_configurations.py +8 -11
  24. zenml/config/pipeline_run_configuration.py +6 -2
  25. zenml/config/pipeline_spec.py +3 -4
  26. zenml/config/resource_settings.py +8 -9
  27. zenml/config/schedule.py +16 -20
  28. zenml/config/secret_reference_mixin.py +6 -3
  29. zenml/config/secrets_store_config.py +16 -23
  30. zenml/config/server_config.py +50 -46
  31. zenml/config/settings_resolver.py +1 -1
  32. zenml/config/source.py +45 -35
  33. zenml/config/step_configurations.py +53 -31
  34. zenml/config/store_config.py +20 -19
  35. zenml/config/strict_base_model.py +2 -6
  36. zenml/constants.py +26 -2
  37. zenml/container_registries/base_container_registry.py +3 -2
  38. zenml/container_registries/default_container_registry.py +3 -3
  39. zenml/event_hub/base_event_hub.py +1 -1
  40. zenml/event_sources/base_event_source.py +11 -16
  41. zenml/exceptions.py +4 -0
  42. zenml/integrations/airflow/__init__.py +2 -10
  43. zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
  44. zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
  45. zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
  46. zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
  47. zenml/integrations/aws/__init__.py +1 -1
  48. zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
  49. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
  50. zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
  51. zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
  52. zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
  53. zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
  54. zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
  55. zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
  56. zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
  57. zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
  58. zenml/integrations/evidently/__init__.py +3 -4
  59. zenml/integrations/evidently/column_mapping.py +11 -3
  60. zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
  61. zenml/integrations/evidently/metrics.py +5 -6
  62. zenml/integrations/evidently/tests.py +5 -6
  63. zenml/integrations/facets/models.py +2 -6
  64. zenml/integrations/feast/__init__.py +3 -1
  65. zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
  66. zenml/integrations/gcp/__init__.py +1 -1
  67. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
  68. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
  69. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
  70. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
  71. zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
  72. zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
  73. zenml/integrations/great_expectations/__init__.py +1 -1
  74. zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
  75. zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
  76. zenml/integrations/great_expectations/ge_store_backend.py +24 -11
  77. zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
  78. zenml/integrations/great_expectations/utils.py +5 -5
  79. zenml/integrations/huggingface/__init__.py +3 -0
  80. zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
  81. zenml/integrations/huggingface/steps/__init__.py +3 -0
  82. zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
  83. zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
  84. zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
  85. zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
  86. zenml/integrations/kubeflow/__init__.py +1 -1
  87. zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
  88. zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
  89. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
  90. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
  91. zenml/integrations/kubernetes/pod_settings.py +17 -31
  92. zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
  93. zenml/integrations/label_studio/__init__.py +1 -3
  94. zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
  95. zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
  96. zenml/integrations/langchain/materializers/document_materializer.py +44 -8
  97. zenml/integrations/mlflow/__init__.py +9 -3
  98. zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
  99. zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
  100. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
  101. zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
  102. zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
  103. zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
  104. zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
  105. zenml/integrations/seldon/seldon_client.py +52 -67
  106. zenml/integrations/seldon/services/seldon_deployment.py +3 -3
  107. zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
  108. zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
  109. zenml/integrations/skypilot_aws/__init__.py +1 -1
  110. zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
  111. zenml/integrations/skypilot_azure/__init__.py +1 -1
  112. zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
  113. zenml/integrations/skypilot_gcp/__init__.py +2 -1
  114. zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
  115. zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
  116. zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
  117. zenml/integrations/tekton/__init__.py +1 -1
  118. zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
  119. zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
  120. zenml/integrations/tensorboard/__init__.py +1 -12
  121. zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
  122. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
  123. zenml/integrations/tensorflow/__init__.py +2 -10
  124. zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
  125. zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
  126. zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
  127. zenml/lineage_graph/lineage_graph.py +1 -1
  128. zenml/materializers/built_in_materializer.py +3 -3
  129. zenml/materializers/pydantic_materializer.py +2 -2
  130. zenml/metadata/lazy_load.py +4 -4
  131. zenml/metadata/metadata_types.py +64 -4
  132. zenml/model/model.py +79 -54
  133. zenml/model_deployers/base_model_deployer.py +14 -12
  134. zenml/model_registries/base_model_registry.py +17 -15
  135. zenml/models/__init__.py +79 -206
  136. zenml/models/v2/base/base.py +54 -41
  137. zenml/models/v2/base/base_plugin_flavor.py +2 -6
  138. zenml/models/v2/base/filter.py +91 -76
  139. zenml/models/v2/base/page.py +2 -12
  140. zenml/models/v2/base/scoped.py +4 -7
  141. zenml/models/v2/core/api_key.py +22 -8
  142. zenml/models/v2/core/artifact.py +2 -2
  143. zenml/models/v2/core/artifact_version.py +74 -40
  144. zenml/models/v2/core/code_repository.py +37 -10
  145. zenml/models/v2/core/component.py +65 -16
  146. zenml/models/v2/core/device.py +14 -4
  147. zenml/models/v2/core/event_source.py +1 -2
  148. zenml/models/v2/core/flavor.py +74 -8
  149. zenml/models/v2/core/logs.py +68 -8
  150. zenml/models/v2/core/model.py +8 -4
  151. zenml/models/v2/core/model_version.py +25 -6
  152. zenml/models/v2/core/model_version_artifact.py +51 -21
  153. zenml/models/v2/core/model_version_pipeline_run.py +45 -13
  154. zenml/models/v2/core/pipeline.py +37 -72
  155. zenml/models/v2/core/pipeline_build.py +29 -17
  156. zenml/models/v2/core/pipeline_deployment.py +18 -6
  157. zenml/models/v2/core/pipeline_namespace.py +113 -0
  158. zenml/models/v2/core/pipeline_run.py +50 -22
  159. zenml/models/v2/core/run_metadata.py +59 -36
  160. zenml/models/v2/core/schedule.py +37 -24
  161. zenml/models/v2/core/secret.py +31 -12
  162. zenml/models/v2/core/service.py +64 -36
  163. zenml/models/v2/core/service_account.py +24 -11
  164. zenml/models/v2/core/service_connector.py +219 -44
  165. zenml/models/v2/core/stack.py +45 -17
  166. zenml/models/v2/core/step_run.py +28 -8
  167. zenml/models/v2/core/tag.py +8 -4
  168. zenml/models/v2/core/trigger.py +2 -2
  169. zenml/models/v2/core/trigger_execution.py +1 -0
  170. zenml/models/v2/core/user.py +18 -21
  171. zenml/models/v2/core/workspace.py +13 -3
  172. zenml/models/v2/misc/build_item.py +3 -3
  173. zenml/models/v2/misc/external_user.py +2 -6
  174. zenml/models/v2/misc/hub_plugin_models.py +9 -9
  175. zenml/models/v2/misc/loaded_visualization.py +2 -2
  176. zenml/models/v2/misc/service_connector_type.py +8 -17
  177. zenml/models/v2/misc/user_auth.py +7 -2
  178. zenml/new/pipelines/build_utils.py +3 -3
  179. zenml/new/pipelines/pipeline.py +17 -13
  180. zenml/new/pipelines/run_utils.py +103 -1
  181. zenml/orchestrators/base_orchestrator.py +10 -7
  182. zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
  183. zenml/orchestrators/step_runner.py +3 -6
  184. zenml/orchestrators/utils.py +1 -1
  185. zenml/plugins/base_plugin_flavor.py +6 -10
  186. zenml/plugins/plugin_flavor_registry.py +3 -7
  187. zenml/secret/base_secret.py +7 -8
  188. zenml/service_connectors/docker_service_connector.py +4 -3
  189. zenml/service_connectors/service_connector.py +5 -12
  190. zenml/service_connectors/service_connector_registry.py +2 -4
  191. zenml/services/container/container_service.py +1 -1
  192. zenml/services/container/container_service_endpoint.py +1 -1
  193. zenml/services/local/local_service.py +1 -1
  194. zenml/services/local/local_service_endpoint.py +1 -1
  195. zenml/services/service.py +16 -10
  196. zenml/services/service_type.py +4 -5
  197. zenml/services/terraform/terraform_service.py +1 -1
  198. zenml/stack/flavor.py +1 -5
  199. zenml/stack/flavor_registry.py +4 -4
  200. zenml/stack/stack.py +4 -1
  201. zenml/stack/stack_component.py +55 -31
  202. zenml/steps/base_step.py +34 -28
  203. zenml/steps/entrypoint_function_utils.py +3 -5
  204. zenml/steps/utils.py +12 -14
  205. zenml/utils/cuda_utils.py +50 -0
  206. zenml/utils/deprecation_utils.py +18 -20
  207. zenml/utils/dict_utils.py +1 -1
  208. zenml/utils/filesync_model.py +65 -28
  209. zenml/utils/function_utils.py +260 -0
  210. zenml/utils/json_utils.py +131 -0
  211. zenml/utils/mlstacks_utils.py +2 -2
  212. zenml/utils/pydantic_utils.py +270 -62
  213. zenml/utils/secret_utils.py +65 -12
  214. zenml/utils/source_utils.py +2 -2
  215. zenml/utils/typed_model.py +5 -3
  216. zenml/utils/typing_utils.py +243 -0
  217. zenml/utils/yaml_utils.py +1 -1
  218. zenml/zen_server/auth.py +2 -2
  219. zenml/zen_server/cloud_utils.py +6 -6
  220. zenml/zen_server/deploy/base_provider.py +1 -1
  221. zenml/zen_server/deploy/deployment.py +6 -8
  222. zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
  223. zenml/zen_server/deploy/local/local_provider.py +0 -1
  224. zenml/zen_server/deploy/local/local_zen_server.py +6 -6
  225. zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
  226. zenml/zen_server/exceptions.py +4 -1
  227. zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
  228. zenml/zen_server/pipeline_deployment/utils.py +48 -68
  229. zenml/zen_server/rbac/models.py +2 -5
  230. zenml/zen_server/rbac/utils.py +11 -14
  231. zenml/zen_server/routers/auth_endpoints.py +2 -2
  232. zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
  233. zenml/zen_server/routers/runs_endpoints.py +1 -1
  234. zenml/zen_server/routers/secrets_endpoints.py +3 -2
  235. zenml/zen_server/routers/server_endpoints.py +1 -1
  236. zenml/zen_server/routers/steps_endpoints.py +1 -1
  237. zenml/zen_server/routers/workspaces_endpoints.py +1 -1
  238. zenml/zen_stores/base_zen_store.py +46 -9
  239. zenml/zen_stores/migrations/utils.py +42 -46
  240. zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
  241. zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
  242. zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
  243. zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
  244. zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
  245. zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
  246. zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
  247. zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
  248. zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
  249. zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
  250. zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
  251. zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
  252. zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
  253. zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
  254. zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
  255. zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
  256. zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
  257. zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
  258. zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
  259. zenml/zen_stores/rest_zen_store.py +109 -49
  260. zenml/zen_stores/schemas/api_key_schemas.py +1 -1
  261. zenml/zen_stores/schemas/artifact_schemas.py +8 -8
  262. zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
  263. zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
  264. zenml/zen_stores/schemas/component_schemas.py +8 -3
  265. zenml/zen_stores/schemas/device_schemas.py +8 -6
  266. zenml/zen_stores/schemas/event_source_schemas.py +3 -4
  267. zenml/zen_stores/schemas/flavor_schemas.py +5 -3
  268. zenml/zen_stores/schemas/model_schemas.py +26 -1
  269. zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
  270. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
  271. zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
  272. zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
  273. zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
  274. zenml/zen_stores/schemas/secret_schemas.py +8 -5
  275. zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
  276. zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
  277. zenml/zen_stores/schemas/service_schemas.py +11 -2
  278. zenml/zen_stores/schemas/stack_schemas.py +1 -1
  279. zenml/zen_stores/schemas/step_run_schemas.py +11 -11
  280. zenml/zen_stores/schemas/tag_schemas.py +6 -2
  281. zenml/zen_stores/schemas/trigger_schemas.py +2 -2
  282. zenml/zen_stores/schemas/user_schemas.py +2 -2
  283. zenml/zen_stores/schemas/workspace_schemas.py +3 -1
  284. zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
  285. zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
  286. zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
  287. zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
  288. zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
  289. zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
  290. zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
  291. zenml/zen_stores/sql_zen_store.py +196 -120
  292. zenml/zen_stores/zen_store_interface.py +33 -0
  293. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/METADATA +8 -7
  294. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/RECORD +297 -294
  295. zenml/integrations/kubeflow/utils.py +0 -95
  296. zenml/models/v2/base/internal.py +0 -37
  297. zenml/models/v2/base/update.py +0 -44
  298. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/LICENSE +0 -0
  299. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/WHEEL +0 -0
  300. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,7 @@ import urllib
17
17
  from typing import Any, Dict, List, Optional, Type, Union
18
18
  from uuid import UUID
19
19
 
20
- from pydantic import BaseModel, Extra, Field
20
+ from pydantic import BaseModel, ConfigDict, Field
21
21
 
22
22
  from zenml.enums import SecretScope
23
23
  from zenml.event_sources.base_event import (
@@ -111,14 +111,10 @@ class GithubEvent(BaseEvent):
111
111
  after: str
112
112
  repository: Repository
113
113
  commits: List[Commit]
114
- head_commit: Optional[Commit]
115
- tags: Optional[List[Tag]]
116
- pull_requests: Optional[List[PullRequest]]
117
-
118
- class Config:
119
- """Pydantic configuration class."""
120
-
121
- extra = Extra.allow
114
+ head_commit: Optional[Commit] = None
115
+ tags: Optional[List[Tag]] = None
116
+ pull_requests: Optional[List[PullRequest]] = None
117
+ model_config = ConfigDict(extra="allow")
122
118
 
123
119
  @property
124
120
  def branch(self) -> Optional[str]:
@@ -157,9 +153,9 @@ class GithubEvent(BaseEvent):
157
153
  class GithubWebhookEventFilterConfiguration(WebhookEventFilterConfig):
158
154
  """Configuration for github event filters."""
159
155
 
160
- repo: Optional[str]
161
- branch: Optional[str]
162
- event_type: Optional[GithubEventType]
156
+ repo: Optional[str] = None
157
+ branch: Optional[str] = None
158
+ event_type: Optional[GithubEventType] = None
163
159
 
164
160
  def event_matches_filter(self, event: BaseEvent) -> bool:
165
161
  """Checks the filter against the inbound event.
@@ -442,7 +438,7 @@ class GithubWebhookEventSourceHandler(BaseWebhookEventSourceHandler):
442
438
  if config.rotate_secret:
443
439
  # In case the secret is being rotated
444
440
  secret_key_value = random_str(12)
445
- webhook_secret = SecretUpdate( # type: ignore[call-arg]
441
+ webhook_secret = SecretUpdate(
446
442
  values={"webhook_secret": secret_key_value}
447
443
  )
448
444
  self.zen_store.update_secret(
@@ -32,7 +32,7 @@ class GreatExpectationsIntegration(Integration):
32
32
 
33
33
  NAME = GREAT_EXPECTATIONS
34
34
  REQUIREMENTS = [
35
- "great-expectations>=0.15.0,<=0.15.47",
35
+ "great-expectations>=0.17.15,<1.0",
36
36
  ]
37
37
 
38
38
  @staticmethod
@@ -17,21 +17,25 @@ import os
17
17
  from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, cast
18
18
 
19
19
  import pandas as pd
20
- import yaml
21
20
  from great_expectations.checkpoint.types.checkpoint_result import ( # type: ignore[import-untyped]
22
21
  CheckpointResult,
23
22
  )
24
23
  from great_expectations.core import ( # type: ignore[import-untyped]
25
24
  ExpectationSuite,
26
25
  )
27
- from great_expectations.data_context.data_context import ( # type: ignore[import-untyped]
28
- BaseDataContext,
29
- DataContext,
26
+ from great_expectations.data_context.data_context.abstract_data_context import (
27
+ AbstractDataContext,
30
28
  )
31
- from great_expectations.data_context.types.base import ( # type: ignore[import-untyped]
29
+ from great_expectations.data_context.data_context.context_factory import (
30
+ get_context,
31
+ )
32
+ from great_expectations.data_context.data_context.ephemeral_data_context import (
33
+ EphemeralDataContext,
34
+ )
35
+ from great_expectations.data_context.types.base import (
32
36
  DataContextConfig,
33
37
  )
34
- from great_expectations.data_context.types.resource_identifiers import ( # type: ignore[import-untyped]
38
+ from great_expectations.data_context.types.resource_identifiers import (
35
39
  ExpectationSuiteIdentifier,
36
40
  )
37
41
  from great_expectations.profile.user_configurable_profiler import ( # type: ignore[import-untyped]
@@ -65,8 +69,8 @@ class GreatExpectationsDataValidator(BaseDataValidator):
65
69
  GreatExpectationsDataValidatorFlavor
66
70
  )
67
71
 
68
- _context: BaseDataContext = None
69
- _context_config: Optional[Dict[str, Any]] = None
72
+ _context: Optional[AbstractDataContext] = None
73
+ _context_config: Optional[DataContextConfig] = None
70
74
 
71
75
  @property
72
76
  def config(self) -> GreatExpectationsDataValidatorConfig:
@@ -78,7 +82,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
78
82
  return cast(GreatExpectationsDataValidatorConfig, self._config)
79
83
 
80
84
  @classmethod
81
- def get_data_context(cls) -> BaseDataContext:
85
+ def get_data_context(cls) -> AbstractDataContext:
82
86
  """Get the Great Expectations data context managed by ZenML.
83
87
 
84
88
  Call this method to retrieve the data context managed by ZenML
@@ -94,15 +98,11 @@ class GreatExpectationsDataValidator(BaseDataValidator):
94
98
  return data_validator.data_context
95
99
 
96
100
  @property
97
- def context_config(self) -> Optional[Dict[str, Any]]:
101
+ def context_config(self) -> Optional[DataContextConfig]:
98
102
  """Get the Great Expectations data context configuration.
99
103
 
100
- The first time the context config is loaded from the stack component
101
- config, it is converted from JSON/YAML string format to a dict.
102
-
103
104
  Raises:
104
- ValueError: If the context_config value is not a valid JSON/YAML or
105
- if the GE configuration extracted from it fails GE validation.
105
+ ValueError: In case there is an invalid context_config value
106
106
 
107
107
  Returns:
108
108
  A dictionary with the GE data context configuration.
@@ -111,31 +111,18 @@ class GreatExpectationsDataValidator(BaseDataValidator):
111
111
  if self._context_config is not None:
112
112
  return self._context_config
113
113
 
114
- # Otherwise, load it from the stack component config
115
- context_config = self.config.context_config
116
- if context_config is None:
114
+ # Otherwise, use the configuration from the stack component config, if
115
+ # set
116
+ context_config_dict = self.config.context_config
117
+ if context_config_dict is None:
117
118
  return None
118
- if isinstance(context_config, dict):
119
- self._context_config = context_config
120
- return self._context_config
121
-
122
- # If the context config is a string, try to parse it as JSON/YAML
123
- try:
124
- context_config_dict = yaml.safe_load(context_config)
125
- except yaml.parser.ParserError as e:
126
- raise ValueError(
127
- f"Malformed `context_config` value. Only JSON and YAML "
128
- f"formats are supported: {str(e)}"
129
- )
130
119
 
131
120
  # Validate that the context config is a valid GE config
132
121
  try:
133
- context_config = DataContextConfig(**context_config_dict)
134
- BaseDataContext(project_config=context_config)
122
+ self._context_config = DataContextConfig(**context_config_dict)
135
123
  except Exception as e:
136
124
  raise ValueError(f"Invalid `context_config` value: {str(e)}")
137
125
 
138
- self._context_config = cast(Dict[str, Any], context_config_dict)
139
126
  return self._context_config
140
127
 
141
128
  @property
@@ -203,7 +190,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
203
190
  }
204
191
 
205
192
  @property
206
- def data_context(self) -> BaseDataContext:
193
+ def data_context(self) -> AbstractDataContext:
207
194
  """Returns the Great Expectations data context configured for this component.
208
195
 
209
196
  Returns:
@@ -216,7 +203,9 @@ class GreatExpectationsDataValidator(BaseDataValidator):
216
203
  profiler_store_name = "zenml_profiler_store"
217
204
  evaluation_parameter_store_name = "evaluation_parameter_store"
218
205
 
219
- zenml_context_config = dict(
206
+ # Define default configuration options that plug the GX stores
207
+ # in the active ZenML artifact store
208
+ zenml_context_config: Dict[str, Any] = dict(
220
209
  stores={
221
210
  expectations_store_name: self.get_store_config(
222
211
  "ExpectationsStore", "expectations"
@@ -250,18 +239,29 @@ class GreatExpectationsDataValidator(BaseDataValidator):
250
239
  if self.config.context_root_dir:
251
240
  # initialize the local data context, if a local path was
252
241
  # configured
253
- self._context = DataContext(self.config.context_root_dir)
242
+ self._context = get_context(
243
+ context_root_dir=self.config.context_root_dir
244
+ )
245
+
254
246
  else:
255
- # create an in-memory data context configuration that is not
256
- # backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/how_to_instantiate_a_data_context_without_a_yml_file/).
247
+ # create an ephemeral in-memory data context that is not
248
+ # backed by a local YAML file (see https://docs.greatexpectations.io/docs/oss/guides/setup/configuring_data_contexts/instantiating_data_contexts/instantiate_data_context/).
257
249
  if self.context_config:
258
- context_config = DataContextConfig(**self.context_config)
250
+ # Use the data context configuration provided in the stack
251
+ # component configuration
252
+ context_config = self.context_config
259
253
  else:
254
+ # Initialize the data context with the default ZenML
255
+ # configuration options effectively plugging the GX stores
256
+ # into the ZenML artifact store
260
257
  context_config = DataContextConfig(**zenml_context_config)
261
258
  # skip adding the stores after initialization, as they are
262
259
  # already baked in the initial configuration
263
260
  configure_zenml_stores = False
264
- self._context = BaseDataContext(project_config=context_config)
261
+
262
+ self._context = EphemeralDataContext(
263
+ project_config=context_config
264
+ )
265
265
 
266
266
  if configure_zenml_stores:
267
267
  self._context.config.expectations_store_name = (
@@ -277,14 +277,14 @@ class GreatExpectationsDataValidator(BaseDataValidator):
277
277
  self._context.config.evaluation_parameter_store_name = (
278
278
  evaluation_parameter_store_name
279
279
  )
280
- for store_name, store_config in zenml_context_config[ # type: ignore[attr-defined]
280
+ for store_name, store_config in zenml_context_config[
281
281
  "stores"
282
282
  ].items():
283
283
  self._context.add_store(
284
284
  store_name=store_name,
285
285
  store_config=store_config,
286
286
  )
287
- for site_name, site_config in zenml_context_config[ # type: ignore[attr-defined]
287
+ for site_name, site_config in zenml_context_config[
288
288
  "data_docs_sites"
289
289
  ].items():
290
290
  self._context.config.data_docs_sites[site_name] = (
@@ -509,7 +509,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
509
509
  },
510
510
  ]
511
511
 
512
- checkpoint_config = {
512
+ checkpoint_config: Dict[str, Any] = {
513
513
  "name": checkpoint_name,
514
514
  "run_name_template": run_name,
515
515
  "config_version": 1,
@@ -517,7 +517,7 @@ class GreatExpectationsDataValidator(BaseDataValidator):
517
517
  "expectation_suite_name": expectation_suite_name,
518
518
  "action_list": action_list,
519
519
  }
520
- context.add_checkpoint(**checkpoint_config)
520
+ context.add_checkpoint(**checkpoint_config) # type: ignore[has-type]
521
521
 
522
522
  try:
523
523
  results = context.run_checkpoint(
@@ -16,7 +16,9 @@
16
16
  import os
17
17
  from typing import TYPE_CHECKING, Any, Dict, Optional, Type
18
18
 
19
- from pydantic import validator
19
+ import yaml
20
+ from pydantic import field_validator, model_validator
21
+ from yaml.parser import ParserError
20
22
 
21
23
  from zenml.data_validators.base_data_validator import (
22
24
  BaseDataValidatorConfig,
@@ -26,6 +28,7 @@ from zenml.integrations.great_expectations import (
26
28
  GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
27
29
  )
28
30
  from zenml.io import fileio
31
+ from zenml.utils.pydantic_utils import before_validator_handler
29
32
 
30
33
  if TYPE_CHECKING:
31
34
  from zenml.integrations.great_expectations.data_validators import (
@@ -41,6 +44,8 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
41
44
  data context. If configured, the data validator will only be usable
42
45
  with local orchestrators.
43
46
  context_config: in-line Great Expectations data context configuration.
47
+ If the `context_root_dir` attribute is also set, this configuration
48
+ will be ignored.
44
49
  configure_zenml_stores: if set, ZenML will automatically configure
45
50
  stores that use the Artifact Store as a backend. If neither
46
51
  `context_root_dir` nor `context_config` are set, this is the default
@@ -54,7 +59,8 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
54
59
  configure_zenml_stores: bool = False
55
60
  configure_local_docs: bool = True
56
61
 
57
- @validator("context_root_dir")
62
+ @field_validator("context_root_dir")
63
+ @classmethod
58
64
  def _ensure_valid_context_root_dir(
59
65
  cls, context_root_dir: Optional[str] = None
60
66
  ) -> Optional[str]:
@@ -78,6 +84,33 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig):
78
84
  )
79
85
  return context_root_dir
80
86
 
87
+ @model_validator(mode="before")
88
+ @classmethod
89
+ @before_validator_handler
90
+ def validate_context_config(cls, data: Dict[str, Any]) -> Dict[str, Any]:
91
+ """Convert the context configuration if given in JSON/YAML format.
92
+
93
+ Args:
94
+ data: The configuration values.
95
+
96
+ Returns:
97
+ The validated configuration values.
98
+
99
+ Raises:
100
+ ValueError: If the context configuration is not a valid
101
+ JSON/YAML object.
102
+ """
103
+ if isinstance(data.get("context_config"), str):
104
+ try:
105
+ data["context_config"] = yaml.safe_load(data["context_config"])
106
+ except ParserError as e:
107
+ raise ValueError(
108
+ f"Malformed `context_config` value. Only JSON and YAML "
109
+ f"formats are supported: {str(e)}"
110
+ )
111
+
112
+ return data
113
+
81
114
  @property
82
115
  def is_local(self) -> bool:
83
116
  """Checks if this stack component is running locally.
@@ -17,14 +17,16 @@ import os
17
17
  from pathlib import Path
18
18
  from typing import Any, Dict, List, Optional, Tuple, cast
19
19
 
20
- from great_expectations.data_context.store.tuple_store_backend import ( # type: ignore[import-untyped]
20
+ from great_expectations.data_context.store.tuple_store_backend import (
21
21
  TupleStoreBackend,
22
- filter_properties_dict,
23
22
  )
24
23
  from great_expectations.exceptions import ( # type: ignore[import-untyped]
25
24
  InvalidKeyError,
26
25
  StoreBackendError,
27
26
  )
27
+ from great_expectations.util import ( # type: ignore[import-untyped]
28
+ filter_properties_dict,
29
+ )
28
30
 
29
31
  from zenml.client import Client
30
32
  from zenml.io import fileio
@@ -34,7 +36,7 @@ from zenml.utils import io_utils
34
36
  logger = get_logger(__name__)
35
37
 
36
38
 
37
- class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
39
+ class ZenMLArtifactStoreBackend(TupleStoreBackend):
38
40
  """Great Expectations store backend that uses the active ZenML Artifact Store as a store."""
39
41
 
40
42
  def __init__(
@@ -105,7 +107,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
105
107
  if not isinstance(key, tuple):
106
108
  key = key.to_tuple()
107
109
  if not is_prefix:
108
- object_relative_path = self._convert_key_to_filepath(key)
110
+ object_relative_path = self._convert_key_to_filepath(key) # type: ignore[no-untyped-call]
109
111
  elif key:
110
112
  object_relative_path = os.path.join(*key)
111
113
  else:
@@ -116,7 +118,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
116
118
  object_key = object_relative_path
117
119
  return os.path.join(self.root_path, object_key)
118
120
 
119
- def _get(self, key: Tuple[str, ...]) -> str:
121
+ def _get(self, key: Tuple[str, ...]) -> str: # type: ignore[override]
120
122
  """Get the value of an object from the store.
121
123
 
122
124
  Args:
@@ -140,7 +142,18 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
140
142
  )
141
143
  return contents
142
144
 
143
- def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str:
145
+ def _get_all(self) -> List[Any]:
146
+ """Get all objects in the store.
147
+
148
+ Raises:
149
+ NotImplementedError: if the method is not implemented for this store
150
+ backend.
151
+ """
152
+ raise NotImplementedError(
153
+ "Method `_get_all` is not implemented for this store backend."
154
+ )
155
+
156
+ def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str: # type: ignore[override]
144
157
  """Set the value of an object in the store.
145
158
 
146
159
  Args:
@@ -212,12 +225,12 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
212
225
  self.filepath_suffix
213
226
  ):
214
227
  continue
215
- key = self._convert_filepath_to_key(filepath)
216
- if key and not self.is_ignored_key(key):
228
+ key = self._convert_filepath_to_key(filepath) # type: ignore[no-untyped-call]
229
+ if key and not self.is_ignored_key(key): # type: ignore[no-untyped-call]
217
230
  key_list.append(key)
218
231
  return key_list
219
232
 
220
- def remove_key(self, key: Tuple[str, ...]) -> bool:
233
+ def remove_key(self, key: Tuple[str, ...]) -> bool: # type: ignore[override]
221
234
  """Delete an object from the store.
222
235
 
223
236
  Args:
@@ -250,7 +263,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
250
263
  result = fileio.exists(filepath)
251
264
  return result
252
265
 
253
- def get_url_for_key(
266
+ def get_url_for_key( # type: ignore[override]
254
267
  self, key: Tuple[str, ...], protocol: Optional[str] = None
255
268
  ) -> str:
256
269
  """Get the URL of an object in the store.
@@ -292,7 +305,7 @@ class ZenMLArtifactStoreBackend(TupleStoreBackend): # type: ignore[misc]
292
305
  f"requested but `base_public_path` was not configured for the "
293
306
  f"{self.__class__.__name__}"
294
307
  )
295
- filepath = self._convert_key_to_filepath(key)
308
+ filepath = self._convert_key_to_filepath(key) # type: ignore[no-untyped-call]
296
309
  public_url = self.base_public_path + filepath.replace(self.proto, "")
297
310
  return cast(str, public_url)
298
311
 
@@ -25,10 +25,10 @@ from great_expectations.core import ( # type: ignore[import-untyped]
25
25
  from great_expectations.core.expectation_validation_result import ( # type: ignore[import-untyped]
26
26
  ExpectationSuiteValidationResult,
27
27
  )
28
- from great_expectations.data_context.types.base import ( # type: ignore[import-untyped]
28
+ from great_expectations.data_context.types.base import (
29
29
  CheckpointConfig,
30
30
  )
31
- from great_expectations.data_context.types.resource_identifiers import ( # type: ignore[import-untyped]
31
+ from great_expectations.data_context.types.resource_identifiers import (
32
32
  ExpectationSuiteIdentifier,
33
33
  ValidationResultIdentifier,
34
34
  )
@@ -86,7 +86,7 @@ class GreatExpectationsMaterializer(BaseMaterializer):
86
86
  validation_dict = {}
87
87
  for result_ident, results in artifact_dict["run_results"].items():
88
88
  validation_ident = (
89
- ValidationResultIdentifier.from_fixed_length_tuple(
89
+ ValidationResultIdentifier.from_fixed_length_tuple( # type: ignore[no-untyped-call]
90
90
  result_ident.split("::")[1].split("/")
91
91
  )
92
92
  )
@@ -13,14 +13,14 @@
13
13
  # permissions and limitations under the License.
14
14
  """Great Expectations data profiling standard step."""
15
15
 
16
- from typing import Optional
16
+ from typing import Any, Dict, Optional
17
17
 
18
18
  import pandas as pd
19
19
  from great_expectations.core.batch import ( # type: ignore[import-untyped]
20
20
  RuntimeBatchRequest,
21
21
  )
22
- from great_expectations.data_context.data_context import ( # type: ignore[import-untyped]
23
- BaseDataContext,
22
+ from great_expectations.data_context.data_context.abstract_data_context import (
23
+ AbstractDataContext,
24
24
  )
25
25
 
26
26
  from zenml import get_step_context
@@ -31,7 +31,7 @@ logger = get_logger(__name__)
31
31
 
32
32
 
33
33
  def create_batch_request(
34
- context: BaseDataContext,
34
+ context: AbstractDataContext,
35
35
  dataset: pd.DataFrame,
36
36
  data_asset_name: Optional[str],
37
37
  ) -> RuntimeBatchRequest:
@@ -62,7 +62,7 @@ def create_batch_request(
62
62
  data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
63
63
  batch_identifier = "default"
64
64
 
65
- datasource_config = {
65
+ datasource_config: Dict[str, Any] = {
66
66
  "name": datasource_name,
67
67
  "class_name": "Datasource",
68
68
  "module_name": "great_expectations.datasource",
@@ -30,6 +30,9 @@ class HuggingfaceIntegration(Integration):
30
30
  "transformers<=4.31",
31
31
  "datasets",
32
32
  "huggingface_hub>0.19.0",
33
+ "accelerate",
34
+ "bitsandbytes>=0.41.3",
35
+ "peft",
33
36
  # temporary fix for CI issue similar to:
34
37
  # - https://github.com/huggingface/datasets/issues/6737
35
38
  # - https://github.com/huggingface/datasets/issues/6697
@@ -61,7 +61,7 @@ class HuggingFaceModelDeployerConfig(
61
61
  namespace: Hugging Face namespace used to list endpoints
62
62
  """
63
63
 
64
- token: Optional[str] = SecretField()
64
+ token: Optional[str] = SecretField(default=None)
65
65
 
66
66
  # The namespace to list endpoints for. Set to `"*"` to list all endpoints
67
67
  # from all namespaces (i.e. personal namespace and all orgs the user belongs to).
@@ -16,3 +16,6 @@
16
16
  from zenml.integrations.huggingface.steps.huggingface_deployer import (
17
17
  huggingface_model_deployer_step,
18
18
  )
19
+ from zenml.integrations.huggingface.steps.accelerate_runner import (
20
+ run_with_accelerate,
21
+ )
@@ -0,0 +1,149 @@
1
+ # Apache Software License 2.0
2
+ #
3
+ # Copyright (c) ZenML GmbH 2024. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ """Step function to run any ZenML step using Accelerate."""
18
+
19
+ import functools
20
+ from typing import Any, Callable, Optional, TypeVar, cast
21
+
22
+ import cloudpickle as pickle
23
+ from accelerate.commands.launch import ( # type: ignore[import-untyped]
24
+ launch_command,
25
+ launch_command_parser,
26
+ )
27
+
28
+ from zenml.logger import get_logger
29
+ from zenml.steps import BaseStep
30
+ from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script
31
+
32
logger = get_logger(__name__)

# Type of the step entrypoint callable being wrapped; `cast(F, ...)` below
# keeps the wrapped entrypoint's declared type for static type checkers.
F = TypeVar("F", bound=Callable[..., Any])


def run_with_accelerate(
    step_function: BaseStep,
    num_processes: Optional[int] = None,
    use_cpu: bool = False,
) -> BaseStep:
    """Run a function with accelerate.

    Accelerate package: https://huggingface.co/docs/accelerate/en/index
    Example:
        ```python
        from zenml import step, pipeline
        from zenml.integrations.huggingface.steps import run_with_accelerate
        @step
        def training_step(some_param: int, ...):
            # your training code is below
            ...

        @pipeline
        def training_pipeline(some_param: int, ...):
            run_with_accelerate(training_step, num_processes=4)(some_param, ...)
        ```

    Note that the given step's `entrypoint` is replaced in place and the same
    step object is returned.

    Args:
        step_function: The step function to run.
        num_processes: The number of processes to use. When running on GPU,
            this is capped at the number of visible CUDA devices and defaults
            to that number; when running on CPU it defaults to 1.
        use_cpu: Whether to use the CPU.

    Returns:
        The accelerate-enabled version of the step.
    """

    def _decorator(entrypoint: F) -> F:
        # Wrap the step entrypoint so that, instead of running in-process,
        # it is exported as a CLI script and launched via `accelerate launch`.
        @functools.wraps(entrypoint)
        def inner(*args: Any, **kwargs: Any) -> Any:
            """Launch the wrapped entrypoint as an Accelerate job.

            Args:
                *args: Positional arguments; not supported.
                **kwargs: Step parameters, forwarded as CLI options.

            Returns:
                The object unpickled from the script's output file
                (presumably the entrypoint's return value).

            Raises:
                ValueError: If positional arguments are passed.
                RuntimeError: If the Accelerate job fails.
            """
            if args:
                raise ValueError(
                    "Accelerated steps do not support positional arguments."
                )

            logger.info("Starting accelerate job...")

            if not use_cpu:
                import torch

                device_count = torch.cuda.device_count()
                if num_processes is None:
                    _num_processes = device_count
                elif num_processes > device_count:
                    logger.warning(
                        f"Number of processes ({num_processes}) is greater than "
                        f"the number of available GPUs ({device_count}). Using all GPUs."
                    )
                    _num_processes = device_count
                else:
                    _num_processes = num_processes
            else:
                _num_processes = num_processes or 1

            with create_cli_wrapped_script(
                entrypoint, flavour="accelerate"
            ) as (
                script_path,
                output_path,
            ):
                commands = ["--num_processes", str(_num_processes)]
                if use_cpu:
                    commands += [
                        "--cpu",
                        "--num_cpu_threads_per_process",
                        "10",
                    ]
                commands.append(str(script_path.absolute()))

                # Everything after the script path is forwarded to the
                # wrapped script as its own CLI options.
                for k, v in kwargs.items():
                    k = _cli_arg_name(k)
                    if isinstance(v, bool):
                        # Booleans become presence-only flags; `False` values
                        # are omitted (presumably the generated CLI defaults
                        # the flag to off).
                        if v:
                            commands.append(f"--{k}")
                    elif isinstance(v, str):
                        # Quote the actual value (the `f` prefix was missing,
                        # which sent the literal text `"{v}"` instead).
                        commands += [f"--{k}", f'"{v}"']
                    elif type(v) in (list, tuple, set):
                        # Collections repeat the option once per element.
                        for each in v:
                            commands.append(f"--{k}")
                            if isinstance(each, str):
                                commands.append(f'"{each}"')
                            else:
                                commands.append(f"{each}")
                    else:
                        commands += [f"--{k}", f"{v}"]

                logger.debug(commands)

                parser = launch_command_parser()
                # Named distinctly so it does not shadow `inner`'s `*args`.
                launch_args = parser.parse_args(commands)
                try:
                    launch_command(launch_args)
                except Exception as e:
                    logger.error(
                        "Accelerate training job failed... See error message for details."
                    )
                    raise RuntimeError(
                        "Accelerate training job failed."
                    ) from e
                else:
                    logger.info(
                        "Accelerate training job finished successfully."
                    )
                    # Read the result while the temporary script context is
                    # still open, and close the file handle deterministically.
                    with open(output_path, "rb") as output_file:
                        return pickle.load(output_file)

        return cast(F, inner)

    setattr(step_function, "entrypoint", _decorator(step_function.entrypoint))

    return step_function