zenml-nightly 0.83.1.dev20250709__py3-none-any.whl → 0.83.1.dev20250711__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. zenml/VERSION +1 -1
  2. zenml/artifact_stores/base_artifact_store.py +51 -23
  3. zenml/artifacts/utils.py +3 -1
  4. zenml/cli/login.py +141 -18
  5. zenml/cli/pipeline.py +13 -2
  6. zenml/cli/project.py +8 -6
  7. zenml/cli/utils.py +63 -16
  8. zenml/client.py +4 -1
  9. zenml/config/compiler.py +1 -0
  10. zenml/config/retry_config.py +5 -3
  11. zenml/config/step_configurations.py +7 -1
  12. zenml/console.py +4 -1
  13. zenml/constants.py +3 -1
  14. zenml/container_registries/base_container_registry.py +17 -5
  15. zenml/enums.py +13 -4
  16. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +150 -117
  17. zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +43 -42
  18. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +16 -7
  19. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +18 -12
  20. zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py +7 -1
  21. zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +58 -23
  22. zenml/integrations/feast/flavors/feast_feature_store_flavor.py +18 -5
  23. zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py +10 -42
  24. zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +99 -92
  25. zenml/integrations/gcp/google_credentials_mixin.py +13 -8
  26. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +18 -9
  27. zenml/integrations/huggingface/__init__.py +1 -1
  28. zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +28 -30
  29. zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py +56 -40
  30. zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +59 -48
  31. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +189 -97
  32. zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +48 -33
  33. zenml/integrations/kubernetes/orchestrators/kube_utils.py +172 -0
  34. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +219 -24
  35. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +98 -24
  36. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +59 -0
  37. zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +41 -25
  38. zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +51 -44
  39. zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py +9 -4
  40. zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +13 -12
  41. zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +32 -7
  42. zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py +7 -1
  43. zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +34 -25
  44. zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +14 -11
  45. zenml/logger.py +6 -4
  46. zenml/logging/step_logging.py +8 -7
  47. zenml/login/web_login.py +13 -6
  48. zenml/models/v2/core/model_version.py +9 -1
  49. zenml/models/v2/core/pipeline_run.py +1 -59
  50. zenml/models/v2/core/step_run.py +35 -1
  51. zenml/orchestrators/base_orchestrator.py +70 -9
  52. zenml/orchestrators/dag_runner.py +3 -1
  53. zenml/orchestrators/publish_utils.py +4 -1
  54. zenml/orchestrators/step_launcher.py +77 -139
  55. zenml/orchestrators/step_run_utils.py +16 -0
  56. zenml/orchestrators/step_runner.py +1 -4
  57. zenml/pipelines/build_utils.py +2 -1
  58. zenml/pipelines/pipeline_decorator.py +6 -1
  59. zenml/pipelines/pipeline_definition.py +7 -0
  60. zenml/stack/authentication_mixin.py +6 -5
  61. zenml/stack/flavor.py +5 -1
  62. zenml/utils/code_utils.py +2 -1
  63. zenml/utils/docker_utils.py +22 -0
  64. zenml/utils/io_utils.py +18 -0
  65. zenml/utils/pipeline_docker_image_builder.py +4 -1
  66. zenml/utils/run_utils.py +101 -8
  67. zenml/zen_server/auth.py +0 -1
  68. zenml/zen_server/deploy/daemon/daemon_zen_server.py +4 -0
  69. zenml/zen_server/deploy/docker/docker_zen_server.py +2 -0
  70. zenml/zen_server/routers/runs_endpoints.py +20 -28
  71. zenml/zen_stores/migrations/versions/360fa84718bf_step_run_versioning.py +64 -0
  72. zenml/zen_stores/migrations/versions/85289fea86ff_adding_source_to_logs.py +1 -1
  73. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +21 -0
  74. zenml/zen_stores/schemas/pipeline_run_schemas.py +31 -2
  75. zenml/zen_stores/schemas/step_run_schemas.py +41 -17
  76. zenml/zen_stores/sql_zen_store.py +152 -32
  77. zenml/zen_stores/template_utils.py +29 -9
  78. zenml_nightly-0.83.1.dev20250711.dist-info/METADATA +486 -0
  79. {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250711.dist-info}/RECORD +82 -81
  80. zenml_nightly-0.83.1.dev20250709.dist-info/METADATA +0 -538
  81. {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250711.dist-info}/LICENSE +0 -0
  82. {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250711.dist-info}/WHEEL +0 -0
  83. {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250711.dist-info}/entry_points.txt +0 -0
zenml/config/retry_config.py CHANGED
@@ -13,6 +13,8 @@
 # permissions and limitations under the License.
 """Retry configuration for a step."""
 
+from pydantic import NonNegativeInt, PositiveInt
+
 from zenml.config.strict_base_model import StrictBaseModel
 
 
@@ -22,6 +24,6 @@ class StepRetryConfig(StrictBaseModel):
     Delay is an integer (specified in seconds).
     """
 
-    max_retries: int = 1
-    delay: int = 0  # in seconds
-    backoff: int = 0
+    max_retries: PositiveInt = 1
+    delay: NonNegativeInt = 0  # in seconds
+    backoff: NonNegativeInt = 0
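
The switch to PositiveInt/NonNegativeInt means invalid retry settings are now rejected at validation time instead of being silently accepted. A minimal sketch of how this surfaces to pipeline authors, assuming ZenML's public `@step(retry=...)` parameter (not part of this diff):

    from zenml import step
    from zenml.config.retry_config import StepRetryConfig

    @step(retry=StepRetryConfig(max_retries=3, delay=10, backoff=2))
    def flaky_step() -> int:
        return 42

    # With the new field types, StepRetryConfig(max_retries=0) or
    # StepRetryConfig(delay=-1) now raises a pydantic ValidationError.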
zenml/config/step_configurations.py CHANGED
@@ -289,6 +289,7 @@ class StepConfiguration(PartialStepConfiguration):
                     "extra",
                     "failure_hook_source",
                     "success_hook_source",
+                    "retry",
                     "substitutions",
                 },
                 exclude_none=True,
@@ -300,12 +301,17 @@ class StepConfiguration(PartialStepConfiguration):
                     "extra",
                     "failure_hook_source",
                     "success_hook_source",
+                    "retry",
                     "substitutions",
                 },
                 exclude_none=True,
             )
 
-            updated_config = self.model_copy(update=pipeline_values, deep=True)
+            updated_config_dict = {
+                **self.model_dump(),
+                **pipeline_values,
+            }
+            updated_config = self.model_validate(updated_config_dict)
             return update_model(updated_config, original_values)
         else:
             return self.model_copy(deep=True)
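
The replacement of model_copy(update=...) with a dump/merge/model_validate round trip matters because Pydantic v2 does not run validation on values passed through model_copy(update=...). A generic Pydantic sketch (not ZenML code) illustrating the difference:

    from pydantic import BaseModel, PositiveInt, ValidationError

    class RetrySettings(BaseModel):
        max_retries: PositiveInt = 1

    cfg = RetrySettings()
    # model_copy(update=...) performs no validation, so an invalid value slips through:
    broken = cfg.model_copy(update={"max_retries": 0})
    assert broken.max_retries == 0
    # model_validate() re-runs validation and rejects the same value:
    try:
        RetrySettings.model_validate({**cfg.model_dump(), "max_retries": 0})
    except ValidationError:
        print("invalid update rejected")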
zenml/console.py CHANGED
@@ -18,11 +18,14 @@ from rich.style import Style
 from rich.theme import Theme
 
 zenml_style_defaults = {
-    "info": Style(color="cyan", dim=True),
+    "info": Style(color="white", dim=True),
     "warning": Style(color="yellow"),
     "danger": Style(color="red", bold=True),
     "title": Style(color="cyan", bold=True, underline=True),
     "error": Style(color="red"),
+    "success": Style(color="green"),
+    "repr.str": Style(color="white", dim=False),
+    "repr.uuid": Style(color="magenta", dim=False),
 }
 
 zenml_custom_theme = Theme(zenml_style_defaults)
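
A brief usage sketch of the new styles; it assumes (as in current releases, though not shown in this diff) that zenml.console exposes a rich Console built from zenml_custom_theme:

    from zenml.console import console

    console.print("Stack registered.", style="success")      # new green style
    console.print("Connecting to server...", style="info")   # now dim white instead of cyan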
zenml/constants.py CHANGED
@@ -166,12 +166,14 @@ ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING = (
 ENV_ZENML_SKIP_IMAGE_BUILDER_DEFAULT = "ZENML_SKIP_IMAGE_BUILDER_DEFAULT"
 ENV_ZENML_SKIP_STACK_VALIDATION = "ZENML_SKIP_STACK_VALIDATION"
 ENV_ZENML_SERVER = "ZENML_SERVER"
+ENV_ZENML_SERVER_ALLOW_LOCAL_FILE_ACCESS = (
+    "ZENML_SERVER_ALLOW_LOCAL_FILE_ACCESS"
+)
 ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS"
 ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS"
 ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE = "ZENML_DISABLE_PIPELINE_LOGS_STORAGE"
 ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE"
 ENV_ZENML_DISABLE_STEP_NAMES_IN_LOGS = "ZENML_DISABLE_STEP_NAMES_IN_LOGS"
-ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
 ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
 ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
     "ZENML_PIPELINE_API_TOKEN_EXPIRATION"
zenml/container_registries/base_container_registry.py CHANGED
@@ -16,7 +16,7 @@
 import re
 from typing import TYPE_CHECKING, Optional, Tuple, Type, cast
 
-from pydantic import field_validator
+from pydantic import Field, field_validator
 
 from zenml.constants import DOCKER_REGISTRY_RESOURCE_TYPE
 from zenml.enums import StackComponentType
@@ -36,12 +36,24 @@ if TYPE_CHECKING:
 class BaseContainerRegistryConfig(AuthenticationConfigMixin):
     """Base config for a container registry.
 
-    Attributes:
-        uri: The URI of the container registry.
+    Configuration for connecting to container image registries.
+    Field descriptions are defined inline using Field() descriptors.
     """
 
-    uri: str
-    default_repository: Optional[str] = None
+    uri: str = Field(
+        description="Container registry URI (e.g., 'gcr.io' for Google Container "
+        "Registry, 'docker.io' for Docker Hub, 'registry.gitlab.com' for GitLab "
+        "Container Registry, 'ghcr.io' for GitHub Container Registry). This is "
+        "the base URL where container images will be pushed to and pulled from."
+    )
+    default_repository: Optional[str] = Field(
+        default=None,
+        description="Default repository namespace for image storage (e.g., "
+        "'username' for Docker Hub, 'project-id' for GCR, 'organization' for "
+        "GitHub Container Registry). If not specified, images will be stored at "
+        "the registry root. For Docker Hub this would mean only official images "
+        "can be pushed.",
+    )
 
     @field_validator("uri")
     @classmethod
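
The Field(...) wrappers only attach metadata (field descriptions); runtime behavior is unchanged. A small sketch of reading that metadata, assuming the import path shown in the file list above:

    from zenml.container_registries.base_container_registry import (
        BaseContainerRegistryConfig,
    )

    # The descriptions are plain Pydantic v2 field metadata:
    print(BaseContainerRegistryConfig.model_fields["uri"].description)
    print(BaseContainerRegistryConfig.model_fields["default_repository"].description)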
zenml/enums.py CHANGED
@@ -78,6 +78,8 @@ class ExecutionStatus(StrEnum):
     COMPLETED = "completed"
     RUNNING = "running"
     CACHED = "cached"
+    RETRYING = "retrying"
+    RETRIED = "retried"
     STOPPED = "stopped"
     STOPPING = "stopping"
 
@@ -92,9 +94,19 @@ class ExecutionStatus(StrEnum):
             ExecutionStatus.FAILED,
             ExecutionStatus.COMPLETED,
             ExecutionStatus.CACHED,
+            ExecutionStatus.RETRIED,
             ExecutionStatus.STOPPED,
         }
 
+    @property
+    def is_successful(self) -> bool:
+        """Whether the execution status refers to a successful execution.
+
+        Returns:
+            Whether the execution status refers to a successful execution.
+        """
+        return self in {ExecutionStatus.COMPLETED, ExecutionStatus.CACHED}
+
 
 class LoggingLevels(Enum):
     """Enum for logging levels."""
@@ -411,9 +423,6 @@ class OnboardingStep(StrEnum):
     DEVICE_VERIFIED = "device_verified"
     PROJECT_CREATED = "project_created"
     PIPELINE_RUN = "pipeline_run"
-    SECOND_PIPELINE_RUN = "second_pipeline_run"
-    THIRD_PIPELINE_RUN = "third_pipeline_run"
-    STARTER_SETUP_COMPLETED = "starter_setup_completed"
     STACK_WITH_REMOTE_ORCHESTRATOR_CREATED = (
         "stack_with_remote_orchestrator_created"
     )
@@ -426,7 +435,7 @@ class OnboardingStep(StrEnum):
     PIPELINE_RUN_WITH_REMOTE_ARTIFACT_STORE = (
         "pipeline_run_with_remote_artifact_store"
     )
-    PRODUCTION_SETUP_COMPLETED = "production_setup_completed"
+    OSS_ONBOARDING_COMPLETED = "oss_onboarding_completed"
     PRO_ONBOARDING_COMPLETED = "pro_onboarding_completed"
 
 
@@ -37,100 +37,123 @@ DEFAULT_OUTPUT_DATA_S3_MODE = "EndOfJob"
37
37
 
38
38
 
39
39
  class SagemakerOrchestratorSettings(BaseSettings):
40
- """Settings for the Sagemaker orchestrator.
41
-
42
- Attributes:
43
- synchronous: If `True`, the client running a pipeline using this
44
- orchestrator waits until all steps finish running. If `False`,
45
- the client returns immediately and the pipeline is executed
46
- asynchronously. Defaults to `True`.
47
- instance_type: The instance type to use for the processing job.
48
- execution_role: The IAM role to use for the step execution.
49
- processor_role: DEPRECATED: use `execution_role` instead.
50
- volume_size_in_gb: The size of the EBS volume to use for the processing
51
- job.
52
- max_runtime_in_seconds: The maximum runtime in seconds for the
53
- processing job.
54
- tags: Tags to apply to the Processor/Estimator assigned to the step.
55
- pipeline_tags: Tags to apply to the pipeline via the
56
- sagemaker.workflow.pipeline.Pipeline.create method.
57
- processor_tags: DEPRECATED: use `tags` instead.
58
- keep_alive_period_in_seconds: The time in seconds after which the
59
- provisioned instance will be terminated if not used. This is only
60
- applicable for TrainingStep type and it is not possible to use
61
- TrainingStep type if the `output_data_s3_uri` is set to Dict[str, str].
62
- use_training_step: Whether to use the TrainingStep type.
63
- It is not possible to use TrainingStep type
64
- if the `output_data_s3_uri` is set to Dict[str, str] or if the
65
- `output_data_s3_mode` != "EndOfJob".
66
- processor_args: Arguments that are directly passed to the SageMaker
67
- Processor for a specific step, allowing for overriding the default
68
- settings provided when configuring the component. See
69
- https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor
70
- for a full list of arguments.
71
- For processor_args.instance_type, check
72
- https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
73
- for a list of available instance types.
74
- environment: Environment variables to pass to the container.
75
- estimator_args: Arguments that are directly passed to the SageMaker
76
- Estimator for a specific step, allowing for overriding the default
77
- settings provided when configuring the component. See
78
- https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
79
- for a full list of arguments.
80
- For a list of available instance types, check
81
- https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html.
82
- input_data_s3_mode: How data is made available to the container.
83
- Two possible input modes: File, Pipe.
84
- input_data_s3_uri: S3 URI where data is located if not locally,
85
- e.g. s3://my-bucket/my-data/train. How data will be made available
86
- to the container is configured with input_data_s3_mode. Two possible
87
- input types:
88
- - str: S3 location where training data is saved.
89
- - Dict[str, str]: (ChannelName, S3Location) which represent
90
- - Dict[str, str]: (ChannelName, S3Location) which represent
91
- channels (e.g. training, validation, testing) where
92
- specific parts of the data are saved in S3.
93
- output_data_s3_mode: How data is uploaded to the S3 bucket.
94
- Two possible output modes: EndOfJob, Continuous.
95
- output_data_s3_uri: S3 URI where data is uploaded after or during processing run.
96
- e.g. s3://my-bucket/my-data/output. How data will be made available
97
- to the container is configured with output_data_s3_mode. Two possible
98
- input types:
99
- - str: S3 location where data will be uploaded from a local folder
100
- named /opt/ml/processing/output/data.
101
- - Dict[str, str]: (ChannelName, S3Location) which represent
102
- channels (e.g. output_one, output_two) where
103
- specific parts of the data are stored locally for S3 upload.
104
- Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.
105
- """
106
-
107
- synchronous: bool = True
40
+ """Settings for the Sagemaker orchestrator."""
41
+
42
+ synchronous: bool = Field(
43
+ True,
44
+ description="Controls whether pipeline execution blocks the client. If True, "
45
+ "the client waits until all steps complete before returning. If False, "
46
+ "returns immediately and executes asynchronously. Useful for long-running "
47
+ "production pipelines where you don't want to maintain a connection",
48
+ )
108
49
 
109
- instance_type: Optional[str] = None
110
- execution_role: Optional[str] = None
111
- volume_size_in_gb: int = 30
112
- max_runtime_in_seconds: int = 86400
113
- tags: Dict[str, str] = {}
114
- pipeline_tags: Dict[str, str] = {}
115
- keep_alive_period_in_seconds: Optional[int] = 300 # 5 minutes
116
- use_training_step: Optional[bool] = None
50
+ instance_type: Optional[str] = Field(
51
+ None,
52
+ description="AWS EC2 instance type for step execution. Must be a valid "
53
+ "SageMaker-supported instance type. Examples: 'ml.t3.medium' (2 vCPU, 4GB RAM), "
54
+ "'ml.m5.xlarge' (4 vCPU, 16GB RAM), 'ml.p3.2xlarge' (8 vCPU, 61GB RAM, 1 GPU). "
55
+ "Defaults to ml.m5.xlarge for training steps or ml.t3.medium for processing steps",
56
+ )
57
+ execution_role: Optional[str] = Field(
58
+ None,
59
+ description="IAM role ARN for SageMaker step execution permissions. Must have "
60
+ "necessary policies attached (SageMakerFullAccess, S3 access, etc.). "
61
+ "Example: 'arn:aws:iam::123456789012:role/SageMakerExecutionRole'. "
62
+ "If not provided, uses the default SageMaker execution role",
63
+ )
64
+ volume_size_in_gb: int = Field(
65
+ 30,
66
+ description="EBS volume size in GB for step execution storage. Must be between "
67
+ "1-16384 GB. Used for temporary files, model artifacts, and data processing. "
68
+ "Larger volumes needed for big datasets or model training. Example: 30 for "
69
+ "small jobs, 100+ for large ML training jobs",
70
+ )
71
+ max_runtime_in_seconds: int = Field(
72
+ 86400, # 24 hours
73
+ description="Maximum execution time in seconds before job termination. Must be "
74
+ "between 1-432000 seconds (5 days). Used to prevent runaway jobs and control costs. "
75
+ "Examples: 3600 (1 hour), 86400 (24 hours), 259200 (3 days). "
76
+ "Consider your longest expected step duration",
77
+ )
78
+ tags: Dict[str, str] = Field(
79
+ default_factory=dict,
80
+ description="Tags to apply to the Processor/Estimator assigned to the step. "
81
+ "Example: {'Environment': 'Production', 'Project': 'MLOps'}",
82
+ )
83
+ pipeline_tags: Dict[str, str] = Field(
84
+ default_factory=dict,
85
+ description="Tags to apply to the pipeline via the "
86
+ "sagemaker.workflow.pipeline.Pipeline.create method. Example: "
87
+ "{'Environment': 'Production', 'Project': 'MLOps'}",
88
+ )
89
+ keep_alive_period_in_seconds: Optional[int] = Field(
90
+ 300, # 5 minutes
91
+ description="The time in seconds after which the provisioned instance "
92
+ "will be terminated if not used. This is only applicable for "
93
+ "TrainingStep type.",
94
+ )
95
+ use_training_step: Optional[bool] = Field(
96
+ None,
97
+ description="Whether to use the TrainingStep type. It is not possible "
98
+ "to use TrainingStep type if the `output_data_s3_uri` is set to "
99
+ "Dict[str, str] or if the `output_data_s3_mode` != 'EndOfJob'.",
100
+ )
117
101
 
118
- processor_args: Dict[str, Any] = {}
119
- estimator_args: Dict[str, Any] = {}
120
- environment: Dict[str, str] = {}
102
+ processor_args: Dict[str, Any] = Field(
103
+ default_factory=dict,
104
+ description="Arguments that are directly passed to the SageMaker "
105
+ "Processor for a specific step, allowing for overriding the default "
106
+ "settings provided when configuring the component. Example: "
107
+ "{'instance_count': 2, 'base_job_name': 'my-processing-job'}",
108
+ )
109
+ estimator_args: Dict[str, Any] = Field(
110
+ default_factory=dict,
111
+ description="Arguments that are directly passed to the SageMaker "
112
+ "Estimator for a specific step, allowing for overriding the default "
113
+ "settings provided when configuring the component. Example: "
114
+ "{'train_instance_count': 2, 'train_max_run': 3600}",
115
+ )
116
+ environment: Dict[str, str] = Field(
117
+ default_factory=dict,
118
+ description="Environment variables to pass to the container. "
119
+ "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
120
+ )
121
121
 
122
- input_data_s3_mode: str = "File"
122
+ input_data_s3_mode: str = Field(
123
+ "File",
124
+ description="How data is made available to the container. "
125
+ "Two possible input modes: File, Pipe.",
126
+ )
123
127
  input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
124
- default=None, union_mode="left_to_right"
128
+ default=None,
129
+ union_mode="left_to_right",
130
+ description="S3 URI where data is located if not locally. Example string: "
131
+ "'s3://my-bucket/my-data/train'. Example dict: "
132
+ "{'training': 's3://bucket/train', 'validation': 's3://bucket/val'}",
125
133
  )
126
134
 
127
- output_data_s3_mode: str = DEFAULT_OUTPUT_DATA_S3_MODE
135
+ output_data_s3_mode: str = Field(
136
+ DEFAULT_OUTPUT_DATA_S3_MODE,
137
+ description="How data is uploaded to the S3 bucket. "
138
+ "Two possible output modes: EndOfJob, Continuous.",
139
+ )
128
140
  output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
129
- default=None, union_mode="left_to_right"
141
+ default=None,
142
+ union_mode="left_to_right",
143
+ description="S3 URI where data is uploaded after or during processing run. "
144
+ "Example string: 's3://my-bucket/my-data/output'. Example dict: "
145
+ "{'output_one': 's3://bucket/out1', 'output_two': 's3://bucket/out2'}",
146
+ )
147
+ processor_role: Optional[str] = Field(
148
+ None,
149
+ description="DEPRECATED: use `execution_role` instead. "
150
+ "The IAM role to use for the step execution.",
151
+ )
152
+ processor_tags: Optional[Dict[str, str]] = Field(
153
+ None,
154
+ description="DEPRECATED: use `tags` instead. "
155
+ "Tags to apply to the Processor assigned to the step.",
130
156
  )
131
-
132
- processor_role: Optional[str] = None
133
- processor_tags: Optional[Dict[str, str]] = None
134
157
  _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
135
158
  ("processor_role", "execution_role"), ("processor_tags", "tags")
136
159
  )
@@ -184,39 +207,49 @@ class SagemakerOrchestratorConfig(
       `aws_secret_access_key`, and optional `aws_auth_role_arn`,
     - If none of the above are provided, unspecified credentials will be
       loaded from the default AWS config.
-
-    Attributes:
-        execution_role: The IAM role ARN to use for the pipeline.
-        scheduler_role: The ARN of the IAM role that will be assumed by
-            the EventBridge service to launch Sagemaker pipelines
-            (For more details regarding the required permissions, please check:
-            https://docs.zenml.io/stacks/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules)
-        aws_access_key_id: The AWS access key ID to use to authenticate to AWS.
-            If not provided, the value from the default AWS config will be used.
-        aws_secret_access_key: The AWS secret access key to use to authenticate
-            to AWS. If not provided, the value from the default AWS config will
-            be used.
-        aws_profile: The AWS profile to use for authentication if not using
-            service connectors or explicit credentials. If not provided, the
-            default profile will be used.
-        aws_auth_role_arn: The ARN of an intermediate IAM role to assume when
-            authenticating to AWS.
-        region: The AWS region where the processing job will be run. If not
-            provided, the value from the default AWS config will be used.
-        bucket: Name of the S3 bucket to use for storing artifacts
-            from the job run. If not provided, a default bucket will be created
-            based on the following format:
-            "sagemaker-{region}-{aws-account-id}".
     """
 
-    execution_role: str
-    scheduler_role: Optional[str] = None
-    aws_access_key_id: Optional[str] = SecretField(default=None)
-    aws_secret_access_key: Optional[str] = SecretField(default=None)
-    aws_profile: Optional[str] = None
-    aws_auth_role_arn: Optional[str] = None
-    region: Optional[str] = None
-    bucket: Optional[str] = None
+    execution_role: str = Field(
+        ..., description="The IAM role ARN to use for the pipeline."
+    )
+    scheduler_role: Optional[str] = Field(
+        None,
+        description="The ARN of the IAM role that will be assumed by "
+        "the EventBridge service to launch Sagemaker pipelines. "
+        "Required for scheduled pipelines.",
+    )
+    aws_access_key_id: Optional[str] = SecretField(
+        default=None,
+        description="The AWS access key ID to use to authenticate to AWS. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+    aws_secret_access_key: Optional[str] = SecretField(
+        default=None,
+        description="The AWS secret access key to use to authenticate to AWS. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+    aws_profile: Optional[str] = Field(
+        None,
+        description="The AWS profile to use for authentication if not using "
+        "service connectors or explicit credentials. If not provided, the "
+        "default profile will be used.",
+    )
+    aws_auth_role_arn: Optional[str] = Field(
+        None,
+        description="The ARN of an intermediate IAM role to assume when "
+        "authenticating to AWS.",
+    )
+    region: Optional[str] = Field(
+        None,
+        description="The AWS region where the processing job will be run. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+    bucket: Optional[str] = Field(
+        None,
+        description="Name of the S3 bucket to use for storing artifacts "
+        "from the job run. If not provided, a default bucket will be created "
+        "based on the following format: 'sagemaker-{region}-{aws-account-id}'.",
+    )
 
     @property
     def is_remote(self) -> bool:
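
How these settings typically reach the orchestrator, sketched under the assumption of ZenML's usual settings-key convention ("orchestrator.sagemaker"), which is not part of this diff:

    from zenml import pipeline
    from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
        SagemakerOrchestratorSettings,
    )

    sagemaker_settings = SagemakerOrchestratorSettings(
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=50,
        environment={"LOG_LEVEL": "INFO"},
    )

    @pipeline(settings={"orchestrator.sagemaker": sagemaker_settings})
    def my_pipeline() -> None:
        ...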
zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py CHANGED
@@ -34,38 +34,37 @@ if TYPE_CHECKING:
 
 
 class SagemakerStepOperatorSettings(BaseSettings):
-    """Settings for the Sagemaker step operator.
-
-    Attributes:
-        experiment_name: The name for the experiment to which the job
-            will be associated. If not provided, the job runs would be
-            independent.
-        input_data_s3_uri: S3 URI where training data is located if not locally,
-            e.g. s3://my-bucket/my-data/train. How data will be made available
-            to the container is configured with estimator_args.input_mode. Two possible
-            input types:
-                - str: S3 location where training data is saved.
-                - Dict[str, str]: (ChannelName, S3Location) which represent
-                    channels (e.g. training, validation, testing) where
-                    specific parts of the data are saved in S3.
-        estimator_args: Arguments that are directly passed to the SageMaker
-            Estimator. See
-            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
-            for a full list of arguments.
-            For estimator_args.instance_type, check
-            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
-            for a list of available instance types.
-        environment: Environment variables to pass to the container.
-
-    """
-
-    instance_type: Optional[str] = None
-    experiment_name: Optional[str] = None
+    """Settings for the Sagemaker step operator."""
+
+    instance_type: Optional[str] = Field(
+        None,
+        description="DEPRECATED: The instance type to use for the step execution. "
+        "Use estimator_args instead. Example: 'ml.m5.xlarge'",
+    )
+    experiment_name: Optional[str] = Field(
+        None,
+        description="The name for the experiment to which the job will be associated. "
+        "If not provided, the job runs would be independent. Example: 'my-training-experiment'",
+    )
     input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
-        default=None, union_mode="left_to_right"
+        default=None,
+        union_mode="left_to_right",
+        description="S3 URI where training data is located if not locally. "
+        "Example string: 's3://my-bucket/my-data/train'. Example dict: "
+        "{'training': 's3://bucket/train', 'validation': 's3://bucket/val'}",
+    )
+    estimator_args: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments that are directly passed to the SageMaker Estimator. "
+        "See SageMaker documentation for available arguments and instance types. Example: "
+        "{'instance_type': 'ml.m5.xlarge', 'instance_count': 1, "
+        "'train_max_run': 3600, 'input_mode': 'File'}",
+    )
+    environment: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Environment variables to pass to the container during execution. "
+        "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
     )
-    estimator_args: Dict[str, Any] = {}
-    environment: Dict[str, str] = {}
 
     _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
         "instance_type"
@@ -75,18 +74,20 @@ class SagemakerStepOperatorSettings(BaseSettings):
 class SagemakerStepOperatorConfig(
     BaseStepOperatorConfig, SagemakerStepOperatorSettings
 ):
-    """Config for the Sagemaker step operator.
-
-    Attributes:
-        role: The role that has to be assigned to the jobs which are
-            running in Sagemaker.
-        bucket: Name of the S3 bucket to use for storing artifacts
-            from the job run. If not provided, a default bucket will be created
-            based on the following format: "sagemaker-{region}-{aws-account-id}".
-    """
-
-    role: str
-    bucket: Optional[str] = None
+    """Config for the Sagemaker step operator."""
+
+    role: str = Field(
+        ...,
+        description="The IAM role ARN that has to be assigned to the jobs "
+        "running in SageMaker. This role must have the necessary permissions "
+        "to access SageMaker and S3 resources.",
+    )
+    bucket: Optional[str] = Field(
+        None,
+        description="Name of the S3 bucket to use for storing artifacts from the job run. "
+        "If not provided, a default bucket will be created based on the format: "
+        "'sagemaker-{region}-{aws-account-id}'.",
+    )
 
     @property
     def is_remote(self) -> bool:
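
A matching sketch for the step operator settings, again assuming the conventional "step_operator.sagemaker" settings key (not shown in this diff); note that instance_type is deprecated above in favor of estimator_args:

    from zenml import step
    from zenml.integrations.aws.flavors.sagemaker_step_operator_flavor import (
        SagemakerStepOperatorSettings,
    )

    settings = SagemakerStepOperatorSettings(
        estimator_args={"instance_type": "ml.m5.xlarge", "instance_count": 1},
        environment={"LOG_LEVEL": "INFO"},
    )

    @step(settings={"step_operator.sagemaker": settings})
    def train() -> None:
        ...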
zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py CHANGED
@@ -807,14 +807,20 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
             settings=settings,
         )
 
-    def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
+    def fetch_status(
+        self, run: "PipelineRunResponse", include_steps: bool = False
+    ) -> Tuple[
+        Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
+    ]:
         """Refreshes the status of a specific pipeline run.
 
         Args:
             run: The run that was executed by this orchestrator.
+            include_steps: Whether to fetch steps
 
         Returns:
-            the actual status of the pipeline job.
+            A tuple of (pipeline_status, step_statuses_dict).
+            Step statuses are not supported for SageMaker, so step_statuses_dict will always be None.
 
         Raises:
             AssertionError: If the run was not executed by to this orchestrator.
@@ -855,18 +861,21 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
         # Map the potential outputs to ZenML ExecutionStatus. Potential values:
         # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
         if status == "Executing":
-            return ExecutionStatus.RUNNING
+            pipeline_status = ExecutionStatus.RUNNING
         elif status == "Stopping":
-            return ExecutionStatus.STOPPING
+            pipeline_status = ExecutionStatus.STOPPING
        elif status == "Stopped":
-            return ExecutionStatus.STOPPED
+            pipeline_status = ExecutionStatus.STOPPED
         elif status == "Failed":
-            return ExecutionStatus.FAILED
+            pipeline_status = ExecutionStatus.FAILED
         elif status == "Succeeded":
-            return ExecutionStatus.COMPLETED
+            pipeline_status = ExecutionStatus.COMPLETED
         else:
             raise ValueError("Unknown status for the pipeline execution.")
 
+        # SageMaker doesn't support step-level status fetching yet
+        return pipeline_status, None
+
     def compute_metadata(
         self,
         execution_arn: str,
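
A small consumer-side sketch of the new return shape (the caller is not part of this diff; the helper below is hypothetical):

    from typing import Dict, Optional
    from zenml.enums import ExecutionStatus

    def summarize(
        status: Optional[ExecutionStatus],
        step_statuses: Optional[Dict[str, ExecutionStatus]],
    ) -> str:
        # Mirrors the (pipeline_status, step_statuses) tuple returned by fetch_status.
        if status is None:
            return "status unknown"
        if step_statuses is None:  # always the case for SageMaker, per the docstring above
            return f"pipeline: {status.value}"
        return f"pipeline: {status.value}, steps reported: {len(step_statuses)}"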