zenml-nightly 0.66.0.dev20240924__py3-none-any.whl → 0.66.0.dev20240926__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. zenml/VERSION +1 -1
  2. zenml/cli/__init__.py +7 -0
  3. zenml/cli/pipeline.py +21 -0
  4. zenml/constants.py +3 -0
  5. zenml/integrations/__init__.py +1 -0
  6. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +288 -71
  7. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
  8. zenml/integrations/constants.py +1 -0
  9. zenml/integrations/deepchecks/__init__.py +1 -1
  10. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
  11. zenml/integrations/deepchecks/validation_checks.py +62 -5
  12. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
  13. zenml/integrations/lightning/__init__.py +1 -1
  14. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
  15. zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
  16. zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
  17. zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
  18. zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
  19. zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
  20. zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
  21. zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
  22. zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
  23. zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
  24. zenml/models/v2/core/pipeline_run.py +62 -1
  25. zenml/new/pipelines/run_utils.py +4 -1
  26. zenml/orchestrators/base_orchestrator.py +41 -12
  27. zenml/stack/stack.py +11 -2
  28. zenml/zen_server/cloud_utils.py +33 -8
  29. zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-CMnKjD-L.js} +1 -1
  30. zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-CEC2f0cl.js} +1 -1
  31. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BqM1UpCD.js +1 -0
  32. zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-DRy_0J4D.js} +2 -2
  33. zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-lE-75Zob.js} +1 -1
  34. zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-CVx2RAoT.js} +1 -1
  35. zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-C_yRGWuP.js} +1 -1
  36. zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-Dd0P02Iz.js} +1 -1
  37. zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-BCrc2wIk.js} +1 -1
  38. zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BuMJbG-M.js} +1 -1
  39. zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-fIulMG4w.js} +1 -1
  40. zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-CJAp4kbv.js} +1 -1
  41. zenml/zen_server/dashboard/assets/Infobox-CC70zvGO.js +1 -0
  42. zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-C3QXdFW1.js} +1 -1
  43. zenml/zen_server/dashboard/assets/{Partials-DX-8iEa1.js → Partials-Cb8lrNsi.js} +1 -1
  44. zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C9BuYVSN.js} +1 -1
  45. zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-GYc9PJtG.js} +1 -1
  46. zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BeF1yR7M.js} +1 -1
  47. zenml/zen_server/dashboard/assets/SecretTooltip-DgVWrPxX.js +1 -0
  48. zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-nAhHddXW.js} +1 -1
  49. zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-C5ZVvNRQ.js} +1 -1
  50. zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-7KFsDbKb.js} +1 -1
  51. zenml/zen_server/dashboard/assets/UsageReason-DL5NL_ZD.js +1 -0
  52. zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-CgvFSppz.js} +1 -1
  53. zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-DAPSF_74.js} +1 -1
  54. zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-CxoNxh86.js} +1 -1
  55. zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-BfgeXFuV.js} +1 -1
  56. zenml/zen_server/dashboard/assets/delete-run-OkGmZQ5G.js +1 -0
  57. zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-C09PrQUJ.js} +1 -1
  58. zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-CLT4K7oC.js} +1 -1
  59. zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-D0bJjaey.js} +3 -3
  60. zenml/zen_server/dashboard/assets/index-PcI3Xw77.css +1 -0
  61. zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-CB45FHbP.js} +1 -1
  62. zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-NtCUfXiV.js} +1 -1
  63. zenml/zen_server/dashboard/assets/page-AvcQe_oR.js +1 -0
  64. zenml/zen_server/dashboard/assets/page-B6DccgPa.js +1 -0
  65. zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-B7DTiwhv.js} +1 -1
  66. zenml/zen_server/dashboard/assets/{page-CIbehp7V.js → page-B7LduaiG.js} +1 -1
  67. zenml/zen_server/dashboard/assets/{page-CEJWu1YO.js → page-B8WlhDq6.js} +1 -1
  68. zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-BIhP9udn.js} +1 -1
  69. zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-BLS9bXB8.js} +1 -1
  70. zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-BYXn4SXu.js} +1 -1
  71. zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-Bfvwt3AB.js} +1 -1
  72. zenml/zen_server/dashboard/assets/{page-D5F3DJjm.js → page-BipKr1Pt.js} +1 -1
  73. zenml/zen_server/dashboard/assets/page-BwG4f5qc.js +1 -0
  74. zenml/zen_server/dashboard/assets/page-C1c_unjg.js +9 -0
  75. zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-C25tiRdj.js} +1 -1
  76. zenml/zen_server/dashboard/assets/page-CIATsAA7.js +1 -0
  77. zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-CKUVhcYr.js} +1 -1
  78. zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-CXLwze-m.js} +1 -1
  79. zenml/zen_server/dashboard/assets/page-D7TD0k_A.js +1 -0
  80. zenml/zen_server/dashboard/assets/{page-CDOQLrPC.js → page-DIlOQjGU.js} +1 -1
  81. zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-DJ31Huvj.js} +1 -1
  82. zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-DOqsdVzG.js} +1 -1
  83. zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DUapawuM.js} +1 -1
  84. zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-Dd3jZyrf.js} +1 -1
  85. zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DyZzYHWA.js} +2 -2
  86. zenml/zen_server/dashboard/assets/page-L_xNBh_5.js +3 -0
  87. zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-VsrKiIdF.js} +1 -1
  88. zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-ioO58ULo.js} +1 -1
  89. zenml/zen_server/dashboard/assets/page-kalpiPZz.js +6 -0
  90. zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-ChKZVcn3.js} +1 -1
  91. zenml/zen_server/dashboard/assets/{persist-mEZN_fgH.js → persist-DodaLO0k.js} +1 -1
  92. zenml/zen_server/dashboard/assets/{sharedSchema-BfZcy7aP.js → sharedSchema-BvRWAv-c.js} +1 -1
  93. zenml/zen_server/dashboard/assets/{stack-detail-query-CU4egfhp.js → stack-detail-query-C9XwNP1E.js} +1 -1
  94. zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
  95. zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-DJDefwqW.js} +1 -1
  96. zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-DdWrpIhi.js} +1 -1
  97. zenml/zen_server/dashboard/index.html +4 -4
  98. zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
  99. zenml/zen_server/dashboard_legacy/index.html +1 -1
  100. zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.4f9db97de1b48fd5944e8a766c1300fe.js} +4 -4
  101. zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
  102. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.0fdd4aad.chunk.js} +2 -2
  103. zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.0fdd4aad.chunk.js.map} +1 -1
  104. zenml/zen_server/routers/runs_endpoints.py +89 -3
  105. {zenml_nightly-0.66.0.dev20240924.dist-info → zenml_nightly-0.66.0.dev20240926.dist-info}/METADATA +8 -1
  106. {zenml_nightly-0.66.0.dev20240924.dist-info → zenml_nightly-0.66.0.dev20240926.dist-info}/RECORD +109 -102
  107. zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
  108. zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
  109. zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
  110. zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
  111. zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
  112. zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
  113. zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
  114. zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
  115. zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
  116. zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
  117. zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
  118. zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
  119. zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
  120. {zenml_nightly-0.66.0.dev20240924.dist-info → zenml_nightly-0.66.0.dev20240926.dist-info}/LICENSE +0 -0
  121. {zenml_nightly-0.66.0.dev20240924.dist-info → zenml_nightly-0.66.0.dev20240926.dist-info}/WHEEL +0 -0
  122. {zenml_nightly-0.66.0.dev20240924.dist-info → zenml_nightly-0.66.0.dev20240926.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.66.0.dev20240924
1
+ 0.66.0.dev20240926
zenml/cli/__init__.py CHANGED
@@ -1715,6 +1715,13 @@ To delete a pipeline run, use:
1715
1715
  zenml pipeline runs delete <PIPELINE_RUN_NAME_OR_ID>
1716
1716
  ```
1717
1717
 
1718
+ To refresh the status of a pipeline run, you can use the `refresh` command
1719
+ (only supported for pipelines executed on Vertex, Sagemaker or AzureML).
1720
+
1721
+ ```bash
1722
+ zenml pipeline runs refresh <PIPELINE_RUN_NAME_OR_ID>
1723
+ ```
1724
+
1718
1725
  If you run any of your pipelines with `pipeline.run(schedule=...)`, ZenML keeps
1719
1726
  track of the schedule and you can list all schedules via:
1720
1727
 
zenml/cli/pipeline.py CHANGED
@@ -500,6 +500,27 @@ def delete_pipeline_run(
500
500
  cli_utils.declare(f"Deleted pipeline run '{run_name_or_id}'.")
501
501
 
502
502
 
503
@runs.command("refresh")
@click.argument("run_name_or_id", type=str, required=True)
def refresh_pipeline_run(run_name_or_id: str) -> None:
    """Refresh the status of a pipeline run.

    Looks the run up by name, ID or prefix and asks it to re-sync its status
    with the orchestrator that executed it.

    Args:
        run_name_or_id: The name or ID of the pipeline run to refresh.
    """
    try:
        pipeline_run = Client().get_pipeline_run(
            name_id_or_prefix=run_name_or_id
        )
        pipeline_run.refresh_run_status()
    except KeyError as err:
        # The run could not be found: report it and skip the success message.
        cli_utils.error(str(err))
        return

    cli_utils.declare(
        f"Refreshed the status of pipeline run '{pipeline_run.name}'."
    )
522
+
523
+
503
524
  @pipeline.group()
504
525
  def builds() -> None:
505
526
  """Commands for pipeline builds."""
zenml/constants.py CHANGED
@@ -364,6 +364,7 @@ PIPELINE_DEPLOYMENTS = "/pipeline_deployments"
364
364
  PIPELINES = "/pipelines"
365
365
  PIPELINE_SPEC = "/pipeline-spec"
366
366
  PLUGIN_FLAVORS = "/plugin-flavors"
367
+ REFRESH = "/refresh"
367
368
  RUNS = "/runs"
368
369
  RUN_TEMPLATES = "/run_templates"
369
370
  RUN_METADATA = "/run-metadata"
@@ -430,6 +431,8 @@ SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run"
430
431
 
431
432
  # Metadata constants
432
433
  METADATA_ORCHESTRATOR_URL = "orchestrator_url"
434
+ METADATA_ORCHESTRATOR_LOGS_URL = "orchestrator_logs_url"
435
+ METADATA_ORCHESTRATOR_RUN_ID = "orchestrator_run_id"
433
436
  METADATA_EXPERIMENT_TRACKER_URL = "experiment_tracker_url"
434
437
  METADATA_DEPLOYED_MODEL_URL = "deployed_model_url"
435
438
 
@@ -69,6 +69,7 @@ from zenml.integrations.skypilot_aws import SkypilotAWSIntegration # noqa
69
69
  from zenml.integrations.skypilot_gcp import SkypilotGCPIntegration # noqa
70
70
  from zenml.integrations.skypilot_azure import SkypilotAzureIntegration # noqa
71
71
  from zenml.integrations.skypilot_lambda import SkypilotLambdaIntegration # noqa
72
+ from zenml.integrations.skypilot_kubernetes import SkypilotKubernetesIntegration # noqa
72
73
  from zenml.integrations.slack import SlackIntegration # noqa
73
74
  from zenml.integrations.spark import SparkIntegration # noqa
74
75
  from zenml.integrations.tekton import TektonIntegration # noqa
@@ -15,7 +15,16 @@
15
15
 
16
16
  import os
17
17
  import re
18
- from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type, cast
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ Any,
21
+ Dict,
22
+ Iterator,
23
+ Optional,
24
+ Tuple,
25
+ Type,
26
+ cast,
27
+ )
19
28
  from uuid import UUID
20
29
 
21
30
  import boto3
@@ -29,9 +38,11 @@ from sagemaker.workflow.steps import ProcessingStep, TrainingStep
29
38
 
30
39
  from zenml.config.base_settings import BaseSettings
31
40
  from zenml.constants import (
41
+ METADATA_ORCHESTRATOR_LOGS_URL,
42
+ METADATA_ORCHESTRATOR_RUN_ID,
32
43
  METADATA_ORCHESTRATOR_URL,
33
44
  )
34
- from zenml.enums import StackComponentType
45
+ from zenml.enums import ExecutionStatus, StackComponentType
35
46
  from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
36
47
  SagemakerOrchestratorConfig,
37
48
  SagemakerOrchestratorSettings,
@@ -48,7 +59,7 @@ from zenml.stack import StackValidator
48
59
  from zenml.utils.env_utils import split_environment_variables
49
60
 
50
61
  if TYPE_CHECKING:
51
- from zenml.models import PipelineDeploymentResponse
62
+ from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
52
63
  from zenml.stack import Stack
53
64
 
54
65
  ENV_ZENML_SAGEMAKER_RUN_ID = "ZENML_SAGEMAKER_RUN_ID"
@@ -58,6 +69,34 @@ POLLING_DELAY = 30
58
69
  logger = get_logger(__name__)
59
70
 
60
71
 
72
def dissect_pipeline_execution_arn(
    pipeline_execution_arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Split a SageMaker pipeline execution ARN into its parts.

    Each part is extracted independently; any part whose pattern does not
    match the ARN is returned as None.

    Args:
        pipeline_execution_arn: the pipeline execution ARN

    Returns:
        Region Name, Pipeline Name, Execution ID in order
    """
    part_patterns = (
        r"sagemaker:(.*?):",  # region name
        r"pipeline/(.*?)/execution",  # pipeline name
        r"execution/(.*)",  # execution id
    )

    parts = []
    for pattern in part_patterns:
        found = re.search(pattern, pipeline_execution_arn)
        parts.append(found.group(1) if found else None)

    region, pipeline, execution = parts
    return region, pipeline, execution
98
+
99
+
61
100
  class SagemakerOrchestrator(ContainerizedOrchestrator):
62
101
  """Orchestrator responsible for running pipelines on Sagemaker."""
63
102
 
@@ -136,42 +175,16 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
136
175
  """
137
176
  return SagemakerOrchestratorSettings
138
177
 
139
- def prepare_or_run_pipeline(
140
- self,
141
- deployment: "PipelineDeploymentResponse",
142
- stack: "Stack",
143
- environment: Dict[str, str],
144
- ) -> None:
145
- """Prepares or runs a pipeline on Sagemaker.
178
+ def _get_sagemaker_session(self) -> sagemaker.Session:
179
+ """Method to create the sagemaker session with proper authentication.
146
180
 
147
- Args:
148
- deployment: The deployment to prepare or run.
149
- stack: The stack to run on.
150
- environment: Environment variables to set in the orchestration
151
- environment.
181
+ Returns:
182
+ The Sagemaker Session.
152
183
 
153
184
  Raises:
154
- RuntimeError: If a connector is used that does not return a
155
- `boto3.Session` object.
156
- TypeError: If the network_config passed is not compatible with the
157
- AWS SageMaker NetworkConfig class.
185
+ RuntimeError: If the connector returns the wrong type for the
186
+ session.
158
187
  """
159
- if deployment.schedule:
160
- logger.warning(
161
- "The Sagemaker Orchestrator currently does not support the "
162
- "use of schedules. The `schedule` will be ignored "
163
- "and the pipeline will be run immediately."
164
- )
165
-
166
- # sagemaker requires pipelineName to use alphanum and hyphens only
167
- unsanitized_orchestrator_run_name = get_orchestrator_run_name(
168
- pipeline_name=deployment.pipeline_configuration.name
169
- )
170
- # replace all non-alphanum and non-hyphens with hyphens
171
- orchestrator_run_name = re.sub(
172
- r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
173
- )
174
-
175
188
  # Get authenticated session
176
189
  # Option 1: Service connector
177
190
  boto_session: boto3.Session
@@ -205,10 +218,51 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
205
218
  aws_session_token=credentials["SessionToken"],
206
219
  region_name=self.config.region,
207
220
  )
208
- session = sagemaker.Session(
221
+ return sagemaker.Session(
209
222
  boto_session=boto_session, default_bucket=self.config.bucket
210
223
  )
211
224
 
225
+ def prepare_or_run_pipeline(
226
+ self,
227
+ deployment: "PipelineDeploymentResponse",
228
+ stack: "Stack",
229
+ environment: Dict[str, str],
230
+ ) -> Iterator[Dict[str, MetadataType]]:
231
+ """Prepares or runs a pipeline on Sagemaker.
232
+
233
+ Args:
234
+ deployment: The deployment to prepare or run.
235
+ stack: The stack to run on.
236
+ environment: Environment variables to set in the orchestration
237
+ environment.
238
+
239
+ Raises:
240
+ RuntimeError: If a connector is used that does not return a
241
+ `boto3.Session` object.
242
+ TypeError: If the network_config passed is not compatible with the
243
+ AWS SageMaker NetworkConfig class.
244
+
245
+ Yields:
246
+ A dictionary of metadata related to the pipeline run.
247
+ """
248
+ if deployment.schedule:
249
+ logger.warning(
250
+ "The Sagemaker Orchestrator currently does not support the "
251
+ "use of schedules. The `schedule` will be ignored "
252
+ "and the pipeline will be run immediately."
253
+ )
254
+
255
+ # sagemaker requires pipelineName to use alphanum and hyphens only
256
+ unsanitized_orchestrator_run_name = get_orchestrator_run_name(
257
+ pipeline_name=deployment.pipeline_configuration.name
258
+ )
259
+ # replace all non-alphanum and non-hyphens with hyphens
260
+ orchestrator_run_name = re.sub(
261
+ r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
262
+ )
263
+
264
+ session = self._get_sagemaker_session()
265
+
212
266
  # Sagemaker does not allow environment variables longer than 256
213
267
  # characters to be passed to Processor steps. If an environment variable
214
268
  # is longer than 256 characters, we split it into multiple environment
@@ -254,8 +308,8 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
254
308
  else:
255
309
  args_for_step_executor = step_settings.processor_args or {}
256
310
 
257
- # Set default values from configured orchestrator Component to arguments
258
- # to be used when they are not present in processor_args.
311
+ # Set default values from configured orchestrator Component to
312
+ # arguments to be used when they are not present in processor_args.
259
313
  args_for_step_executor.setdefault(
260
314
  "role",
261
315
  step_settings.execution_role or self.config.execution_role,
@@ -288,18 +342,21 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
288
342
  args_for_step_executor["sagemaker_session"] = session
289
343
  args_for_step_executor["base_job_name"] = orchestrator_run_name
290
344
 
291
- # Convert network_config to sagemaker.network.NetworkConfig if present
345
+ # Convert network_config to sagemaker.network.NetworkConfig if
346
+ # present
292
347
  network_config = args_for_step_executor.get("network_config")
348
+
293
349
  if network_config and isinstance(network_config, dict):
294
350
  try:
295
351
  args_for_step_executor["network_config"] = NetworkConfig(
296
352
  **network_config
297
353
  )
298
354
  except TypeError:
299
- # If the network_config passed is not compatible with the NetworkConfig class,
300
- # raise a more informative error.
355
+ # If the network_config passed is not compatible with the
356
+ # NetworkConfig class, raise a more informative error.
301
357
  raise TypeError(
302
- "Expected a sagemaker.network.NetworkConfig compatible object for the network_config argument, "
358
+ "Expected a sagemaker.network.NetworkConfig "
359
+ "compatible object for the network_config argument, "
303
360
  "but the network_config processor argument is invalid."
304
361
  "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
305
362
  "for more information about the NetworkConfig class."
@@ -401,48 +458,37 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
401
458
  )
402
459
 
403
460
  pipeline.create(role_arn=self.config.execution_role)
404
- pipeline_execution = pipeline.start()
461
+ execution = pipeline.start()
405
462
  logger.warning(
406
463
  "Steps can take 5-15 minutes to start running "
407
464
  "when using the Sagemaker Orchestrator."
408
465
  )
409
466
 
467
+ # Yield metadata based on the generated execution object
468
+ yield from self.compute_metadata(execution=execution)
469
+
410
470
  # mainly for testing purposes, we wait for the pipeline to finish
411
471
  if self.config.synchronous:
412
472
  logger.info(
413
473
  "Executing synchronously. Waiting for pipeline to finish... \n"
414
- "At this point you can `Ctrl-C` out without cancelling the execution."
474
+ "At this point you can `Ctrl-C` out without cancelling the "
475
+ "execution."
415
476
  )
416
477
  try:
417
- pipeline_execution.wait(
478
+ execution.wait(
418
479
  delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
419
480
  )
420
481
  logger.info("Pipeline completed successfully.")
421
482
  except WaiterError:
422
483
  raise RuntimeError(
423
- "Timed out while waiting for pipeline execution to finish. For long-running "
424
- "pipelines we recommend configuring your orchestrator for asynchronous execution. "
484
+ "Timed out while waiting for pipeline execution to "
485
+ "finish. For long-running pipelines we recommend "
486
+ "configuring your orchestrator for asynchronous execution. "
425
487
  "The following command does this for you: \n"
426
- f"`zenml orchestrator update {self.name} --synchronous=False`"
488
+ f"`zenml orchestrator update {self.name} "
489
+ f"--synchronous=False`"
427
490
  )
428
491
 
429
- def _get_region_name(self) -> str:
430
- """Returns the AWS region name.
431
-
432
- Returns:
433
- The region name.
434
-
435
- Raises:
436
- RuntimeError: If the region name cannot be retrieved.
437
- """
438
- try:
439
- return cast(str, sagemaker.Session().boto_region_name)
440
- except Exception as e:
441
- raise RuntimeError(
442
- "Unable to get region name. Please ensure that you have "
443
- "configured your AWS credentials correctly."
444
- ) from e
445
-
446
492
  def get_pipeline_run_metadata(
447
493
  self, run_id: UUID
448
494
  ) -> Dict[str, "MetadataType"]:
@@ -454,16 +500,17 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
454
500
  Returns:
455
501
  A dictionary of metadata.
456
502
  """
503
+ pipeline_execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
457
504
  run_metadata: Dict[str, "MetadataType"] = {
458
- "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
505
+ "pipeline_execution_arn": pipeline_execution_arn,
459
506
  }
460
- try:
461
- region_name = self._get_region_name()
462
- except RuntimeError:
463
- logger.warning("Unable to get region name from AWS Sagemaker.")
464
- return run_metadata
465
507
 
466
508
  aws_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID].split("/")[-1]
509
+
510
+ region_name, _, _ = dissect_pipeline_execution_arn(
511
+ pipeline_execution_arn=pipeline_execution_arn
512
+ )
513
+
467
514
  orchestrator_logs_url = (
468
515
  f"https://{region_name}.console.aws.amazon.com/"
469
516
  f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
@@ -472,3 +519,173 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
472
519
  )
473
520
  run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
474
521
  return run_metadata
522
+
523
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
    """Refreshes the status of a specific pipeline run.

    Args:
        run: The run that was executed by this orchestrator.

    Returns:
        The actual status of the pipeline execution.

    Raises:
        AssertionError: If the run was not executed by this orchestrator.
        ValueError: If it fetches an unknown state or if we can not fetch
            the orchestrator run ID.
    """
    # Make sure that the stack exists and is accessible
    if run.stack is None:
        raise ValueError(
            "The stack that the run was executed on is not available "
            "anymore."
        )

    # Make sure that the run belongs to this orchestrator. An explicit
    # check is used instead of `assert` so that it is not stripped when
    # Python runs with `-O`.
    run_orchestrator = run.stack.components[
        StackComponentType.ORCHESTRATOR
    ][0]
    if self.id != run_orchestrator.id:
        raise AssertionError(
            "The run was not executed by this orchestrator."
        )

    # Initialize the Sagemaker client
    session = self._get_sagemaker_session()
    sagemaker_client = session.sagemaker_client

    # The orchestrator run ID stored for Sagemaker runs is the pipeline
    # execution ARN; prefer the value from the run metadata.
    if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
        run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
    elif run.orchestrator_run_id is not None:
        run_id = run.orchestrator_run_id
    else:
        raise ValueError(
            "Can not find the orchestrator run ID, thus can not fetch "
            "the status."
        )

    # Fetch the status of the _PipelineExecution
    status = sagemaker_client.describe_pipeline_execution(
        PipelineExecutionArn=run_id
    )["PipelineExecutionStatus"]

    # Map the potential outputs to ZenML ExecutionStatus. Potential values
    # of `PipelineExecutionStatus` (SageMaker DescribePipelineExecution API):
    # Executing | Stopping | Stopped | Failed | Succeeded
    if status in ["Executing", "Stopping"]:
        return ExecutionStatus.RUNNING
    elif status in ["Stopped", "Failed"]:
        return ExecutionStatus.FAILED
    elif status in ["Succeeded"]:
        return ExecutionStatus.COMPLETED
    else:
        raise ValueError(
            f"Unknown status for the pipeline execution: '{status}'."
        )
578
+
579
def compute_metadata(
    self, execution: Any
) -> Iterator[Dict[str, MetadataType]]:
    """Build run metadata from a started Sagemaker pipeline execution.

    Only entries whose value could be computed are included. The run ID,
    the dashboard URL and the CloudWatch logs URL are computed in that order.

    Args:
        execution: The corresponding _PipelineExecution object.

    Yields:
        A dictionary of metadata related to the pipeline run.
    """
    collected: Dict[str, MetadataType] = {}

    # Orchestrator Run ID (the execution ARN)
    orchestrator_run_id = self._compute_orchestrator_run_id(execution)
    if orchestrator_run_id:
        collected[METADATA_ORCHESTRATOR_RUN_ID] = orchestrator_run_id

    # Link to the pipeline view in Sagemaker Studio
    dashboard_url = self._compute_orchestrator_url(execution)
    if dashboard_url:
        collected[METADATA_ORCHESTRATOR_URL] = Uri(dashboard_url)

    # Link to the matching CloudWatch log streams
    cloudwatch_url = self._compute_orchestrator_logs_url(execution)
    if cloudwatch_url:
        collected[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(cloudwatch_url)

    yield collected
606
+
607
@staticmethod
def _compute_orchestrator_url(
    pipeline_execution: Any,
) -> Optional[str]:
    """Generate the Orchestrator Dashboard URL upon pipeline execution.

    This is best-effort: any problem is logged as a warning and None is
    returned instead of raising.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The URL to the pipeline execution view in SageMaker Studio, or None
        if it could not be computed.
    """
    try:
        region_name, pipeline_name, execution_id = (
            dissect_pipeline_execution_arn(pipeline_execution.arn)
        )
        # Without all three parts the URL would contain a literal "None".
        if not (region_name and pipeline_name and execution_id):
            logger.warning(
                "There was an issue while extracting the pipeline url: "
                "could not parse the pipeline execution ARN "
                f"'{pipeline_execution.arn}'."
            )
            return None

        # Get the Sagemaker session
        session = pipeline_execution.sagemaker_session

        # List the Studio domains; the URL points at the first one listed.
        domains = session.sagemaker_client.list_domains().get("Domains", [])
        if not domains:
            logger.warning(
                "There was an issue while extracting the pipeline url: "
                "no SageMaker Studio domain was found."
            )
            return None
        studio_domain_id = domains[0]["DomainId"]

        return (
            f"https://studio-{studio_domain_id}.studio.{region_name}."
            f"sagemaker.aws/pipelines/view/{pipeline_name}/executions"
            f"/{execution_id}/graph"
        )

    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the pipeline url: {e}"
        )
        return None
642
+
643
@staticmethod
def _compute_orchestrator_logs_url(
    pipeline_execution: Any,
) -> Optional[str]:
    """Generate the CloudWatch URL upon pipeline execution.

    This is best-effort: any problem is logged as a warning and None is
    returned instead of raising.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The URL querying the pipeline logs in CloudWatch on AWS, or None if
        it could not be computed.
    """
    try:
        region_name, _, execution_id = dissect_pipeline_execution_arn(
            pipeline_execution.arn
        )
        # Without both parts the URL would contain a literal "None".
        if not (region_name and execution_id):
            logger.warning(
                "There was an issue while extracting the logs url: could "
                "not parse the pipeline execution ARN "
                f"'{pipeline_execution.arn}'."
            )
            return None

        # The log-group path is URL-encoded for the CloudWatch console
        # ("$252F" -> "/", "$3F" -> "?", "$3D" -> "=").
        return (
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            "/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
            f"$3Dpipelines-{execution_id}-"
        )
    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the logs url: {e}"
        )
        return None
671
+
672
@staticmethod
def _compute_orchestrator_run_id(
    pipeline_execution: Any,
) -> Optional[str]:
    """Fetch the Orchestrator Run ID upon pipeline execution.

    The run ID used for Sagemaker is the full pipeline execution ARN. This
    is best-effort: any problem is logged as a warning and None is returned.

    Args:
        pipeline_execution: The corresponding _PipelineExecution object.

    Returns:
        The ARN of the pipeline execution in SageMaker, or None if it is not
        available.
    """
    try:
        arn = pipeline_execution.arn
        # `str(None)` would yield the truthy string "None", which would then
        # be stored as the run ID; treat a missing ARN as unavailable.
        if arn is None:
            logger.warning(
                "There was an issue while extracting the pipeline run ID: "
                "the pipeline execution has no ARN."
            )
            return None
        return str(arn)

    except Exception as e:
        logger.warning(
            f"There was an issue while extracting the pipeline run ID: {e}"
        )
        return None