sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
mlp_sdk/session.py
ADDED
|
@@ -0,0 +1,1160 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core MLP_Session class - main interface for all SDK operations
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Optional, Dict, Any, List
|
|
6
|
+
import logging
|
|
7
|
+
from .config import ConfigurationManager
|
|
8
|
+
from .exceptions import MLPSDKError, SessionError, AWSServiceError, ValidationError, MLPLogger, AuditTrail
|
|
9
|
+
from .wrappers import FeatureStoreWrapper, ProcessingWrapper, TrainingWrapper, PipelineWrapper, DeploymentWrapper
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MLP_Session:
|
|
13
|
+
"""
|
|
14
|
+
Main interface for all mlp_sdk operations.
|
|
15
|
+
|
|
16
|
+
Provides simplified SageMaker operations with configuration-driven defaults.
|
|
17
|
+
Built on top of SageMaker Python SDK v3.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def __init__(self,
             config_path: Optional[str] = None,
             log_level: int = logging.INFO,
             enable_audit_trail: bool = True,
             **kwargs):
    """
    Build a ready-to-use session.

    Creates the logger and (optionally) the audit trail, then loads the
    configuration, instantiates one wrapper per SageMaker feature area and
    finally initializes the underlying SageMaker session.

    Args:
        config_path: Optional path to a configuration file. The value is
            handed to ConfigurationManager unchanged (None presumably
            selects its default location - confirm in config.py).
        log_level: Logging level passed to MLPLogger
        enable_audit_trail: When False, no AuditTrail is created and
            ``self.audit_trail`` is None
        **kwargs: Additional session parameters; kept on
            ``self._session_kwargs`` and consumed by
            ``_initialize_sagemaker_session``

    Raises:
        SessionError: If configuration loading, wrapper creation or
            SageMaker session initialization fails. The triggering
            exception is chained as ``__cause__``.
    """
    self.logger = MLPLogger("mlp_sdk.session", level=log_level)
    self.logger.info("Initializing MLP_Session", config_path=config_path or "default")

    audit = AuditTrail() if enable_audit_trail else None
    self.audit_trail = audit
    if audit is not None:
        audit.record("session_init", "started", config_path=config_path)

    try:
        self.config_manager = ConfigurationManager(config_path)

        # Every wrapper shares the same configuration manager and logger.
        wrapper_slots = (
            ("_feature_store_wrapper", FeatureStoreWrapper),
            ("_processing_wrapper", ProcessingWrapper),
            ("_training_wrapper", TrainingWrapper),
            ("_pipeline_wrapper", PipelineWrapper),
            ("_deployment_wrapper", DeploymentWrapper),
        )
        for attribute, wrapper_cls in wrapper_slots:
            setattr(self, attribute, wrapper_cls(self.config_manager, self.logger))

        # The session slot exists (as None) before initialization is attempted.
        self._sagemaker_session = None
        self._session_kwargs = kwargs
        self._initialize_sagemaker_session()

        self.logger.info("MLP_Session initialized successfully",
                         has_config=self.config_manager.has_config)

        if audit is not None:
            audit.record("session_init", "completed",
                         has_config=self.config_manager.has_config)

    except Exception as e:
        # Any failure is logged, audited and surfaced as a SessionError.
        self.logger.error("Failed to initialize MLP_Session", error=e)
        if audit is not None:
            audit.record("session_init", "failed", error=str(e))
        raise SessionError(f"Failed to initialize MLP_Session: {e}") from e
|
|
78
|
+
|
|
79
|
+
def _initialize_sagemaker_session(self) -> None:
    """
    Initialize (or re-initialize) the underlying SageMaker session state.

    Sets ``_boto_session``, ``_default_bucket``, ``_region_name``,
    ``_sagemaker_session`` (a SessionSettings instance),
    ``_sagemaker_client`` and ``_sagemaker_runtime_client``.

    The following keys of ``self._session_kwargs`` are interpreted here and
    are NOT forwarded to SessionSettings:

    - ``default_bucket``: overrides the bucket from the S3 configuration
    - ``boto_session``: boto3 session to use instead of a new default one
    - ``sagemaker_client`` / ``sagemaker_runtime_client``: pre-built boto3
      clients to use instead of creating them from the boto3 session

    All remaining keys are passed to ``SessionSettings(**...)``.
    ``self._session_kwargs`` itself is left untouched so that a later
    re-initialization (see ``update_session_config``) still sees every
    previously supplied option.

    Raises:
        SessionError: If a required package cannot be imported or any step
            of the initialization fails
    """
    # Imports get their own try block: if one of them fails, the names
    # ClientError / NoCredentialsError are not bound yet and must not be
    # referenced by an except clause.
    try:
        from sagemaker.core.session_settings import SessionSettings
        from botocore.exceptions import ClientError, NoCredentialsError
        import boto3
    except ImportError as e:
        raise SessionError(
            "SageMaker SDK v3 not installed. Install with: pip install sagemaker>=3.0.0"
        ) from e
    except Exception as e:
        raise SessionError(
            f"Failed to initialize SageMaker session: {e}"
        ) from e

    try:
        # Work on a copy so the stored kwargs survive for re-initialization.
        settings_kwargs = dict(self._session_kwargs)

        # An explicit default_bucket kwarg (even None) wins over the config.
        if 'default_bucket' in settings_kwargs:
            default_bucket = settings_kwargs.pop('default_bucket')
        else:
            s3_config = self.config_manager.get_s3_config()
            default_bucket = s3_config.default_bucket if s3_config else None

        boto_session = settings_kwargs.pop('boto_session', None)
        if boto_session is None:
            boto_session = boto3.Session()

        # Caller-supplied clients are documented session parameters; they
        # are not SessionSettings options, so take them out before the call.
        sagemaker_client = settings_kwargs.pop('sagemaker_client', None)
        sagemaker_runtime_client = settings_kwargs.pop('sagemaker_runtime_client', None)

        self._boto_session = boto_session
        self._default_bucket = default_bucket
        self._region_name = boto_session.region_name

        # Create SageMaker SessionSettings (SDK v3)
        self._sagemaker_session = SessionSettings(**settings_kwargs)

        if sagemaker_client is None:
            sagemaker_client = boto_session.client('sagemaker')
        if sagemaker_runtime_client is None:
            sagemaker_runtime_client = boto_session.client('sagemaker-runtime')
        self._sagemaker_client = sagemaker_client
        self._sagemaker_runtime_client = sagemaker_runtime_client

        self.logger.debug("SageMaker SessionSettings initialized",
                          region=self._region_name,
                          default_bucket=default_bucket)

    except (ClientError, NoCredentialsError) as e:
        raise SessionError(
            f"Failed to initialize SageMaker session due to AWS credentials issue: {e}"
        ) from e
    except Exception as e:
        raise SessionError(
            f"Failed to initialize SageMaker session: {e}"
        ) from e
|
|
136
|
+
|
|
137
|
+
@property
def sagemaker_session(self):
    """
    Get the underlying SageMaker session object.

    In this SDK the "session" is the SessionSettings instance created by
    ``_initialize_sagemaker_session``; it is a settings holder, not a
    service client. Use ``sagemaker_client`` or
    ``sagemaker_runtime_client`` for direct AWS API calls.

    Returns:
        The SessionSettings object, or None if it has not been created

    Example:
        >>> session = MLP_Session()
        >>> settings = session.sagemaker_session
        >>> jobs = session.sagemaker_client.list_training_jobs()
    """
    # getattr with a default mirrors the other accessors of this class.
    return getattr(self, '_sagemaker_session', None)
|
|
155
|
+
|
|
156
|
+
@property
def boto_session(self):
    """
    Get the boto3 session this MLP_Session works with.

    This is the session stored by ``_initialize_sagemaker_session``:
    either the one supplied by the caller or a freshly created
    ``boto3.Session()``.

    Returns:
        boto3.Session object, or None when no session has been stored

    Example:
        >>> session = MLP_Session()
        >>> s3_client = session.boto_session.client('s3')
    """
    try:
        return self._boto_session
    except AttributeError:
        # Attribute not set yet: report "no session" instead of failing.
        return None
|
|
172
|
+
|
|
173
|
+
@property
def sagemaker_client(self):
    """
    Get the boto3 SageMaker client held by this session.

    Useful for low-level SageMaker API calls that the higher-level
    helpers of this class do not cover.

    Returns:
        boto3 SageMaker client, or None when no client has been stored

    Example:
        >>> session = MLP_Session()
        >>> response = session.sagemaker_client.describe_training_job(
        ...     TrainingJobName='my-job'
        ... )
    """
    try:
        return self._sagemaker_client
    except AttributeError:
        # Attribute not set yet: report "no client" instead of failing.
        return None
|
|
191
|
+
|
|
192
|
+
@property
def sagemaker_runtime_client(self):
    """
    Get the boto3 SageMaker Runtime client held by this session.

    This is the client to use for invoking deployed endpoints.

    Returns:
        boto3 SageMaker Runtime client, or None when no client has been stored

    Example:
        >>> session = MLP_Session()
        >>> response = session.sagemaker_runtime_client.invoke_endpoint(
        ...     EndpointName='my-endpoint',
        ...     Body=json.dumps(data)
        ... )
    """
    try:
        return self._sagemaker_runtime_client
    except AttributeError:
        # Attribute not set yet: report "no client" instead of failing.
        return None
|
|
211
|
+
|
|
212
|
+
@property
def region_name(self) -> Optional[str]:
    """
    Get the AWS region name recorded for this session.

    The value is copied from the boto3 session's ``region_name`` when the
    SageMaker session is initialized.

    Returns:
        AWS region name, or None when none has been recorded
    """
    try:
        return self._region_name
    except AttributeError:
        # Attribute not set yet: report "unknown region" instead of failing.
        return None
|
|
221
|
+
|
|
222
|
+
@property
def default_bucket(self) -> Optional[str]:
    """
    Get the default S3 bucket recorded for this session.

    The value is resolved when the SageMaker session is initialized: an
    explicit ``default_bucket`` session parameter if given, otherwise the
    bucket from the S3 configuration.

    Returns:
        Default S3 bucket name, or None when none has been recorded
    """
    try:
        return self._default_bucket
    except AttributeError:
        # Attribute not set yet: report "no bucket" instead of failing.
        return None
|
|
231
|
+
|
|
232
|
+
@property
def account_id(self) -> Optional[str]:
    """
    Look up the AWS account ID of the current credentials.

    Each access creates an STS client from the stored boto3 session and
    calls ``get_caller_identity``; the result is not cached.

    Returns:
        AWS account ID, or None when there is no boto3 session or the
        lookup fails for any reason
    """
    session = getattr(self, '_boto_session', None)
    if not session:
        return None

    try:
        identity = session.client('sts').get_caller_identity()
        return identity['Account']
    except Exception:
        # Deliberately best-effort: any failure is reported as "unknown".
        return None
|
|
248
|
+
|
|
249
|
+
def set_log_level(self, level: int) -> None:
    """
    Change the logging level of this session's logger.

    The new level is applied first and the change is then reported through
    the same logger at INFO level.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    """
    session_logger = self.logger
    session_logger.set_level(level)

    level_label = logging.getLevelName(level)
    session_logger.info("Log level changed", level=level_label)
|
|
258
|
+
|
|
259
|
+
def get_config(self) -> Optional[Dict[str, Any]]:
    """
    Get the currently loaded configuration.

    The object returned is the configuration manager's own ``_config``
    attribute, not a copy.

    Returns:
        Configuration dictionary, or None if the configuration manager
        reports that no configuration is loaded

    Example:
        >>> session = MLP_Session()
        >>> config = session.get_config()
        >>> print(config['defaults']['s3']['default_bucket'])
    """
    manager = self.config_manager
    if not manager.has_config:
        return None
    return manager._config
|
|
272
|
+
|
|
273
|
+
def get_execution_role(self) -> Optional[str]:
    """
    Get the IAM execution role from the configuration.

    Returns:
        The ``execution_role`` of the IAM configuration, or None when the
        configuration manager provides no IAM configuration

    Example:
        >>> session = MLP_Session()
        >>> role = session.get_execution_role()
    """
    iam_settings = self.config_manager.get_iam_config()
    if not iam_settings:
        return None
    return iam_settings.execution_role
|
|
286
|
+
|
|
287
|
+
def update_session_config(self, **kwargs) -> None:
    """
    Update session parameters at runtime and rebuild the SageMaker session.

    The given parameters are merged into the stored session parameters and
    ``_initialize_sagemaker_session`` is run again. If that fails, the
    stored parameters are restored to their previous values so that a
    rejected option does not break every subsequent update. Session
    attributes already overwritten by the failed re-initialization are not
    restored.

    Args:
        **kwargs: Session parameters to add or replace
                 (e.g., default_bucket, boto_session, sagemaker_client)

    Raises:
        SessionError: If the session cannot be re-initialized

    Example:
        >>> session = MLP_Session()
        >>> session.update_session_config(default_bucket='my-new-bucket')
    """
    self.logger.info("Updating session configuration", params=list(kwargs.keys()))

    if self.audit_trail is not None:
        self.audit_trail.record("update_session_config", "started", params=list(kwargs.keys()))

    # Snapshot taken before merging, used for rollback on failure.
    previous_kwargs = dict(self._session_kwargs)

    try:
        self._session_kwargs.update(kwargs)
        self._initialize_sagemaker_session()

        self.logger.info("Session configuration updated successfully")

        if self.audit_trail is not None:
            self.audit_trail.record("update_session_config", "completed")

    except Exception as e:
        # Roll back first so the session stays usable whatever happens next.
        self._session_kwargs = previous_kwargs

        self.logger.error("Failed to update session configuration", error=e)
        if self.audit_trail is not None:
            self.audit_trail.record("update_session_config", "failed", error=str(e))
        raise SessionError(f"Failed to update session configuration: {e}") from e
|
|
327
|
+
|
|
328
|
+
def get_audit_trail(self, operation: Optional[str] = None,
                    status: Optional[str] = None,
                    limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """
    Get audit trail entries, optionally filtered.

    Args:
        operation: Filter by operation name
        status: Filter by status
        limit: Maximum number of entries to return

    Returns:
        Whatever the audit trail's ``get_entries`` returns for the given
        filters, or an empty list when the audit trail is disabled
    """
    # "Disabled" means None, exactly as everywhere else in this class;
    # a truthiness test could misjudge an enabled trail object that
    # happens to evaluate as false.
    if self.audit_trail is None:
        return []
    return self.audit_trail.get_entries(operation=operation, status=status, limit=limit)
|
|
345
|
+
|
|
346
|
+
def get_audit_trail_summary(self) -> Dict[str, Any]:
    """
    Get audit trail summary statistics.

    Returns:
        The dictionary produced by the audit trail's ``get_summary``

    Raises:
        SessionError: If audit trail is not enabled
    """
    # Only a missing (None) trail counts as "not enabled"; an enabled
    # trail must never be rejected just because it evaluates as false.
    if self.audit_trail is None:
        raise SessionError("Audit trail is not enabled for this session")

    return self.audit_trail.get_summary()
|
|
361
|
+
|
|
362
|
+
def export_audit_trail(self, file_path: str, format: str = 'json') -> None:
    """
    Export the audit trail to a file.

    Args:
        file_path: Path to output file
        format: Export format ('json' or 'csv')

    Raises:
        SessionError: If audit trail is not enabled
        ValidationError: If format is invalid
    """
    # Only a missing (None) trail counts as "not enabled"; an enabled
    # trail must never be rejected just because it evaluates as false.
    if self.audit_trail is None:
        raise SessionError("Audit trail is not enabled for this session")

    if format not in ('json', 'csv'):
        raise ValidationError(f"Invalid export format: {format}. Must be 'json' or 'csv'")

    if format == 'json':
        self.audit_trail.export_json(file_path)
    else:
        self.audit_trail.export_csv(file_path)

    self.logger.info("Audit trail exported", file_path=file_path, format=format)
|
|
386
|
+
|
|
387
|
+
def create_feature_group(self,
                         feature_group_name: str,
                         record_identifier_name: str,
                         event_time_feature_name: str,
                         feature_definitions: List[Dict[str, str]],
                         **kwargs):
    """
    Validate the arguments and delegate feature group creation to the
    feature store wrapper, recording the outcome in the audit trail.

    Args:
        feature_group_name: Name of the feature group
        record_identifier_name: Name of the record identifier feature
        event_time_feature_name: Name of the event time feature
        feature_definitions: Non-empty list of feature definitions
        **kwargs: Additional parameters forwarded to the wrapper unchanged

    Returns:
        Whatever the feature store wrapper's ``create_feature_group`` returns

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If a required argument is empty or of the wrong type
        Exception: Any error raised by the wrapper is audited and re-raised
            unchanged
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # The three name arguments share one rule: non-empty string.
    required_names = (
        ("feature_group_name", feature_group_name),
        ("record_identifier_name", record_identifier_name),
        ("event_time_feature_name", event_time_feature_name),
    )
    for label, value in required_names:
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    if not (feature_definitions and isinstance(feature_definitions, list)):
        raise ValidationError("feature_definitions must be a non-empty list")

    self.logger.info("create_feature_group called", name=feature_group_name)

    if self.audit_trail is not None:
        self.audit_trail.record("create_feature_group", "started", name=feature_group_name)

    try:
        created = self._feature_store_wrapper.create_feature_group(
            sagemaker_session=self._sagemaker_session,
            feature_group_name=feature_group_name,
            record_identifier_name=record_identifier_name,
            event_time_feature_name=event_time_feature_name,
            feature_definitions=feature_definitions,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("create_feature_group", "completed", name=feature_group_name)

        return created

    except Exception as e:
        # Validation/session errors and all other failures get the same
        # treatment: audit the failure, then propagate without wrapping.
        if self.audit_trail is not None:
            self.audit_trail.record("create_feature_group", "failed",
                                    name=feature_group_name, error=str(e))
        raise
|
|
459
|
+
|
|
460
|
+
def run_processing_job(self,
                       job_name: str,
                       processing_script: Optional[str] = None,
                       inputs: Optional[List[Dict[str, Any]]] = None,
                       outputs: Optional[List[Dict[str, Any]]] = None,
                       **kwargs):
    """
    Validate the arguments and delegate the processing job to the
    processing wrapper, recording the outcome in the audit trail.

    Args:
        job_name: Processing job name
        processing_script: Optional path to custom processing script
        inputs: Optional list of processing inputs
        outputs: Optional list of processing outputs
        **kwargs: Additional parameters forwarded to the wrapper unchanged

    Returns:
        Whatever the processing wrapper's ``run_processing_job`` returns

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If an argument is empty or of the wrong type
        Exception: Any error raised by the wrapper is audited and re-raised
            unchanged
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    if not (job_name and isinstance(job_name, str)):
        raise ValidationError("job_name must be a non-empty string")

    # Optional arguments are only type-checked when actually supplied.
    for label, value in (("inputs", inputs), ("outputs", outputs)):
        if value is not None and not isinstance(value, list):
            raise ValidationError(f"{label} must be a list if provided")

    if processing_script is not None and not isinstance(processing_script, str):
        raise ValidationError("processing_script must be a string if provided")

    self.logger.info("run_processing_job called", name=job_name)

    if self.audit_trail is not None:
        self.audit_trail.record("run_processing_job", "started", name=job_name)

    try:
        job_result = self._processing_wrapper.run_processing_job(
            sagemaker_session=self._sagemaker_session,
            job_name=job_name,
            processing_script=processing_script,
            inputs=inputs,
            outputs=outputs,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("run_processing_job", "completed", name=job_name)

        return job_result

    except Exception as e:
        # Validation/session errors and all other failures get the same
        # treatment: audit the failure, then propagate without wrapping.
        if self.audit_trail is not None:
            self.audit_trail.record("run_processing_job", "failed",
                                    name=job_name, error=str(e))
        raise
|
|
533
|
+
|
|
534
|
+
def run_training_job(self,
                     job_name: str,
                     training_image: str,
                     source_code_dir: Optional[str] = None,
                     entry_script: Optional[str] = None,
                     requirements: Optional[str] = None,
                     inputs: Optional[Dict[str, Any]] = None,
                     **kwargs):
    """
    Validate the arguments and delegate the training job to the training
    wrapper, recording the outcome in the audit trail.

    Args:
        job_name: Training job name
        training_image: Container image URI for training
        source_code_dir: Directory containing training script and dependencies
        entry_script: Entry point script for training (e.g., 'train.py')
        requirements: Path to requirements.txt file for dependencies
        inputs: Training data inputs; the check performed here accepts a
            dictionary or a list
        **kwargs: Additional parameters forwarded to the wrapper unchanged
                 (e.g., hyperparameters, environment, distributed_runner)

    Returns:
        Whatever the training wrapper's ``run_training_job`` returns

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If an argument is empty or of the wrong type
        Exception: Any error raised by the wrapper is audited and re-raised
            unchanged

    Example:
        >>> session = MLP_Session()
        >>> trainer = session.run_training_job(
        ...     job_name='my-training-job',
        ...     training_image='763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310',
        ...     source_code_dir='training-scripts',
        ...     entry_script='train.py',
        ...     inputs={'train': 's3://my-bucket/data/train/'}
        ... )
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Mandatory arguments: non-empty strings.
    for label, value in (("job_name", job_name), ("training_image", training_image)):
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    # Optional arguments are only type-checked when actually supplied.
    optional_strings = (
        ("source_code_dir", source_code_dir),
        ("entry_script", entry_script),
        ("requirements", requirements),
    )
    for label, value in optional_strings:
        if value is not None and not isinstance(value, str):
            raise ValidationError(f"{label} must be a string if provided")

    if inputs is not None and not isinstance(inputs, (dict, list)):
        raise ValidationError("inputs must be a dictionary or list if provided")

    self.logger.info("run_training_job called", name=job_name)

    if self.audit_trail is not None:
        self.audit_trail.record("run_training_job", "started", name=job_name)

    try:
        training_result = self._training_wrapper.run_training_job(
            sagemaker_session=self._sagemaker_session,
            job_name=job_name,
            training_image=training_image,
            source_code_dir=source_code_dir,
            entry_script=entry_script,
            requirements=requirements,
            inputs=inputs,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("run_training_job", "completed", name=job_name)

        return training_result

    except Exception as e:
        # Validation/session errors and all other failures get the same
        # treatment: audit the failure, then propagate without wrapping.
        if self.audit_trail is not None:
            self.audit_trail.record("run_training_job", "failed",
                                    name=job_name, error=str(e))
        raise
|
|
630
|
+
|
|
631
|
+
def deploy_model(self,
                 model_data: str,
                 image_uri: str,
                 endpoint_name: str,
                 enable_vpc: bool = False,
                 **kwargs):
    """
    Validate the arguments and delegate model deployment to the deployment
    wrapper, recording the outcome in the audit trail.

    Configuration defaults (instance settings, IAM role, VPC, KMS) are the
    wrapper's responsibility; this method only forwards the arguments.

    Args:
        model_data: S3 URI of the model artifacts (e.g., 's3://bucket/model.tar.gz')
        image_uri: Container image URI for inference
        endpoint_name: Name for the endpoint
        enable_vpc: Forwarded to the wrapper; per the wrapper contract it
            switches VPC configuration on (default: False)
        **kwargs: Additional parameters forwarded to the wrapper unchanged
                 (e.g., instance_type, instance_count, environment, subnets, security_group_ids)

    Returns:
        Whatever the deployment wrapper's ``deploy_model`` returns

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If a required argument is empty or of the wrong type
        Exception: Any error raised by the wrapper is audited and re-raised
            unchanged

    Example:
        >>> session = MLP_Session()
        >>> predictor = session.deploy_model(
        ...     model_data='s3://my-bucket/model.tar.gz',
        ...     image_uri='683313688378.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1',
        ...     endpoint_name='my-endpoint',
        ...     enable_vpc=True
        ... )
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # All three required arguments share one rule: non-empty string.
    required_strings = (
        ("model_data", model_data),
        ("image_uri", image_uri),
        ("endpoint_name", endpoint_name),
    )
    for label, value in required_strings:
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    self.logger.info("deploy_model called", endpoint_name=endpoint_name)

    if self.audit_trail is not None:
        self.audit_trail.record("deploy_model", "started", endpoint_name=endpoint_name)

    try:
        deployed = self._deployment_wrapper.deploy_model(
            model_data=model_data,
            image_uri=image_uri,
            endpoint_name=endpoint_name,
            enable_vpc=enable_vpc,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("deploy_model", "completed", endpoint_name=endpoint_name)

        return deployed

    except Exception as e:
        # Validation/session errors and all other failures get the same
        # treatment: audit the failure, then propagate without wrapping.
        if self.audit_trail is not None:
            self.audit_trail.record("deploy_model", "failed",
                                    endpoint_name=endpoint_name, error=str(e))
        raise
|
|
724
|
+
|
|
725
|
+
def delete_endpoint(self, endpoint_name: str) -> None:
    """
    Remove a SageMaker endpoint by name.

    The deletion is issued through the session's SageMaker client. The
    call is logged and, when an audit trail is configured, recorded as
    "started" followed by either "completed" or "failed".

    Args:
        endpoint_name: Name of the endpoint to delete

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If endpoint_name is empty or not a string
        AWSServiceError: If anything raised while deleting the endpoint
            (the original exception is attached as ``aws_error`` and as
            the exception cause)

    Example:
        >>> session = MLP_Session()
        >>> session.delete_endpoint('my-endpoint')
    """
    # Guard: nothing can be done without an initialized session
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Guard: the name has to be a truthy str
    if not (endpoint_name and isinstance(endpoint_name, str)):
        raise ValidationError("endpoint_name must be a non-empty string")

    operation = "delete_endpoint"
    # Keyword context shared by every log line and audit record below
    context = {"endpoint_name": endpoint_name}

    self.logger.info("delete_endpoint called", **context)
    if self.audit_trail is not None:
        self.audit_trail.record(operation, "started", **context)

    try:
        # The deletion itself goes through the SageMaker client
        self._sagemaker_client.delete_endpoint(EndpointName=endpoint_name)

        self.logger.info("Endpoint deleted successfully", **context)
        if self.audit_trail is not None:
            self.audit_trail.record(operation, "completed", **context)
    except Exception as exc:
        self.logger.error("Failed to delete endpoint", **context, error=exc)
        if self.audit_trail is not None:
            self.audit_trail.record(operation, "failed", **context, error=str(exc))
        message = f"Failed to delete endpoint '{endpoint_name}': {exc}"
        raise AWSServiceError(message, aws_error=exc) from exc
|
|
774
|
+
|
|
775
|
+
def create_pipeline(self,
                    pipeline_name: str,
                    steps: List,
                    parameters: Optional[List] = None,
                    **kwargs):
    """
    Create pipeline with consistent defaults.

    Validates the arguments, then delegates to the pipeline wrapper,
    passing the session's SageMaker session along with the given
    arguments. The call is logged and, when an audit trail is
    configured, recorded as "started" followed by "completed" or
    "failed".

    Args:
        pipeline_name: Pipeline name
        steps: List of pipeline steps (ProcessingStep, TrainingStep, etc.)
        parameters: Optional list of pipeline parameters for cross-step communication
        **kwargs: Additional pipeline parameters forwarded to the wrapper

    Returns:
        Pipeline object returned by the pipeline wrapper

    Raises:
        SessionError: If session is not initialized
        ValidationError: If required parameters are missing or invalid
        Exception: Any error raised during creation is recorded in the
            audit trail and re-raised unchanged (no wrapping happens
            here; AWSServiceError is presumably raised by the wrapper
            itself - confirm against PipelineWrapper)
    """
    # Validate session is initialized
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Validate required parameters at session level
    if not pipeline_name or not isinstance(pipeline_name, str):
        raise ValidationError("pipeline_name must be a non-empty string")

    if not steps or not isinstance(steps, list):
        raise ValidationError("steps must be a non-empty list")

    # Validate optional parameters if provided
    if parameters is not None and not isinstance(parameters, list):
        raise ValidationError("parameters must be a list if provided")

    self.logger.info("create_pipeline called", name=pipeline_name)

    if self.audit_trail is not None:
        self.audit_trail.record("create_pipeline", "started", name=pipeline_name)

    try:
        pipeline = self._pipeline_wrapper.create_pipeline(
            sagemaker_session=self._sagemaker_session,
            pipeline_name=pipeline_name,
            steps=steps,
            parameters=parameters,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("create_pipeline", "completed", name=pipeline_name)

        return pipeline

    except Exception as e:
        # One handler for every failure: SDK errors and unexpected errors
        # were handled identically (audit, then re-raise as-is), so the
        # separate ValidationError/SessionError clause was redundant.
        if self.audit_trail is not None:
            self.audit_trail.record("create_pipeline", "failed",
                                    name=pipeline_name, error=str(e))
        raise
|
|
845
|
+
|
|
846
|
+
def upsert_pipeline(self, pipeline, **kwargs) -> Dict[str, Any]:
    """
    Create or update a pipeline definition.

    Delegates to the pipeline wrapper. The call is logged and, when an
    audit trail is configured, recorded as "started" followed by
    "completed" (including the response's 'PipelineArn') or "failed".

    Args:
        pipeline: Pipeline object to upsert; must be truthy and expose
            a ``name`` attribute
        **kwargs: Additional parameters forwarded to the wrapper's
            upsert operation

    Returns:
        Response from the pipeline wrapper (read with ``.get('PipelineArn')``
        for the audit record)

    Raises:
        SessionError: If session is not initialized
        ValidationError: If pipeline is missing or has no ``name``
        Exception: Any error raised during the upsert is recorded in the
            audit trail and re-raised unchanged (no wrapping happens
            here; AWSServiceError is presumably raised by the wrapper
            itself - confirm against PipelineWrapper)
    """
    # Validate session is initialized
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Validate required parameters
    if not pipeline:
        raise ValidationError("pipeline is required")

    if not hasattr(pipeline, 'name'):
        raise ValidationError("pipeline must have a 'name' attribute")

    self.logger.info("upsert_pipeline called", name=pipeline.name)

    if self.audit_trail is not None:
        self.audit_trail.record("upsert_pipeline", "started", name=pipeline.name)

    try:
        response = self._pipeline_wrapper.upsert_pipeline(
            pipeline=pipeline,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("upsert_pipeline", "completed",
                                    name=pipeline.name,
                                    arn=response.get('PipelineArn'))

        return response

    except Exception as e:
        # One handler for every failure: SDK errors and unexpected errors
        # were handled identically (audit, then re-raise as-is), so the
        # separate ValidationError/SessionError clause was redundant.
        if self.audit_trail is not None:
            self.audit_trail.record("upsert_pipeline", "failed",
                                    name=pipeline.name, error=str(e))
        raise
|
|
902
|
+
|
|
903
|
+
def start_pipeline_execution(self,
                             pipeline,
                             execution_display_name: Optional[str] = None,
                             execution_parameters: Optional[Dict[str, Any]] = None,
                             **kwargs):
    """
    Start pipeline execution with monitoring support.

    Delegates to the pipeline wrapper. The call is logged and, when an
    audit trail is configured, recorded as "started" followed by
    "completed" (including the execution's ``arn``) or "failed".

    Args:
        pipeline: Pipeline object to execute; must be truthy and expose
            a ``name`` attribute
        execution_display_name: Optional display name for the execution
        execution_parameters: Optional parameters to override pipeline defaults
        **kwargs: Additional execution parameters forwarded to the wrapper

    Returns:
        PipelineExecution object returned by the pipeline wrapper

    Raises:
        SessionError: If session is not initialized
        ValidationError: If pipeline is missing or has no ``name``, or
            an optional argument has the wrong type
        Exception: Any error raised while starting the execution is
            recorded in the audit trail and re-raised unchanged (no
            wrapping happens here; AWSServiceError is presumably raised
            by the wrapper itself - confirm against PipelineWrapper)
    """
    # Validate session is initialized
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Validate required parameters
    if not pipeline:
        raise ValidationError("pipeline is required")

    if not hasattr(pipeline, 'name'):
        raise ValidationError("pipeline must have a 'name' attribute")

    # Validate optional parameters if provided
    if execution_display_name is not None and not isinstance(execution_display_name, str):
        raise ValidationError("execution_display_name must be a string if provided")

    if execution_parameters is not None and not isinstance(execution_parameters, dict):
        raise ValidationError("execution_parameters must be a dictionary if provided")

    self.logger.info("start_pipeline_execution called", name=pipeline.name)

    if self.audit_trail is not None:
        self.audit_trail.record("start_pipeline_execution", "started",
                                name=pipeline.name,
                                display_name=execution_display_name)

    try:
        execution = self._pipeline_wrapper.start_pipeline_execution(
            pipeline=pipeline,
            execution_display_name=execution_display_name,
            execution_parameters=execution_parameters,
            **kwargs
        )

        if self.audit_trail is not None:
            self.audit_trail.record("start_pipeline_execution", "completed",
                                    name=pipeline.name,
                                    execution_arn=execution.arn)

        return execution

    except Exception as e:
        # One handler for every failure: SDK errors and unexpected errors
        # were handled identically (audit, then re-raise as-is), so the
        # separate ValidationError/SessionError clause was redundant.
        if self.audit_trail is not None:
            self.audit_trail.record("start_pipeline_execution", "failed",
                                    name=pipeline.name, error=str(e))
        raise
|
|
976
|
+
|
|
977
|
+
def describe_pipeline_execution(self, pipeline_execution) -> Dict[str, Any]:
|
|
978
|
+
"""
|
|
979
|
+
Get pipeline execution status and details.
|
|
980
|
+
|
|
981
|
+
Args:
|
|
982
|
+
pipeline_execution: PipelineExecution object
|
|
983
|
+
|
|
984
|
+
Returns:
|
|
985
|
+
Dictionary with execution status and details
|
|
986
|
+
|
|
987
|
+
Raises:
|
|
988
|
+
ValidationError: If pipeline_execution is invalid
|
|
989
|
+
SessionError: If session is not initialized
|
|
990
|
+
AWSServiceError: If describe operation fails
|
|
991
|
+
"""
|
|
992
|
+
# Validate session is initialized
|
|
993
|
+
if not self._sagemaker_session:
|
|
994
|
+
raise SessionError("SageMaker session is not initialized")
|
|
995
|
+
|
|
996
|
+
# Validate required parameters
|
|
997
|
+
if not pipeline_execution:
|
|
998
|
+
raise ValidationError("pipeline_execution is required")
|
|
999
|
+
|
|
1000
|
+
if not hasattr(pipeline_execution, 'arn'):
|
|
1001
|
+
raise ValidationError("pipeline_execution must have an 'arn' attribute")
|
|
1002
|
+
|
|
1003
|
+
self.logger.info("describe_pipeline_execution called", arn=pipeline_execution.arn)
|
|
1004
|
+
|
|
1005
|
+
if self.audit_trail is not None:
|
|
1006
|
+
self.audit_trail.record("describe_pipeline_execution", "started",
|
|
1007
|
+
arn=pipeline_execution.arn)
|
|
1008
|
+
|
|
1009
|
+
try:
|
|
1010
|
+
response = self._pipeline_wrapper.describe_pipeline_execution(
|
|
1011
|
+
pipeline_execution=pipeline_execution
|
|
1012
|
+
)
|
|
1013
|
+
|
|
1014
|
+
if self.audit_trail is not None:
|
|
1015
|
+
self.audit_trail.record("describe_pipeline_execution", "completed",
|
|
1016
|
+
arn=pipeline_execution.arn,
|
|
1017
|
+
status=response.get('PipelineExecutionStatus'))
|
|
1018
|
+
|
|
1019
|
+
return response
|
|
1020
|
+
|
|
1021
|
+
except (ValidationError, SessionError) as e:
|
|
1022
|
+
# Re-raise validation and session errors without wrapping
|
|
1023
|
+
if self.audit_trail is not None:
|
|
1024
|
+
self.audit_trail.record("describe_pipeline_execution", "failed",
|
|
1025
|
+
arn=pipeline_execution.arn, error=str(e))
|
|
1026
|
+
raise
|
|
1027
|
+
except Exception as e:
|
|
1028
|
+
if self.audit_trail is not None:
|
|
1029
|
+
self.audit_trail.record("describe_pipeline_execution", "failed",
|
|
1030
|
+
arn=pipeline_execution.arn, error=str(e))
|
|
1031
|
+
raise
|
|
1032
|
+
|
|
1033
|
+
def list_pipeline_execution_steps(self, pipeline_execution) -> List[Dict[str, Any]]:
    """
    List all steps in a pipeline execution with their status.

    Delegates to the pipeline wrapper. The call is logged and, when an
    audit trail is configured, recorded as "started" followed by
    "completed" (including the number of steps returned) or "failed".

    Args:
        pipeline_execution: PipelineExecution object; must be truthy
            and expose an ``arn`` attribute

    Returns:
        Step collection returned by the pipeline wrapper (its ``len()``
        is taken for the audit record)

    Raises:
        SessionError: If session is not initialized
        ValidationError: If pipeline_execution is missing or has no ``arn``
        Exception: Any error raised by the list operation is recorded
            in the audit trail and re-raised unchanged (no wrapping
            happens here; AWSServiceError is presumably raised by the
            wrapper itself - confirm against PipelineWrapper)
    """
    # Validate session is initialized
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Validate required parameters
    if not pipeline_execution:
        raise ValidationError("pipeline_execution is required")

    if not hasattr(pipeline_execution, 'arn'):
        raise ValidationError("pipeline_execution must have an 'arn' attribute")

    self.logger.info("list_pipeline_execution_steps called", arn=pipeline_execution.arn)

    if self.audit_trail is not None:
        self.audit_trail.record("list_pipeline_execution_steps", "started",
                                arn=pipeline_execution.arn)

    try:
        steps = self._pipeline_wrapper.list_pipeline_execution_steps(
            pipeline_execution=pipeline_execution
        )

        if self.audit_trail is not None:
            self.audit_trail.record("list_pipeline_execution_steps", "completed",
                                    arn=pipeline_execution.arn,
                                    step_count=len(steps))

        return steps

    except Exception as e:
        # One handler for every failure: SDK errors and unexpected errors
        # were handled identically (audit, then re-raise as-is), so the
        # separate ValidationError/SessionError clause was redundant.
        if self.audit_trail is not None:
            self.audit_trail.record("list_pipeline_execution_steps", "failed",
                                    arn=pipeline_execution.arn, error=str(e))
        raise
|
|
1088
|
+
|
|
1089
|
+
def wait_for_pipeline_execution(self,
|
|
1090
|
+
pipeline_execution,
|
|
1091
|
+
delay: int = 30,
|
|
1092
|
+
max_attempts: int = 60) -> Dict[str, Any]:
|
|
1093
|
+
"""
|
|
1094
|
+
Wait for pipeline execution to complete.
|
|
1095
|
+
|
|
1096
|
+
Args:
|
|
1097
|
+
pipeline_execution: PipelineExecution object
|
|
1098
|
+
delay: Delay between status checks in seconds (default: 30)
|
|
1099
|
+
max_attempts: Maximum number of status checks (default: 60)
|
|
1100
|
+
|
|
1101
|
+
Returns:
|
|
1102
|
+
Final execution status dictionary
|
|
1103
|
+
|
|
1104
|
+
Raises:
|
|
1105
|
+
ValidationError: If pipeline_execution is invalid or parameters are invalid
|
|
1106
|
+
SessionError: If session is not initialized
|
|
1107
|
+
AWSServiceError: If wait operation fails
|
|
1108
|
+
"""
|
|
1109
|
+
# Validate session is initialized
|
|
1110
|
+
if not self._sagemaker_session:
|
|
1111
|
+
raise SessionError("SageMaker session is not initialized")
|
|
1112
|
+
|
|
1113
|
+
# Validate required parameters
|
|
1114
|
+
if not pipeline_execution:
|
|
1115
|
+
raise ValidationError("pipeline_execution is required")
|
|
1116
|
+
|
|
1117
|
+
if not hasattr(pipeline_execution, 'arn'):
|
|
1118
|
+
raise ValidationError("pipeline_execution must have an 'arn' attribute")
|
|
1119
|
+
|
|
1120
|
+
# Validate delay and max_attempts
|
|
1121
|
+
if not isinstance(delay, int) or delay < 1:
|
|
1122
|
+
raise ValidationError("delay must be a positive integer")
|
|
1123
|
+
|
|
1124
|
+
if not isinstance(max_attempts, int) or max_attempts < 1:
|
|
1125
|
+
raise ValidationError("max_attempts must be a positive integer")
|
|
1126
|
+
|
|
1127
|
+
self.logger.info("wait_for_pipeline_execution called",
|
|
1128
|
+
arn=pipeline_execution.arn,
|
|
1129
|
+
delay=delay,
|
|
1130
|
+
max_attempts=max_attempts)
|
|
1131
|
+
|
|
1132
|
+
if self.audit_trail is not None:
|
|
1133
|
+
self.audit_trail.record("wait_for_pipeline_execution", "started",
|
|
1134
|
+
arn=pipeline_execution.arn)
|
|
1135
|
+
|
|
1136
|
+
try:
|
|
1137
|
+
final_status = self._pipeline_wrapper.wait_for_pipeline_execution(
|
|
1138
|
+
pipeline_execution=pipeline_execution,
|
|
1139
|
+
delay=delay,
|
|
1140
|
+
max_attempts=max_attempts
|
|
1141
|
+
)
|
|
1142
|
+
|
|
1143
|
+
if self.audit_trail is not None:
|
|
1144
|
+
self.audit_trail.record("wait_for_pipeline_execution", "completed",
|
|
1145
|
+
arn=pipeline_execution.arn,
|
|
1146
|
+
status=final_status.get('PipelineExecutionStatus'))
|
|
1147
|
+
|
|
1148
|
+
return final_status
|
|
1149
|
+
|
|
1150
|
+
except (ValidationError, SessionError) as e:
|
|
1151
|
+
# Re-raise validation and session errors without wrapping
|
|
1152
|
+
if self.audit_trail is not None:
|
|
1153
|
+
self.audit_trail.record("wait_for_pipeline_execution", "failed",
|
|
1154
|
+
arn=pipeline_execution.arn, error=str(e))
|
|
1155
|
+
raise
|
|
1156
|
+
except Exception as e:
|
|
1157
|
+
if self.audit_trail is not None:
|
|
1158
|
+
self.audit_trail.record("wait_for_pipeline_execution", "failed",
|
|
1159
|
+
arn=pipeline_execution.arn, error=str(e))
|
|
1160
|
+
raise
|