zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/base.py +2 -2
  4. zenml/cli/pipeline.py +21 -0
  5. zenml/cli/utils.py +14 -11
  6. zenml/client.py +68 -3
  7. zenml/config/step_configurations.py +0 -5
  8. zenml/constants.py +3 -0
  9. zenml/enums.py +2 -0
  10. zenml/integrations/__init__.py +1 -0
  11. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  12. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
  13. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  14. zenml/integrations/constants.py +1 -0
  15. zenml/integrations/deepchecks/__init__.py +1 -1
  16. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
  17. zenml/integrations/deepchecks/validation_checks.py +62 -5
  18. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  19. zenml/integrations/lightning/__init__.py +1 -1
  20. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  21. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  22. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  23. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  24. zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
  25. zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
  26. zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
  27. zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
  28. zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
  29. zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
  30. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  31. zenml/models/v2/base/filter.py +315 -149
  32. zenml/models/v2/base/scoped.py +5 -2
  33. zenml/models/v2/core/artifact_version.py +69 -8
  34. zenml/models/v2/core/model.py +43 -6
  35. zenml/models/v2/core/model_version.py +49 -1
  36. zenml/models/v2/core/model_version_artifact.py +18 -3
  37. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  38. zenml/models/v2/core/pipeline.py +108 -1
  39. zenml/models/v2/core/pipeline_run.py +172 -21
  40. zenml/models/v2/core/run_template.py +53 -1
  41. zenml/models/v2/core/stack.py +33 -5
  42. zenml/models/v2/core/step_run.py +7 -0
  43. zenml/new/pipelines/pipeline.py +4 -0
  44. zenml/new/pipelines/run_utils.py +4 -1
  45. zenml/orchestrators/base_orchestrator.py +41 -12
  46. zenml/stack/stack.py +11 -2
  47. zenml/utils/env_utils.py +54 -1
  48. zenml/utils/string_utils.py +50 -0
  49. zenml/zen_server/cloud_utils.py +33 -8
  50. zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
  51. zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
  52. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
  53. zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
  54. zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
  55. zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
  56. zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
  57. zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
  58. zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
  59. zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
  60. zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
  61. zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
  63. zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
  64. zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
  65. zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
  66. zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
  67. zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
  69. zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
  70. zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
  71. zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
  72. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
  73. zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
  74. zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
  75. zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
  77. zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
  78. zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
  79. zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
  80. zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
  81. zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
  82. zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
  83. zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
  85. zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
  86. zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
  87. zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
  88. zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
  89. zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
  90. zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
  91. zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
  92. zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
  93. zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
  94. zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
  95. zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
  96. zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
  97. zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
  98. zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
  99. zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
  100. zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
  101. zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
  102. zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
  103. zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
  104. zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
  105. zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
  106. zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
  107. zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
  108. zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
  109. zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
  110. zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
  111. zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
  112. zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
  113. zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
  114. zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
  115. zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
  116. zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
  117. zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
  118. zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
  119. zenml/zen_server/dashboard/index.html +4 -4
  120. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  121. zenml/zen_server/dashboard_legacy/index.html +1 -1
  122. zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
  123. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  124. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
  125. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
  126. zenml/zen_server/routers/runs_endpoints.py +89 -3
  127. zenml/zen_stores/sql_zen_store.py +1 -0
  128. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/METADATA +8 -1
  129. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/RECORD +132 -124
  130. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
  131. zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
  132. zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
  133. zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
  134. zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
  135. zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
  136. zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
  137. zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
  138. zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
  139. zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
  140. zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
  141. zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
  142. zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
  143. zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
  144. zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
  145. zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
  146. zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
  147. zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
  148. zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
  149. zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
  150. zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
  151. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/LICENSE +0 -0
  152. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/WHEEL +0 -0
  153. {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.66.0.dev20240923
1
+ 0.66.0.dev20240927
zenml/cli/__init__.py CHANGED
@@ -1715,6 +1715,13 @@ To delete a pipeline run, use:
1715
1715
  zenml pipeline runs delete <PIPELINE_RUN_NAME_OR_ID>
1716
1716
  ```
1717
1717
 
1718
+ To refresh the status of a pipeline run, you can use the `refresh` command (
1719
+ only supported for pipelines executed on Vertex, Sagemaker or AzureML).
1720
+
1721
+ ```bash
1722
+ zenml pipeline runs refresh <PIPELINE_RUN_NAME_OR_ID>
1723
+ ```
1724
+
1718
1725
  If you run any of your pipelines with `pipeline.run(schedule=...)`, ZenML keeps
1719
1726
  track of the schedule and you can list all schedules via:
1720
1727
 
zenml/cli/base.py CHANGED
@@ -83,11 +83,11 @@ ZENML_PROJECT_TEMPLATES = dict(
83
83
  ),
84
84
  starter=ZenMLProjectTemplateLocation(
85
85
  github_url="zenml-io/template-starter",
86
- github_tag="2024.08.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
86
+ github_tag="2024.09.23", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
87
87
  ),
88
88
  nlp=ZenMLProjectTemplateLocation(
89
89
  github_url="zenml-io/template-nlp",
90
- github_tag="2024.08.29", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
90
+ github_tag="2024.09.23", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
91
91
  ),
92
92
  llm_finetuning=ZenMLProjectTemplateLocation(
93
93
  github_url="zenml-io/template-llm-finetuning",
zenml/cli/pipeline.py CHANGED
@@ -500,6 +500,27 @@ def delete_pipeline_run(
500
500
  cli_utils.declare(f"Deleted pipeline run '{run_name_or_id}'.")
501
501
 
502
502
 
503
+ @runs.command("refresh")
504
+ @click.argument("run_name_or_id", type=str, required=True)
505
+ def refresh_pipeline_run(run_name_or_id: str) -> None:
506
+ """Refresh the status of a pipeline run.
507
+
508
+ Args:
509
+ run_name_or_id: The name or ID of the pipeline run to refresh.
510
+ """
511
+ try:
512
+ # Fetch and update the run
513
+ run = Client().get_pipeline_run(name_id_or_prefix=run_name_or_id)
514
+ run.refresh_run_status()
515
+
516
+ except KeyError as e:
517
+ cli_utils.error(str(e))
518
+ else:
519
+ cli_utils.declare(
520
+ f"Refreshed the status of pipeline run '{run.name}'."
521
+ )
522
+
523
+
503
524
  @pipeline.group()
504
525
  def builds() -> None:
505
526
  """Commands for pipeline builds."""
zenml/cli/utils.py CHANGED
@@ -76,6 +76,7 @@ from zenml.models import (
76
76
  StrFilter,
77
77
  UUIDFilter,
78
78
  )
79
+ from zenml.models.v2.base.filter import FilterGenerator
79
80
  from zenml.services import BaseService, ServiceState
80
81
  from zenml.stack import StackComponent
81
82
  from zenml.stack.stack_component import StackComponentConfig
@@ -2477,12 +2478,13 @@ def create_filter_help_text(filter_model: Type[BaseFilter], field: str) -> str:
2477
2478
  Returns:
2478
2479
  The help text.
2479
2480
  """
2480
- if filter_model.is_sort_by_field(field):
2481
+ filter_generator = FilterGenerator(filter_model)
2482
+ if filter_generator.is_sort_by_field(field):
2481
2483
  return (
2482
2484
  "[STRING] Example: --sort_by='desc:name' to sort by name in "
2483
2485
  "descending order. "
2484
2486
  )
2485
- if filter_model.is_datetime_field(field):
2487
+ if filter_generator.is_datetime_field(field):
2486
2488
  return (
2487
2489
  f"[DATETIME] The following datetime format is supported: "
2488
2490
  f"'{FILTERING_DATETIME_FORMAT}'. Make sure to keep it in "
@@ -2491,23 +2493,23 @@ def create_filter_help_text(filter_model: Type[BaseFilter], field: str) -> str:
2491
2493
  f"'{GenericFilterOps.GTE}:{FILTERING_DATETIME_FORMAT}' to "
2492
2494
  f"filter for everything created on or after the given date."
2493
2495
  )
2494
- elif filter_model.is_uuid_field(field):
2496
+ elif filter_generator.is_uuid_field(field):
2495
2497
  return (
2496
2498
  f"[UUID] Example: --{field}='{GenericFilterOps.STARTSWITH}:ab53ca' "
2497
2499
  f"to filter for all UUIDs starting with that prefix."
2498
2500
  )
2499
- elif filter_model.is_int_field(field):
2501
+ elif filter_generator.is_int_field(field):
2500
2502
  return (
2501
2503
  f"[INTEGER] Example: --{field}='{GenericFilterOps.GTE}:25' to "
2502
2504
  f"filter for all entities where this field has a value greater than "
2503
2505
  f"or equal to the value."
2504
2506
  )
2505
- elif filter_model.is_bool_field(field):
2507
+ elif filter_generator.is_bool_field(field):
2506
2508
  return (
2507
2509
  f"[BOOL] Example: --{field}='True' to "
2508
2510
  f"filter for all instances where this field is true."
2509
2511
  )
2510
- elif filter_model.is_str_field(field):
2512
+ elif filter_generator.is_str_field(field):
2511
2513
  return (
2512
2514
  f"[STRING] Example: --{field}='{GenericFilterOps.CONTAINS}:example' "
2513
2515
  f"to filter everything that contains the query string somewhere in "
@@ -2529,27 +2531,28 @@ def create_data_type_help_text(
2529
2531
  Returns:
2530
2532
  The help text.
2531
2533
  """
2532
- if filter_model.is_datetime_field(field):
2534
+ filter_generator = FilterGenerator(filter_model)
2535
+ if filter_generator.is_datetime_field(field):
2533
2536
  return (
2534
2537
  f"[DATETIME] supported filter operators: "
2535
2538
  f"{[str(op) for op in NumericFilter.ALLOWED_OPS]}"
2536
2539
  )
2537
- elif filter_model.is_uuid_field(field):
2540
+ elif filter_generator.is_uuid_field(field):
2538
2541
  return (
2539
2542
  f"[UUID] supported filter operators: "
2540
2543
  f"{[str(op) for op in UUIDFilter.ALLOWED_OPS]}"
2541
2544
  )
2542
- elif filter_model.is_int_field(field):
2545
+ elif filter_generator.is_int_field(field):
2543
2546
  return (
2544
2547
  f"[INTEGER] supported filter operators: "
2545
2548
  f"{[str(op) for op in NumericFilter.ALLOWED_OPS]}"
2546
2549
  )
2547
- elif filter_model.is_bool_field(field):
2550
+ elif filter_generator.is_bool_field(field):
2548
2551
  return (
2549
2552
  f"[BOOL] supported filter operators: "
2550
2553
  f"{[str(op) for op in BoolFilter.ALLOWED_OPS]}"
2551
2554
  )
2552
- elif filter_model.is_str_field(field):
2555
+ elif filter_generator.is_str_field(field):
2553
2556
  return (
2554
2557
  f"[STRING] supported filter operators: "
2555
2558
  f"{[str(op) for op in StrFilter.ALLOWED_OPS]}"
zenml/client.py CHANGED
@@ -1236,6 +1236,8 @@ class Client(metaclass=ClientMetaClass):
1236
1236
  workspace_id: Optional[Union[str, UUID]] = None,
1237
1237
  user_id: Optional[Union[str, UUID]] = None,
1238
1238
  component_id: Optional[Union[str, UUID]] = None,
1239
+ user: Optional[Union[UUID, str]] = None,
1240
+ component: Optional[Union[UUID, str]] = None,
1239
1241
  hydrate: bool = False,
1240
1242
  ) -> Page[StackResponse]:
1241
1243
  """Lists all stacks.
@@ -1252,6 +1254,8 @@ class Client(metaclass=ClientMetaClass):
1252
1254
  workspace_id: The id of the workspace to filter by.
1253
1255
  user_id: The id of the user to filter by.
1254
1256
  component_id: The id of the component to filter by.
1257
+ user: The name/ID of the user to filter by.
1258
+ component: The name/ID of the component to filter by.
1255
1259
  name: The name of the stack to filter by.
1256
1260
  hydrate: Flag deciding whether to hydrate the output model(s)
1257
1261
  by including metadata fields in the response.
@@ -1267,6 +1271,8 @@ class Client(metaclass=ClientMetaClass):
1267
1271
  workspace_id=workspace_id,
1268
1272
  user_id=user_id,
1269
1273
  component_id=component_id,
1274
+ user=user,
1275
+ component=component,
1270
1276
  name=name,
1271
1277
  description=description,
1272
1278
  id=id,
@@ -2348,8 +2354,10 @@ class Client(metaclass=ClientMetaClass):
2348
2354
  created: Optional[Union[datetime, str]] = None,
2349
2355
  updated: Optional[Union[datetime, str]] = None,
2350
2356
  name: Optional[str] = None,
2357
+ latest_run_status: Optional[str] = None,
2351
2358
  workspace_id: Optional[Union[str, UUID]] = None,
2352
2359
  user_id: Optional[Union[str, UUID]] = None,
2360
+ user: Optional[Union[UUID, str]] = None,
2353
2361
  tag: Optional[str] = None,
2354
2362
  hydrate: bool = False,
2355
2363
  ) -> Page[PipelineResponse]:
@@ -2364,8 +2372,11 @@ class Client(metaclass=ClientMetaClass):
2364
2372
  created: Use to filter by time of creation
2365
2373
  updated: Use the last updated date for filtering
2366
2374
  name: The name of the pipeline to filter by.
2375
+ latest_run_status: Filter by the status of the latest run of a
2376
+ pipeline.
2367
2377
  workspace_id: The id of the workspace to filter by.
2368
2378
  user_id: The id of the user to filter by.
2379
+ user: The name/ID of the user to filter by.
2369
2380
  tag: Tag to filter by.
2370
2381
  hydrate: Flag deciding whether to hydrate the output model(s)
2371
2382
  by including metadata fields in the response.
@@ -2382,8 +2393,10 @@ class Client(metaclass=ClientMetaClass):
2382
2393
  created=created,
2383
2394
  updated=updated,
2384
2395
  name=name,
2396
+ latest_run_status=latest_run_status,
2385
2397
  workspace_id=workspace_id,
2386
2398
  user_id=user_id,
2399
+ user=user,
2387
2400
  tag=tag,
2388
2401
  )
2389
2402
  pipeline_filter_model.set_scope_workspace(self.active_workspace.id)
@@ -3467,6 +3480,9 @@ class Client(metaclass=ClientMetaClass):
3467
3480
  build_id: Optional[Union[str, UUID]] = None,
3468
3481
  stack_id: Optional[Union[str, UUID]] = None,
3469
3482
  code_repository_id: Optional[Union[str, UUID]] = None,
3483
+ user: Optional[Union[UUID, str]] = None,
3484
+ pipeline: Optional[Union[UUID, str]] = None,
3485
+ stack: Optional[Union[UUID, str]] = None,
3470
3486
  hydrate: bool = False,
3471
3487
  ) -> Page[RunTemplateResponse]:
3472
3488
  """Get a page of run templates.
@@ -3486,6 +3502,9 @@ class Client(metaclass=ClientMetaClass):
3486
3502
  build_id: Filter by build ID.
3487
3503
  stack_id: Filter by stack ID.
3488
3504
  code_repository_id: Filter by code repository ID.
3505
+ user: Filter by user name/ID.
3506
+ pipeline: Filter by pipeline name/ID.
3507
+ stack: Filter by stack name/ID.
3489
3508
  hydrate: Flag deciding whether to hydrate the output model(s)
3490
3509
  by including metadata fields in the response.
3491
3510
 
@@ -3507,6 +3526,9 @@ class Client(metaclass=ClientMetaClass):
3507
3526
  build_id=build_id,
3508
3527
  stack_id=stack_id,
3509
3528
  code_repository_id=code_repository_id,
3529
+ user=user,
3530
+ pipeline=pipeline,
3531
+ stack=stack,
3510
3532
  )
3511
3533
 
3512
3534
  return self.zen_store.list_run_templates(
@@ -3742,6 +3764,7 @@ class Client(metaclass=ClientMetaClass):
3742
3764
  deployment_id: Optional[Union[str, UUID]] = None,
3743
3765
  code_repository_id: Optional[Union[str, UUID]] = None,
3744
3766
  template_id: Optional[Union[str, UUID]] = None,
3767
+ model_version_id: Optional[Union[str, UUID]] = None,
3745
3768
  orchestrator_run_id: Optional[str] = None,
3746
3769
  status: Optional[str] = None,
3747
3770
  start_time: Optional[Union[datetime, str]] = None,
@@ -3750,6 +3773,11 @@ class Client(metaclass=ClientMetaClass):
3750
3773
  unlisted: Optional[bool] = None,
3751
3774
  templatable: Optional[bool] = None,
3752
3775
  tag: Optional[str] = None,
3776
+ user: Optional[Union[UUID, str]] = None,
3777
+ pipeline: Optional[Union[UUID, str]] = None,
3778
+ code_repository: Optional[Union[UUID, str]] = None,
3779
+ model: Optional[Union[UUID, str]] = None,
3780
+ stack: Optional[Union[UUID, str]] = None,
3753
3781
  hydrate: bool = False,
3754
3782
  ) -> Page[PipelineRunResponse]:
3755
3783
  """List all pipeline runs.
@@ -3764,7 +3792,8 @@ class Client(metaclass=ClientMetaClass):
3764
3792
  updated: Use the last updated date for filtering
3765
3793
  workspace_id: The id of the workspace to filter by.
3766
3794
  pipeline_id: The id of the pipeline to filter by.
3767
- pipeline_name: The name of the pipeline to filter by.
3795
+ pipeline_name: DEPRECATED. Use `pipeline` instead to filter by
3796
+ pipeline name.
3768
3797
  user_id: The id of the user to filter by.
3769
3798
  stack_id: The id of the stack to filter by.
3770
3799
  schedule_id: The id of the schedule to filter by.
@@ -3772,6 +3801,7 @@ class Client(metaclass=ClientMetaClass):
3772
3801
  deployment_id: The id of the deployment to filter by.
3773
3802
  code_repository_id: The id of the code repository to filter by.
3774
3803
  template_id: The ID of the template to filter by.
3804
+ model_version_id: The ID of the model version to filter by.
3775
3805
  orchestrator_run_id: The run id of the orchestrator to filter by.
3776
3806
  name: The name of the run to filter by.
3777
3807
  status: The status of the pipeline run
@@ -3781,6 +3811,11 @@ class Client(metaclass=ClientMetaClass):
3781
3811
  unlisted: If the runs should be unlisted or not.
3782
3812
  templatable: If the runs should be templatable or not.
3783
3813
  tag: Tag to filter by.
3814
+ user: The name/ID of the user to filter by.
3815
+ pipeline: The name/ID of the pipeline to filter by.
3816
+ code_repository: Filter by code repository name/ID.
3817
+ model: Filter by model name/ID.
3818
+ stack: Filter by stack name/ID.
3784
3819
  hydrate: Flag deciding whether to hydrate the output model(s)
3785
3820
  by including metadata fields in the response.
3786
3821
 
@@ -3804,6 +3839,7 @@ class Client(metaclass=ClientMetaClass):
3804
3839
  deployment_id=deployment_id,
3805
3840
  code_repository_id=code_repository_id,
3806
3841
  template_id=template_id,
3842
+ model_version_id=model_version_id,
3807
3843
  orchestrator_run_id=orchestrator_run_id,
3808
3844
  user_id=user_id,
3809
3845
  stack_id=stack_id,
@@ -3813,6 +3849,11 @@ class Client(metaclass=ClientMetaClass):
3813
3849
  num_steps=num_steps,
3814
3850
  tag=tag,
3815
3851
  unlisted=unlisted,
3852
+ user=user,
3853
+ pipeline=pipeline,
3854
+ code_repository=code_repository,
3855
+ stack=stack,
3856
+ model=model,
3816
3857
  templatable=templatable,
3817
3858
  )
3818
3859
  runs_filter_model.set_scope_workspace(self.active_workspace.id)
@@ -3892,6 +3933,7 @@ class Client(metaclass=ClientMetaClass):
3892
3933
  original_step_run_id: Optional[Union[str, UUID]] = None,
3893
3934
  workspace_id: Optional[Union[str, UUID]] = None,
3894
3935
  user_id: Optional[Union[str, UUID]] = None,
3936
+ model_version_id: Optional[Union[str, UUID]] = None,
3895
3937
  num_outputs: Optional[Union[int, str]] = None,
3896
3938
  hydrate: bool = False,
3897
3939
  ) -> Page[StepRunResponse]:
@@ -3911,6 +3953,7 @@ class Client(metaclass=ClientMetaClass):
3911
3953
  user_id: The id of the user to filter by.
3912
3954
  pipeline_run_id: The id of the pipeline run to filter by.
3913
3955
  original_step_run_id: The id of the pipeline run to filter by.
3956
+ model_version_id: The ID of the model version to filter by.
3914
3957
  name: The name of the run to filter by.
3915
3958
  entrypoint_name: The entrypoint_name of the run to filter by.
3916
3959
  code_hash: The code_hash of the run to filter by.
@@ -3942,6 +3985,7 @@ class Client(metaclass=ClientMetaClass):
3942
3985
  name=name,
3943
3986
  workspace_id=workspace_id,
3944
3987
  user_id=user_id,
3988
+ model_version_id=model_version_id,
3945
3989
  num_outputs=num_outputs,
3946
3990
  )
3947
3991
  step_run_filter_model.set_scope_workspace(self.active_workspace.id)
@@ -4150,8 +4194,11 @@ class Client(metaclass=ClientMetaClass):
4150
4194
  user_id: Optional[Union[str, UUID]] = None,
4151
4195
  only_unused: Optional[bool] = False,
4152
4196
  has_custom_name: Optional[bool] = None,
4153
- hydrate: bool = False,
4197
+ user: Optional[Union[UUID, str]] = None,
4198
+ model: Optional[Union[UUID, str]] = None,
4199
+ pipeline_run: Optional[Union[UUID, str]] = None,
4154
4200
  tag: Optional[str] = None,
4201
+ hydrate: bool = False,
4155
4202
  ) -> Page[ArtifactVersionResponse]:
4156
4203
  """Get a list of artifact versions.
4157
4204
 
@@ -4177,9 +4224,12 @@ class Client(metaclass=ClientMetaClass):
4177
4224
  only_unused: Only return artifact versions that are not used in
4178
4225
  any pipeline runs.
4179
4226
  has_custom_name: Filter artifacts with/without custom names.
4227
+ tag: A tag to filter by.
4228
+ user: Filter by user name or ID.
4229
+ model: Filter by model name or ID.
4230
+ pipeline_run: Filter by pipeline run name or ID.
4180
4231
  hydrate: Flag deciding whether to hydrate the output model(s)
4181
4232
  by including metadata fields in the response.
4182
- tag: A tag to filter by.
4183
4233
 
4184
4234
  Returns:
4185
4235
  A list of artifact versions.
@@ -4206,6 +4256,9 @@ class Client(metaclass=ClientMetaClass):
4206
4256
  only_unused=only_unused,
4207
4257
  has_custom_name=has_custom_name,
4208
4258
  tag=tag,
4259
+ user=user,
4260
+ model=model,
4261
+ pipeline_run=pipeline_run,
4209
4262
  )
4210
4263
  artifact_version_filter_model.set_scope_workspace(
4211
4264
  self.active_workspace.id
@@ -6102,6 +6155,7 @@ class Client(metaclass=ClientMetaClass):
6102
6155
  created: Optional[Union[datetime, str]] = None,
6103
6156
  updated: Optional[Union[datetime, str]] = None,
6104
6157
  name: Optional[str] = None,
6158
+ user: Optional[Union[UUID, str]] = None,
6105
6159
  hydrate: bool = False,
6106
6160
  tag: Optional[str] = None,
6107
6161
  ) -> Page[ModelResponse]:
@@ -6115,6 +6169,7 @@ class Client(metaclass=ClientMetaClass):
6115
6169
  created: Use to filter by time of creation
6116
6170
  updated: Use the last updated date for filtering
6117
6171
  name: The name of the model to filter by.
6172
+ user: Filter by user name/ID.
6118
6173
  hydrate: Flag deciding whether to hydrate the output model(s)
6119
6174
  by including metadata fields in the response.
6120
6175
  tag: The tag of the model to filter by.
@@ -6131,6 +6186,7 @@ class Client(metaclass=ClientMetaClass):
6131
6186
  created=created,
6132
6187
  updated=updated,
6133
6188
  tag=tag,
6189
+ user=user,
6134
6190
  )
6135
6191
 
6136
6192
  return self.zen_store.list_models(
@@ -6310,6 +6366,7 @@ class Client(metaclass=ClientMetaClass):
6310
6366
  name: Optional[str] = None,
6311
6367
  number: Optional[int] = None,
6312
6368
  stage: Optional[Union[str, ModelStages]] = None,
6369
+ user: Optional[Union[UUID, str]] = None,
6313
6370
  hydrate: bool = False,
6314
6371
  tag: Optional[str] = None,
6315
6372
  ) -> Page[ModelVersionResponse]:
@@ -6327,6 +6384,7 @@ class Client(metaclass=ClientMetaClass):
6327
6384
  name: name or id of the model version.
6328
6385
  number: number of the model version.
6329
6386
  stage: stage of the model version.
6387
+ user: Filter by user name/ID.
6330
6388
  hydrate: Flag deciding whether to hydrate the output model(s)
6331
6389
  by including metadata fields in the response.
6332
6390
  tag: The tag to filter by.
@@ -6345,6 +6403,7 @@ class Client(metaclass=ClientMetaClass):
6345
6403
  number=number,
6346
6404
  stage=stage,
6347
6405
  tag=tag,
6406
+ user=user,
6348
6407
  )
6349
6408
 
6350
6409
  return self.zen_store.list_model_versions(
@@ -6422,6 +6481,7 @@ class Client(metaclass=ClientMetaClass):
6422
6481
  only_model_artifacts: Optional[bool] = None,
6423
6482
  only_deployment_artifacts: Optional[bool] = None,
6424
6483
  has_custom_name: Optional[bool] = None,
6484
+ user: Optional[Union[UUID, str]] = None,
6425
6485
  hydrate: bool = False,
6426
6486
  ) -> Page[ModelVersionArtifactResponse]:
6427
6487
  """Get model version to artifact links by filter in Model Control Plane.
@@ -6443,6 +6503,7 @@ class Client(metaclass=ClientMetaClass):
6443
6503
  only_model_artifacts: Use to filter by model artifacts
6444
6504
  only_deployment_artifacts: Use to filter by deployment artifacts
6445
6505
  has_custom_name: Filter artifacts with/without custom names.
6506
+ user: Filter by user name/ID.
6446
6507
  hydrate: Flag deciding whether to hydrate the output model(s)
6447
6508
  by including metadata fields in the response.
6448
6509
 
@@ -6467,6 +6528,7 @@ class Client(metaclass=ClientMetaClass):
6467
6528
  only_model_artifacts=only_model_artifacts,
6468
6529
  only_deployment_artifacts=only_deployment_artifacts,
6469
6530
  has_custom_name=has_custom_name,
6531
+ user=user,
6470
6532
  ),
6471
6533
  hydrate=hydrate,
6472
6534
  )
@@ -6536,6 +6598,7 @@ class Client(metaclass=ClientMetaClass):
6536
6598
  model_version_id: Optional[Union[UUID, str]] = None,
6537
6599
  pipeline_run_id: Optional[Union[UUID, str]] = None,
6538
6600
  pipeline_run_name: Optional[str] = None,
6601
+ user: Optional[Union[UUID, str]] = None,
6539
6602
  hydrate: bool = False,
6540
6603
  ) -> Page[ModelVersionPipelineRunResponse]:
6541
6604
  """Get all model version to pipeline run links by filter.
@@ -6553,6 +6616,7 @@ class Client(metaclass=ClientMetaClass):
6553
6616
  model_version_id: Use the model version id for filtering
6554
6617
  pipeline_run_id: Use the pipeline run id for filtering
6555
6618
  pipeline_run_name: Use the pipeline run name for filtering
6619
+ user: Filter by user name or ID.
6556
6620
  hydrate: Flag deciding whether to hydrate the output model(s)
6557
6621
  by including metadata fields in the response
6558
6622
 
@@ -6573,6 +6637,7 @@ class Client(metaclass=ClientMetaClass):
6573
6637
  model_version_id=model_version_id,
6574
6638
  pipeline_run_id=pipeline_run_id,
6575
6639
  pipeline_run_name=pipeline_run_name,
6640
+ user=user,
6576
6641
  ),
6577
6642
  hydrate=hydrate,
6578
6643
  )
@@ -137,7 +137,6 @@ class ArtifactConfiguration(PartialArtifactConfiguration):
137
137
  class StepConfigurationUpdate(StrictBaseModel):
138
138
  """Class for step configuration updates."""
139
139
 
140
- name: Optional[str] = None
141
140
  enable_cache: Optional[bool] = None
142
141
  enable_artifact_metadata: Optional[bool] = None
143
142
  enable_artifact_visualization: Optional[bool] = None
@@ -154,10 +153,6 @@ class StepConfigurationUpdate(StrictBaseModel):
154
153
 
155
154
  outputs: Mapping[str, PartialArtifactConfiguration] = {}
156
155
 
157
- _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
158
- "name"
159
- )
160
-
161
156
 
162
157
  class PartialStepConfiguration(StepConfigurationUpdate):
163
158
  """Class representing a partial step configuration."""
zenml/constants.py CHANGED
@@ -364,6 +364,7 @@ PIPELINE_DEPLOYMENTS = "/pipeline_deployments"
364
364
  PIPELINES = "/pipelines"
365
365
  PIPELINE_SPEC = "/pipeline-spec"
366
366
  PLUGIN_FLAVORS = "/plugin-flavors"
367
+ REFRESH = "/refresh"
367
368
  RUNS = "/runs"
368
369
  RUN_TEMPLATES = "/run_templates"
369
370
  RUN_METADATA = "/run-metadata"
@@ -430,6 +431,8 @@ SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run"
430
431
 
431
432
  # Metadata constants
432
433
  METADATA_ORCHESTRATOR_URL = "orchestrator_url"
434
+ METADATA_ORCHESTRATOR_LOGS_URL = "orchestrator_logs_url"
435
+ METADATA_ORCHESTRATOR_RUN_ID = "orchestrator_run_id"
433
436
  METADATA_EXPERIMENT_TRACKER_URL = "experiment_tracker_url"
434
437
  METADATA_DEPLOYED_MODEL_URL = "deployed_model_url"
435
438
 
zenml/enums.py CHANGED
@@ -244,6 +244,7 @@ class GenericFilterOps(StrEnum):
244
244
  """Ops for all filters for string values on list methods."""
245
245
 
246
246
  EQUALS = "equals"
247
+ NOT_EQUALS = "notequals"
247
248
  CONTAINS = "contains"
248
249
  STARTSWITH = "startswith"
249
250
  ENDSWITH = "endswith"
@@ -251,6 +252,7 @@ class GenericFilterOps(StrEnum):
251
252
  GT = "gt"
252
253
  LTE = "lte"
253
254
  LT = "lt"
255
+ IN = "in"
254
256
 
255
257
 
256
258
  class SorterOps(StrEnum):
@@ -69,6 +69,7 @@ from zenml.integrations.skypilot_aws import SkypilotAWSIntegration # noqa
69
69
  from zenml.integrations.skypilot_gcp import SkypilotGCPIntegration # noqa
70
70
  from zenml.integrations.skypilot_azure import SkypilotAzureIntegration # noqa
71
71
  from zenml.integrations.skypilot_lambda import SkypilotLambdaIntegration # noqa
72
+ from zenml.integrations.skypilot_kubernetes import SkypilotKubernetesIntegration # noqa
72
73
  from zenml.integrations.slack import SlackIntegration # noqa
73
74
  from zenml.integrations.spark import SparkIntegration # noqa
74
75
  from zenml.integrations.tekton import TektonIntegration # noqa
@@ -15,7 +15,7 @@
15
15
 
16
16
  from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
17
17
 
18
- from pydantic import Field
18
+ from pydantic import Field, model_validator
19
19
 
20
20
  from zenml.config.base_settings import BaseSettings
21
21
  from zenml.integrations.aws import (
@@ -25,23 +25,38 @@ from zenml.integrations.aws import (
25
25
  from zenml.models import ServiceConnectorRequirements
26
26
  from zenml.orchestrators import BaseOrchestratorConfig
27
27
  from zenml.orchestrators.base_orchestrator import BaseOrchestratorFlavor
28
+ from zenml.utils import deprecation_utils
28
29
  from zenml.utils.secret_utils import SecretField
29
30
 
30
31
  if TYPE_CHECKING:
31
32
  from zenml.integrations.aws.orchestrators import SagemakerOrchestrator
32
33
 
34
+ DEFAULT_TRAINING_INSTANCE_TYPE = "ml.m5.xlarge"
35
+ DEFAULT_PROCESSING_INSTANCE_TYPE = "ml.t3.medium"
36
+ DEFAULT_OUTPUT_DATA_S3_MODE = "EndOfJob"
37
+
33
38
 
34
39
  class SagemakerOrchestratorSettings(BaseSettings):
35
40
  """Settings for the Sagemaker orchestrator.
36
41
 
37
42
  Attributes:
38
43
  instance_type: The instance type to use for the processing job.
39
- processor_role: The IAM role to use for the step execution on a Processor.
44
+ execution_role: The IAM role to use for the step execution.
45
+ processor_role: DEPRECATED: use `execution_role` instead.
40
46
  volume_size_in_gb: The size of the EBS volume to use for the processing
41
47
  job.
42
48
  max_runtime_in_seconds: The maximum runtime in seconds for the
43
49
  processing job.
44
- processor_tags: Tags to apply to the Processor assigned to the step.
50
+ tags: Tags to apply to the Processor/Estimator assigned to the step.
51
+ processor_tags: DEPRECATED: use `tags` instead.
52
+ keep_alive_period_in_seconds: The time in seconds after which the
53
+ provisioned instance will be terminated if not used. This is only
54
+ applicable for TrainingStep type and it is not possible to use
55
+ TrainingStep type if the `output_data_s3_uri` is set to Dict[str, str].
56
+ use_training_step: Whether to use the TrainingStep type.
57
+ It is not possible to use TrainingStep type
58
+ if the `output_data_s3_uri` is set to Dict[str, str] or if the
59
+ `output_data_s3_mode` != "EndOfJob".
45
60
  processor_args: Arguments that are directly passed to the SageMaker
46
61
  Processor for a specific step, allowing for overriding the default
47
62
  settings provided when configuring the component. See
@@ -50,6 +65,13 @@ class SagemakerOrchestratorSettings(BaseSettings):
50
65
  For processor_args.instance_type, check
51
66
  https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
52
67
  for a list of available instance types.
68
+ estimator_args: Arguments that are directly passed to the SageMaker
69
+ Estimator for a specific step, allowing for overriding the default
70
+ settings provided when configuring the component. See
71
+ https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
72
+ for a full list of arguments.
73
+ For a list of available instance types, check
74
+ https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html.
53
75
  input_data_s3_mode: How data is made available to the container.
54
76
  Two possible input modes: File, Pipe.
55
77
  input_data_s3_uri: S3 URI where data is located if not locally,
@@ -74,23 +96,70 @@ class SagemakerOrchestratorSettings(BaseSettings):
74
96
  Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.
75
97
  """
76
98
 
77
- instance_type: str = "ml.t3.medium"
78
- processor_role: Optional[str] = None
99
+ instance_type: Optional[str] = None
100
+ execution_role: Optional[str] = None
79
101
  volume_size_in_gb: int = 30
80
102
  max_runtime_in_seconds: int = 86400
81
- processor_tags: Dict[str, str] = {}
103
+ tags: Dict[str, str] = {}
104
+ keep_alive_period_in_seconds: Optional[int] = 300 # 5 minutes
105
+ use_training_step: Optional[bool] = None
82
106
 
83
107
  processor_args: Dict[str, Any] = {}
108
+ estimator_args: Dict[str, Any] = {}
109
+
84
110
  input_data_s3_mode: str = "File"
85
111
  input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
86
112
  default=None, union_mode="left_to_right"
87
113
  )
88
114
 
89
- output_data_s3_mode: str = "EndOfJob"
115
+ output_data_s3_mode: str = DEFAULT_OUTPUT_DATA_S3_MODE
90
116
  output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
91
117
  default=None, union_mode="left_to_right"
92
118
  )
93
119
 
120
+ processor_role: Optional[str] = None
121
+ processor_tags: Dict[str, str] = {}
122
+ _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
123
+ ("processor_role", "execution_role"), ("processor_tags", "tags")
124
+ )
125
+
126
+ @model_validator(mode="before")
127
+ def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]:
128
+ """Check if model is configured correctly.
129
+
130
+ Args:
131
+ data: The model data.
132
+
133
+ Returns:
134
+ The validated model data.
135
+
136
+ Raises:
137
+ ValueError: If the model is configured incorrectly.
138
+ """
139
+ use_training_step = data.get("use_training_step", True)
140
+ output_data_s3_uri = data.get("output_data_s3_uri", None)
141
+ output_data_s3_mode = data.get(
142
+ "output_data_s3_mode", DEFAULT_OUTPUT_DATA_S3_MODE
143
+ )
144
+ if use_training_step and (
145
+ isinstance(output_data_s3_uri, dict)
146
+ or (
147
+ isinstance(output_data_s3_uri, str)
148
+ and (output_data_s3_mode != DEFAULT_OUTPUT_DATA_S3_MODE)
149
+ )
150
+ ):
151
+ raise ValueError(
152
+ "`use_training_step=True` is not supported when `output_data_s3_uri` is a dict or "
153
+ f"when `output_data_s3_mode` is not '{DEFAULT_OUTPUT_DATA_S3_MODE}'."
154
+ )
155
+ instance_type = data.get("instance_type", None)
156
+ if instance_type is None:
157
+ if use_training_step:
158
+ data["instance_type"] = DEFAULT_TRAINING_INSTANCE_TYPE
159
+ else:
160
+ data["instance_type"] = DEFAULT_PROCESSING_INSTANCE_TYPE
161
+ return data
162
+
94
163
 
95
164
  class SagemakerOrchestratorConfig(
96
165
  BaseOrchestratorConfig, SagemakerOrchestratorSettings