zenml-nightly 0.66.0.dev20240923__py3-none-any.whl → 0.66.0.dev20240928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/__init__.py +7 -0
- zenml/cli/base.py +2 -2
- zenml/cli/pipeline.py +21 -0
- zenml/cli/utils.py +14 -11
- zenml/client.py +68 -3
- zenml/config/step_configurations.py +0 -5
- zenml/constants.py +3 -0
- zenml/enums.py +2 -0
- zenml/integrations/__init__.py +1 -0
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +76 -7
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +370 -115
- zenml/integrations/azure/__init__.py +6 -2
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +157 -4
- zenml/integrations/constants.py +1 -0
- zenml/integrations/deepchecks/__init__.py +1 -1
- zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +55 -14
- zenml/integrations/deepchecks/validation_checks.py +62 -5
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +207 -18
- zenml/integrations/lightning/__init__.py +1 -1
- zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +9 -0
- zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +18 -17
- zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +2 -6
- zenml/integrations/mlflow/steps/mlflow_registry.py +2 -0
- zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +38 -26
- zenml/integrations/skypilot_kubernetes/__init__.py +52 -0
- zenml/integrations/skypilot_kubernetes/flavors/__init__.py +26 -0
- zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +125 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/__init__.py +25 -0
- zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +74 -0
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +1 -1
- zenml/models/v2/base/filter.py +315 -149
- zenml/models/v2/base/scoped.py +5 -2
- zenml/models/v2/core/artifact_version.py +69 -8
- zenml/models/v2/core/model.py +43 -6
- zenml/models/v2/core/model_version.py +49 -1
- zenml/models/v2/core/model_version_artifact.py +18 -3
- zenml/models/v2/core/model_version_pipeline_run.py +18 -4
- zenml/models/v2/core/pipeline.py +108 -1
- zenml/models/v2/core/pipeline_run.py +172 -21
- zenml/models/v2/core/run_template.py +53 -1
- zenml/models/v2/core/stack.py +33 -5
- zenml/models/v2/core/step_run.py +7 -0
- zenml/new/pipelines/pipeline.py +4 -0
- zenml/new/pipelines/run_utils.py +4 -1
- zenml/orchestrators/base_orchestrator.py +41 -12
- zenml/stack/stack.py +11 -2
- zenml/utils/env_utils.py +54 -1
- zenml/utils/string_utils.py +50 -0
- zenml/zen_server/cloud_utils.py +33 -8
- zenml/zen_server/dashboard/assets/{404-iO8vpun1.js → 404-Y50hSt65.js} +1 -1
- zenml/zen_server/dashboard/assets/{@reactflow-B6kq9fJZ.js → @reactflow-ytavUpwh.js} +1 -1
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-xLR9a1iw.js +1 -0
- zenml/zen_server/dashboard/assets/{CodeSnippet-DNWdQmbo.js → CodeSnippet-IxXNxUDa.js} +2 -2
- zenml/zen_server/dashboard/assets/{CollapsibleCard-B2OVjWYE.js → CollapsibleCard-BhutZbBL.js} +1 -1
- zenml/zen_server/dashboard/assets/{Commands-DsoaVElZ.js → Commands-Bf-rd1z8.js} +1 -1
- zenml/zen_server/dashboard/assets/ComponentBadge-gKR1OIwG.js +1 -0
- zenml/zen_server/dashboard/assets/{CopyButton-BqE_-PHO.js → CopyButton-DcFHidFJ.js} +1 -1
- zenml/zen_server/dashboard/assets/{CsvVizualization-Dyasr2jU.js → CsvVizualization-QSbjrfxw.js} +1 -1
- zenml/zen_server/dashboard/assets/{DialogItem-Cz1VLRwa.js → DialogItem-Cd3HqST4.js} +1 -1
- zenml/zen_server/dashboard/assets/{Error-DorJD_va.js → Error-BhwdmqK7.js} +1 -1
- zenml/zen_server/dashboard/assets/{ExecutionStatus-CIfQTutR.js → ExecutionStatus-D6r6aK8J.js} +1 -1
- zenml/zen_server/dashboard/assets/{Helpbox-CmfvtNeq.js → Helpbox-0pBpTwTm.js} +1 -1
- zenml/zen_server/dashboard/assets/Infobox-BTK_EUKT.js +1 -0
- zenml/zen_server/dashboard/assets/{InlineAvatar-Ds2ZFHPc.js → InlineAvatar-CA3DFMcM.js} +1 -1
- zenml/zen_server/dashboard/assets/Partials-QLOZw624.js +1 -0
- zenml/zen_server/dashboard/assets/{ProviderIcon-BOQJgapd.js → ProviderIcon-C16CCIN4.js} +1 -1
- zenml/zen_server/dashboard/assets/{ProviderRadio-BsYBw9YA.js → ProviderRadio-D3FuCHf3.js} +1 -1
- zenml/zen_server/dashboard/assets/{SearchField-W3GXpLlI.js → SearchField-BzmfxS0L.js} +1 -1
- zenml/zen_server/dashboard/assets/SecretTooltip-BaMwHF-Q.js +1 -0
- zenml/zen_server/dashboard/assets/{SetPassword-B-0a8UCj.js → SetPassword-DuIC65H9.js} +1 -1
- zenml/zen_server/dashboard/assets/{Tick-i1DYsVcX.js → Tick-DJTCF0Re.js} +1 -1
- zenml/zen_server/dashboard/assets/{UpdatePasswordSchemas-C6Zb7ASL.js → UpdatePasswordSchemas-CUm-DMpw.js} +1 -1
- zenml/zen_server/dashboard/assets/UsageReason-CKw0juLF.js +1 -0
- zenml/zen_server/dashboard/assets/{WizardFooter-BHbO7zOa.js → WizardFooter-Cv9ApYWU.js} +1 -1
- zenml/zen_server/dashboard/assets/{all-pipeline-runs-query-BBEe6I9-.js → all-pipeline-runs-query-BA3R2Sey.js} +1 -1
- zenml/zen_server/dashboard/assets/{cloud-only-BuP4Kt_7.js → cloud-only-BB4BVa6E.js} +1 -1
- zenml/zen_server/dashboard/assets/{create-stack-B2x2d4r1.js → create-stack-F29xAUEx.js} +1 -1
- zenml/zen_server/dashboard/assets/delete-run-CP0pcJ3U.js +1 -0
- zenml/zen_server/dashboard/assets/{form-schemas-Bap0f854.js → form-schemas-BKXwSDK2.js} +1 -1
- zenml/zen_server/dashboard/assets/index-BhJ6ZJxv.css +1 -0
- zenml/zen_server/dashboard/assets/{index-B9wVwe7u.js → index-Ci0nJ8EZ.js} +5 -5
- zenml/zen_server/dashboard/assets/{index-DFi8BroH.js → index-D-mtoBj3.js} +1 -1
- zenml/zen_server/dashboard/assets/{login-mutation-DwxUz8VA.js → login-mutation-ax6iL2Mb.js} +1 -1
- zenml/zen_server/dashboard/assets/{not-found-D5i9DunU.js → not-found-DbjllLY_.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-oS4hqS8M.js → page-3qPX9WYH.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-iwoJnwPv.js → page-6mfzecin.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DGMa3ZQL.js → page-8kYmrh0B.js} +1 -1
- zenml/zen_server/dashboard/assets/page-B1n7_W7z.js +1 -0
- zenml/zen_server/dashboard/assets/page-BDg1F-Ug.js +6 -0
- zenml/zen_server/dashboard/assets/{page-xQG6GmFJ.js → page-BXarY9K2.js} +1 -1
- zenml/zen_server/dashboard/assets/page-BZZhLo2u.js +1 -0
- zenml/zen_server/dashboard/assets/page-Bbf_oBjn.js +1 -0
- zenml/zen_server/dashboard/assets/page-BjjuBvZG.js +9 -0
- zenml/zen_server/dashboard/assets/{page-J0s8Sq3N.js → page-BukXK1Aa.js} +1 -1
- zenml/zen_server/dashboard/assets/page-CHaQkFK5.js +1 -0
- zenml/zen_server/dashboard/assets/{page-BitfWsiW.js → page-CKHNAq7z.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-DE03uZZR.js → page-CS0SYFK8.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-WCQ659by.js → page-CvKnNK1S.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-CrSdkteO.js → page-DGM1CbYT.js} +2 -2
- zenml/zen_server/dashboard/assets/{page-DQGCHKrQ.js → page-DMSLXKGT.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DOmIZ2ra.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DgM-N9RL.js → page-DRfcRK1w.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DYVmJ9_w.js +3 -0
- zenml/zen_server/dashboard/assets/{page-BiF8hLbO.js → page-DcTjHmYZ.js} +1 -1
- zenml/zen_server/dashboard/assets/page-DuqYMYmH.js +1 -0
- zenml/zen_server/dashboard/assets/page-Dwow2doB.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DQdwZZ9x.js → page-HkVBdZl6.js} +1 -1
- zenml/zen_server/dashboard/assets/{page-bimkItOg.js → page-MAXyfXBq.js} +1 -1
- zenml/zen_server/dashboard/assets/page-miU2rhYG.js +1 -0
- zenml/zen_server/dashboard/assets/page-p0BhSAWx.js +1 -0
- zenml/zen_server/dashboard/assets/{page-DFCK65G9.js → page-uORspyRu.js} +1 -1
- zenml/zen_server/dashboard/assets/persist-BxIR2XZs.js +1 -0
- zenml/zen_server/dashboard/assets/{persist-xsYgVtR1.js → persist-CfJMar_k.js} +1 -1
- zenml/zen_server/dashboard/assets/sharedSchema-vub0rii3.js +14 -0
- zenml/zen_server/dashboard/assets/stack-detail-query-DQcyzG-2.js +1 -0
- zenml/zen_server/dashboard/assets/tick-circle-m-hJG8i9.js +1 -0
- zenml/zen_server/dashboard/assets/{update-server-settings-mutation-DNqmQXDM.js → update-server-settings-mutation-FGVP7X2U.js} +1 -1
- zenml/zen_server/dashboard/assets/{url-DwbuKk1b.js → url-CbAPzsmT.js} +1 -1
- zenml/zen_server/dashboard/index.html +4 -4
- zenml/zen_server/dashboard_legacy/asset-manifest.json +4 -4
- zenml/zen_server/dashboard_legacy/index.html +1 -1
- zenml/zen_server/dashboard_legacy/{precache-manifest.290b95d5b43efa3368b3dc63d20c4782.js → precache-manifest.6d320abb70db612019dda6c4948e7a90.js} +4 -4
- zenml/zen_server/dashboard_legacy/service-worker.js +1 -1
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js → main.fa9299d5.chunk.js} +2 -2
- zenml/zen_server/dashboard_legacy/static/js/{main.840d1bf0.chunk.js.map → main.fa9299d5.chunk.js.map} +1 -1
- zenml/zen_server/routers/runs_endpoints.py +89 -3
- zenml/zen_stores/sql_zen_store.py +1 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/METADATA +8 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/RECORD +133 -125
- zenml/zen_server/dashboard/assets/AlertDialogDropdownItem-BXeSvmMY.js +0 -1
- zenml/zen_server/dashboard/assets/EditSecretDialog-Du423_3U.js +0 -1
- zenml/zen_server/dashboard/assets/Infobox-BL9NOS37.js +0 -1
- zenml/zen_server/dashboard/assets/Partials-DX-8iEa1.js +0 -1
- zenml/zen_server/dashboard/assets/UsageReason-CCnzmwS8.js +0 -1
- zenml/zen_server/dashboard/assets/index-6DYjZgDn.css +0 -1
- zenml/zen_server/dashboard/assets/page-BFuJICXM.js +0 -9
- zenml/zen_server/dashboard/assets/page-CDOQLrPC.js +0 -1
- zenml/zen_server/dashboard/assets/page-CEJWu1YO.js +0 -1
- zenml/zen_server/dashboard/assets/page-CIbehp7V.js +0 -1
- zenml/zen_server/dashboard/assets/page-CLiRGfWo.js +0 -1
- zenml/zen_server/dashboard/assets/page-CV44mQn9.js +0 -1
- zenml/zen_server/dashboard/assets/page-D5F3DJjm.js +0 -1
- zenml/zen_server/dashboard/assets/page-DI-qTWrm.js +0 -1
- zenml/zen_server/dashboard/assets/page-Dt8VgzbE.js +0 -1
- zenml/zen_server/dashboard/assets/page-oSqx9dkH.js +0 -1
- zenml/zen_server/dashboard/assets/page-p3GqEAUW.js +0 -1
- zenml/zen_server/dashboard/assets/page-qvcUVPE-.js +0 -1
- zenml/zen_server/dashboard/assets/persist-mEZN_fgH.js +0 -1
- zenml/zen_server/dashboard/assets/sharedSchema-BfZcy7aP.js +0 -14
- zenml/zen_server/dashboard/assets/stack-detail-query-CU4egfhp.js +0 -1
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.66.0.dev20240923.dist-info → zenml_nightly-0.66.0.dev20240928.dist-info}/entry_points.txt +0 -0
@@ -15,7 +15,16 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
import re
|
18
|
-
from typing import
|
18
|
+
from typing import (
|
19
|
+
TYPE_CHECKING,
|
20
|
+
Any,
|
21
|
+
Dict,
|
22
|
+
Iterator,
|
23
|
+
Optional,
|
24
|
+
Tuple,
|
25
|
+
Type,
|
26
|
+
cast,
|
27
|
+
)
|
19
28
|
from uuid import UUID
|
20
29
|
|
21
30
|
import boto3
|
@@ -25,13 +34,15 @@ from sagemaker.network import NetworkConfig
|
|
25
34
|
from sagemaker.processing import ProcessingInput, ProcessingOutput
|
26
35
|
from sagemaker.workflow.execution_variables import ExecutionVariables
|
27
36
|
from sagemaker.workflow.pipeline import Pipeline
|
28
|
-
from sagemaker.workflow.steps import ProcessingStep
|
37
|
+
from sagemaker.workflow.steps import ProcessingStep, TrainingStep
|
29
38
|
|
30
39
|
from zenml.config.base_settings import BaseSettings
|
31
40
|
from zenml.constants import (
|
41
|
+
METADATA_ORCHESTRATOR_LOGS_URL,
|
42
|
+
METADATA_ORCHESTRATOR_RUN_ID,
|
32
43
|
METADATA_ORCHESTRATOR_URL,
|
33
44
|
)
|
34
|
-
from zenml.enums import StackComponentType
|
45
|
+
from zenml.enums import ExecutionStatus, StackComponentType
|
35
46
|
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
|
36
47
|
SagemakerOrchestratorConfig,
|
37
48
|
SagemakerOrchestratorSettings,
|
@@ -48,7 +59,7 @@ from zenml.stack import StackValidator
|
|
48
59
|
from zenml.utils.env_utils import split_environment_variables
|
49
60
|
|
50
61
|
if TYPE_CHECKING:
|
51
|
-
from zenml.models import PipelineDeploymentResponse
|
62
|
+
from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
|
52
63
|
from zenml.stack import Stack
|
53
64
|
|
54
65
|
ENV_ZENML_SAGEMAKER_RUN_ID = "ZENML_SAGEMAKER_RUN_ID"
|
@@ -58,6 +69,34 @@ POLLING_DELAY = 30
|
|
58
69
|
logger = get_logger(__name__)
|
59
70
|
|
60
71
|
|
72
|
+
def dissect_pipeline_execution_arn(
|
73
|
+
pipeline_execution_arn: str,
|
74
|
+
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
75
|
+
"""Extract region name, pipeline name, and execution id from the ARN.
|
76
|
+
|
77
|
+
Args:
|
78
|
+
pipeline_execution_arn: the pipeline execution ARN
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
Region Name, Pipeline Name, Execution ID in order
|
82
|
+
"""
|
83
|
+
# Extract region_name
|
84
|
+
region_match = re.search(r"sagemaker:(.*?):", pipeline_execution_arn)
|
85
|
+
region_name = region_match.group(1) if region_match else None
|
86
|
+
|
87
|
+
# Extract pipeline_name
|
88
|
+
pipeline_match = re.search(
|
89
|
+
r"pipeline/(.*?)/execution", pipeline_execution_arn
|
90
|
+
)
|
91
|
+
pipeline_name = pipeline_match.group(1) if pipeline_match else None
|
92
|
+
|
93
|
+
# Extract execution_id
|
94
|
+
execution_match = re.search(r"execution/(.*)", pipeline_execution_arn)
|
95
|
+
execution_id = execution_match.group(1) if execution_match else None
|
96
|
+
|
97
|
+
return region_name, pipeline_name, execution_id
|
98
|
+
|
99
|
+
|
61
100
|
class SagemakerOrchestrator(ContainerizedOrchestrator):
|
62
101
|
"""Orchestrator responsible for running pipelines on Sagemaker."""
|
63
102
|
|
@@ -136,42 +175,16 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
136
175
|
"""
|
137
176
|
return SagemakerOrchestratorSettings
|
138
177
|
|
139
|
-
def
|
140
|
-
|
141
|
-
deployment: "PipelineDeploymentResponse",
|
142
|
-
stack: "Stack",
|
143
|
-
environment: Dict[str, str],
|
144
|
-
) -> None:
|
145
|
-
"""Prepares or runs a pipeline on Sagemaker.
|
178
|
+
def _get_sagemaker_session(self) -> sagemaker.Session:
|
179
|
+
"""Method to create the sagemaker session with proper authentication.
|
146
180
|
|
147
|
-
|
148
|
-
|
149
|
-
stack: The stack to run on.
|
150
|
-
environment: Environment variables to set in the orchestration
|
151
|
-
environment.
|
181
|
+
Returns:
|
182
|
+
The Sagemaker Session.
|
152
183
|
|
153
184
|
Raises:
|
154
|
-
RuntimeError: If
|
155
|
-
|
156
|
-
TypeError: If the network_config passed is not compatible with the
|
157
|
-
AWS SageMaker NetworkConfig class.
|
185
|
+
RuntimeError: If the connector returns the wrong type for the
|
186
|
+
session.
|
158
187
|
"""
|
159
|
-
if deployment.schedule:
|
160
|
-
logger.warning(
|
161
|
-
"The Sagemaker Orchestrator currently does not support the "
|
162
|
-
"use of schedules. The `schedule` will be ignored "
|
163
|
-
"and the pipeline will be run immediately."
|
164
|
-
)
|
165
|
-
|
166
|
-
# sagemaker requires pipelineName to use alphanum and hyphens only
|
167
|
-
unsanitized_orchestrator_run_name = get_orchestrator_run_name(
|
168
|
-
pipeline_name=deployment.pipeline_configuration.name
|
169
|
-
)
|
170
|
-
# replace all non-alphanum and non-hyphens with hyphens
|
171
|
-
orchestrator_run_name = re.sub(
|
172
|
-
r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
|
173
|
-
)
|
174
|
-
|
175
188
|
# Get authenticated session
|
176
189
|
# Option 1: Service connector
|
177
190
|
boto_session: boto3.Session
|
@@ -205,10 +218,51 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
205
218
|
aws_session_token=credentials["SessionToken"],
|
206
219
|
region_name=self.config.region,
|
207
220
|
)
|
208
|
-
|
221
|
+
return sagemaker.Session(
|
209
222
|
boto_session=boto_session, default_bucket=self.config.bucket
|
210
223
|
)
|
211
224
|
|
225
|
+
def prepare_or_run_pipeline(
|
226
|
+
self,
|
227
|
+
deployment: "PipelineDeploymentResponse",
|
228
|
+
stack: "Stack",
|
229
|
+
environment: Dict[str, str],
|
230
|
+
) -> Iterator[Dict[str, MetadataType]]:
|
231
|
+
"""Prepares or runs a pipeline on Sagemaker.
|
232
|
+
|
233
|
+
Args:
|
234
|
+
deployment: The deployment to prepare or run.
|
235
|
+
stack: The stack to run on.
|
236
|
+
environment: Environment variables to set in the orchestration
|
237
|
+
environment.
|
238
|
+
|
239
|
+
Raises:
|
240
|
+
RuntimeError: If a connector is used that does not return a
|
241
|
+
`boto3.Session` object.
|
242
|
+
TypeError: If the network_config passed is not compatible with the
|
243
|
+
AWS SageMaker NetworkConfig class.
|
244
|
+
|
245
|
+
Yields:
|
246
|
+
A dictionary of metadata related to the pipeline run.
|
247
|
+
"""
|
248
|
+
if deployment.schedule:
|
249
|
+
logger.warning(
|
250
|
+
"The Sagemaker Orchestrator currently does not support the "
|
251
|
+
"use of schedules. The `schedule` will be ignored "
|
252
|
+
"and the pipeline will be run immediately."
|
253
|
+
)
|
254
|
+
|
255
|
+
# sagemaker requires pipelineName to use alphanum and hyphens only
|
256
|
+
unsanitized_orchestrator_run_name = get_orchestrator_run_name(
|
257
|
+
pipeline_name=deployment.pipeline_configuration.name
|
258
|
+
)
|
259
|
+
# replace all non-alphanum and non-hyphens with hyphens
|
260
|
+
orchestrator_run_name = re.sub(
|
261
|
+
r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
|
262
|
+
)
|
263
|
+
|
264
|
+
session = self._get_sagemaker_session()
|
265
|
+
|
212
266
|
# Sagemaker does not allow environment variables longer than 256
|
213
267
|
# characters to be passed to Processor steps. If an environment variable
|
214
268
|
# is longer than 256 characters, we split it into multiple environment
|
@@ -238,54 +292,71 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
238
292
|
ExecutionVariables.PIPELINE_EXECUTION_ARN
|
239
293
|
)
|
240
294
|
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
295
|
+
use_training_step = (
|
296
|
+
step_settings.use_training_step
|
297
|
+
if step_settings.use_training_step is not None
|
298
|
+
else (
|
299
|
+
self.config.use_training_step
|
300
|
+
if self.config.use_training_step is not None
|
301
|
+
else True
|
302
|
+
)
|
248
303
|
)
|
249
|
-
|
304
|
+
|
305
|
+
# Retrieve Executor arguments provided in the Step settings.
|
306
|
+
if use_training_step:
|
307
|
+
args_for_step_executor = step_settings.estimator_args or {}
|
308
|
+
else:
|
309
|
+
args_for_step_executor = step_settings.processor_args or {}
|
310
|
+
|
311
|
+
# Set default values from configured orchestrator Component to
|
312
|
+
# arguments to be used when they are not present in processor_args.
|
313
|
+
args_for_step_executor.setdefault(
|
250
314
|
"role",
|
251
|
-
step_settings.
|
315
|
+
step_settings.execution_role or self.config.execution_role,
|
252
316
|
)
|
253
|
-
|
317
|
+
args_for_step_executor.setdefault(
|
254
318
|
"volume_size_in_gb", step_settings.volume_size_in_gb
|
255
319
|
)
|
256
|
-
|
320
|
+
args_for_step_executor.setdefault(
|
257
321
|
"max_runtime_in_seconds", step_settings.max_runtime_in_seconds
|
258
322
|
)
|
259
|
-
|
323
|
+
tags = step_settings.tags
|
324
|
+
args_for_step_executor.setdefault(
|
260
325
|
"tags",
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
326
|
+
(
|
327
|
+
[
|
328
|
+
{"Key": key, "Value": value}
|
329
|
+
for key, value in tags.items()
|
330
|
+
]
|
331
|
+
if tags
|
332
|
+
else None
|
333
|
+
),
|
334
|
+
)
|
335
|
+
args_for_step_executor.setdefault(
|
336
|
+
"instance_type", step_settings.instance_type
|
267
337
|
)
|
268
338
|
|
269
339
|
# Set values that cannot be overwritten
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
340
|
+
args_for_step_executor["image_uri"] = image
|
341
|
+
args_for_step_executor["instance_count"] = 1
|
342
|
+
args_for_step_executor["sagemaker_session"] = session
|
343
|
+
args_for_step_executor["base_job_name"] = orchestrator_run_name
|
344
|
+
|
345
|
+
# Convert network_config to sagemaker.network.NetworkConfig if
|
346
|
+
# present
|
347
|
+
network_config = args_for_step_executor.get("network_config")
|
348
|
+
|
279
349
|
if network_config and isinstance(network_config, dict):
|
280
350
|
try:
|
281
|
-
|
351
|
+
args_for_step_executor["network_config"] = NetworkConfig(
|
282
352
|
**network_config
|
283
353
|
)
|
284
354
|
except TypeError:
|
285
|
-
# If the network_config passed is not compatible with the
|
286
|
-
# raise a more informative error.
|
355
|
+
# If the network_config passed is not compatible with the
|
356
|
+
# NetworkConfig class, raise a more informative error.
|
287
357
|
raise TypeError(
|
288
|
-
"Expected a sagemaker.network.NetworkConfig
|
358
|
+
"Expected a sagemaker.network.NetworkConfig "
|
359
|
+
"compatible object for the network_config argument, "
|
289
360
|
"but the network_config processor argument is invalid."
|
290
361
|
"See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
|
291
362
|
"for more information about the NetworkConfig class."
|
@@ -317,17 +388,21 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
317
388
|
|
318
389
|
# Construct S3 outputs from container for step
|
319
390
|
outputs = None
|
391
|
+
output_path = None
|
320
392
|
|
321
393
|
if step_settings.output_data_s3_uri is None:
|
322
394
|
pass
|
323
395
|
elif isinstance(step_settings.output_data_s3_uri, str):
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
396
|
+
if use_training_step:
|
397
|
+
output_path = step_settings.output_data_s3_uri
|
398
|
+
else:
|
399
|
+
outputs = [
|
400
|
+
ProcessingOutput(
|
401
|
+
source="/opt/ml/processing/output/data",
|
402
|
+
destination=step_settings.output_data_s3_uri,
|
403
|
+
s3_upload_mode=step_settings.output_data_s3_mode,
|
404
|
+
)
|
405
|
+
]
|
331
406
|
elif isinstance(step_settings.output_data_s3_uri, dict):
|
332
407
|
outputs = []
|
333
408
|
for (
|
@@ -342,17 +417,37 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
342
417
|
)
|
343
418
|
)
|
344
419
|
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
420
|
+
if use_training_step:
|
421
|
+
# Create Estimator and TrainingStep
|
422
|
+
estimator = sagemaker.estimator.Estimator(
|
423
|
+
keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
|
424
|
+
output_path=output_path,
|
425
|
+
environment=environment,
|
426
|
+
container_entry_point=entrypoint,
|
427
|
+
**args_for_step_executor,
|
428
|
+
)
|
429
|
+
sagemaker_step = TrainingStep(
|
430
|
+
name=step_name,
|
431
|
+
depends_on=step.spec.upstream_steps,
|
432
|
+
inputs=inputs,
|
433
|
+
estimator=estimator,
|
434
|
+
)
|
435
|
+
else:
|
436
|
+
# Create Processor and ProcessingStep
|
437
|
+
processor = sagemaker.processing.Processor(
|
438
|
+
entrypoint=entrypoint,
|
439
|
+
env=environment,
|
440
|
+
**args_for_step_executor,
|
441
|
+
)
|
442
|
+
|
443
|
+
sagemaker_step = ProcessingStep(
|
444
|
+
name=step_name,
|
445
|
+
processor=processor,
|
446
|
+
depends_on=step.spec.upstream_steps,
|
447
|
+
inputs=inputs,
|
448
|
+
outputs=outputs,
|
449
|
+
)
|
450
|
+
|
356
451
|
sagemaker_steps.append(sagemaker_step)
|
357
452
|
|
358
453
|
# construct the pipeline from the sagemaker_steps
|
@@ -363,48 +458,37 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
363
458
|
)
|
364
459
|
|
365
460
|
pipeline.create(role_arn=self.config.execution_role)
|
366
|
-
|
461
|
+
execution = pipeline.start()
|
367
462
|
logger.warning(
|
368
463
|
"Steps can take 5-15 minutes to start running "
|
369
464
|
"when using the Sagemaker Orchestrator."
|
370
465
|
)
|
371
466
|
|
467
|
+
# Yield metadata based on the generated execution object
|
468
|
+
yield from self.compute_metadata(execution=execution)
|
469
|
+
|
372
470
|
# mainly for testing purposes, we wait for the pipeline to finish
|
373
471
|
if self.config.synchronous:
|
374
472
|
logger.info(
|
375
473
|
"Executing synchronously. Waiting for pipeline to finish... \n"
|
376
|
-
"At this point you can `Ctrl-C` out without cancelling the
|
474
|
+
"At this point you can `Ctrl-C` out without cancelling the "
|
475
|
+
"execution."
|
377
476
|
)
|
378
477
|
try:
|
379
|
-
|
478
|
+
execution.wait(
|
380
479
|
delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
|
381
480
|
)
|
382
481
|
logger.info("Pipeline completed successfully.")
|
383
482
|
except WaiterError:
|
384
483
|
raise RuntimeError(
|
385
|
-
"Timed out while waiting for pipeline execution to
|
386
|
-
"pipelines we recommend
|
484
|
+
"Timed out while waiting for pipeline execution to "
|
485
|
+
"finish. For long-running pipelines we recommend "
|
486
|
+
"configuring your orchestrator for asynchronous execution. "
|
387
487
|
"The following command does this for you: \n"
|
388
|
-
f"`zenml orchestrator update {self.name}
|
488
|
+
f"`zenml orchestrator update {self.name} "
|
489
|
+
f"--synchronous=False`"
|
389
490
|
)
|
390
491
|
|
391
|
-
def _get_region_name(self) -> str:
|
392
|
-
"""Returns the AWS region name.
|
393
|
-
|
394
|
-
Returns:
|
395
|
-
The region name.
|
396
|
-
|
397
|
-
Raises:
|
398
|
-
RuntimeError: If the region name cannot be retrieved.
|
399
|
-
"""
|
400
|
-
try:
|
401
|
-
return cast(str, sagemaker.Session().boto_region_name)
|
402
|
-
except Exception as e:
|
403
|
-
raise RuntimeError(
|
404
|
-
"Unable to get region name. Please ensure that you have "
|
405
|
-
"configured your AWS credentials correctly."
|
406
|
-
) from e
|
407
|
-
|
408
492
|
def get_pipeline_run_metadata(
|
409
493
|
self, run_id: UUID
|
410
494
|
) -> Dict[str, "MetadataType"]:
|
@@ -416,16 +500,17 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
416
500
|
Returns:
|
417
501
|
A dictionary of metadata.
|
418
502
|
"""
|
503
|
+
pipeline_execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
|
419
504
|
run_metadata: Dict[str, "MetadataType"] = {
|
420
|
-
"pipeline_execution_arn":
|
505
|
+
"pipeline_execution_arn": pipeline_execution_arn,
|
421
506
|
}
|
422
|
-
try:
|
423
|
-
region_name = self._get_region_name()
|
424
|
-
except RuntimeError:
|
425
|
-
logger.warning("Unable to get region name from AWS Sagemaker.")
|
426
|
-
return run_metadata
|
427
507
|
|
428
508
|
aws_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID].split("/")[-1]
|
509
|
+
|
510
|
+
region_name, _, _ = dissect_pipeline_execution_arn(
|
511
|
+
pipeline_execution_arn=pipeline_execution_arn
|
512
|
+
)
|
513
|
+
|
429
514
|
orchestrator_logs_url = (
|
430
515
|
f"https://{region_name}.console.aws.amazon.com/"
|
431
516
|
f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
|
@@ -434,3 +519,173 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
|
|
434
519
|
)
|
435
520
|
run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
|
436
521
|
return run_metadata
|
522
|
+
|
523
|
+
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
|
524
|
+
"""Refreshes the status of a specific pipeline run.
|
525
|
+
|
526
|
+
Args:
|
527
|
+
run: The run that was executed by this orchestrator.
|
528
|
+
|
529
|
+
Returns:
|
530
|
+
the actual status of the pipeline job.
|
531
|
+
|
532
|
+
Raises:
|
533
|
+
AssertionError: If the run was not executed by to this orchestrator.
|
534
|
+
ValueError: If it fetches an unknown state or if we can not fetch
|
535
|
+
the orchestrator run ID.
|
536
|
+
"""
|
537
|
+
# Make sure that the stack exists and is accessible
|
538
|
+
if run.stack is None:
|
539
|
+
raise ValueError(
|
540
|
+
"The stack that the run was executed on is not available "
|
541
|
+
"anymore."
|
542
|
+
)
|
543
|
+
|
544
|
+
# Make sure that the run belongs to this orchestrator
|
545
|
+
assert (
|
546
|
+
self.id
|
547
|
+
== run.stack.components[StackComponentType.ORCHESTRATOR][0].id
|
548
|
+
)
|
549
|
+
|
550
|
+
# Initialize the Sagemaker client
|
551
|
+
session = self._get_sagemaker_session()
|
552
|
+
sagemaker_client = session.sagemaker_client
|
553
|
+
|
554
|
+
# Fetch the status of the _PipelineExecution
|
555
|
+
if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
|
556
|
+
run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
|
557
|
+
elif run.orchestrator_run_id is not None:
|
558
|
+
run_id = run.orchestrator_run_id
|
559
|
+
else:
|
560
|
+
raise ValueError(
|
561
|
+
"Can not find the orchestrator run ID, thus can not fetch "
|
562
|
+
"the status."
|
563
|
+
)
|
564
|
+
status = sagemaker_client.describe_pipeline_execution(
|
565
|
+
PipelineExecutionArn=run_id
|
566
|
+
)["PipelineExecutionStatus"]
|
567
|
+
|
568
|
+
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
|
569
|
+
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/PipelineState
|
570
|
+
if status in ["Executing", "Stopping"]:
|
571
|
+
return ExecutionStatus.RUNNING
|
572
|
+
elif status in ["Stopped", "Failed"]:
|
573
|
+
return ExecutionStatus.FAILED
|
574
|
+
elif status in ["Succeeded"]:
|
575
|
+
return ExecutionStatus.COMPLETED
|
576
|
+
else:
|
577
|
+
raise ValueError("Unknown status for the pipeline execution.")
|
578
|
+
|
579
|
+
def compute_metadata(
|
580
|
+
self, execution: Any
|
581
|
+
) -> Iterator[Dict[str, MetadataType]]:
|
582
|
+
"""Generate run metadata based on the generated Sagemaker Execution.
|
583
|
+
|
584
|
+
Args:
|
585
|
+
execution: The corresponding _PipelineExecution object.
|
586
|
+
|
587
|
+
Yields:
|
588
|
+
A dictionary of metadata related to the pipeline run.
|
589
|
+
"""
|
590
|
+
# Metadata
|
591
|
+
metadata: Dict[str, MetadataType] = {}
|
592
|
+
|
593
|
+
# Orchestrator Run ID
|
594
|
+
if run_id := self._compute_orchestrator_run_id(execution):
|
595
|
+
metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
|
596
|
+
|
597
|
+
# URL to the Sagemaker's pipeline view
|
598
|
+
if orchestrator_url := self._compute_orchestrator_url(execution):
|
599
|
+
metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
|
600
|
+
|
601
|
+
# URL to the corresponding CloudWatch page
|
602
|
+
if logs_url := self._compute_orchestrator_logs_url(execution):
|
603
|
+
metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
|
604
|
+
|
605
|
+
yield metadata
|
606
|
+
|
607
|
+
@staticmethod
|
608
|
+
def _compute_orchestrator_url(
|
609
|
+
pipeline_execution: Any,
|
610
|
+
) -> Optional[str]:
|
611
|
+
"""Generate the Orchestrator Dashboard URL upon pipeline execution.
|
612
|
+
|
613
|
+
Args:
|
614
|
+
pipeline_execution: The corresponding _PipelineExecution object.
|
615
|
+
|
616
|
+
Returns:
|
617
|
+
the URL to the dashboard view in SageMaker.
|
618
|
+
"""
|
619
|
+
try:
|
620
|
+
region_name, pipeline_name, execution_id = (
|
621
|
+
dissect_pipeline_execution_arn(pipeline_execution.arn)
|
622
|
+
)
|
623
|
+
|
624
|
+
# Get the Sagemaker session
|
625
|
+
session = pipeline_execution.sagemaker_session
|
626
|
+
|
627
|
+
# List the Studio domains and get the Studio Domain ID
|
628
|
+
domains_response = session.sagemaker_client.list_domains()
|
629
|
+
studio_domain_id = domains_response["Domains"][0]["DomainId"]
|
630
|
+
|
631
|
+
return (
|
632
|
+
f"https://studio-{studio_domain_id}.studio.{region_name}."
|
633
|
+
f"sagemaker.aws/pipelines/view/{pipeline_name}/executions"
|
634
|
+
f"/{execution_id}/graph"
|
635
|
+
)
|
636
|
+
|
637
|
+
except Exception as e:
|
638
|
+
logger.warning(
|
639
|
+
f"There was an issue while extracting the pipeline url: {e}"
|
640
|
+
)
|
641
|
+
return None
|
642
|
+
|
643
|
+
@staticmethod
|
644
|
+
def _compute_orchestrator_logs_url(
|
645
|
+
pipeline_execution: Any,
|
646
|
+
) -> Optional[str]:
|
647
|
+
"""Generate the CloudWatch URL upon pipeline execution.
|
648
|
+
|
649
|
+
Args:
|
650
|
+
pipeline_execution: The corresponding _PipelineExecution object.
|
651
|
+
|
652
|
+
Returns:
|
653
|
+
the URL querying the pipeline logs in CloudWatch on AWS.
|
654
|
+
"""
|
655
|
+
try:
|
656
|
+
region_name, _, execution_id = dissect_pipeline_execution_arn(
|
657
|
+
pipeline_execution.arn
|
658
|
+
)
|
659
|
+
|
660
|
+
return (
|
661
|
+
f"https://{region_name}.console.aws.amazon.com/"
|
662
|
+
f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
|
663
|
+
f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
|
664
|
+
f"$3Dpipelines-{execution_id}-"
|
665
|
+
)
|
666
|
+
except Exception as e:
|
667
|
+
logger.warning(
|
668
|
+
f"There was an issue while extracting the logs url: {e}"
|
669
|
+
)
|
670
|
+
return None
|
671
|
+
|
672
|
+
@staticmethod
|
673
|
+
def _compute_orchestrator_run_id(
|
674
|
+
pipeline_execution: Any,
|
675
|
+
) -> Optional[str]:
|
676
|
+
"""Fetch the Orchestrator Run ID upon pipeline execution.
|
677
|
+
|
678
|
+
Args:
|
679
|
+
pipeline_execution: The corresponding _PipelineExecution object.
|
680
|
+
|
681
|
+
Returns:
|
682
|
+
the Execution ID of the run in SageMaker.
|
683
|
+
"""
|
684
|
+
try:
|
685
|
+
return str(pipeline_execution.arn)
|
686
|
+
|
687
|
+
except Exception as e:
|
688
|
+
logger.warning(
|
689
|
+
f"There was an issue while extracting the pipeline run ID: {e}"
|
690
|
+
)
|
691
|
+
return None
|
@@ -48,9 +48,13 @@ class AzureIntegration(Integration):
|
|
48
48
|
"azure-mgmt-containerservice>=20.0.0",
|
49
49
|
"azure-storage-blob==12.17.0", # temporary fix for https://github.com/Azure/azure-sdk-for-python/issues/32056
|
50
50
|
"kubernetes",
|
51
|
-
"azure-ai-ml==1.18.0"
|
51
|
+
"azure-ai-ml==1.18.0",
|
52
|
+
# In azureml/core/_metrics.py:212 of azureml-core 1.56.0, they use
|
53
|
+
# an attribute that was removed in Numpy 2.0. However, AzureML itself
|
54
|
+
# does not have a limitation on numpy.
|
55
|
+
"numpy<2.0",
|
52
56
|
]
|
53
|
-
REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes"]
|
57
|
+
REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kubernetes", "numpy"]
|
54
58
|
|
55
59
|
@classmethod
|
56
60
|
def activate(cls) -> None:
|