sagemaker-mlp-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlp_sdk/session.py ADDED
@@ -0,0 +1,1160 @@
1
+ """
2
+ Core MLP_Session class - main interface for all SDK operations
3
+ """
4
+
5
+ from typing import Optional, Dict, Any, List
6
+ import logging
7
+ from .config import ConfigurationManager
8
+ from .exceptions import MLPSDKError, SessionError, AWSServiceError, ValidationError, MLPLogger, AuditTrail
9
+ from .wrappers import FeatureStoreWrapper, ProcessingWrapper, TrainingWrapper, PipelineWrapper, DeploymentWrapper
10
+
11
+
12
+ class MLP_Session:
13
+ """
14
+ Main interface for all mlp_sdk operations.
15
+
16
+ Provides simplified SageMaker operations with configuration-driven defaults.
17
+ Built on top of SageMaker Python SDK v3.
18
+ """
19
+
20
def __init__(self,
             config_path: Optional[str] = None,
             log_level: int = logging.INFO,
             enable_audit_trail: bool = True,
             **kwargs):
    """
    Set up logging, the optional audit trail, configuration, wrappers and
    the underlying SageMaker session.

    Args:
        config_path: Optional path to a configuration file. Passed
            unchanged to ConfigurationManager (None leaves the choice of
            location to the manager).
        log_level: Logging level handed to the session logger.
        enable_audit_trail: When True an AuditTrail is created and the
            initialization outcome is recorded in it; when False
            ``self.audit_trail`` is None.
        **kwargs: Additional session parameters. Stored as
            ``self._session_kwargs`` and consumed by
            ``_initialize_sagemaker_session``.

    Raises:
        SessionError: If configuration loading, wrapper construction or
            SageMaker session initialization fails. Every exception
            raised by those steps is wrapped in SessionError, with the
            original available as ``__cause__``.
    """
    # The logger is created first so that all later steps can report through it.
    self.logger = MLPLogger("mlp_sdk.session", level=log_level)
    self.logger.info("Initializing MLP_Session", config_path=config_path or "default")

    # None means "auditing disabled"; every use below checks for it.
    if enable_audit_trail:
        self.audit_trail = AuditTrail()
    else:
        self.audit_trail = None
    if self.audit_trail is not None:
        self.audit_trail.record("session_init", "started", config_path=config_path)

    try:
        self.config_manager = ConfigurationManager(config_path)

        # All wrappers share the same configuration manager and logger.
        wrapper_specs = (
            ("_feature_store_wrapper", FeatureStoreWrapper),
            ("_processing_wrapper", ProcessingWrapper),
            ("_training_wrapper", TrainingWrapper),
            ("_pipeline_wrapper", PipelineWrapper),
            ("_deployment_wrapper", DeploymentWrapper),
        )
        for attr_name, wrapper_cls in wrapper_specs:
            setattr(self, attr_name, wrapper_cls(self.config_manager, self.logger))

        # Placeholder until _initialize_sagemaker_session assigns the real object.
        self._sagemaker_session = None
        self._session_kwargs = kwargs
        self._initialize_sagemaker_session()

        self.logger.info("MLP_Session initialized successfully",
                         has_config=self.config_manager.has_config)

        if self.audit_trail is not None:
            self.audit_trail.record("session_init", "completed",
                                    has_config=self.config_manager.has_config)

    except Exception as exc:
        self.logger.error("Failed to initialize MLP_Session", error=exc)
        if self.audit_trail is not None:
            self.audit_trail.record("session_init", "failed", error=str(exc))
        raise SessionError(f"Failed to initialize MLP_Session: {exc}") from exc
78
+
79
def _initialize_sagemaker_session(self) -> None:
    """
    Create the boto3 session, the SageMaker boto3 clients and the SDK v3
    SessionSettings object from ``self._session_kwargs``.

    The stored ``self._session_kwargs`` dictionary is read but never
    mutated, so calling this method again (as ``update_session_config``
    does) keeps a caller-supplied ``boto_session`` or ``default_bucket``.

    Keys of ``self._session_kwargs`` handled here:
        default_bucket: Used as the default bucket; when the key is
            absent the value comes from ``config_manager.get_s3_config()``.
        boto_session: Used as the boto3 session; when absent or None a
            new ``boto3.Session()`` is created.
        sagemaker_client / sagemaker_runtime_client: Used as the clients
            when supplied; otherwise clients are created from the boto3
            session.
    All remaining keys are forwarded to ``SessionSettings``.

    Attributes set: ``_boto_session``, ``_default_bucket``,
    ``_region_name``, ``_sagemaker_session``, ``_sagemaker_client``,
    ``_sagemaker_runtime_client``.

    Raises:
        SessionError: If the required packages cannot be imported or any
            step of the initialization fails.
    """
    # Import separately from the main work so that the exception classes
    # named in the handlers further down are guaranteed to be bound.
    try:
        import boto3
        from botocore.exceptions import ClientError, NoCredentialsError
        from sagemaker.core.session_settings import SessionSettings
    except ImportError as e:
        raise SessionError(
            "SageMaker SDK v3 not installed. Install with: pip install sagemaker>=3.0.0"
        ) from e
    except Exception as e:
        raise SessionError(
            f"Failed to initialize SageMaker session: {e}"
        ) from e

    try:
        # Work on a copy: popping from the stored dict would drop these
        # values for every later re-initialization.
        settings_kwargs = dict(self._session_kwargs)

        if 'default_bucket' in settings_kwargs:
            default_bucket = settings_kwargs.pop('default_bucket')
        else:
            s3_config = self.config_manager.get_s3_config()
            default_bucket = s3_config.default_bucket if s3_config else None

        boto_session = settings_kwargs.pop('boto_session', None)
        if boto_session is None:
            boto_session = boto3.Session()

        # Caller-supplied clients are session-level objects, not
        # SessionSettings options, so they must not be forwarded below.
        sagemaker_client = settings_kwargs.pop('sagemaker_client', None)
        runtime_client = settings_kwargs.pop('sagemaker_runtime_client', None)

        self._boto_session = boto_session
        self._default_bucket = default_bucket
        self._region_name = boto_session.region_name

        self._sagemaker_session = SessionSettings(**settings_kwargs)

        if sagemaker_client is None:
            sagemaker_client = boto_session.client('sagemaker')
        if runtime_client is None:
            runtime_client = boto_session.client('sagemaker-runtime')
        self._sagemaker_client = sagemaker_client
        self._sagemaker_runtime_client = runtime_client

        self.logger.debug("SageMaker SessionSettings initialized",
                          region=self._region_name,
                          default_bucket=default_bucket)

    except (ClientError, NoCredentialsError) as e:
        raise SessionError(
            f"Failed to initialize SageMaker session due to AWS credentials issue: {e}"
        ) from e
    except Exception as e:
        raise SessionError(
            f"Failed to initialize SageMaker session: {e}"
        ) from e
136
+
137
@property
def sagemaker_session(self):
    """
    Return the session object built by ``_initialize_sagemaker_session``.

    In this class the attribute is assigned a ``SessionSettings``
    instance (SageMaker SDK v3), or None before initialization has
    completed. The same object is what the wrapper calls receive as
    ``sagemaker_session``.

    Returns:
        The stored SessionSettings object, or None.

    Example:
        >>> session = MLP_Session()
        >>> settings = session.sagemaker_session
    """
    session_obj = self._sagemaker_session
    return session_obj
155
+
156
@property
def boto_session(self):
    """
    Return the boto3 session stored on this object.

    Returns:
        The value of ``self._boto_session``, or None when that attribute
        has not been set.

    Example:
        >>> session = MLP_Session()
        >>> s3_client = session.boto_session.client('s3')
    """
    try:
        return self._boto_session
    except AttributeError:
        # Attribute is only assigned during session initialization.
        return None
172
+
173
@property
def sagemaker_client(self):
    """
    Return the SageMaker boto3 client stored on this object.

    Returns:
        The value of ``self._sagemaker_client``, or None when that
        attribute has not been set.

    Example:
        >>> session = MLP_Session()
        >>> response = session.sagemaker_client.describe_training_job(
        ...     TrainingJobName='my-job'
        ... )
    """
    try:
        return self._sagemaker_client
    except AttributeError:
        # Attribute is only assigned during session initialization.
        return None
191
+
192
@property
def sagemaker_runtime_client(self):
    """
    Return the SageMaker Runtime boto3 client stored on this object.

    Returns:
        The value of ``self._sagemaker_runtime_client``, or None when
        that attribute has not been set.

    Example:
        >>> session = MLP_Session()
        >>> response = session.sagemaker_runtime_client.invoke_endpoint(
        ...     EndpointName='my-endpoint',
        ...     Body=json.dumps(data)
        ... )
    """
    try:
        return self._sagemaker_runtime_client
    except AttributeError:
        # Attribute is only assigned during session initialization.
        return None
211
+
212
@property
def region_name(self) -> Optional[str]:
    """
    Return the AWS region name stored on this object.

    Returns:
        The value of ``self._region_name``, or None when that attribute
        has not been set.
    """
    try:
        return self._region_name
    except AttributeError:
        # Attribute is only assigned during session initialization.
        return None
221
+
222
@property
def default_bucket(self) -> Optional[str]:
    """
    Return the default S3 bucket stored on this object.

    Returns:
        The value of ``self._default_bucket``, or None when that
        attribute has not been set.
    """
    try:
        return self._default_bucket
    except AttributeError:
        # Attribute is only assigned during session initialization.
        return None
231
+
232
@property
def account_id(self) -> Optional[str]:
    """
    Look up the AWS account ID through STS.

    Each access creates an ``sts`` client from the stored boto3 session
    and calls ``get_caller_identity``. The lookup is best-effort: any
    exception raised along the way results in None rather than an error.

    Returns:
        The ``Account`` value of the caller identity, or None when no
        boto3 session is stored or the lookup fails.
    """
    session = getattr(self, '_boto_session', None)
    if not session:
        return None

    try:
        identity = session.client('sts').get_caller_identity()
        return identity['Account']
    except Exception:
        # Deliberately swallowed: callers get None instead of an AWS error.
        return None
248
+
249
def set_log_level(self, level: int) -> None:
    """
    Change the level of the session logger and log the change.

    Args:
        level: Logging level (e.g. logging.DEBUG, logging.INFO).
    """
    self.logger.set_level(level)
    # Report the symbolic name of the level rather than the raw number.
    level_name = logging.getLevelName(level)
    self.logger.info("Log level changed", level=level_name)
258
+
259
def get_config(self) -> Optional[Dict[str, Any]]:
    """
    Return the configuration held by the configuration manager.

    Returns:
        The manager's ``_config`` object when ``has_config`` is truthy,
        otherwise None. The object is returned as-is, not copied.

    Example:
        >>> session = MLP_Session()
        >>> config = session.get_config()
        >>> print(config['defaults']['s3']['default_bucket'])
    """
    manager = self.config_manager
    if manager.has_config:
        return manager._config
    return None
272
+
273
def get_execution_role(self) -> Optional[str]:
    """
    Return the IAM execution role from the configuration.

    Returns:
        The ``execution_role`` attribute of the object returned by
        ``config_manager.get_iam_config()``, or None when that call
        returns a falsy value.

    Example:
        >>> session = MLP_Session()
        >>> role = session.get_execution_role()
    """
    iam_config = self.config_manager.get_iam_config()
    if not iam_config:
        return None
    return iam_config.execution_role
286
+
287
def update_session_config(self, **kwargs) -> None:
    """
    Merge new parameters into the session parameters and rebuild the
    SageMaker session.

    The new parameters are merged into ``self._session_kwargs`` and
    ``_initialize_sagemaker_session`` is run again. If that fails, the
    stored parameters are rolled back to their previous contents so a
    rejected parameter does not break every later update.

    Args:
        **kwargs: Session parameters to add or override
            (e.g., default_bucket, boto_session, sagemaker_client).

    Raises:
        SessionError: If re-initialization fails. The original exception
            is chained as ``__cause__``.

    Example:
        >>> session = MLP_Session()
        >>> session.update_session_config(default_bucket='my-new-bucket')
    """
    param_names = list(kwargs.keys())
    self.logger.info("Updating session configuration", params=param_names)

    if self.audit_trail is not None:
        self.audit_trail.record("update_session_config", "started", params=param_names)

    # Snapshot taken before merging; restored if the rebuild fails.
    previous_kwargs = dict(self._session_kwargs)

    try:
        self._session_kwargs.update(kwargs)
        self._initialize_sagemaker_session()

        self.logger.info("Session configuration updated successfully")

        if self.audit_trail is not None:
            self.audit_trail.record("update_session_config", "completed")

    except Exception as e:
        # Roll back first so the stored parameters are sane even if the
        # logging or auditing below raises.
        self._session_kwargs = previous_kwargs
        self.logger.error("Failed to update session configuration", error=e)
        if self.audit_trail is not None:
            self.audit_trail.record("update_session_config", "failed", error=str(e))
        raise SessionError(f"Failed to update session configuration: {e}") from e
327
+
328
def get_audit_trail(self, operation: Optional[str] = None,
                    status: Optional[str] = None,
                    limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """
    Return audit trail entries, optionally filtered.

    Args:
        operation: Filter by operation name.
        status: Filter by status.
        limit: Maximum number of entries to return.

    Returns:
        The result of ``audit_trail.get_entries`` with the given
        filters, or an empty list when the audit trail is disabled.
    """
    # "Disabled" is represented by None everywhere else in this class;
    # an identity check avoids mistaking a falsy (e.g. empty) trail
    # object for a disabled one.
    if self.audit_trail is None:
        return []
    return self.audit_trail.get_entries(operation=operation, status=status, limit=limit)
345
+
346
def get_audit_trail_summary(self) -> Dict[str, Any]:
    """
    Return summary statistics of the audit trail.

    Returns:
        The result of ``audit_trail.get_summary()``.

    Raises:
        SessionError: If the audit trail is disabled for this session.
    """
    # Identity check, consistent with the rest of the class: only None
    # means "disabled"; a falsy (e.g. empty) trail object is still valid.
    if self.audit_trail is None:
        raise SessionError("Audit trail is not enabled for this session")

    return self.audit_trail.get_summary()
361
+
362
def export_audit_trail(self, file_path: str, format: str = 'json') -> None:
    """
    Export the audit trail to a file.

    Args:
        file_path: Path of the output file.
        format: Export format, 'json' or 'csv'.

    Raises:
        SessionError: If the audit trail is disabled for this session.
        ValidationError: If ``format`` is neither 'json' nor 'csv'.
    """
    # Identity check, consistent with the rest of the class: only None
    # means "disabled"; a falsy (e.g. empty) trail object is still valid.
    if self.audit_trail is None:
        raise SessionError("Audit trail is not enabled for this session")

    if format not in ('json', 'csv'):
        raise ValidationError(f"Invalid export format: {format}. Must be 'json' or 'csv'")

    if format == 'json':
        self.audit_trail.export_json(file_path)
    else:
        self.audit_trail.export_csv(file_path)

    self.logger.info("Audit trail exported", file_path=file_path, format=format)
386
+
387
def create_feature_group(self,
                         feature_group_name: str,
                         record_identifier_name: str,
                         event_time_feature_name: str,
                         feature_definitions: List[Dict[str, str]],
                         **kwargs):
    """
    Validate the arguments and delegate feature group creation to the
    feature store wrapper.

    Args:
        feature_group_name: Name of the feature group.
        record_identifier_name: Name of the record identifier feature.
        event_time_feature_name: Name of the event time feature.
        feature_definitions: Non-empty list of feature definitions.
        **kwargs: Additional parameters forwarded unchanged to the wrapper.

    Returns:
        Whatever ``FeatureStoreWrapper.create_feature_group`` returns.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If a required argument is empty or has the
            wrong type.
        Exception: Any error raised by the wrapper is recorded in the
            audit trail as "failed" and re-raised unchanged.
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Required string arguments, checked in declaration order.
    required_strings = (
        ("feature_group_name", feature_group_name),
        ("record_identifier_name", record_identifier_name),
        ("event_time_feature_name", event_time_feature_name),
    )
    for label, value in required_strings:
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    if not (feature_definitions and isinstance(feature_definitions, list)):
        raise ValidationError("feature_definitions must be a non-empty list")

    self.logger.info("create_feature_group called", name=feature_group_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("create_feature_group", "started", name=feature_group_name)

    try:
        result = self._feature_store_wrapper.create_feature_group(
            sagemaker_session=self._sagemaker_session,
            feature_group_name=feature_group_name,
            record_identifier_name=record_identifier_name,
            event_time_feature_name=event_time_feature_name,
            feature_definitions=feature_definitions,
            **kwargs
        )

        if trail is not None:
            trail.record("create_feature_group", "completed", name=feature_group_name)

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("create_feature_group", "failed",
                         name=feature_group_name, error=str(failure))
        raise
459
+
460
def run_processing_job(self,
                       job_name: str,
                       processing_script: Optional[str] = None,
                       inputs: Optional[List[Dict[str, Any]]] = None,
                       outputs: Optional[List[Dict[str, Any]]] = None,
                       **kwargs):
    """
    Validate the arguments and delegate the processing job to the
    processing wrapper.

    Args:
        job_name: Processing job name (non-empty string).
        processing_script: Optional path to a custom processing script.
        inputs: Optional list of processing inputs.
        outputs: Optional list of processing outputs.
        **kwargs: Additional parameters forwarded unchanged to the wrapper.

    Returns:
        Whatever ``ProcessingWrapper.run_processing_job`` returns.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If an argument is empty or has the wrong type.
        Exception: Any error raised by the wrapper is recorded in the
            audit trail as "failed" and re-raised unchanged.
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    if not (job_name and isinstance(job_name, str)):
        raise ValidationError("job_name must be a non-empty string")

    # Optional list arguments, checked in declaration order.
    for label, value in (("inputs", inputs), ("outputs", outputs)):
        if value is not None and not isinstance(value, list):
            raise ValidationError(f"{label} must be a list if provided")

    if processing_script is not None and not isinstance(processing_script, str):
        raise ValidationError("processing_script must be a string if provided")

    self.logger.info("run_processing_job called", name=job_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("run_processing_job", "started", name=job_name)

    try:
        result = self._processing_wrapper.run_processing_job(
            sagemaker_session=self._sagemaker_session,
            job_name=job_name,
            processing_script=processing_script,
            inputs=inputs,
            outputs=outputs,
            **kwargs
        )

        if trail is not None:
            trail.record("run_processing_job", "completed", name=job_name)

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("run_processing_job", "failed",
                         name=job_name, error=str(failure))
        raise
533
+
534
def run_training_job(self,
                     job_name: str,
                     training_image: str,
                     source_code_dir: Optional[str] = None,
                     entry_script: Optional[str] = None,
                     requirements: Optional[str] = None,
                     inputs: Optional[Dict[str, Any]] = None,
                     **kwargs):
    """
    Validate the arguments and delegate the training job to the
    training wrapper.

    Args:
        job_name: Training job name (non-empty string).
        training_image: Container image URI for training (non-empty string).
        source_code_dir: Optional directory with the training code.
        entry_script: Optional entry point script (e.g., 'train.py').
        requirements: Optional path to a requirements file.
        inputs: Optional training data inputs. The annotation names a
            dict, but the validation below accepts a dict or a list.
        **kwargs: Additional parameters forwarded unchanged to the wrapper.

    Returns:
        Whatever ``TrainingWrapper.run_training_job`` returns.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If an argument is empty or has the wrong type.
        Exception: Any error raised by the wrapper is recorded in the
            audit trail as "failed" and re-raised unchanged.

    Example:
        >>> session = MLP_Session()
        >>> trainer = session.run_training_job(
        ...     job_name='my-training-job',
        ...     training_image='763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310',
        ...     source_code_dir='training-scripts',
        ...     entry_script='train.py',
        ...     inputs={'train': 's3://my-bucket/data/train/'}
        ... )
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Required string arguments, checked in declaration order.
    for label, value in (("job_name", job_name), ("training_image", training_image)):
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    # Optional string arguments, checked in declaration order.
    optional_strings = (
        ("source_code_dir", source_code_dir),
        ("entry_script", entry_script),
        ("requirements", requirements),
    )
    for label, value in optional_strings:
        if value is not None and not isinstance(value, str):
            raise ValidationError(f"{label} must be a string if provided")

    if inputs is not None and not isinstance(inputs, (dict, list)):
        raise ValidationError("inputs must be a dictionary or list if provided")

    self.logger.info("run_training_job called", name=job_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("run_training_job", "started", name=job_name)

    try:
        result = self._training_wrapper.run_training_job(
            sagemaker_session=self._sagemaker_session,
            job_name=job_name,
            training_image=training_image,
            source_code_dir=source_code_dir,
            entry_script=entry_script,
            requirements=requirements,
            inputs=inputs,
            **kwargs
        )

        if trail is not None:
            trail.record("run_training_job", "completed", name=job_name)

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("run_training_job", "failed",
                         name=job_name, error=str(failure))
        raise
630
+
631
def deploy_model(self,
                 model_data: str,
                 image_uri: str,
                 endpoint_name: str,
                 enable_vpc: bool = False,
                 **kwargs):
    """
    Validate the arguments and delegate model deployment to the
    deployment wrapper.

    Unlike the other wrapper calls in this class, the SageMaker session
    object is not passed to the wrapper here; it is only checked for
    presence.

    Args:
        model_data: S3 URI of the model artifacts
            (e.g., 's3://bucket/model.tar.gz').
        image_uri: Container image URI for inference.
        endpoint_name: Name for the endpoint.
        enable_vpc: Forwarded to the wrapper (default: False).
        **kwargs: Additional parameters forwarded unchanged to the wrapper
            (e.g., instance_type, instance_count, environment).

    Returns:
        Whatever ``DeploymentWrapper.deploy_model`` returns.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If a required argument is empty or not a string.
        Exception: Any error raised by the wrapper is recorded in the
            audit trail as "failed" and re-raised unchanged.

    Example:
        >>> session = MLP_Session()
        >>> predictor = session.deploy_model(
        ...     model_data='s3://my-bucket/model.tar.gz',
        ...     image_uri='683313688378.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1',
        ...     endpoint_name='my-endpoint',
        ...     enable_vpc=True
        ... )
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    # Required string arguments, checked in declaration order.
    required_strings = (
        ("model_data", model_data),
        ("image_uri", image_uri),
        ("endpoint_name", endpoint_name),
    )
    for label, value in required_strings:
        if not (value and isinstance(value, str)):
            raise ValidationError(f"{label} must be a non-empty string")

    self.logger.info("deploy_model called", endpoint_name=endpoint_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("deploy_model", "started", endpoint_name=endpoint_name)

    try:
        result = self._deployment_wrapper.deploy_model(
            model_data=model_data,
            image_uri=image_uri,
            endpoint_name=endpoint_name,
            enable_vpc=enable_vpc,
            **kwargs
        )

        if trail is not None:
            trail.record("deploy_model", "completed", endpoint_name=endpoint_name)

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("deploy_model", "failed",
                         endpoint_name=endpoint_name, error=str(failure))
        raise
724
+
725
def delete_endpoint(self, endpoint_name: str) -> None:
    """
    Delete a SageMaker endpoint through the stored SageMaker boto3 client.

    Args:
        endpoint_name: Name of the endpoint to delete (non-empty string).

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If ``endpoint_name`` is empty or not a string.
        AWSServiceError: If the delete call (or the logging/auditing of
            its success) raises; the original exception is chained and
            also passed as ``aws_error``.

    Example:
        >>> session = MLP_Session()
        >>> session.delete_endpoint('my-endpoint')
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    if not (endpoint_name and isinstance(endpoint_name, str)):
        raise ValidationError("endpoint_name must be a non-empty string")

    self.logger.info("delete_endpoint called", endpoint_name=endpoint_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("delete_endpoint", "started", endpoint_name=endpoint_name)

    try:
        self._sagemaker_client.delete_endpoint(EndpointName=endpoint_name)

        self.logger.info("Endpoint deleted successfully", endpoint_name=endpoint_name)

        if trail is not None:
            trail.record("delete_endpoint", "completed", endpoint_name=endpoint_name)

    except Exception as failure:
        self.logger.error("Failed to delete endpoint",
                          endpoint_name=endpoint_name,
                          error=failure)
        if trail is not None:
            trail.record("delete_endpoint", "failed",
                         endpoint_name=endpoint_name, error=str(failure))
        raise AWSServiceError(
            f"Failed to delete endpoint '{endpoint_name}': {failure}",
            aws_error=failure
        ) from failure
774
+
775
def create_pipeline(self,
                    pipeline_name: str,
                    steps: List,
                    parameters: Optional[List] = None,
                    **kwargs):
    """
    Validate the arguments and delegate pipeline creation to the
    pipeline wrapper.

    Args:
        pipeline_name: Pipeline name (non-empty string).
        steps: Non-empty list of pipeline steps.
        parameters: Optional list of pipeline parameters.
        **kwargs: Additional parameters forwarded unchanged to the wrapper.

    Returns:
        Whatever ``PipelineWrapper.create_pipeline`` returns.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If an argument is empty or has the wrong type.
        Exception: Any error raised by the wrapper is recorded in the
            audit trail as "failed" and re-raised unchanged.
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    if not (pipeline_name and isinstance(pipeline_name, str)):
        raise ValidationError("pipeline_name must be a non-empty string")

    if not (steps and isinstance(steps, list)):
        raise ValidationError("steps must be a non-empty list")

    if parameters is not None and not isinstance(parameters, list):
        raise ValidationError("parameters must be a list if provided")

    self.logger.info("create_pipeline called", name=pipeline_name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("create_pipeline", "started", name=pipeline_name)

    try:
        result = self._pipeline_wrapper.create_pipeline(
            sagemaker_session=self._sagemaker_session,
            pipeline_name=pipeline_name,
            steps=steps,
            parameters=parameters,
            **kwargs
        )

        if trail is not None:
            trail.record("create_pipeline", "completed", name=pipeline_name)

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("create_pipeline", "failed",
                         name=pipeline_name, error=str(failure))
        raise
845
+
846
def upsert_pipeline(self, pipeline, **kwargs) -> Dict[str, Any]:
    """
    Validate the pipeline object and delegate the upsert to the
    pipeline wrapper.

    Args:
        pipeline: Pipeline object to upsert; must be truthy and expose a
            ``name`` attribute.
        **kwargs: Additional parameters forwarded unchanged to the wrapper.

    Returns:
        The response returned by ``PipelineWrapper.upsert_pipeline``. Its
        ``'PipelineArn'`` entry (read with ``.get``) is recorded in the
        audit trail.

    Raises:
        SessionError: If the SageMaker session is not initialized.
        ValidationError: If ``pipeline`` is falsy or has no ``name``.
        Exception: Any error raised by the wrapper, or while recording
            the result, is recorded as "failed" and re-raised unchanged.
    """
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")

    if not pipeline:
        raise ValidationError("pipeline is required")

    if not hasattr(pipeline, 'name'):
        raise ValidationError("pipeline must have a 'name' attribute")

    self.logger.info("upsert_pipeline called", name=pipeline.name)

    trail = self.audit_trail
    if trail is not None:
        trail.record("upsert_pipeline", "started", name=pipeline.name)

    try:
        result = self._pipeline_wrapper.upsert_pipeline(
            pipeline=pipeline,
            **kwargs
        )

        if trail is not None:
            trail.record("upsert_pipeline", "completed",
                         name=pipeline.name,
                         arn=result.get('PipelineArn'))

        return result

    except Exception as failure:
        # Every failure is audited and then propagated without wrapping.
        if trail is not None:
            trail.record("upsert_pipeline", "failed",
                         name=pipeline.name, error=str(failure))
        raise
902
+
903
def start_pipeline_execution(self,
                             pipeline,
                             execution_display_name: Optional[str] = None,
                             execution_parameters: Optional[Dict[str, Any]] = None,
                             **kwargs):
    """
    Launch an execution of the given pipeline through the pipeline wrapper.

    The start, completion and failure of the call are written to the audit
    trail whenever one is configured on the session.

    Args:
        pipeline: Pipeline object to execute; must expose a ``name`` attribute
        execution_display_name: Optional display name for the execution
        execution_parameters: Optional dictionary of parameters to override
            pipeline defaults
        **kwargs: Extra keyword arguments forwarded unchanged to the wrapper

    Returns:
        The execution object produced by the pipeline wrapper; its ``arn``
        attribute is recorded in the audit trail

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If pipeline is missing, has no ``name`` attribute,
            or an optional argument has the wrong type
        Exception: Any error raised while starting the execution (e.g. an
            AWSServiceError coming from the wrapper) is audited and
            re-raised as-is
    """
    # Guard clauses: session first, then the pipeline argument itself.
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")
    if not pipeline:
        raise ValidationError("pipeline is required")
    if not hasattr(pipeline, 'name'):
        raise ValidationError("pipeline must have a 'name' attribute")

    # Optional arguments are only type-checked when they were supplied.
    if not (execution_display_name is None or isinstance(execution_display_name, str)):
        raise ValidationError("execution_display_name must be a string if provided")
    if not (execution_parameters is None or isinstance(execution_parameters, dict)):
        raise ValidationError("execution_parameters must be a dictionary if provided")

    operation = "start_pipeline_execution"
    pipeline_name = pipeline.name
    audit = self.audit_trail

    self.logger.info("start_pipeline_execution called", name=pipeline_name)
    if audit is not None:
        audit.record(operation, "started",
                     name=pipeline_name,
                     display_name=execution_display_name)

    try:
        started = self._pipeline_wrapper.start_pipeline_execution(
            pipeline=pipeline,
            execution_display_name=execution_display_name,
            execution_parameters=execution_parameters,
            **kwargs
        )
        if audit is not None:
            audit.record(operation, "completed",
                         name=pipeline_name,
                         execution_arn=started.arn)
        return started
    except Exception as exc:
        # SDK validation/session errors and every other failure are treated
        # alike: note the failure in the audit trail, then propagate the
        # original exception without wrapping it.
        if audit is not None:
            audit.record(operation, "failed", name=pipeline_name, error=str(exc))
        raise
976
+
977
def describe_pipeline_execution(self, pipeline_execution) -> Dict[str, Any]:
    """
    Fetch the status and details of a pipeline execution via the wrapper.

    The start, completion and failure of the call are written to the audit
    trail whenever one is configured on the session.

    Args:
        pipeline_execution: PipelineExecution object; must expose an ``arn``
            attribute

    Returns:
        The response produced by the pipeline wrapper; its
        ``PipelineExecutionStatus`` entry is recorded in the audit trail

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If pipeline_execution is missing or has no ``arn``
            attribute
        Exception: Any error raised while describing the execution (e.g. an
            AWSServiceError coming from the wrapper) is audited and
            re-raised as-is
    """
    # Guard clauses: session first, then the execution argument itself.
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")
    if not pipeline_execution:
        raise ValidationError("pipeline_execution is required")
    if not hasattr(pipeline_execution, 'arn'):
        raise ValidationError("pipeline_execution must have an 'arn' attribute")

    operation = "describe_pipeline_execution"
    execution_arn = pipeline_execution.arn
    audit = self.audit_trail

    self.logger.info("describe_pipeline_execution called", arn=execution_arn)
    if audit is not None:
        audit.record(operation, "started", arn=execution_arn)

    try:
        details = self._pipeline_wrapper.describe_pipeline_execution(
            pipeline_execution=pipeline_execution
        )
        if audit is not None:
            audit.record(operation, "completed",
                         arn=execution_arn,
                         status=details.get('PipelineExecutionStatus'))
        return details
    except Exception as exc:
        # SDK validation/session errors and every other failure are treated
        # alike: note the failure in the audit trail, then propagate the
        # original exception without wrapping it.
        if audit is not None:
            audit.record(operation, "failed", arn=execution_arn, error=str(exc))
        raise
1032
+
1033
def list_pipeline_execution_steps(self, pipeline_execution) -> List[Dict[str, Any]]:
    """
    Retrieve the steps of a pipeline execution via the pipeline wrapper.

    The start, completion and failure of the call are written to the audit
    trail whenever one is configured on the session.

    Args:
        pipeline_execution: PipelineExecution object; must expose an ``arn``
            attribute

    Returns:
        The step list produced by the pipeline wrapper; its length is
        recorded in the audit trail as ``step_count``

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If pipeline_execution is missing or has no ``arn``
            attribute
        Exception: Any error raised while listing the steps (e.g. an
            AWSServiceError coming from the wrapper) is audited and
            re-raised as-is
    """
    # Guard clauses: session first, then the execution argument itself.
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")
    if not pipeline_execution:
        raise ValidationError("pipeline_execution is required")
    if not hasattr(pipeline_execution, 'arn'):
        raise ValidationError("pipeline_execution must have an 'arn' attribute")

    operation = "list_pipeline_execution_steps"
    execution_arn = pipeline_execution.arn
    audit = self.audit_trail

    self.logger.info("list_pipeline_execution_steps called", arn=execution_arn)
    if audit is not None:
        audit.record(operation, "started", arn=execution_arn)

    try:
        step_list = self._pipeline_wrapper.list_pipeline_execution_steps(
            pipeline_execution=pipeline_execution
        )
        if audit is not None:
            audit.record(operation, "completed",
                         arn=execution_arn,
                         step_count=len(step_list))
        return step_list
    except Exception as exc:
        # SDK validation/session errors and every other failure are treated
        # alike: note the failure in the audit trail, then propagate the
        # original exception without wrapping it.
        if audit is not None:
            audit.record(operation, "failed", arn=execution_arn, error=str(exc))
        raise
1088
+
1089
def wait_for_pipeline_execution(self,
                                pipeline_execution,
                                delay: int = 30,
                                max_attempts: int = 60) -> Dict[str, Any]:
    """
    Wait for a pipeline execution to complete via the pipeline wrapper.

    The start, completion and failure of the call are written to the audit
    trail whenever one is configured on the session.

    Args:
        pipeline_execution: PipelineExecution object; must expose an ``arn``
            attribute
        delay: Delay between status checks in seconds (default: 30).
            Must be a positive integer; booleans are rejected.
        max_attempts: Maximum number of status checks (default: 60).
            Must be a positive integer; booleans are rejected.

    Returns:
        The final status produced by the pipeline wrapper; its
        ``PipelineExecutionStatus`` entry is recorded in the audit trail

    Raises:
        SessionError: If the SageMaker session is not initialized
        ValidationError: If pipeline_execution is missing or has no ``arn``
            attribute, or if delay / max_attempts is not a positive integer
        Exception: Any error raised while waiting (e.g. an AWSServiceError
            coming from the wrapper) is audited and re-raised as-is
    """
    # Guard clauses: session first, then the execution argument itself.
    if not self._sagemaker_session:
        raise SessionError("SageMaker session is not initialized")
    if not pipeline_execution:
        raise ValidationError("pipeline_execution is required")
    if not hasattr(pipeline_execution, 'arn'):
        raise ValidationError("pipeline_execution must have an 'arn' attribute")

    # bool is a subclass of int, so a bare isinstance(..., int) check would
    # let True through as if it were 1. Reject booleans explicitly.
    if isinstance(delay, bool) or not isinstance(delay, int) or delay < 1:
        raise ValidationError("delay must be a positive integer")
    if isinstance(max_attempts, bool) or not isinstance(max_attempts, int) or max_attempts < 1:
        raise ValidationError("max_attempts must be a positive integer")

    operation = "wait_for_pipeline_execution"
    execution_arn = pipeline_execution.arn
    audit = self.audit_trail

    self.logger.info("wait_for_pipeline_execution called",
                     arn=execution_arn,
                     delay=delay,
                     max_attempts=max_attempts)
    if audit is not None:
        audit.record(operation, "started", arn=execution_arn)

    try:
        final_status = self._pipeline_wrapper.wait_for_pipeline_execution(
            pipeline_execution=pipeline_execution,
            delay=delay,
            max_attempts=max_attempts
        )
        if audit is not None:
            audit.record(operation, "completed",
                         arn=execution_arn,
                         status=final_status.get('PipelineExecutionStatus'))
        return final_status
    except Exception as exc:
        # The original had two handlers with identical bodies; one handler
        # covers both. The failure is noted in the audit trail and the
        # original exception propagates without being wrapped.
        if audit is not None:
            audit.record(operation, "failed", arn=execution_arn, error=str(exc))
        raise