zenml-nightly 0.58.2.dev20240618__py3-none-any.whl → 0.58.2.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300)
  1. zenml/VERSION +1 -1
  2. zenml/_hub/client.py +8 -5
  3. zenml/actions/base_action.py +8 -10
  4. zenml/artifact_stores/base_artifact_store.py +20 -15
  5. zenml/artifact_stores/local_artifact_store.py +3 -2
  6. zenml/artifacts/artifact_config.py +34 -19
  7. zenml/artifacts/external_artifact.py +18 -8
  8. zenml/artifacts/external_artifact_config.py +14 -6
  9. zenml/artifacts/unmaterialized_artifact.py +2 -11
  10. zenml/cli/__init__.py +6 -0
  11. zenml/cli/artifact.py +20 -2
  12. zenml/cli/served_model.py +0 -1
  13. zenml/cli/server.py +3 -3
  14. zenml/cli/utils.py +36 -40
  15. zenml/cli/web_login.py +2 -2
  16. zenml/client.py +198 -24
  17. zenml/client_lazy_loader.py +20 -14
  18. zenml/config/base_settings.py +5 -6
  19. zenml/config/build_configuration.py +1 -1
  20. zenml/config/compiler.py +3 -3
  21. zenml/config/docker_settings.py +27 -28
  22. zenml/config/global_config.py +33 -37
  23. zenml/config/pipeline_configurations.py +8 -11
  24. zenml/config/pipeline_run_configuration.py +6 -2
  25. zenml/config/pipeline_spec.py +3 -4
  26. zenml/config/resource_settings.py +8 -9
  27. zenml/config/schedule.py +16 -20
  28. zenml/config/secret_reference_mixin.py +6 -3
  29. zenml/config/secrets_store_config.py +16 -23
  30. zenml/config/server_config.py +50 -46
  31. zenml/config/settings_resolver.py +1 -1
  32. zenml/config/source.py +45 -35
  33. zenml/config/step_configurations.py +53 -31
  34. zenml/config/store_config.py +20 -19
  35. zenml/config/strict_base_model.py +2 -6
  36. zenml/constants.py +26 -2
  37. zenml/container_registries/base_container_registry.py +3 -2
  38. zenml/container_registries/default_container_registry.py +3 -3
  39. zenml/event_hub/base_event_hub.py +1 -1
  40. zenml/event_sources/base_event_source.py +11 -16
  41. zenml/exceptions.py +4 -0
  42. zenml/integrations/airflow/__init__.py +2 -10
  43. zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +6 -7
  44. zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +13 -249
  45. zenml/integrations/airflow/orchestrators/dag_generator.py +5 -3
  46. zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +5 -4
  47. zenml/integrations/aws/__init__.py +1 -1
  48. zenml/integrations/aws/flavors/aws_container_registry_flavor.py +3 -2
  49. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +11 -5
  50. zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +6 -2
  51. zenml/integrations/aws/service_connectors/aws_service_connector.py +5 -4
  52. zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +4 -4
  53. zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -3
  54. zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
  55. zenml/integrations/bentoml/steps/bentoml_deployer.py +1 -1
  56. zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +8 -12
  57. zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +1 -1
  58. zenml/integrations/evidently/__init__.py +3 -4
  59. zenml/integrations/evidently/column_mapping.py +11 -3
  60. zenml/integrations/evidently/data_validators/evidently_data_validator.py +21 -3
  61. zenml/integrations/evidently/metrics.py +5 -6
  62. zenml/integrations/evidently/tests.py +5 -6
  63. zenml/integrations/facets/models.py +2 -6
  64. zenml/integrations/feast/__init__.py +3 -1
  65. zenml/integrations/feast/feature_stores/feast_feature_store.py +0 -23
  66. zenml/integrations/gcp/__init__.py +1 -1
  67. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +1 -1
  68. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +1 -1
  69. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +234 -103
  70. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +57 -42
  71. zenml/integrations/github/code_repositories/github_code_repository.py +1 -1
  72. zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +9 -13
  73. zenml/integrations/great_expectations/__init__.py +1 -1
  74. zenml/integrations/great_expectations/data_validators/ge_data_validator.py +44 -44
  75. zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +35 -2
  76. zenml/integrations/great_expectations/ge_store_backend.py +24 -11
  77. zenml/integrations/great_expectations/materializers/ge_materializer.py +3 -3
  78. zenml/integrations/great_expectations/utils.py +5 -5
  79. zenml/integrations/huggingface/__init__.py +3 -0
  80. zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +1 -1
  81. zenml/integrations/huggingface/steps/__init__.py +3 -0
  82. zenml/integrations/huggingface/steps/accelerate_runner.py +149 -0
  83. zenml/integrations/huggingface/steps/huggingface_deployer.py +2 -2
  84. zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +1 -1
  85. zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +4 -3
  86. zenml/integrations/kubeflow/__init__.py +1 -1
  87. zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +48 -81
  88. zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +295 -245
  89. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +1 -1
  90. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -2
  91. zenml/integrations/kubernetes/pod_settings.py +17 -31
  92. zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +8 -7
  93. zenml/integrations/label_studio/__init__.py +1 -3
  94. zenml/integrations/label_studio/annotators/label_studio_annotator.py +3 -4
  95. zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +2 -2
  96. zenml/integrations/langchain/materializers/document_materializer.py +44 -8
  97. zenml/integrations/mlflow/__init__.py +9 -3
  98. zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +1 -1
  99. zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +29 -37
  100. zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +4 -4
  101. zenml/integrations/mlflow/steps/mlflow_deployer.py +1 -1
  102. zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +1 -1
  103. zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +1 -1
  104. zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +9 -8
  105. zenml/integrations/seldon/seldon_client.py +52 -67
  106. zenml/integrations/seldon/services/seldon_deployment.py +3 -3
  107. zenml/integrations/seldon/steps/seldon_deployer.py +4 -4
  108. zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +15 -5
  109. zenml/integrations/skypilot_aws/__init__.py +1 -1
  110. zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +1 -1
  111. zenml/integrations/skypilot_azure/__init__.py +1 -1
  112. zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +1 -1
  113. zenml/integrations/skypilot_gcp/__init__.py +2 -1
  114. zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +1 -1
  115. zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +2 -2
  116. zenml/integrations/spark/flavors/spark_step_operator_flavor.py +1 -1
  117. zenml/integrations/tekton/__init__.py +1 -1
  118. zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +66 -23
  119. zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +547 -233
  120. zenml/integrations/tensorboard/__init__.py +1 -12
  121. zenml/integrations/tensorboard/services/tensorboard_service.py +3 -5
  122. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +6 -6
  123. zenml/integrations/tensorflow/__init__.py +2 -10
  124. zenml/integrations/tensorflow/materializers/keras_materializer.py +17 -9
  125. zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +9 -14
  126. zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +1 -1
  127. zenml/lineage_graph/lineage_graph.py +1 -1
  128. zenml/materializers/built_in_materializer.py +3 -3
  129. zenml/materializers/pydantic_materializer.py +2 -2
  130. zenml/metadata/lazy_load.py +4 -4
  131. zenml/metadata/metadata_types.py +64 -4
  132. zenml/model/model.py +79 -54
  133. zenml/model_deployers/base_model_deployer.py +14 -12
  134. zenml/model_registries/base_model_registry.py +17 -15
  135. zenml/models/__init__.py +79 -206
  136. zenml/models/v2/base/base.py +54 -41
  137. zenml/models/v2/base/base_plugin_flavor.py +2 -6
  138. zenml/models/v2/base/filter.py +91 -76
  139. zenml/models/v2/base/page.py +2 -12
  140. zenml/models/v2/base/scoped.py +4 -7
  141. zenml/models/v2/core/api_key.py +22 -8
  142. zenml/models/v2/core/artifact.py +2 -2
  143. zenml/models/v2/core/artifact_version.py +74 -40
  144. zenml/models/v2/core/code_repository.py +37 -10
  145. zenml/models/v2/core/component.py +65 -16
  146. zenml/models/v2/core/device.py +14 -4
  147. zenml/models/v2/core/event_source.py +1 -2
  148. zenml/models/v2/core/flavor.py +74 -8
  149. zenml/models/v2/core/logs.py +68 -8
  150. zenml/models/v2/core/model.py +8 -4
  151. zenml/models/v2/core/model_version.py +25 -6
  152. zenml/models/v2/core/model_version_artifact.py +51 -21
  153. zenml/models/v2/core/model_version_pipeline_run.py +45 -13
  154. zenml/models/v2/core/pipeline.py +37 -72
  155. zenml/models/v2/core/pipeline_build.py +29 -17
  156. zenml/models/v2/core/pipeline_deployment.py +18 -6
  157. zenml/models/v2/core/pipeline_namespace.py +113 -0
  158. zenml/models/v2/core/pipeline_run.py +50 -22
  159. zenml/models/v2/core/run_metadata.py +59 -36
  160. zenml/models/v2/core/schedule.py +37 -24
  161. zenml/models/v2/core/secret.py +31 -12
  162. zenml/models/v2/core/service.py +64 -36
  163. zenml/models/v2/core/service_account.py +24 -11
  164. zenml/models/v2/core/service_connector.py +219 -44
  165. zenml/models/v2/core/stack.py +45 -17
  166. zenml/models/v2/core/step_run.py +28 -8
  167. zenml/models/v2/core/tag.py +8 -4
  168. zenml/models/v2/core/trigger.py +2 -2
  169. zenml/models/v2/core/trigger_execution.py +1 -0
  170. zenml/models/v2/core/user.py +18 -21
  171. zenml/models/v2/core/workspace.py +13 -3
  172. zenml/models/v2/misc/build_item.py +3 -3
  173. zenml/models/v2/misc/external_user.py +2 -6
  174. zenml/models/v2/misc/hub_plugin_models.py +9 -9
  175. zenml/models/v2/misc/loaded_visualization.py +2 -2
  176. zenml/models/v2/misc/service_connector_type.py +8 -17
  177. zenml/models/v2/misc/user_auth.py +7 -2
  178. zenml/new/pipelines/build_utils.py +3 -3
  179. zenml/new/pipelines/pipeline.py +17 -13
  180. zenml/new/pipelines/run_utils.py +103 -1
  181. zenml/orchestrators/base_orchestrator.py +10 -7
  182. zenml/orchestrators/local_docker/local_docker_orchestrator.py +1 -1
  183. zenml/orchestrators/step_runner.py +3 -6
  184. zenml/orchestrators/utils.py +1 -1
  185. zenml/plugins/base_plugin_flavor.py +6 -10
  186. zenml/plugins/plugin_flavor_registry.py +3 -7
  187. zenml/secret/base_secret.py +7 -8
  188. zenml/service_connectors/docker_service_connector.py +4 -3
  189. zenml/service_connectors/service_connector.py +5 -12
  190. zenml/service_connectors/service_connector_registry.py +2 -4
  191. zenml/services/container/container_service.py +1 -1
  192. zenml/services/container/container_service_endpoint.py +1 -1
  193. zenml/services/local/local_service.py +1 -1
  194. zenml/services/local/local_service_endpoint.py +1 -1
  195. zenml/services/service.py +16 -10
  196. zenml/services/service_type.py +4 -5
  197. zenml/services/terraform/terraform_service.py +1 -1
  198. zenml/stack/flavor.py +1 -5
  199. zenml/stack/flavor_registry.py +4 -4
  200. zenml/stack/stack.py +4 -1
  201. zenml/stack/stack_component.py +55 -31
  202. zenml/steps/base_step.py +34 -28
  203. zenml/steps/entrypoint_function_utils.py +3 -5
  204. zenml/steps/utils.py +12 -14
  205. zenml/utils/cuda_utils.py +50 -0
  206. zenml/utils/deprecation_utils.py +18 -20
  207. zenml/utils/dict_utils.py +1 -1
  208. zenml/utils/filesync_model.py +65 -28
  209. zenml/utils/function_utils.py +260 -0
  210. zenml/utils/json_utils.py +131 -0
  211. zenml/utils/mlstacks_utils.py +2 -2
  212. zenml/utils/pydantic_utils.py +270 -62
  213. zenml/utils/secret_utils.py +65 -12
  214. zenml/utils/source_utils.py +2 -2
  215. zenml/utils/typed_model.py +5 -3
  216. zenml/utils/typing_utils.py +243 -0
  217. zenml/utils/yaml_utils.py +1 -1
  218. zenml/zen_server/auth.py +2 -2
  219. zenml/zen_server/cloud_utils.py +6 -6
  220. zenml/zen_server/deploy/base_provider.py +1 -1
  221. zenml/zen_server/deploy/deployment.py +6 -8
  222. zenml/zen_server/deploy/docker/docker_zen_server.py +3 -4
  223. zenml/zen_server/deploy/local/local_provider.py +0 -1
  224. zenml/zen_server/deploy/local/local_zen_server.py +6 -6
  225. zenml/zen_server/deploy/terraform/terraform_zen_server.py +4 -6
  226. zenml/zen_server/exceptions.py +4 -1
  227. zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +1 -1
  228. zenml/zen_server/pipeline_deployment/utils.py +48 -68
  229. zenml/zen_server/rbac/models.py +2 -5
  230. zenml/zen_server/rbac/utils.py +11 -14
  231. zenml/zen_server/routers/auth_endpoints.py +2 -2
  232. zenml/zen_server/routers/pipeline_builds_endpoints.py +1 -1
  233. zenml/zen_server/routers/runs_endpoints.py +1 -1
  234. zenml/zen_server/routers/secrets_endpoints.py +3 -2
  235. zenml/zen_server/routers/server_endpoints.py +1 -1
  236. zenml/zen_server/routers/steps_endpoints.py +1 -1
  237. zenml/zen_server/routers/workspaces_endpoints.py +1 -1
  238. zenml/zen_stores/base_zen_store.py +46 -9
  239. zenml/zen_stores/migrations/utils.py +42 -46
  240. zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py +1 -1
  241. zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +5 -3
  242. zenml/zen_stores/migrations/versions/10a907dad202_delete_mlmd_tables.py +1 -1
  243. zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +8 -10
  244. zenml/zen_stores/migrations/versions/37835ce041d2_optimizing_database.py +3 -3
  245. zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +10 -12
  246. zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +3 -2
  247. zenml/zen_stores/migrations/versions/6917bce75069_add_pipeline_run_unique_constraint.py +4 -4
  248. zenml/zen_stores/migrations/versions/728c6369cfaa_add_name_column_to_input_artifact_pk.py +3 -2
  249. zenml/zen_stores/migrations/versions/743ec82b1b3c_update_size_of_build_images.py +2 -2
  250. zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
  251. zenml/zen_stores/migrations/versions/7834208cc3f6_artifact_project_scoping.py +8 -7
  252. zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +6 -4
  253. zenml/zen_stores/migrations/versions/7e4a481d17f7_add_identity_table.py +2 -2
  254. zenml/zen_stores/migrations/versions/7f603e583dd7_fixed_migration.py +1 -1
  255. zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +2 -2
  256. zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +4 -4
  257. zenml/zen_stores/migrations/versions/alembic_start.py +1 -1
  258. zenml/zen_stores/migrations/versions/fbd7f18ced1e_increase_step_run_field_lengths.py +4 -4
  259. zenml/zen_stores/rest_zen_store.py +109 -49
  260. zenml/zen_stores/schemas/api_key_schemas.py +1 -1
  261. zenml/zen_stores/schemas/artifact_schemas.py +8 -8
  262. zenml/zen_stores/schemas/artifact_visualization_schemas.py +3 -3
  263. zenml/zen_stores/schemas/code_repository_schemas.py +1 -1
  264. zenml/zen_stores/schemas/component_schemas.py +8 -3
  265. zenml/zen_stores/schemas/device_schemas.py +8 -6
  266. zenml/zen_stores/schemas/event_source_schemas.py +3 -4
  267. zenml/zen_stores/schemas/flavor_schemas.py +5 -3
  268. zenml/zen_stores/schemas/model_schemas.py +26 -1
  269. zenml/zen_stores/schemas/pipeline_build_schemas.py +1 -1
  270. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +4 -4
  271. zenml/zen_stores/schemas/pipeline_run_schemas.py +6 -6
  272. zenml/zen_stores/schemas/pipeline_schemas.py +5 -2
  273. zenml/zen_stores/schemas/run_metadata_schemas.py +2 -2
  274. zenml/zen_stores/schemas/secret_schemas.py +8 -5
  275. zenml/zen_stores/schemas/server_settings_schemas.py +3 -1
  276. zenml/zen_stores/schemas/service_connector_schemas.py +1 -1
  277. zenml/zen_stores/schemas/service_schemas.py +11 -2
  278. zenml/zen_stores/schemas/stack_schemas.py +1 -1
  279. zenml/zen_stores/schemas/step_run_schemas.py +11 -11
  280. zenml/zen_stores/schemas/tag_schemas.py +6 -2
  281. zenml/zen_stores/schemas/trigger_schemas.py +2 -2
  282. zenml/zen_stores/schemas/user_schemas.py +2 -2
  283. zenml/zen_stores/schemas/workspace_schemas.py +3 -1
  284. zenml/zen_stores/secrets_stores/aws_secrets_store.py +19 -20
  285. zenml/zen_stores/secrets_stores/azure_secrets_store.py +17 -20
  286. zenml/zen_stores/secrets_stores/base_secrets_store.py +79 -12
  287. zenml/zen_stores/secrets_stores/gcp_secrets_store.py +17 -20
  288. zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +4 -8
  289. zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +10 -7
  290. zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -6
  291. zenml/zen_stores/sql_zen_store.py +196 -120
  292. zenml/zen_stores/zen_store_interface.py +33 -0
  293. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/METADATA +8 -7
  294. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/RECORD +297 -294
  295. zenml/integrations/kubeflow/utils.py +0 -95
  296. zenml/models/v2/base/internal.py +0 -37
  297. zenml/models/v2/base/update.py +0 -44
  298. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/LICENSE +0 -0
  299. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/WHEEL +0 -0
  300. {zenml_nightly-0.58.2.dev20240618.dist-info → zenml_nightly-0.58.2.dev20240619.dist-info}/entry_points.txt +0 -0
zenml/steps/base_step.py CHANGED
@@ -34,7 +34,7 @@ from typing import (
34
34
  cast,
35
35
  )
36
36
 
37
- from pydantic import BaseModel, Extra, ValidationError
37
+ from pydantic import BaseModel, ConfigDict, ValidationError
38
38
 
39
39
  from zenml.client_lazy_loader import ClientLazyLoader
40
40
  from zenml.config.retry_config import StepRetryConfig
@@ -59,6 +59,7 @@ from zenml.utils import (
59
59
  settings_utils,
60
60
  source_code_utils,
61
61
  source_utils,
62
+ typing_utils,
62
63
  )
63
64
 
64
65
  if TYPE_CHECKING:
@@ -513,17 +514,17 @@ class BaseStep(metaclass=BaseStepMeta):
513
514
  )
514
515
  elif isinstance(value, LazyArtifactVersionResponse):
515
516
  model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
516
- model=value._lazy_load_model,
517
- artifact_name=value._lazy_load_name,
518
- artifact_version=value._lazy_load_version,
517
+ model=value.lazy_load_model,
518
+ artifact_name=value.lazy_load_name,
519
+ artifact_version=value.lazy_load_version,
519
520
  metadata_name=None,
520
521
  )
521
522
  elif isinstance(value, LazyRunMetadataResponse):
522
523
  model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader(
523
- model=value._lazy_load_model,
524
- artifact_name=value._lazy_load_artifact_name,
525
- artifact_version=value._lazy_load_artifact_version,
526
- metadata_name=value._lazy_load_metadata_name,
524
+ model=value.lazy_load_model,
525
+ artifact_name=value.lazy_load_artifact_name,
526
+ artifact_version=value.lazy_load_artifact_version,
527
+ metadata_name=value.lazy_load_metadata_name,
527
528
  )
528
529
  elif isinstance(value, ClientLazyLoader):
529
530
  client_lazy_loaders[key] = value
@@ -583,7 +584,7 @@ class BaseStep(metaclass=BaseStepMeta):
583
584
  from zenml.new.pipelines.pipeline import Pipeline
584
585
 
585
586
  if not Pipeline.ACTIVE_PIPELINE:
586
- # The step is being called outside of the context of a pipeline,
587
+ # The step is being called outside the context of a pipeline,
587
588
  # we simply call the entrypoint
588
589
  return self.call_entrypoint(*args, **kwargs)
589
590
 
@@ -645,12 +646,15 @@ class BaseStep(metaclass=BaseStepMeta):
645
646
  try:
646
647
  validated_args = pydantic_utils.validate_function_args(
647
648
  self.entrypoint,
648
- {"arbitrary_types_allowed": True, "smart_union": True},
649
+ ConfigDict(arbitrary_types_allowed=True),
649
650
  *args,
650
651
  **kwargs,
651
652
  )
652
653
  except ValidationError as e:
653
- raise StepInterfaceError("Invalid entrypoint arguments.") from e
654
+ raise StepInterfaceError(
655
+ "Invalid step function entrypoint arguments. Check out the "
656
+ "pydantic error above for more details."
657
+ ) from e
654
658
 
655
659
  return self.entrypoint(**validated_args)
656
660
 
@@ -796,7 +800,7 @@ class BaseStep(metaclass=BaseStepMeta):
796
800
  success_hook_source = resolve_and_validate_hook(on_success)
797
801
 
798
802
  if isinstance(parameters, BaseParameters):
799
- parameters = parameters.dict()
803
+ parameters = parameters.model_dump()
800
804
 
801
805
  values = dict_utils.remove_none_values(
802
806
  {
@@ -1111,12 +1115,6 @@ To avoid this consider setting step parameters only in one place (config or code
1111
1115
  output_name, PartialArtifactConfiguration()
1112
1116
  )
1113
1117
 
1114
- from pydantic.typing import (
1115
- get_origin,
1116
- is_none_type,
1117
- is_union,
1118
- )
1119
-
1120
1118
  from zenml.steps.utils import get_args
1121
1119
 
1122
1120
  if not output.materializer_source:
@@ -1129,13 +1127,15 @@ To avoid this consider setting step parameters only in one place (config or code
1129
1127
  )
1130
1128
  continue
1131
1129
 
1132
- if is_union(
1133
- get_origin(output_annotation.resolved_annotation)
1130
+ if typing_utils.is_union(
1131
+ typing_utils.get_origin(
1132
+ output_annotation.resolved_annotation
1133
+ )
1134
1134
  or output_annotation.resolved_annotation
1135
1135
  ):
1136
1136
  output_types = tuple(
1137
1137
  type(None)
1138
- if is_none_type(output_type)
1138
+ if typing_utils.is_none_type(output_type)
1139
1139
  else output_type
1140
1140
  for output_type in get_args(
1141
1141
  output_annotation.resolved_annotation
@@ -1169,7 +1169,7 @@ To avoid this consider setting step parameters only in one place (config or code
1169
1169
  config = StepConfigurationUpdate(**values)
1170
1170
  self._apply_configuration(config)
1171
1171
 
1172
- self._configuration = self._configuration.copy(
1172
+ self._configuration = self._configuration.model_copy(
1173
1173
  update={
1174
1174
  "caching_parameters": self.caching_parameters,
1175
1175
  "external_input_artifacts": external_artifacts,
@@ -1178,7 +1178,9 @@ To avoid this consider setting step parameters only in one place (config or code
1178
1178
  }
1179
1179
  )
1180
1180
 
1181
- return StepConfiguration.parse_obj(self._configuration)
1181
+ return StepConfiguration.model_validate(
1182
+ self._configuration.model_dump()
1183
+ )
1182
1184
 
1183
1185
  def _finalize_parameters(self) -> Dict[str, Any]:
1184
1186
  """Finalizes the config parameters for running this step.
@@ -1199,7 +1201,7 @@ To avoid this consider setting step parameters only in one place (config or code
1199
1201
  # Make sure we have all necessary values to instantiate the
1200
1202
  # pydantic model later
1201
1203
  model = annotation(**value)
1202
- params[key] = model.dict()
1204
+ params[key] = model.model_dump()
1203
1205
  else:
1204
1206
  params[key] = value
1205
1207
 
@@ -1254,7 +1256,7 @@ To avoid this consider setting step parameters only in one place (config or code
1254
1256
  for (
1255
1257
  name,
1256
1258
  field,
1257
- ) in self.entrypoint_definition.legacy_params.annotation.__fields__.items():
1259
+ ) in self.entrypoint_definition.legacy_params.annotation.model_fields.items():
1258
1260
  if name in self.configuration.parameters:
1259
1261
  # a value for this parameter has been set already
1260
1262
  values[name] = self.configuration.parameters[name]
@@ -1262,7 +1264,7 @@ To avoid this consider setting step parameters only in one place (config or code
1262
1264
  # a value for this parameter has been set in the "new" way
1263
1265
  # already
1264
1266
  values[name] = params_defined_in_new_way[name]
1265
- elif field.required:
1267
+ elif field.is_required():
1266
1268
  # this field has no default value set and therefore needs
1267
1269
  # to be passed via an initialized config object
1268
1270
  missing_keys.append(name)
@@ -1278,8 +1280,12 @@ To avoid this consider setting step parameters only in one place (config or code
1278
1280
  )
1279
1281
 
1280
1282
  if (
1281
- self.entrypoint_definition.legacy_params.annotation.__config__.extra
1282
- == Extra.allow
1283
+ getattr(
1284
+ self.entrypoint_definition.legacy_params.annotation.model_config,
1285
+ "extra",
1286
+ None,
1287
+ )
1288
+ == "allow"
1283
1289
  ):
1284
1290
  # Add all parameters for the config class for backwards
1285
1291
  # compatibility if the config class allows extra attributes
zenml/steps/entrypoint_function_utils.py CHANGED
@@ -27,7 +27,7 @@ from typing import (
27
27
  Union,
28
28
  )
29
29
 
30
- from pydantic import BaseConfig, ValidationError, create_model
30
+ from pydantic import ConfigDict, ValidationError, create_model
31
31
 
32
32
  from zenml.constants import ENFORCE_TYPE_ANNOTATIONS
33
33
  from zenml.exceptions import StepInterfaceError
@@ -235,16 +235,14 @@ class EntrypointFunctionDefinition(NamedTuple):
235
235
  parameter: The function parameter for which the value was provided.
236
236
  value: The input value.
237
237
  """
238
-
239
- class ModelConfig(BaseConfig):
240
- arbitrary_types_allowed = False
238
+ config_dict = ConfigDict(arbitrary_types_allowed=False)
241
239
 
242
240
  # Create a pydantic model with just a single required field with the
243
241
  # type annotation of the parameter to verify the input type including
244
242
  # pydantics type coercion
245
243
  validation_model_class = create_model(
246
244
  "input_validation_model",
247
- __config__=ModelConfig,
245
+ __config__=config_dict,
248
246
  value=(parameter.annotation, ...),
249
247
  )
250
248
  validation_model_class(value=value)
zenml/steps/utils.py CHANGED
@@ -21,7 +21,6 @@ import textwrap
21
21
  from typing import Any, Callable, Dict, Optional, Tuple, Union
22
22
  from uuid import UUID
23
23
 
24
- import pydantic.typing as pydantic_typing
25
24
  from pydantic import BaseModel
26
25
  from typing_extensions import Annotated
27
26
 
@@ -32,7 +31,7 @@ from zenml.logger import get_logger
32
31
  from zenml.metadata.metadata_types import MetadataType
33
32
  from zenml.new.steps.step_context import get_step_context
34
33
  from zenml.steps.step_output import Output
35
- from zenml.utils import source_code_utils
34
+ from zenml.utils import source_code_utils, typing_utils
36
35
 
37
36
  logger = get_logger(__name__)
38
37
 
@@ -42,8 +41,8 @@ SINGLE_RETURN_OUT_NAME = "output"
42
41
  class OutputSignature(BaseModel):
43
42
  """The signature of an output artifact."""
44
43
 
45
- resolved_annotation: Any
46
- artifact_config: Optional[ArtifactConfig]
44
+ resolved_annotation: Any = None
45
+ artifact_config: Optional[ArtifactConfig] = None
47
46
  has_custom_name: bool = False
48
47
 
49
48
 
@@ -60,8 +59,7 @@ def get_args(obj: Any) -> Tuple[Any, ...]:
60
59
  The args of the annotation.
61
60
  """
62
61
  return tuple(
63
- pydantic_typing.get_origin(v) or v
64
- for v in pydantic_typing.get_args(obj)
62
+ typing_utils.get_origin(v) or v for v in typing_utils.get_args(obj)
65
63
  )
66
64
 
67
65
 
@@ -123,11 +121,11 @@ def parse_return_type_annotations(
123
121
  for output_name, output_type in return_annotation.items()
124
122
  }
125
123
 
126
- elif pydantic_typing.get_origin(return_annotation) is tuple:
124
+ elif typing_utils.get_origin(return_annotation) is tuple:
127
125
  requires_multiple_artifacts = has_tuple_return(func)
128
126
  if requires_multiple_artifacts:
129
127
  output_signature: Dict[str, Any] = {}
130
- args = pydantic_typing.get_args(return_annotation)
128
+ args = typing_utils.get_args(return_annotation)
131
129
  if args[-1] is Ellipsis:
132
130
  raise RuntimeError(
133
131
  "Variable length output annotations are not allowed."
@@ -179,12 +177,12 @@ def resolve_type_annotation(obj: Any) -> Any:
179
177
  Returns:
180
178
  The non-generic class for generic aliases of the typing module.
181
179
  """
182
- origin = pydantic_typing.get_origin(obj) or obj
180
+ origin = typing_utils.get_origin(obj) or obj
183
181
 
184
182
  if origin is Annotated:
185
- annotation, *_ = pydantic_typing.get_args(obj)
183
+ annotation, *_ = typing_utils.get_args(obj)
186
184
  return resolve_type_annotation(annotation)
187
- elif pydantic_typing.is_union(origin):
185
+ elif typing_utils.is_union(origin):
188
186
  return obj
189
187
 
190
188
  return origin
@@ -212,10 +210,10 @@ def get_artifact_config_from_annotation_metadata(
212
210
  Returns:
213
211
  The artifact config.
214
212
  """
215
- if (pydantic_typing.get_origin(annotation) or annotation) is not Annotated:
213
+ if (typing_utils.get_origin(annotation) or annotation) is not Annotated:
216
214
  return None
217
215
 
218
- annotation, *metadata = pydantic_typing.get_args(annotation)
216
+ annotation, *metadata = typing_utils.get_args(annotation)
219
217
 
220
218
  error_message = (
221
219
  "Artifact annotation should only contain two elements: the artifact "
@@ -251,7 +249,7 @@ def get_artifact_config_from_annotation_metadata(
251
249
  if not artifact_config:
252
250
  artifact_config = ArtifactConfig(name=output_name)
253
251
  elif not artifact_config.name:
254
- artifact_config = artifact_config.copy()
252
+ artifact_config = artifact_config.model_copy()
255
253
  artifact_config.name = output_name
256
254
 
257
255
  if artifact_config and artifact_config.name == "":
@@ -0,0 +1,50 @@
1
+ # Apache Software License 2.0
2
+ #
3
+ # Copyright (c) ZenML GmbH 2024. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ """Utilities for managing GPU memory."""
18
+
19
+ import gc
20
+
21
+ from zenml.logger import get_logger
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
+ def cleanup_gpu_memory(force: bool = False) -> None:
27
+ """Clean up GPU memory.
28
+
29
+ Args:
30
+ force: whether to force the cleanup of GPU memory (must be passed explicitly)
31
+ """
32
+ if not force:
33
+ logger.warning(
34
+ "This will clean up all GPU memory on current physical machine. "
35
+ "This action is considered to be dangerous by default, since "
36
+ "it might affect other processes running in the same environment. "
37
+ "If this is intended, please explicitly pass `force=True`."
38
+ )
39
+ else:
40
+ try:
41
+ import torch
42
+ except ModuleNotFoundError:
43
+ logger.warning(
44
+ "No PyTorch installed. Skipping GPU memory cleanup."
45
+ )
46
+ return
47
+
48
+ logger.info("Cleaning up GPU memory...")
49
+ while gc.collect():
50
+ torch.cuda.empty_cache()
@@ -14,14 +14,12 @@
14
14
  """Deprecation utilities."""
15
15
 
16
16
  import warnings
17
- from typing import TYPE_CHECKING, Any, Dict, Set, Tuple, Type, Union
17
+ from typing import Any, Dict, Set, Tuple, Type, Union
18
18
 
19
- from pydantic import BaseModel, root_validator
19
+ from pydantic import BaseModel, model_validator
20
20
 
21
21
  from zenml.logger import get_logger
22
-
23
- if TYPE_CHECKING:
24
- AnyClassMethod = classmethod[Any] # type: ignore[type-arg]
22
+ from zenml.utils.pydantic_utils import before_validator_handler
25
23
 
26
24
  logger = get_logger(__name__)
27
25
 
@@ -30,7 +28,7 @@ PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE = "__previous_deprecation_warnings"
30
28
 
31
29
  def deprecate_pydantic_attributes(
32
30
  *attributes: Union[str, Tuple[str, str]],
33
- ) -> "AnyClassMethod":
31
+ ) -> Any:
34
32
  """Utility function for deprecating and migrating pydantic attributes.
35
33
 
36
34
  **Usage**:
@@ -55,22 +53,24 @@ def deprecate_pydantic_attributes(
55
53
  Args:
56
54
  *attributes: List of attributes to deprecate. This is either the name
57
55
  of the attribute to deprecate, or a tuple containing the name of
58
- the deprecated attribute and it's replacement.
56
+ the deprecated attribute, and it's replacement.
59
57
 
60
58
  Returns:
61
59
  Pydantic validator class method to be used on BaseModel subclasses
62
60
  to deprecate or migrate attributes.
63
61
  """
64
62
 
65
- @root_validator(pre=True, allow_reuse=True)
63
+ @model_validator(mode="before") # type: ignore[misc]
64
+ @classmethod
65
+ @before_validator_handler
66
66
  def _deprecation_validator(
67
- cls: Type[BaseModel], values: Dict[str, Any]
67
+ cls: Type[BaseModel], data: Dict[str, Any]
68
68
  ) -> Dict[str, Any]:
69
69
  """Pydantic validator function for deprecating pydantic attributes.
70
70
 
71
71
  Args:
72
72
  cls: The class on which the attributes are defined.
73
- values: All values passed at model initialization.
73
+ data: All values passed at model initialization.
74
74
 
75
75
  Raises:
76
76
  AssertionError: If either the deprecated or replacement attribute
@@ -110,14 +110,14 @@ def deprecate_pydantic_attributes(
110
110
  deprecated_attribute, replacement_attribute = attribute
111
111
 
112
112
  assert (
113
- replacement_attribute in cls.__fields__
113
+ replacement_attribute in cls.model_fields
114
114
  ), f"Unable to find attribute {replacement_attribute}."
115
115
 
116
116
  assert (
117
- deprecated_attribute in cls.__fields__
117
+ deprecated_attribute in cls.model_fields
118
118
  ), f"Unable to find attribute {deprecated_attribute}."
119
119
 
120
- if cls.__fields__[deprecated_attribute].required:
120
+ if cls.model_fields[deprecated_attribute].is_required():
121
121
  raise TypeError(
122
122
  f"Unable to deprecate attribute '{deprecated_attribute}' "
123
123
  f"of class {cls.__name__}. In order to deprecate an "
@@ -126,7 +126,7 @@ def deprecate_pydantic_attributes(
126
126
  "annotation."
127
127
  )
128
128
 
129
- if values.get(deprecated_attribute, None) is None:
129
+ if data.get(deprecated_attribute, None) is None:
130
130
  continue
131
131
 
132
132
  if replacement_attribute is None:
@@ -144,17 +144,15 @@ def deprecate_pydantic_attributes(
144
144
  attribute=deprecated_attribute,
145
145
  )
146
146
 
147
- if values.get(replacement_attribute, None) is None:
147
+ if data.get(replacement_attribute, None) is None:
148
148
  logger.debug(
149
149
  "Migrating value of deprecated attribute %s to "
150
150
  "replacement attribute %s.",
151
151
  deprecated_attribute,
152
152
  replacement_attribute,
153
153
  )
154
- values[replacement_attribute] = values.pop(
155
- deprecated_attribute
156
- )
157
- elif values[deprecated_attribute] != values[replacement_attribute]:
154
+ data[replacement_attribute] = data.pop(deprecated_attribute)
155
+ elif data[deprecated_attribute] != data[replacement_attribute]:
158
156
  raise ValueError(
159
157
  "Got different values for deprecated attribute "
160
158
  f"{deprecated_attribute} and replacement "
@@ -170,6 +168,6 @@ def deprecate_pydantic_attributes(
170
168
  previous_deprecation_warnings,
171
169
  )
172
170
 
173
- return values
171
+ return data
174
172
 
175
173
  return _deprecation_validator
zenml/utils/dict_utils.py CHANGED
@@ -17,7 +17,7 @@ import base64
17
17
  import json
18
18
  from typing import Any, Dict
19
19
 
20
- from pydantic.json import pydantic_encoder
20
+ from zenml.utils.json_utils import pydantic_encoder
21
21
 
22
22
 
23
23
  def recursive_update(
@@ -13,11 +13,16 @@
13
13
  # permissions and limitations under the License.
14
14
  """Filesync utils for ZenML."""
15
15
 
16
- import json
17
16
  import os
18
17
  from typing import Any, Optional
19
18
 
20
- from pydantic import BaseModel
19
+ from pydantic import (
20
+ BaseModel,
21
+ ValidationError,
22
+ ValidationInfo,
23
+ ValidatorFunctionWrapHandler,
24
+ model_validator,
25
+ )
21
26
 
22
27
  from zenml.io import fileio
23
28
  from zenml.logger import get_logger
@@ -40,29 +45,69 @@ class FileSyncModel(BaseModel):
40
45
  _config_file: str
41
46
  _config_file_timestamp: Optional[float] = None
42
47
 
43
- def __init__(self, config_file: str, **kwargs: Any) -> None:
44
- """Create a FileSyncModel instance synchronized with a configuration file on disk.
48
+ @model_validator(mode="wrap")
49
+ @classmethod
50
+ def config_validator(
51
+ cls,
52
+ data: Any,
53
+ handler: ValidatorFunctionWrapHandler,
54
+ info: ValidationInfo,
55
+ ) -> "FileSyncModel":
56
+ """Wrap model validator to infer the config_file during initialization.
45
57
 
46
58
  Args:
47
- config_file: configuration file path. If the file exists, the model
48
- will be initialized with the values from the file.
49
- **kwargs: additional keyword arguments to pass to the Pydantic model
50
- constructor. If supplied, these values will override those
51
- loaded from the configuration file.
59
+ data: The raw data that is provided before the validation.
60
+ handler: The actual validation function pydantic would use for the
61
+ built-in validation function.
62
+ info: The context information during the execution of this
63
+ validation function.
64
+
65
+ Returns:
66
+ the actual instance after the validation
67
+
68
+ Raises:
69
+ ValidationError: if you try to validate through a JSON string. You
70
+ need to provide a config_file path when you create a
71
+ FileSyncModel.
72
+ AssertionError: if the raw input does not include a config_file
73
+ path for the configuration file.
52
74
  """
53
- config_dict = {}
54
- if fileio.exists(config_file):
55
- config_dict = yaml_utils.read_yaml(config_file)
75
+ # Disable json validation
76
+ if info.mode == "json":
77
+ raise ValidationError(
78
+ "You can not instantiate filesync models using the JSON mode."
79
+ )
56
80
 
57
- self._config_file = config_file
58
- self._config_file_timestamp = None
81
+ if isinstance(data, dict):
82
+ # Assert that the config file is defined
83
+ assert (
84
+ "config_file" in data
85
+ ), "You have to provide a path for the configuration file."
59
86
 
60
- config_dict.update(kwargs)
61
- super(FileSyncModel, self).__init__(**config_dict)
87
+ config_file = data.pop("config_file")
62
88
 
63
- # write the configuration file to disk, to reflect new attributes
64
- # and schema changes
65
- self.write_config()
89
+ # Load the current values and update with new values
90
+ config_dict = {}
91
+ if fileio.exists(config_file):
92
+ config_dict = yaml_utils.read_yaml(config_file)
93
+ config_dict.update(data)
94
+
95
+ # Execute the regular validation
96
+ model = handler(config_dict)
97
+
98
+ assert isinstance(model, cls)
99
+
100
+ # Assign the private attribute and save the config
101
+ model._config_file = config_file
102
+ model.write_config()
103
+
104
+ else:
105
+ # If the raw value is not a dict, apply proper validation.
106
+ model = handler(data)
107
+
108
+ assert isinstance(model, cls)
109
+
110
+ return model
66
111
 
67
112
  def __setattr__(self, key: str, value: Any) -> None:
68
113
  """Sets an attribute on the model and persists it in the configuration file.
@@ -91,8 +136,7 @@ class FileSyncModel(BaseModel):
91
136
 
92
137
  def write_config(self) -> None:
93
138
  """Writes the model to the configuration file."""
94
- config_dict = json.loads(self.json())
95
- yaml_utils.write_yaml(self._config_file, config_dict)
139
+ yaml_utils.write_yaml(self._config_file, self.model_dump(mode="json"))
96
140
  self._config_file_timestamp = os.path.getmtime(self._config_file)
97
141
 
98
142
  def load_config(self) -> None:
@@ -115,10 +159,3 @@ class FileSyncModel(BaseModel):
115
159
  super(FileSyncModel, self).__setattr__(key, value)
116
160
 
117
161
  self._config_file_timestamp = file_timestamp
118
-
119
- class Config:
120
- """Pydantic configuration class."""
121
-
122
- # all attributes with leading underscore are private and therefore
123
- # are mutable and not included in serialization
124
- underscore_attrs_are_private = True