sagemaker-mlp-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,452 @@
1
+ """
2
+ Pipeline wrapper for mlp_sdk
3
+ Provides simplified pipeline creation with configuration-driven defaults
4
+ """
5
+
6
+ from typing import Optional, Dict, Any, List, Union
7
+ from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
8
+ from ..config import ConfigurationManager
9
+
10
+
11
class PipelineWrapper:
    """
    Wrapper for SageMaker Pipeline operations.

    Applies default configurations from ConfigurationManager.

    The IAM execution role is resolved when a pipeline is created
    (runtime ``role_arn`` > configured execution role) and applied when the
    pipeline is upserted. SageMaker associates the role with the pipeline
    resource (create/upsert), not with the in-memory ``Pipeline`` definition.
    """

    def __init__(self, config_manager: ConfigurationManager, logger: Optional[MLPLogger] = None):
        """
        Initialize Pipeline wrapper.

        Args:
            config_manager: Configuration manager instance
            logger: Optional logger instance; a logger named
                "mlp_sdk.pipeline" is created when omitted
        """
        self.config_manager = config_manager
        self.logger = logger or MLPLogger("mlp_sdk.pipeline")
        # Role resolved by create_pipeline(), keyed by pipeline name, so that
        # upsert_pipeline() can apply it later without the caller repeating it.
        self._pipeline_roles: Dict[str, str] = {}

    @staticmethod
    def _require(value: Any, name: str) -> None:
        """Raise ValidationError("<name> is required") when *value* is falsy."""
        if not value:
            raise ValidationError(f"{name} is required")

    def create_pipeline(self,
                        sagemaker_session,
                        pipeline_name: str,
                        steps: List,
                        parameters: Optional[List] = None,
                        **kwargs) -> Any:
        """
        Create pipeline with step connection and consistent default configurations.

        Applies defaults from configuration for:
        - IAM execution role (resolved here, applied by upsert_pipeline)
        - Step-level configurations (inherited from processing/training defaults)

        Runtime parameters override configuration defaults.

        Supports parameter passing between pipeline steps.

        Args:
            sagemaker_session: SageMaker session object
            pipeline_name: Name of the pipeline
            steps: List of pipeline steps (ProcessingStep, TrainingStep, etc.)
            parameters: Optional list of pipeline parameters for cross-step communication
            **kwargs: Additional parameters that override defaults
                (``role_arn``, ``pipeline_definition_config``)

        Returns:
            Pipeline object

        Raises:
            ValidationError: If required parameters are missing or invalid
            MLPSDKError: If the SageMaker SDK is not installed
            AWSServiceError: If pipeline creation fails
        """
        self.logger.info("Creating pipeline", name=pipeline_name)

        # Validate required parameters
        if not pipeline_name:
            raise ValidationError("pipeline_name is required")

        if not steps:
            raise ValidationError("steps is required and cannot be empty")

        if not isinstance(steps, list):
            raise ValidationError("steps must be a list")

        # `parameters` is a named argument, so it can never appear in **kwargs;
        # include it explicitly so it is validated like the other overrides.
        overrides = dict(kwargs)
        if parameters is not None:
            overrides['parameters'] = parameters
        self.validate_parameter_override(overrides)

        try:
            from sagemaker.workflow.pipeline import Pipeline
        except ImportError as e:
            # Quoted so that a shell does not treat ">=" as a redirection.
            raise MLPSDKError(
                "SageMaker SDK not installed. Install with: pip install 'sagemaker>=3.0.0'"
            ) from e

        # Build configuration with defaults (runtime > config)
        config = self._build_pipeline_config(kwargs)

        pipeline_params: Dict[str, Any] = {
            'name': pipeline_name,
            'steps': steps,
            'sagemaker_session': sagemaker_session,
        }

        if parameters:
            pipeline_params['parameters'] = parameters

        if config.get('pipeline_definition_config'):
            pipeline_params['pipeline_definition_config'] = config['pipeline_definition_config']

        # The role is deliberately NOT passed to the Pipeline constructor (it
        # does not accept one); it is remembered and applied at upsert time.
        role_arn = config.get('role_arn')

        try:
            self.logger.debug("Creating pipeline with config",
                              name=pipeline_name,
                              step_count=len(steps),
                              has_parameters=bool(parameters))

            pipeline = Pipeline(**pipeline_params)

        except Exception as e:
            self.logger.error("Failed to create pipeline",
                              name=pipeline_name,
                              error=e)
            raise AWSServiceError(
                f"Failed to create pipeline '{pipeline_name}': {e}",
                aws_error=e
            ) from e

        if role_arn:
            self._pipeline_roles[pipeline_name] = role_arn
        else:
            # Avoid reusing a stale role from an earlier pipeline of the same name.
            self._pipeline_roles.pop(pipeline_name, None)

        self.logger.info("Pipeline created successfully", name=pipeline_name)
        return pipeline

    def _default_role_for(self, pipeline_name: str) -> Optional[str]:
        """Return the role remembered by create_pipeline(), else the configured execution role."""
        remembered = self._pipeline_roles.get(pipeline_name)
        if remembered:
            return remembered
        return self._build_pipeline_config({}).get('role_arn')

    def upsert_pipeline(self,
                        pipeline,
                        **kwargs) -> Dict[str, Any]:
        """
        Create or update a pipeline definition.

        Role precedence: ``role_arn`` passed here > role resolved by
        create_pipeline() for this pipeline name > configured execution role.
        When none is available, ``role_arn`` is omitted and the SDK decides.

        Args:
            pipeline: Pipeline object to upsert
            **kwargs: Additional parameters for upsert operation
                (``role_arn``, ``description``, ``tags``, ``parallelism_config``)

        Returns:
            Dictionary with pipeline ARN and other metadata

        Raises:
            ValidationError: If pipeline is invalid
            AWSServiceError: If upsert operation fails
        """
        # Validate before touching pipeline.name so None raises ValidationError.
        self._require(pipeline, "pipeline")
        self.logger.info("Upserting pipeline", name=pipeline.name)

        upsert_params: Dict[str, Any] = {}
        if 'role_arn' in kwargs:
            upsert_params['role_arn'] = kwargs['role_arn']
        else:
            default_role = self._default_role_for(pipeline.name)
            if default_role:
                upsert_params['role_arn'] = default_role

        for key in ('description', 'tags', 'parallelism_config'):
            if key in kwargs:
                upsert_params[key] = kwargs[key]

        try:
            self.logger.debug("Upserting pipeline", name=pipeline.name)

            response = pipeline.upsert(**upsert_params)

            self.logger.info("Pipeline upserted successfully",
                             name=pipeline.name,
                             arn=response.get('PipelineArn'))

            return response

        except Exception as e:
            self.logger.error("Failed to upsert pipeline",
                              name=pipeline.name,
                              error=e)
            raise AWSServiceError(
                f"Failed to upsert pipeline '{pipeline.name}': {e}",
                aws_error=e
            ) from e

    def start_pipeline_execution(self,
                                 pipeline,
                                 execution_display_name: Optional[str] = None,
                                 execution_parameters: Optional[Dict[str, Any]] = None,
                                 **kwargs) -> Any:
        """
        Start pipeline execution with monitoring support.

        Args:
            pipeline: Pipeline object to execute
            execution_display_name: Optional display name for the execution
            execution_parameters: Optional parameters to override pipeline defaults
            **kwargs: Additional execution parameters (``parallelism_config``)

        Returns:
            PipelineExecution object

        Raises:
            ValidationError: If pipeline is invalid
            AWSServiceError: If execution start fails
        """
        self._require(pipeline, "pipeline")
        self.logger.info("Starting pipeline execution",
                         name=pipeline.name,
                         display_name=execution_display_name)

        exec_params: Dict[str, Any] = {}
        if execution_display_name:
            exec_params['execution_display_name'] = execution_display_name
        if execution_parameters:
            exec_params['parameters'] = execution_parameters
        if 'parallelism_config' in kwargs:
            exec_params['parallelism_config'] = kwargs['parallelism_config']

        try:
            self.logger.debug("Starting pipeline execution",
                              name=pipeline.name,
                              has_parameters=bool(execution_parameters))

            execution = pipeline.start(**exec_params)

            self.logger.info("Pipeline execution started successfully",
                             name=pipeline.name,
                             execution_arn=execution.arn)

            return execution

        except Exception as e:
            self.logger.error("Failed to start pipeline execution",
                              name=pipeline.name,
                              error=e)
            raise AWSServiceError(
                f"Failed to start pipeline execution for '{pipeline.name}': {e}",
                aws_error=e
            ) from e

    def describe_pipeline_execution(self,
                                    pipeline_execution) -> Dict[str, Any]:
        """
        Get pipeline execution status and details.

        Args:
            pipeline_execution: PipelineExecution object

        Returns:
            Dictionary with execution status and details

        Raises:
            ValidationError: If pipeline_execution is invalid
            AWSServiceError: If describe operation fails
        """
        self._require(pipeline_execution, "pipeline_execution")
        self.logger.info("Describing pipeline execution", arn=pipeline_execution.arn)

        try:
            response = pipeline_execution.describe()

            self.logger.debug("Pipeline execution described",
                              arn=pipeline_execution.arn,
                              status=response.get('PipelineExecutionStatus'))

            return response

        except Exception as e:
            self.logger.error("Failed to describe pipeline execution",
                              arn=pipeline_execution.arn,
                              error=e)
            raise AWSServiceError(
                f"Failed to describe pipeline execution: {e}",
                aws_error=e
            ) from e

    def list_pipeline_execution_steps(self,
                                      pipeline_execution) -> List[Dict[str, Any]]:
        """
        List all steps in a pipeline execution with their status.

        Args:
            pipeline_execution: PipelineExecution object

        Returns:
            List of step details with status information

        Raises:
            ValidationError: If pipeline_execution is invalid
            AWSServiceError: If list operation fails
        """
        self._require(pipeline_execution, "pipeline_execution")
        self.logger.info("Listing pipeline execution steps", arn=pipeline_execution.arn)

        try:
            steps = pipeline_execution.list_steps()

            self.logger.debug("Pipeline execution steps listed",
                              arn=pipeline_execution.arn,
                              step_count=len(steps))

            return steps

        except Exception as e:
            self.logger.error("Failed to list pipeline execution steps",
                              arn=pipeline_execution.arn,
                              error=e)
            raise AWSServiceError(
                f"Failed to list pipeline execution steps: {e}",
                aws_error=e
            ) from e

    def wait_for_pipeline_execution(self,
                                    pipeline_execution,
                                    delay: int = 30,
                                    max_attempts: int = 60) -> Dict[str, Any]:
        """
        Wait for pipeline execution to complete.

        Args:
            pipeline_execution: PipelineExecution object
            delay: Delay between status checks in seconds (default: 30)
            max_attempts: Maximum number of status checks (default: 60)

        Returns:
            Final execution status dictionary (result of ``describe()`` after waiting)

        Raises:
            ValidationError: If pipeline_execution is invalid
            AWSServiceError: If the wait or the final describe fails
        """
        self._require(pipeline_execution, "pipeline_execution")
        self.logger.info("Waiting for pipeline execution",
                         arn=pipeline_execution.arn,
                         delay=delay,
                         max_attempts=max_attempts)

        try:
            pipeline_execution.wait(delay=delay, max_attempts=max_attempts)

            final_status = pipeline_execution.describe()

            self.logger.info("Pipeline execution completed",
                             arn=pipeline_execution.arn,
                             status=final_status.get('PipelineExecutionStatus'))

            return final_status

        except Exception as e:
            self.logger.error("Failed while waiting for pipeline execution",
                              arn=pipeline_execution.arn,
                              error=e)
            raise AWSServiceError(
                f"Failed while waiting for pipeline execution: {e}",
                aws_error=e
            ) from e

    def _build_pipeline_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build pipeline configuration by merging defaults with runtime parameters.

        Parameter precedence: runtime > config > SageMaker defaults

        This implements the parameter override behavior specified in Requirements 6.1, 6.2, 6.3:
        - Runtime parameters always take precedence over configuration defaults
        - Configuration defaults take precedence over SageMaker SDK defaults
        - Individual step configurations can be overridden while maintaining pipeline-level defaults

        Args:
            runtime_params: Runtime parameters provided by user

        Returns:
            Merged configuration dictionary. ``role_arn`` comes from
            runtime_params when present, otherwise from the IAM config's
            ``execution_role``; every other runtime key is copied verbatim.
        """
        config: Dict[str, Any] = {}

        iam_config = self.config_manager.get_iam_config()

        # Apply IAM role default (runtime > config)
        if 'role_arn' in runtime_params:
            config['role_arn'] = runtime_params['role_arn']
            self.logger.debug("Using runtime role_arn")
        elif iam_config:
            config['role_arn'] = iam_config.execution_role
            self.logger.debug("Using config role_arn", role=iam_config.execution_role)

        # Apply any remaining runtime parameters (these override everything)
        for key, value in runtime_params.items():
            if key != 'role_arn':
                config[key] = value
                self.logger.debug(f"Using runtime parameter: {key}")

        return config

    @staticmethod
    def _is_iam_arn(arn: str) -> bool:
        """
        Return True if *arn* starts with an IAM ARN prefix ``arn:<partition>:iam::``.

        Accepts every AWS partition (``aws``, ``aws-cn``, ``aws-us-gov``, ...)
        rather than only the commercial ``aws`` partition.
        """
        parts = arn.split(':', 4)
        return (
            len(parts) == 5
            and parts[0] == 'arn'
            and (parts[1] == 'aws' or parts[1].startswith('aws-'))
            and parts[2] == 'iam'
            and parts[3] == ''  # IAM is a global service: the region field is empty
        )

    def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
        """
        Validate runtime parameter overrides.

        This ensures that runtime parameters are valid and compatible with the configuration.
        Implements validation requirements from Requirements 6.1, 6.2, 6.3, 6.5.

        Args:
            runtime_params: Runtime parameters to validate

        Raises:
            ValidationError: If runtime parameters are invalid
        """
        # Validate role_arn format if provided
        if 'role_arn' in runtime_params:
            role_arn = runtime_params['role_arn']
            if not isinstance(role_arn, str):
                raise ValidationError("role_arn must be a string")

            if not self._is_iam_arn(role_arn):
                raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")

        # Validate parameters if provided
        if 'parameters' in runtime_params:
            parameters = runtime_params['parameters']
            if not isinstance(parameters, list):
                raise ValidationError("parameters must be a list")

        # Validate pipeline_definition_config if provided
        # NOTE(review): the SageMaker SDK's Pipeline expects a
        # PipelineDefinitionConfig object here; requiring a dict may reject
        # valid input — confirm against the targeted SDK version.
        if 'pipeline_definition_config' in runtime_params:
            pipeline_def_config = runtime_params['pipeline_definition_config']
            if not isinstance(pipeline_def_config, dict):
                raise ValidationError("pipeline_definition_config must be a dictionary")