zenml-nightly 0.66.0.dev20240919__py3-none-any.whl → 0.66.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/base.py +2 -2
  4. zenml/cli/pipeline.py +21 -0
  5. zenml/cli/utils.py +14 -11
  6. zenml/client.py +68 -3
  7. zenml/config/step_configurations.py +0 -5
  8. zenml/constants.py +3 -0
  9. zenml/enums.py +2 -0
  10. zenml/integrations/__init__.py +1 -0
  11. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
  12. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
  13. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  14. zenml/integrations/constants.py +1 -0
  15. zenml/integrations/deepchecks/__init__.py +1 -1
  16. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
  17. zenml/integrations/deepchecks/validation_checks.py +62 -5
  18. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  19. zenml/integrations/lightning/__init__.py +1 -1
  20. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  21. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  22. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  23. zenml/integrations/mlflow/__init__.py +14 -15
  24. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  25. zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
  26. zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
  27. zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
  28. zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
  29. zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
  30. zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
  31. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
  32. zenml/models/v2/base/filter.py +315 -149
  33. zenml/models/v2/base/scoped.py +5 -2
  34. zenml/models/v2/core/artifact_version.py +69 -8
  35. zenml/models/v2/core/model.py +43 -6
  36. zenml/models/v2/core/model_version.py +49 -1
  37. zenml/models/v2/core/model_version_artifact.py +18 -3
  38. zenml/models/v2/core/model_version_pipeline_run.py +18 -4
  39. zenml/models/v2/core/pipeline.py +108 -1
  40. zenml/models/v2/core/pipeline_run.py +172 -21
  41. zenml/models/v2/core/run_template.py +53 -1
  42. zenml/models/v2/core/stack.py +33 -5
  43. zenml/models/v2/core/step_run.py +7 -0
  44. zenml/new/pipelines/pipeline.py +4 -0
  45. zenml/new/pipelines/run_utils.py +4 -1
  46. zenml/orchestrators/base_orchestrator.py +41 -12
  47. zenml/stack/stack.py +11 -2
  48. zenml/utils/env_utils.py +54 -1
  49. zenml/utils/string_utils.py +50 -0
  50. zenml/zen_server/cloud_utils.py +33 -8
  51. zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
  52. zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
  53. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
  54. zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
  55. zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
  56. zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
  57. zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
  58. zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
  59. zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
  60. zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
  61. zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
  63. zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
  64. zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
  65. zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
  66. zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
  67. zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
  69. zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
  70. zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
  71. zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
  72. zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
  73. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
  74. zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
  75. zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
  76. zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
  77. zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
  78. zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
  79. zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
  80. zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
  81. zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
  82. zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
  83. zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
  85. zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
  86. zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
  87. zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
  88. zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
  89. zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
  90. zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
  91. zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
  92. zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
  93. zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
  94. zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
  95. zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
  96. zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
  97. zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
  98. zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
  99. zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
  100. zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
  101. zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
  102. zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
  103. zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
  104. zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
  105. zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
  106. zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
  107. zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
  108. zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
  109. zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
  110. zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
  111. zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
  112. zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
  113. zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
  114. zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
  115. zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
  116. zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
  117. zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
  118. zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
  119. zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
  120. zenml/zen_server/dashboard/index.html +4 -4
  121. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  122. zenml/zen_server/dashboard_legacy/index.html +1 -1
  123. zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
  124. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  125. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
  126. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
  127. zenml/zen_server/rbac/utils.py +6 -2
  128. zenml/zen_server/routers/runs_endpoints.py +89 -3
  129. zenml/zen_stores/sql_zen_store.py +1 -0
  130. {zenml_nightly-0.66.0.dev20240919.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/METADATA +8 -1
  131. {zenml_nightly-0.66.0.dev20240919.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/RECORD +134 -126
  132. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
  133. zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
  134. zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
  135. zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
  136. zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
  137. zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
  138. zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
  139. zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
  140. zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
  141. zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
  142. zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
  143. zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
  144. zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
  145. zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
  146. zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
  147. zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
  148. zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
  149. zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
  150. zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
  151. zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
  152. zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
  153. {zenml_nightly-0.66.0.dev20240919.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/LICENSE +0 -0
  154. {zenml_nightly-0.66.0.dev20240919.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/WHEEL +0 -0
  155. {zenml_nightly-0.66.0.dev20240919.dist-info → zenml_nightly-0.66.0.dev20240927.dist-info}/entry_points.txt +0 -0
@@ -15,7 +15,16 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type, cast
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ Any,
21
+ Dict,
22
+ Iterator,
23
+ Optional,
24
+ Tuple,
25
+ Type,
26
+ cast,
27
+ )
19
28
  from uuid import UUID
20
29
 
21
30
  import boto3
@@ -25,13 +34,15 @@ from sagemaker.network import NetworkConfig
25
34
  from sagemaker.processing import ProcessingInput, ProcessingOutput
26
35
  from sagemaker.workflow.execution_variables import ExecutionVariables
27
36
  from sagemaker.workflow.pipeline import Pipeline
28
- from sagemaker.workflow.steps import ProcessingStep
37
+ from sagemaker.workflow.steps import ProcessingStep, TrainingStep
29
38
 
30
39
  from zenml.config.base_settings import BaseSettings
31
40
  from zenml.constants import (
41
+ METADATA_ORCHESTRATOR_LOGS_URL,
42
+ METADATA_ORCHESTRATOR_RUN_ID,
32
43
  METADATA_ORCHESTRATOR_URL,
33
44
  )
34
- from zenml.enums import StackComponentType
45
+ from zenml.enums import ExecutionStatus, StackComponentType
35
46
  from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
36
47
  SagemakerOrchestratorConfig,
37
48
  SagemakerOrchestratorSettings,
@@ -48,7 +59,7 @@ from zenml.stack import StackValidator
48
59
  from zenml.utils.env_utils import split_environment_variables
49
60
 
50
61
  if TYPE_CHECKING:
51
- from zenml.models import PipelineDeploymentResponse
62
+ from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
52
63
  from zenml.stack import Stack
53
64
 
54
65
  ENV_ZENML_SAGEMAKER_RUN_ID = "ZENML_SAGEMAKER_RUN_ID"
@@ -58,6 +69,34 @@ POLLING_DELAY = 30
58
69
  logger = get_logger(__name__)
59
70
 
60
71
 
72
+ def dissect_pipeline_execution_arn(
73
+ pipeline_execution_arn: str,
74
+ ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
75
+ """Extract region name, pipeline name, and execution id from the ARN.
76
+
77
+ Args:
78
+ pipeline_execution_arn: the pipeline execution ARN
79
+
80
+ Returns:
81
+ Region Name, Pipeline Name, Execution ID in order
82
+ """
83
+ # Extract region_name
84
+ region_match = re.search(r"sagemaker:(.*?):", pipeline_execution_arn)
85
+ region_name = region_match.group(1) if region_match else None
86
+
87
+ # Extract pipeline_name
88
+ pipeline_match = re.search(
89
+ r"pipeline/(.*?)/execution", pipeline_execution_arn
90
+ )
91
+ pipeline_name = pipeline_match.group(1) if pipeline_match else None
92
+
93
+ # Extract execution_id
94
+ execution_match = re.search(r"execution/(.*)", pipeline_execution_arn)
95
+ execution_id = execution_match.group(1) if execution_match else None
96
+
97
+ return region_name, pipeline_name, execution_id
98
+
99
+
61
100
  class SagemakerOrchestrator(ContainerizedOrchestrator):
62
101
  """Orchestrator responsible for running pipelines on Sagemaker."""
63
102
 
@@ -136,42 +175,16 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
136
175
  """
137
176
  return SagemakerOrchestratorSettings
138
177
 
139
- def prepare_or_run_pipeline(
140
- self,
141
- deployment: "PipelineDeploymentResponse",
142
- stack: "Stack",
143
- environment: Dict[str, str],
144
- ) -> None:
145
- """Prepares or runs a pipeline on Sagemaker.
178
+ def _get_sagemaker_session(self) -> sagemaker.Session:
179
+ """Method to create the sagemaker session with proper authentication.
146
180
 
147
- Args:
148
- deployment: The deployment to prepare or run.
149
- stack: The stack to run on.
150
- environment: Environment variables to set in the orchestration
151
- environment.
181
+ Returns:
182
+ The Sagemaker Session.
152
183
 
153
184
  Raises:
154
- RuntimeError: If a connector is used that does not return a
155
- `boto3.Session` object.
156
- TypeError: If the network_config passed is not compatible with the
157
- AWS SageMaker NetworkConfig class.
185
+ RuntimeError: If the connector returns the wrong type for the
186
+ session.
158
187
  """
159
- if deployment.schedule:
160
- logger.warning(
161
- "The Sagemaker Orchestrator currently does not support the "
162
- "use of schedules. The `schedule` will be ignored "
163
- "and the pipeline will be run immediately."
164
- )
165
-
166
- # sagemaker requires pipelineName to use alphanum and hyphens only
167
- unsanitized_orchestrator_run_name = get_orchestrator_run_name(
168
- pipeline_name=deployment.pipeline_configuration.name
169
- )
170
- # replace all non-alphanum and non-hyphens with hyphens
171
- orchestrator_run_name = re.sub(
172
- r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
173
- )
174
-
175
188
  # Get authenticated session
176
189
  # Option 1: Service connector
177
190
  boto_session: boto3.Session
@@ -205,10 +218,51 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
205
218
  aws_session_token=credentials["SessionToken"],
206
219
  region_name=self.config.region,
207
220
  )
208
- session = sagemaker.Session(
221
+ return sagemaker.Session(
209
222
  boto_session=boto_session, default_bucket=self.config.bucket
210
223
  )
211
224
 
225
+ def prepare_or_run_pipeline(
226
+ self,
227
+ deployment: "PipelineDeploymentResponse",
228
+ stack: "Stack",
229
+ environment: Dict[str, str],
230
+ ) -> Iterator[Dict[str, MetadataType]]:
231
+ """Prepares or runs a pipeline on Sagemaker.
232
+
233
+ Args:
234
+ deployment: The deployment to prepare or run.
235
+ stack: The stack to run on.
236
+ environment: Environment variables to set in the orchestration
237
+ environment.
238
+
239
+ Raises:
240
+ RuntimeError: If a connector is used that does not return a
241
+ `boto3.Session` object.
242
+ TypeError: If the network_config passed is not compatible with the
243
+ AWS SageMaker NetworkConfig class.
244
+
245
+ Yields:
246
+ A dictionary of metadata related to the pipeline run.
247
+ """
248
+ if deployment.schedule:
249
+ logger.warning(
250
+ "The Sagemaker Orchestrator currently does not support the "
251
+ "use of schedules. The `schedule` will be ignored "
252
+ "and the pipeline will be run immediately."
253
+ )
254
+
255
+ # sagemaker requires pipelineName to use alphanum and hyphens only
256
+ unsanitized_orchestrator_run_name = get_orchestrator_run_name(
257
+ pipeline_name=deployment.pipeline_configuration.name
258
+ )
259
+ # replace all non-alphanum and non-hyphens with hyphens
260
+ orchestrator_run_name = re.sub(
261
+ r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
262
+ )
263
+
264
+ session = self._get_sagemaker_session()
265
+
212
266
  # Sagemaker does not allow environment variables longer than 256
213
267
  # characters to be passed to Processor steps. If an environment variable
214
268
  # is longer than 256 characters, we split it into multiple environment
@@ -238,54 +292,71 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
238
292
  ExecutionVariables.PIPELINE_EXECUTION_ARN
239
293
  )
240
294
 
241
- # Retrieve Processor arguments provided in the Step settings.
242
- processor_args_for_step = step_settings.processor_args or {}
243
-
244
- # Set default values from configured orchestrator Component to arguments
245
- # to be used when they are not present in processor_args.
246
- processor_args_for_step.setdefault(
247
- "instance_type", step_settings.instance_type
295
+ use_training_step = (
296
+ step_settings.use_training_step
297
+ if step_settings.use_training_step is not None
298
+ else (
299
+ self.config.use_training_step
300
+ if self.config.use_training_step is not None
301
+ else True
302
+ )
248
303
  )
249
- processor_args_for_step.setdefault(
304
+
305
+ # Retrieve Executor arguments provided in the Step settings.
306
+ if use_training_step:
307
+ args_for_step_executor = step_settings.estimator_args or {}
308
+ else:
309
+ args_for_step_executor = step_settings.processor_args or {}
310
+
311
+ # Set default values from configured orchestrator Component to
312
+ # arguments to be used when absent from estimator_args/processor_args.
313
+ args_for_step_executor.setdefault(
250
314
  "role",
251
- step_settings.processor_role or self.config.execution_role,
315
+ step_settings.execution_role or self.config.execution_role,
252
316
  )
253
- processor_args_for_step.setdefault(
317
+ args_for_step_executor.setdefault(
254
318
  "volume_size_in_gb", step_settings.volume_size_in_gb
255
319
  )
256
- processor_args_for_step.setdefault(
320
+ args_for_step_executor.setdefault(
257
321
  "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
258
322
  )
259
- processor_args_for_step.setdefault(
323
+ tags = step_settings.tags
324
+ args_for_step_executor.setdefault(
260
325
  "tags",
261
- [
262
- {"Key": key, "Value": value}
263
- for key, value in step_settings.processor_tags.items()
264
- ]
265
- if step_settings.processor_tags
266
- else None,
326
+ (
327
+ [
328
+ {"Key": key, "Value": value}
329
+ for key, value in tags.items()
330
+ ]
331
+ if tags
332
+ else None
333
+ ),
334
+ )
335
+ args_for_step_executor.setdefault(
336
+ "instance_type", step_settings.instance_type
267
337
  )
268
338
 
269
339
  # Set values that cannot be overwritten
270
- processor_args_for_step["image_uri"] = image
271
- processor_args_for_step["instance_count"] = 1
272
- processor_args_for_step["sagemaker_session"] = session
273
- processor_args_for_step["entrypoint"] = entrypoint
274
- processor_args_for_step["base_job_name"] = orchestrator_run_name
275
- processor_args_for_step["env"] = environment
276
-
277
- # Convert network_config to sagemaker.network.NetworkConfig if present
278
- network_config = processor_args_for_step.get("network_config")
340
+ args_for_step_executor["image_uri"] = image
341
+ args_for_step_executor["instance_count"] = 1
342
+ args_for_step_executor["sagemaker_session"] = session
343
+ args_for_step_executor["base_job_name"] = orchestrator_run_name
344
+
345
+ # Convert network_config to sagemaker.network.NetworkConfig if
346
+ # present
347
+ network_config = args_for_step_executor.get("network_config")
348
+
279
349
  if network_config and isinstance(network_config, dict):
280
350
  try:
281
- processor_args_for_step["network_config"] = NetworkConfig(
351
+ args_for_step_executor["network_config"] = NetworkConfig(
282
352
  **network_config
283
353
  )
284
354
  except TypeError:
285
- # If the network_config passed is not compatible with the NetworkConfig class,
286
- # raise a more informative error.
355
+ # If the network_config passed is not compatible with the
356
+ # NetworkConfig class, raise a more informative error.
287
357
  raise TypeError(
288
- "Expected a sagemaker.network.NetworkConfig compatible object for the network_config argument, "
358
+ "Expected a sagemaker.network.NetworkConfig "
359
+ "compatible object for the network_config argument, "
289
360
  "but the network_config processor argument is invalid."
290
361
  "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
291
362
  "for more information about the NetworkConfig class."
@@ -317,17 +388,21 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
317
388
 
318
389
  # Construct S3 outputs from container for step
319
390
  outputs = None
391
+ output_path = None
320
392
 
321
393
  if step_settings.output_data_s3_uri is None:
322
394
  pass
323
395
  elif isinstance(step_settings.output_data_s3_uri, str):
324
- outputs = [
325
- ProcessingOutput(
326
- source="/opt/ml/processing/output/data",
327
- destination=step_settings.output_data_s3_uri,
328
- s3_upload_mode=step_settings.output_data_s3_mode,
329
- )
330
- ]
396
+ if use_training_step:
397
+ output_path = step_settings.output_data_s3_uri
398
+ else:
399
+ outputs = [
400
+ ProcessingOutput(
401
+ source="/opt/ml/processing/output/data",
402
+ destination=step_settings.output_data_s3_uri,
403
+ s3_upload_mode=step_settings.output_data_s3_mode,
404
+ )
405
+ ]
331
406
  elif isinstance(step_settings.output_data_s3_uri, dict):
332
407
  outputs = []
333
408
  for (
@@ -342,17 +417,37 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
342
417
  )
343
418
  )
344
419
 
345
- # Create Processor and ProcessingStep
346
- processor = sagemaker.processing.Processor(
347
- **processor_args_for_step
348
- )
349
- sagemaker_step = ProcessingStep(
350
- name=step_name,
351
- processor=processor,
352
- depends_on=step.spec.upstream_steps,
353
- inputs=inputs,
354
- outputs=outputs,
355
- )
420
+ if use_training_step:
421
+ # Create Estimator and TrainingStep
422
+ estimator = sagemaker.estimator.Estimator(
423
+ keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
424
+ output_path=output_path,
425
+ environment=environment,
426
+ container_entry_point=entrypoint,
427
+ **args_for_step_executor,
428
+ )
429
+ sagemaker_step = TrainingStep(
430
+ name=step_name,
431
+ depends_on=step.spec.upstream_steps,
432
+ inputs=inputs,
433
+ estimator=estimator,
434
+ )
435
+ else:
436
+ # Create Processor and ProcessingStep
437
+ processor = sagemaker.processing.Processor(
438
+ entrypoint=entrypoint,
439
+ env=environment,
440
+ **args_for_step_executor,
441
+ )
442
+
443
+ sagemaker_step = ProcessingStep(
444
+ name=step_name,
445
+ processor=processor,
446
+ depends_on=step.spec.upstream_steps,
447
+ inputs=inputs,
448
+ outputs=outputs,
449
+ )
450
+
356
451
  sagemaker_steps.append(sagemaker_step)
357
452
 
358
453
  # construct the pipeline from the sagemaker_steps
@@ -363,48 +458,37 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
363
458
  )
364
459
 
365
460
  pipeline.create(role_arn=self.config.execution_role)
366
- pipeline_execution = pipeline.start()
461
+ execution = pipeline.start()
367
462
  logger.warning(
368
463
  "Steps can take 5-15 minutes to start running "
369
464
  "when using the Sagemaker Orchestrator."
370
465
  )
371
466
 
467
+ # Yield metadata based on the generated execution object
468
+ yield from self.compute_metadata(execution=execution)
469
+
372
470
  # mainly for testing purposes, we wait for the pipeline to finish
373
471
  if self.config.synchronous:
374
472
  logger.info(
375
473
  "Executing synchronously. Waiting for pipeline to finish... \n"
376
- "At this point you can `Ctrl-C` out without cancelling the execution."
474
+ "At this point you can `Ctrl-C` out without cancelling the "
475
+ "execution."
377
476
  )
378
477
  try:
379
- pipeline_execution.wait(
478
+ execution.wait(
380
479
  delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
381
480
  )
382
481
  logger.info("Pipeline completed successfully.")
383
482
  except WaiterError:
384
483
  raise RuntimeError(
385
- "Timed out while waiting for pipeline execution to finish. For long-running "
386
- "pipelines we recommend configuring your orchestrator for asynchronous execution. "
484
+ "Timed out while waiting for pipeline execution to "
485
+ "finish. For long-running pipelines we recommend "
486
+ "configuring your orchestrator for asynchronous execution. "
387
487
  "The following command does this for you: \n"
388
- f"`zenml orchestrator update {self.name} --synchronous=False`"
488
+ f"`zenml orchestrator update {self.name} "
489
+ f"--synchronous=False`"
389
490
  )
390
491
 
391
- def _get_region_name(self) -> str:
392
- """Returns the AWS region name.
393
-
394
- Returns:
395
- The region name.
396
-
397
- Raises:
398
- RuntimeError: If the region name cannot be retrieved.
399
- """
400
- try:
401
- return cast(str, sagemaker.Session().boto_region_name)
402
- except Exception as e:
403
- raise RuntimeError(
404
- "Unable to get region name. Please ensure that you have "
405
- "configured your AWS credentials correctly."
406
- ) from e
407
-
408
492
  def get_pipeline_run_metadata(
409
493
  self, run_id: UUID
410
494
  ) -> Dict[str, "MetadataType"]:
@@ -416,16 +500,17 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
416
500
  Returns:
417
501
  A dictionary of metadata.
418
502
  """
503
+ pipeline_execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
419
504
  run_metadata: Dict[str, "MetadataType"] = {
420
- "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
505
+ "pipeline_execution_arn": pipeline_execution_arn,
421
506
  }
422
- try:
423
- region_name = self._get_region_name()
424
- except RuntimeError:
425
- logger.warning("Unable to get region name from AWS Sagemaker.")
426
- return run_metadata
427
507
 
428
508
  aws_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID].split("/")[-1]
509
+
510
+ region_name, _, _ = dissect_pipeline_execution_arn(
511
+ pipeline_execution_arn=pipeline_execution_arn
512
+ )
513
+
429
514
  orchestrator_logs_url = (
430
515
  f"https://{region_name}.console.aws.amazon.com/"
431
516
  f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
@@ -434,3 +519,173 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
434
519
  )
435
520
  run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
436
521
  return run_metadata
522
+
523
+ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
524
+ """Refreshes the status of a specific pipeline run.
525
+
526
+ Args:
527
+ run: The run that was executed by this orchestrator.
528
+
529
+ Returns:
530
+ the actual status of the pipeline job.
531
+
532
+ Raises:
533
+ AssertionError: If the run was not executed by this orchestrator.
534
+ ValueError: If it fetches an unknown state or if we can not fetch
535
+ the orchestrator run ID.
536
+ """
537
+ # Make sure that the stack exists and is accessible
538
+ if run.stack is None:
539
+ raise ValueError(
540
+ "The stack that the run was executed on is not available "
541
+ "anymore."
542
+ )
543
+
544
+ # Make sure that the run belongs to this orchestrator
545
+ assert (
546
+ self.id
547
+ == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
548
+ )
549
+
550
+ # Initialize the Sagemaker client
551
+ session = self._get_sagemaker_session()
552
+ sagemaker_client = session.sagemaker_client
553
+
554
+ # Fetch the status of the _PipelineExecution
555
+ if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
556
+ run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
557
+ elif run.orchestrator_run_id is not None:
558
+ run_id = run.orchestrator_run_id
559
+ else:
560
+ raise ValueError(
561
+ "Can not find the orchestrator run ID, thus can not fetch "
562
+ "the status."
563
+ )
564
+ status = sagemaker_client.describe_pipeline_execution(
565
+ PipelineExecutionArn=run_id
566
+ )["PipelineExecutionStatus"]
567
+
568
+ # Map the potential outputs to ZenML ExecutionStatus. Potential values:
569
+ # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
570
+ if status in ["Executing", "Stopping"]:
571
+ return ExecutionStatus.RUNNING
572
+ elif status in ["Stopped", "Failed"]:
573
+ return ExecutionStatus.FAILED
574
+ elif status in ["Succeeded"]:
575
+ return ExecutionStatus.COMPLETED
576
+ else:
577
+ raise ValueError("Unknown status for the pipeline execution.")
578
+
579
+ def compute_metadata(
580
+ self, execution: Any
581
+ ) -> Iterator[Dict[str, MetadataType]]:
582
+ """Generate run metadata based on the generated Sagemaker Execution.
583
+
584
+ Args:
585
+ execution: The corresponding _PipelineExecution object.
586
+
587
+ Yields:
588
+ A dictionary of metadata related to the pipeline run.
589
+ """
590
+ # Metadata
591
+ metadata: Dict[str, MetadataType] = {}
592
+
593
+ # Orchestrator Run ID
594
+ if run_id := self._compute_orchestrator_run_id(execution):
595
+ metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
596
+
597
+ # URL to the Sagemaker's pipeline view
598
+ if orchestrator_url := self._compute_orchestrator_url(execution):
599
+ metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
600
+
601
+ # URL to the corresponding CloudWatch page
602
+ if logs_url := self._compute_orchestrator_logs_url(execution):
603
+ metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
604
+
605
+ yield metadata
606
+
607
+ @staticmethod
608
+ def _compute_orchestrator_url(
609
+ pipeline_execution: Any,
610
+ ) -> Optional[str]:
611
+ """Generate the Orchestrator Dashboard URL upon pipeline execution.
612
+
613
+ Args:
614
+ pipeline_execution: The corresponding _PipelineExecution object.
615
+
616
+ Returns:
617
+ the URL to the dashboard view in SageMaker.
618
+ """
619
+ try:
620
+ region_name, pipeline_name, execution_id = (
621
+ dissect_pipeline_execution_arn(pipeline_execution.arn)
622
+ )
623
+
624
+ # Get the Sagemaker session
625
+ session = pipeline_execution.sagemaker_session
626
+
627
+ # List the Studio domains and get the Studio Domain ID
628
+ domains_response = session.sagemaker_client.list_domains()
629
+ studio_domain_id = domains_response["Domains"][0]["DomainId"]
630
+
631
+ return (
632
+ f"https://studio-{studio_domain_id}.studio.{region_name}."
633
+ f"sagemaker.aws/pipelines/view/{pipeline_name}/executions"
634
+ f"/{execution_id}/graph"
635
+ )
636
+
637
+ except Exception as e:
638
+ logger.warning(
639
+ f"There was an issue while extracting the pipeline url: {e}"
640
+ )
641
+ return None
642
+
643
+ @staticmethod
644
+ def _compute_orchestrator_logs_url(
645
+ pipeline_execution: Any,
646
+ ) -> Optional[str]:
647
+ """Generate the CloudWatch URL upon pipeline execution.
648
+
649
+ Args:
650
+ pipeline_execution: The corresponding _PipelineExecution object.
651
+
652
+ Returns:
653
+ the URL querying the pipeline logs in CloudWatch on AWS.
654
+ """
655
+ try:
656
+ region_name, _, execution_id = dissect_pipeline_execution_arn(
657
+ pipeline_execution.arn
658
+ )
659
+
660
+ return (
661
+ f"https://{region_name}.console.aws.amazon.com/"
662
+ f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
663
+ f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
664
+ f"$3Dpipelines-{execution_id}-"
665
+ )
666
+ except Exception as e:
667
+ logger.warning(
668
+ f"There was an issue while extracting the logs url: {e}"
669
+ )
670
+ return None
671
+
672
+ @staticmethod
673
+ def _compute_orchestrator_run_id(
674
+ pipeline_execution: Any,
675
+ ) -> Optional[str]:
676
+ """Fetch the Orchestrator Run ID upon pipeline execution.
677
+
678
+ Args:
679
+ pipeline_execution: The corresponding _PipelineExecution object.
680
+
681
+ Returns:
682
+ the pipeline execution ARN, used as the orchestrator run ID.
683
+ """
684
+ try:
685
+ return str(pipeline_execution.arn)
686
+
687
+ except Exception as e:
688
+ logger.warning(
689
+ f"There was an issue while extracting the pipeline run ID: {e}"
690
+ )
691
+ return None