sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,452 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pipeline wrapper for mlp_sdk
|
|
3
|
+
Provides simplified pipeline creation with configuration-driven defaults
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional, Dict, Any, List, Union
|
|
7
|
+
from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
|
|
8
|
+
from ..config import ConfigurationManager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PipelineWrapper:
    """
    Wrapper for SageMaker Pipeline operations.

    Applies default configurations from ConfigurationManager; runtime keyword
    arguments passed to the individual methods always override configured
    defaults (see _build_pipeline_config for the precedence rules).
    """
def __init__(self, config_manager: ConfigurationManager, logger: Optional[MLPLogger] = None):
|
|
18
|
+
"""
|
|
19
|
+
Initialize Pipeline wrapper.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
config_manager: Configuration manager instance
|
|
23
|
+
logger: Optional logger instance
|
|
24
|
+
"""
|
|
25
|
+
self.config_manager = config_manager
|
|
26
|
+
self.logger = logger or MLPLogger("mlp_sdk.pipeline")
|
|
27
|
+
|
|
28
|
+
def create_pipeline(self,
|
|
29
|
+
sagemaker_session,
|
|
30
|
+
pipeline_name: str,
|
|
31
|
+
steps: List,
|
|
32
|
+
parameters: Optional[List] = None,
|
|
33
|
+
**kwargs) -> Any:
|
|
34
|
+
"""
|
|
35
|
+
Create pipeline with step connection and consistent default configurations.
|
|
36
|
+
|
|
37
|
+
Applies defaults from configuration for:
|
|
38
|
+
- IAM execution role
|
|
39
|
+
- Pipeline parameters
|
|
40
|
+
- Step-level configurations (inherited from processing/training defaults)
|
|
41
|
+
|
|
42
|
+
Runtime parameters override configuration defaults.
|
|
43
|
+
|
|
44
|
+
Supports parameter passing between pipeline steps.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
sagemaker_session: SageMaker session object
|
|
48
|
+
pipeline_name: Name of the pipeline
|
|
49
|
+
steps: List of pipeline steps (ProcessingStep, TrainingStep, etc.)
|
|
50
|
+
parameters: Optional list of pipeline parameters for cross-step communication
|
|
51
|
+
**kwargs: Additional parameters that override defaults
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Pipeline object
|
|
55
|
+
|
|
56
|
+
Raises:
|
|
57
|
+
ValidationError: If required parameters are missing or invalid
|
|
58
|
+
AWSServiceError: If pipeline creation fails
|
|
59
|
+
"""
|
|
60
|
+
self.logger.info("Creating pipeline", name=pipeline_name)
|
|
61
|
+
|
|
62
|
+
# Validate required parameters
|
|
63
|
+
if not pipeline_name:
|
|
64
|
+
raise ValidationError("pipeline_name is required")
|
|
65
|
+
|
|
66
|
+
if not steps:
|
|
67
|
+
raise ValidationError("steps is required and cannot be empty")
|
|
68
|
+
|
|
69
|
+
if not isinstance(steps, list):
|
|
70
|
+
raise ValidationError("steps must be a list")
|
|
71
|
+
|
|
72
|
+
# Validate runtime parameter overrides
|
|
73
|
+
self.validate_parameter_override(kwargs)
|
|
74
|
+
|
|
75
|
+
try:
|
|
76
|
+
from sagemaker.workflow.pipeline import Pipeline
|
|
77
|
+
except ImportError as e:
|
|
78
|
+
raise MLPSDKError(
|
|
79
|
+
"SageMaker SDK not installed. Install with: pip install sagemaker>=3.0.0"
|
|
80
|
+
) from e
|
|
81
|
+
|
|
82
|
+
# Build configuration with defaults
|
|
83
|
+
config = self._build_pipeline_config(kwargs)
|
|
84
|
+
|
|
85
|
+
# Build pipeline parameters
|
|
86
|
+
pipeline_params = {
|
|
87
|
+
'name': pipeline_name,
|
|
88
|
+
'steps': steps,
|
|
89
|
+
'sagemaker_session': sagemaker_session,
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
# Add parameters if provided
|
|
93
|
+
if parameters:
|
|
94
|
+
pipeline_params['parameters'] = parameters
|
|
95
|
+
elif config.get('parameters'):
|
|
96
|
+
pipeline_params['parameters'] = config['parameters']
|
|
97
|
+
|
|
98
|
+
# Add role ARN if available
|
|
99
|
+
if config.get('role_arn'):
|
|
100
|
+
pipeline_params['role_arn'] = config['role_arn']
|
|
101
|
+
|
|
102
|
+
# Add pipeline definition config if provided
|
|
103
|
+
if config.get('pipeline_definition_config'):
|
|
104
|
+
pipeline_params['pipeline_definition_config'] = config['pipeline_definition_config']
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
self.logger.debug("Creating pipeline with config",
|
|
108
|
+
name=pipeline_name,
|
|
109
|
+
step_count=len(steps),
|
|
110
|
+
has_parameters=bool(parameters or config.get('parameters')))
|
|
111
|
+
|
|
112
|
+
# Create the pipeline
|
|
113
|
+
pipeline = Pipeline(**pipeline_params)
|
|
114
|
+
|
|
115
|
+
self.logger.info("Pipeline created successfully", name=pipeline_name)
|
|
116
|
+
return pipeline
|
|
117
|
+
|
|
118
|
+
except Exception as e:
|
|
119
|
+
self.logger.error("Failed to create pipeline",
|
|
120
|
+
name=pipeline_name,
|
|
121
|
+
error=e)
|
|
122
|
+
raise AWSServiceError(
|
|
123
|
+
f"Failed to create pipeline '{pipeline_name}': {e}",
|
|
124
|
+
aws_error=e
|
|
125
|
+
) from e
|
|
126
|
+
|
|
127
|
+
def upsert_pipeline(self,
|
|
128
|
+
pipeline,
|
|
129
|
+
**kwargs) -> Dict[str, Any]:
|
|
130
|
+
"""
|
|
131
|
+
Create or update a pipeline definition.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
pipeline: Pipeline object to upsert
|
|
135
|
+
**kwargs: Additional parameters for upsert operation
|
|
136
|
+
|
|
137
|
+
Returns:
|
|
138
|
+
Dictionary with pipeline ARN and other metadata
|
|
139
|
+
|
|
140
|
+
Raises:
|
|
141
|
+
ValidationError: If pipeline is invalid
|
|
142
|
+
AWSServiceError: If upsert operation fails
|
|
143
|
+
"""
|
|
144
|
+
self.logger.info("Upserting pipeline", name=pipeline.name)
|
|
145
|
+
|
|
146
|
+
if not pipeline:
|
|
147
|
+
raise ValidationError("pipeline is required")
|
|
148
|
+
|
|
149
|
+
try:
|
|
150
|
+
# Build upsert parameters
|
|
151
|
+
upsert_params = {}
|
|
152
|
+
|
|
153
|
+
# Add role ARN if provided
|
|
154
|
+
if 'role_arn' in kwargs:
|
|
155
|
+
upsert_params['role_arn'] = kwargs['role_arn']
|
|
156
|
+
|
|
157
|
+
# Add description if provided
|
|
158
|
+
if 'description' in kwargs:
|
|
159
|
+
upsert_params['description'] = kwargs['description']
|
|
160
|
+
|
|
161
|
+
# Add tags if provided
|
|
162
|
+
if 'tags' in kwargs:
|
|
163
|
+
upsert_params['tags'] = kwargs['tags']
|
|
164
|
+
|
|
165
|
+
# Add parallelism config if provided
|
|
166
|
+
if 'parallelism_config' in kwargs:
|
|
167
|
+
upsert_params['parallelism_config'] = kwargs['parallelism_config']
|
|
168
|
+
|
|
169
|
+
self.logger.debug("Upserting pipeline", name=pipeline.name)
|
|
170
|
+
|
|
171
|
+
# Upsert the pipeline
|
|
172
|
+
response = pipeline.upsert(**upsert_params)
|
|
173
|
+
|
|
174
|
+
self.logger.info("Pipeline upserted successfully",
|
|
175
|
+
name=pipeline.name,
|
|
176
|
+
arn=response.get('PipelineArn'))
|
|
177
|
+
|
|
178
|
+
return response
|
|
179
|
+
|
|
180
|
+
except Exception as e:
|
|
181
|
+
self.logger.error("Failed to upsert pipeline",
|
|
182
|
+
name=pipeline.name,
|
|
183
|
+
error=e)
|
|
184
|
+
raise AWSServiceError(
|
|
185
|
+
f"Failed to upsert pipeline '{pipeline.name}': {e}",
|
|
186
|
+
aws_error=e
|
|
187
|
+
) from e
|
|
188
|
+
|
|
189
|
+
def start_pipeline_execution(self,
|
|
190
|
+
pipeline,
|
|
191
|
+
execution_display_name: Optional[str] = None,
|
|
192
|
+
execution_parameters: Optional[Dict[str, Any]] = None,
|
|
193
|
+
**kwargs) -> Any:
|
|
194
|
+
"""
|
|
195
|
+
Start pipeline execution with monitoring support.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
pipeline: Pipeline object to execute
|
|
199
|
+
execution_display_name: Optional display name for the execution
|
|
200
|
+
execution_parameters: Optional parameters to override pipeline defaults
|
|
201
|
+
**kwargs: Additional execution parameters
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
PipelineExecution object
|
|
205
|
+
|
|
206
|
+
Raises:
|
|
207
|
+
ValidationError: If pipeline is invalid
|
|
208
|
+
AWSServiceError: If execution start fails
|
|
209
|
+
"""
|
|
210
|
+
self.logger.info("Starting pipeline execution",
|
|
211
|
+
name=pipeline.name,
|
|
212
|
+
display_name=execution_display_name)
|
|
213
|
+
|
|
214
|
+
if not pipeline:
|
|
215
|
+
raise ValidationError("pipeline is required")
|
|
216
|
+
|
|
217
|
+
try:
|
|
218
|
+
# Build execution parameters
|
|
219
|
+
exec_params = {}
|
|
220
|
+
|
|
221
|
+
# Add display name if provided
|
|
222
|
+
if execution_display_name:
|
|
223
|
+
exec_params['execution_display_name'] = execution_display_name
|
|
224
|
+
|
|
225
|
+
# Add execution parameters if provided
|
|
226
|
+
if execution_parameters:
|
|
227
|
+
exec_params['parameters'] = execution_parameters
|
|
228
|
+
|
|
229
|
+
# Add parallelism config if provided
|
|
230
|
+
if 'parallelism_config' in kwargs:
|
|
231
|
+
exec_params['parallelism_config'] = kwargs['parallelism_config']
|
|
232
|
+
|
|
233
|
+
self.logger.debug("Starting pipeline execution",
|
|
234
|
+
name=pipeline.name,
|
|
235
|
+
has_parameters=bool(execution_parameters))
|
|
236
|
+
|
|
237
|
+
# Start the execution
|
|
238
|
+
execution = pipeline.start(**exec_params)
|
|
239
|
+
|
|
240
|
+
self.logger.info("Pipeline execution started successfully",
|
|
241
|
+
name=pipeline.name,
|
|
242
|
+
execution_arn=execution.arn)
|
|
243
|
+
|
|
244
|
+
return execution
|
|
245
|
+
|
|
246
|
+
except Exception as e:
|
|
247
|
+
self.logger.error("Failed to start pipeline execution",
|
|
248
|
+
name=pipeline.name,
|
|
249
|
+
error=e)
|
|
250
|
+
raise AWSServiceError(
|
|
251
|
+
f"Failed to start pipeline execution for '{pipeline.name}': {e}",
|
|
252
|
+
aws_error=e
|
|
253
|
+
) from e
|
|
254
|
+
|
|
255
|
+
def describe_pipeline_execution(self,
|
|
256
|
+
pipeline_execution) -> Dict[str, Any]:
|
|
257
|
+
"""
|
|
258
|
+
Get pipeline execution status and details.
|
|
259
|
+
|
|
260
|
+
Args:
|
|
261
|
+
pipeline_execution: PipelineExecution object
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
Dictionary with execution status and details
|
|
265
|
+
|
|
266
|
+
Raises:
|
|
267
|
+
ValidationError: If pipeline_execution is invalid
|
|
268
|
+
AWSServiceError: If describe operation fails
|
|
269
|
+
"""
|
|
270
|
+
self.logger.info("Describing pipeline execution", arn=pipeline_execution.arn)
|
|
271
|
+
|
|
272
|
+
if not pipeline_execution:
|
|
273
|
+
raise ValidationError("pipeline_execution is required")
|
|
274
|
+
|
|
275
|
+
try:
|
|
276
|
+
# Describe the execution
|
|
277
|
+
response = pipeline_execution.describe()
|
|
278
|
+
|
|
279
|
+
self.logger.debug("Pipeline execution described",
|
|
280
|
+
arn=pipeline_execution.arn,
|
|
281
|
+
status=response.get('PipelineExecutionStatus'))
|
|
282
|
+
|
|
283
|
+
return response
|
|
284
|
+
|
|
285
|
+
except Exception as e:
|
|
286
|
+
self.logger.error("Failed to describe pipeline execution",
|
|
287
|
+
arn=pipeline_execution.arn,
|
|
288
|
+
error=e)
|
|
289
|
+
raise AWSServiceError(
|
|
290
|
+
f"Failed to describe pipeline execution: {e}",
|
|
291
|
+
aws_error=e
|
|
292
|
+
) from e
|
|
293
|
+
|
|
294
|
+
def list_pipeline_execution_steps(self,
|
|
295
|
+
pipeline_execution) -> List[Dict[str, Any]]:
|
|
296
|
+
"""
|
|
297
|
+
List all steps in a pipeline execution with their status.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
pipeline_execution: PipelineExecution object
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
List of step details with status information
|
|
304
|
+
|
|
305
|
+
Raises:
|
|
306
|
+
ValidationError: If pipeline_execution is invalid
|
|
307
|
+
AWSServiceError: If list operation fails
|
|
308
|
+
"""
|
|
309
|
+
self.logger.info("Listing pipeline execution steps", arn=pipeline_execution.arn)
|
|
310
|
+
|
|
311
|
+
if not pipeline_execution:
|
|
312
|
+
raise ValidationError("pipeline_execution is required")
|
|
313
|
+
|
|
314
|
+
try:
|
|
315
|
+
# List execution steps
|
|
316
|
+
steps = pipeline_execution.list_steps()
|
|
317
|
+
|
|
318
|
+
self.logger.debug("Pipeline execution steps listed",
|
|
319
|
+
arn=pipeline_execution.arn,
|
|
320
|
+
step_count=len(steps))
|
|
321
|
+
|
|
322
|
+
return steps
|
|
323
|
+
|
|
324
|
+
except Exception as e:
|
|
325
|
+
self.logger.error("Failed to list pipeline execution steps",
|
|
326
|
+
arn=pipeline_execution.arn,
|
|
327
|
+
error=e)
|
|
328
|
+
raise AWSServiceError(
|
|
329
|
+
f"Failed to list pipeline execution steps: {e}",
|
|
330
|
+
aws_error=e
|
|
331
|
+
) from e
|
|
332
|
+
|
|
333
|
+
def wait_for_pipeline_execution(self,
|
|
334
|
+
pipeline_execution,
|
|
335
|
+
delay: int = 30,
|
|
336
|
+
max_attempts: int = 60) -> Dict[str, Any]:
|
|
337
|
+
"""
|
|
338
|
+
Wait for pipeline execution to complete.
|
|
339
|
+
|
|
340
|
+
Args:
|
|
341
|
+
pipeline_execution: PipelineExecution object
|
|
342
|
+
delay: Delay between status checks in seconds (default: 30)
|
|
343
|
+
max_attempts: Maximum number of status checks (default: 60)
|
|
344
|
+
|
|
345
|
+
Returns:
|
|
346
|
+
Final execution status dictionary
|
|
347
|
+
|
|
348
|
+
Raises:
|
|
349
|
+
ValidationError: If pipeline_execution is invalid
|
|
350
|
+
AWSServiceError: If wait operation fails
|
|
351
|
+
"""
|
|
352
|
+
self.logger.info("Waiting for pipeline execution",
|
|
353
|
+
arn=pipeline_execution.arn,
|
|
354
|
+
delay=delay,
|
|
355
|
+
max_attempts=max_attempts)
|
|
356
|
+
|
|
357
|
+
if not pipeline_execution:
|
|
358
|
+
raise ValidationError("pipeline_execution is required")
|
|
359
|
+
|
|
360
|
+
try:
|
|
361
|
+
# Wait for completion
|
|
362
|
+
pipeline_execution.wait(delay=delay, max_attempts=max_attempts)
|
|
363
|
+
|
|
364
|
+
# Get final status
|
|
365
|
+
final_status = pipeline_execution.describe()
|
|
366
|
+
|
|
367
|
+
self.logger.info("Pipeline execution completed",
|
|
368
|
+
arn=pipeline_execution.arn,
|
|
369
|
+
status=final_status.get('PipelineExecutionStatus'))
|
|
370
|
+
|
|
371
|
+
return final_status
|
|
372
|
+
|
|
373
|
+
except Exception as e:
|
|
374
|
+
self.logger.error("Failed while waiting for pipeline execution",
|
|
375
|
+
arn=pipeline_execution.arn,
|
|
376
|
+
error=e)
|
|
377
|
+
raise AWSServiceError(
|
|
378
|
+
f"Failed while waiting for pipeline execution: {e}",
|
|
379
|
+
aws_error=e
|
|
380
|
+
) from e
|
|
381
|
+
|
|
382
|
+
def _build_pipeline_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
|
|
383
|
+
"""
|
|
384
|
+
Build pipeline configuration by merging defaults with runtime parameters.
|
|
385
|
+
|
|
386
|
+
Parameter precedence: runtime > config > SageMaker defaults
|
|
387
|
+
|
|
388
|
+
This implements the parameter override behavior specified in Requirements 6.1, 6.2, 6.3:
|
|
389
|
+
- Runtime parameters always take precedence over configuration defaults
|
|
390
|
+
- Configuration defaults take precedence over SageMaker SDK defaults
|
|
391
|
+
- Individual step configurations can be overridden while maintaining pipeline-level defaults
|
|
392
|
+
|
|
393
|
+
Args:
|
|
394
|
+
runtime_params: Runtime parameters provided by user
|
|
395
|
+
|
|
396
|
+
Returns:
|
|
397
|
+
Merged configuration dictionary
|
|
398
|
+
"""
|
|
399
|
+
config = {}
|
|
400
|
+
|
|
401
|
+
# Get configuration objects
|
|
402
|
+
iam_config = self.config_manager.get_iam_config()
|
|
403
|
+
|
|
404
|
+
# Apply IAM role default (runtime > config)
|
|
405
|
+
if 'role_arn' in runtime_params:
|
|
406
|
+
config['role_arn'] = runtime_params['role_arn']
|
|
407
|
+
self.logger.debug("Using runtime role_arn")
|
|
408
|
+
elif iam_config:
|
|
409
|
+
config['role_arn'] = iam_config.execution_role
|
|
410
|
+
self.logger.debug("Using config role_arn", role=iam_config.execution_role)
|
|
411
|
+
|
|
412
|
+
# Apply any remaining runtime parameters (these override everything)
|
|
413
|
+
for key, value in runtime_params.items():
|
|
414
|
+
if key not in ['role_arn']:
|
|
415
|
+
config[key] = value
|
|
416
|
+
self.logger.debug(f"Using runtime parameter: {key}")
|
|
417
|
+
|
|
418
|
+
return config
|
|
419
|
+
|
|
420
|
+
def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
|
|
421
|
+
"""
|
|
422
|
+
Validate runtime parameter overrides.
|
|
423
|
+
|
|
424
|
+
This ensures that runtime parameters are valid and compatible with the configuration.
|
|
425
|
+
Implements validation requirements from Requirements 6.1, 6.2, 6.3, 6.5.
|
|
426
|
+
|
|
427
|
+
Args:
|
|
428
|
+
runtime_params: Runtime parameters to validate
|
|
429
|
+
|
|
430
|
+
Raises:
|
|
431
|
+
ValidationError: If runtime parameters are invalid
|
|
432
|
+
"""
|
|
433
|
+
# Validate role_arn format if provided
|
|
434
|
+
if 'role_arn' in runtime_params:
|
|
435
|
+
role_arn = runtime_params['role_arn']
|
|
436
|
+
if not isinstance(role_arn, str):
|
|
437
|
+
raise ValidationError("role_arn must be a string")
|
|
438
|
+
|
|
439
|
+
if not role_arn.startswith('arn:aws:iam::'):
|
|
440
|
+
raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")
|
|
441
|
+
|
|
442
|
+
# Validate parameters if provided
|
|
443
|
+
if 'parameters' in runtime_params:
|
|
444
|
+
parameters = runtime_params['parameters']
|
|
445
|
+
if not isinstance(parameters, list):
|
|
446
|
+
raise ValidationError("parameters must be a list")
|
|
447
|
+
|
|
448
|
+
# Validate pipeline_definition_config if provided
|
|
449
|
+
if 'pipeline_definition_config' in runtime_params:
|
|
450
|
+
pipeline_def_config = runtime_params['pipeline_definition_config']
|
|
451
|
+
if not isinstance(pipeline_def_config, dict):
|
|
452
|
+
raise ValidationError("pipeline_definition_config must be a dictionary")
|