sagemaker-mlp-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ """
2
+ Operation wrappers for SageMaker services
3
+ """
4
+
5
+ from .feature_store import FeatureStoreWrapper
6
+ from .processing import ProcessingWrapper
7
+ from .training import TrainingWrapper
8
+ from .pipeline import PipelineWrapper
9
+ from .deployment import DeploymentWrapper
10
+
11
+ __all__ = ['FeatureStoreWrapper', 'ProcessingWrapper', 'TrainingWrapper', 'PipelineWrapper', 'DeploymentWrapper']
@@ -0,0 +1,459 @@
1
+ """
2
+ Deployment wrapper for mlp_sdk
3
+ Provides simplified model deployment with configuration-driven defaults
4
+
5
+ Uses SageMaker SDK v3 ModelBuilder API for modern deployment.
6
+ """
7
+
8
+ from typing import Optional, Dict, Any
9
+ from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
10
+ from ..config import ConfigurationManager
11
+
12
+
13
class PredictorWrapper:
    """
    Wrapper around SDK v3 Endpoint object to provide a predict() method.

    SDK v3's ModelBuilder.deploy() returns an Endpoint object with invoke() method.
    This wrapper provides backward compatibility by adding a predict() method
    that calls invoke() internally.
    """

    def __init__(self, endpoint, logger: Optional[MLPLogger] = None):
        """
        Initialize predictor wrapper.

        Args:
            endpoint: SDK v3 Endpoint object returned by ModelBuilder.deploy()
            logger: Optional logger instance
        """
        self._endpoint = endpoint
        self.logger = logger or MLPLogger("mlp_sdk.predictor")

        # Cached boto3 sagemaker-runtime client, created lazily on first
        # invocation and shared by predict()/invoke() so each call does not
        # pay client-construction cost.
        self._runtime_client = None

        # Expose endpoint attributes
        self.endpoint_name = getattr(endpoint, 'endpoint_name', None)

    def _get_runtime_client(self):
        """Return the cached boto3 sagemaker-runtime client, creating it on first use."""
        if self._runtime_client is None:
            # Imported lazily so merely constructing the wrapper never
            # requires boto3 to be importable.
            import boto3
            # Use boto3 sagemaker-runtime client for reliable invocation.
            # This is more stable than SDK v3 Endpoint.invoke() which has
            # serialization issues.
            self._runtime_client = boto3.client('sagemaker-runtime')
        return self._runtime_client

    def _invoke_endpoint(self, payload, content_type: str):
        """
        Shared invocation path for predict() and invoke().

        Encodes str payloads to UTF-8 (boto3 expects bytes) and calls
        InvokeEndpoint on the cached sagemaker-runtime client.

        Args:
            payload: Request body (str or bytes)
            content_type: Content type of the request

        Returns:
            Raw InvokeEndpoint response dict from boto3
        """
        # Convert string to bytes if needed (boto3 expects bytes)
        body_data = payload.encode('utf-8') if isinstance(payload, str) else payload

        return self._get_runtime_client().invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType=content_type,
            Body=body_data
        )

    def predict(self, data, content_type: str = 'text/csv'):
        """
        Make predictions using the endpoint.

        This method provides backward compatibility with SDK v2 Predictor interface.
        Uses boto3 sagemaker-runtime client for reliable invocation.

        Args:
            data: Input data for prediction (string or bytes)
            content_type: Content type of the input data (default: 'text/csv')

        Returns:
            Prediction results as string

        Raises:
            AWSServiceError: If prediction fails
        """
        try:
            response = self._invoke_endpoint(data, content_type)

            # Read and decode the response body
            return response['Body'].read().decode('utf-8')

        except Exception as e:
            self.logger.error("Prediction failed", error=e)
            raise AWSServiceError(
                f"Failed to make prediction: {e}",
                aws_error=e
            ) from e

    def invoke(self, body, content_type: str = 'application/json'):
        """
        Invoke the endpoint directly using boto3 sagemaker-runtime client.

        Args:
            body: Request body (bytes or string)
            content_type: Content type of the request

        Returns:
            Response from the endpoint (raw boto3 InvokeEndpoint response)

        Raises:
            AWSServiceError: If invocation fails
        """
        try:
            return self._invoke_endpoint(body, content_type)

        except Exception as e:
            self.logger.error("Invocation failed", error=e)
            raise AWSServiceError(
                f"Failed to invoke endpoint: {e}",
                aws_error=e
            ) from e

    def delete_endpoint(self):
        """
        Delete the endpoint.

        Raises:
            AWSServiceError: If deletion fails, or if the underlying endpoint
                object does not support deletion (the internal MLPSDKError is
                wrapped as AWSServiceError, preserving the original contract).
        """
        try:
            if hasattr(self._endpoint, 'delete'):
                self._endpoint.delete()
            else:
                # NOTE(review): this MLPSDKError is deliberately raised inside
                # the try so it surfaces to callers wrapped as AWSServiceError,
                # matching the documented raise contract.
                raise MLPSDKError("Endpoint object does not support delete operation")
        except Exception as e:
            self.logger.error("Failed to delete endpoint", error=e)
            raise AWSServiceError(
                f"Failed to delete endpoint: {e}",
                aws_error=e
            ) from e
143
+
144
+
145
class DeploymentWrapper:
    """
    Wrapper for SageMaker Model Deployment operations (SDK v3).
    Applies default configurations from ConfigurationManager.
    """

    # AWS partitions whose IAM role ARNs are accepted by validation.
    # The standard partition plus China and GovCloud; all values previously
    # accepted (plain 'arn:aws:iam::...') remain accepted.
    _VALID_ROLE_ARN_PREFIXES = (
        'arn:aws:iam::',
        'arn:aws-cn:iam::',
        'arn:aws-us-gov:iam::',
    )

    def __init__(self, config_manager: ConfigurationManager, logger: Optional[MLPLogger] = None):
        """
        Initialize Deployment wrapper.

        Args:
            config_manager: Configuration manager instance
            logger: Optional logger instance
        """
        self.config_manager = config_manager
        self.logger = logger or MLPLogger("mlp_sdk.deployment")

    def deploy_model(self,
                     model_data: str,
                     image_uri: str,
                     endpoint_name: str,
                     enable_vpc: bool = False,
                     **kwargs) -> Any:
        """
        Deploy a trained model to a SageMaker endpoint with default configurations.

        Applies defaults from configuration for:
        - Instance type and count (via Compute config)
        - IAM execution role
        - VPC configuration (optional, controlled by enable_vpc flag)
        - KMS encryption key

        Runtime parameters override configuration defaults.

        Args:
            model_data: S3 URI of the model artifacts (e.g., 's3://bucket/model.tar.gz')
            image_uri: Container image URI for inference
            endpoint_name: Name for the endpoint
            enable_vpc: If True, applies VPC configuration from config (default: False)
            **kwargs: Additional parameters that override defaults
                (e.g., instance_type, instance_count, environment, subnets, security_group_ids)

        Returns:
            PredictorWrapper object for making predictions

        Raises:
            ValidationError: If required parameters are missing or invalid
            MLPSDKError: If SageMaker SDK v3 is not installed
            AWSServiceError: If deployment fails
        """
        self.logger.info("Deploying model to endpoint", endpoint_name=endpoint_name)

        # Validate required parameters
        if not model_data:
            raise ValidationError("model_data is required")

        if not image_uri:
            raise ValidationError("image_uri is required")

        if not endpoint_name:
            raise ValidationError("endpoint_name is required")

        # Validate runtime parameter overrides before touching AWS
        self.validate_parameter_override(kwargs)

        try:
            from sagemaker.serve.model_builder import ModelBuilder
        except ImportError as e:
            raise MLPSDKError(
                "SageMaker SDK v3 not installed. "
                "Install with: pip install sagemaker>=3.0.0"
            ) from e

        # Build configuration with defaults (runtime > config > SDK defaults)
        config = self._build_deployment_config(kwargs, enable_vpc=enable_vpc)

        try:
            # Create ModelBuilder with model artifacts
            model_builder_params = {
                's3_model_data_url': model_data,  # SDK v3 uses s3_model_data_url
                'image_uri': image_uri,
            }

            # Add role if available
            if config.get('role_arn'):
                model_builder_params['role_arn'] = config['role_arn']

            # Add environment variables if provided
            if config.get('environment'):
                model_builder_params['env_vars'] = config['environment']

            # Add instance type if provided
            if config.get('instance_type'):
                model_builder_params['instance_type'] = config['instance_type']

            self.logger.debug("Creating ModelBuilder",
                              model_data=model_data,
                              image_uri=image_uri)

            # Create ModelBuilder
            model_builder = ModelBuilder(**model_builder_params)

            # Build the model
            model = model_builder.build()

            # Set VPC config on the built model if available and enabled.
            # ModelBuilder doesn't accept vpc_config, but the underlying Model object does.
            # SDK v3 Model expects lowercase keys: 'subnets' and 'security_group_ids'.
            if enable_vpc and config.get('subnets') and config.get('security_group_ids'):
                vpc_config = {
                    'subnets': config['subnets'],
                    'security_group_ids': config['security_group_ids']
                }
                # Set vpc_config on the model object
                if hasattr(model, 'vpc_config'):
                    model.vpc_config = vpc_config
                    self.logger.info("VPC configuration enabled for endpoint",
                                     subnets=config['subnets'],
                                     security_groups=config['security_group_ids'])
                else:
                    self.logger.warning("Model object does not support vpc_config attribute")
            elif enable_vpc:
                self.logger.warning("VPC configuration requested but subnets or security groups not available in config")
            else:
                self.logger.info("VPC configuration disabled for endpoint deployment")

            # Deploy model to endpoint using ModelBuilder.deploy().
            # SDK v3: ModelBuilder.deploy() returns an Endpoint object (not Predictor).
            deploy_params = {
                'model': model,
                'endpoint_name': endpoint_name,
            }

            # Add instance count
            if config.get('instance_count'):
                deploy_params['initial_instance_count'] = config['instance_count']

            # Add wait flag (default is True)
            if 'wait' in config:
                deploy_params['wait'] = config['wait']

            # NOTE(review): config may contain 'kms_key' (see
            # _build_deployment_config) but it is not passed to ModelBuilder or
            # deploy() here — confirm whether endpoint encryption needs wiring.

            self.logger.debug("Deploying model to endpoint",
                              endpoint_name=endpoint_name,
                              instance_type=config.get('instance_type'),
                              instance_count=config.get('instance_count'))

            # Deploy the model using ModelBuilder.deploy().
            # The returned Endpoint object has invoke() with PascalCase parameters.
            endpoint = model_builder.deploy(**deploy_params)

            # Wrap the endpoint to provide predict() method for backward compatibility
            predictor = PredictorWrapper(endpoint, self.logger)

            self.logger.info("Model deployed successfully", endpoint_name=endpoint_name)
            return predictor

        except Exception as e:
            self.logger.error("Failed to deploy model",
                              endpoint_name=endpoint_name,
                              error=e)
            raise AWSServiceError(
                f"Failed to deploy model to endpoint '{endpoint_name}': {e}",
                aws_error=e
            ) from e

    def _build_deployment_config(self, runtime_params: Dict[str, Any], enable_vpc: bool = False) -> Dict[str, Any]:
        """
        Build deployment configuration by merging defaults with runtime parameters.

        Parameter precedence: runtime > config > SageMaker defaults

        Args:
            runtime_params: Runtime parameters provided by user
            enable_vpc: If True, includes VPC configuration from config

        Returns:
            Merged configuration dictionary
        """
        config = {}

        # Get configuration objects
        compute_config = self.config_manager.get_compute_config()
        networking_config = self.config_manager.get_networking_config()
        iam_config = self.config_manager.get_iam_config()
        kms_config = self.config_manager.get_kms_config()

        # Apply compute defaults (runtime > config)
        if 'instance_type' in runtime_params:
            config['instance_type'] = runtime_params['instance_type']
            self.logger.debug("Using runtime instance_type", value=runtime_params['instance_type'])
        elif compute_config and hasattr(compute_config, 'inference_instance_type'):
            config['instance_type'] = compute_config.inference_instance_type
            self.logger.debug("Using config instance_type", value=compute_config.inference_instance_type)
        else:
            # Use default for inference
            config['instance_type'] = 'ml.m5.large'
            self.logger.debug("Using default instance_type", value='ml.m5.large')

        if 'instance_count' in runtime_params:
            config['instance_count'] = runtime_params['instance_count']
            self.logger.debug("Using runtime instance_count", value=runtime_params['instance_count'])
        elif compute_config and hasattr(compute_config, 'inference_instance_count'):
            config['instance_count'] = compute_config.inference_instance_count
            self.logger.debug("Using config instance_count", value=compute_config.inference_instance_count)
        else:
            # Use default for inference
            config['instance_count'] = 1
            self.logger.debug("Using default instance_count", value=1)

        # Apply IAM role default (runtime > config); a role is mandatory
        if 'role_arn' in runtime_params:
            config['role_arn'] = runtime_params['role_arn']
            self.logger.debug("Using runtime role_arn")
        elif iam_config:
            config['role_arn'] = iam_config.execution_role
            self.logger.debug("Using config role_arn", role=iam_config.execution_role)
        else:
            raise ValidationError("IAM execution role is required. Provide via runtime parameter or configuration.")

        # Apply networking defaults (runtime > config) - only if VPC is enabled
        if enable_vpc:
            if 'subnets' in runtime_params:
                config['subnets'] = runtime_params['subnets']
                self.logger.debug("Using runtime subnets")
            elif networking_config:
                config['subnets'] = networking_config.subnets
                self.logger.debug("Using config subnets", subnets=networking_config.subnets)

            if 'security_group_ids' in runtime_params:
                config['security_group_ids'] = runtime_params['security_group_ids']
                self.logger.debug("Using runtime security_group_ids")
            elif networking_config:
                config['security_group_ids'] = networking_config.security_group_ids
                self.logger.debug("Using config security_group_ids",
                                  security_groups=networking_config.security_group_ids)
        else:
            self.logger.debug("VPC configuration disabled, skipping network config")

        # Apply KMS encryption defaults (runtime > config)
        if 'kms_key' in runtime_params:
            config['kms_key'] = runtime_params['kms_key']
            self.logger.debug("Using runtime kms_key")
        elif kms_config and kms_config.key_id:
            config['kms_key'] = kms_config.key_id
            self.logger.debug("Using config kms_key", key_id=kms_config.key_id)

        # Apply any remaining runtime parameters (pass-through to deploy logic)
        for key, value in runtime_params.items():
            if key not in ['instance_type', 'instance_count', 'role_arn',
                           'subnets', 'security_group_ids', 'kms_key']:
                config[key] = value
                self.logger.debug(f"Using runtime parameter: {key}")

        return config

    def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
        """
        Validate runtime parameter overrides.

        Args:
            runtime_params: Runtime parameters to validate

        Raises:
            ValidationError: If runtime parameters are invalid
        """
        # Validate instance_type format if provided
        if 'instance_type' in runtime_params:
            instance_type = runtime_params['instance_type']
            if not isinstance(instance_type, str):
                raise ValidationError("instance_type must be a string")

            if not instance_type.startswith('ml.'):
                raise ValidationError(f"Invalid instance type format: {instance_type}. Must start with 'ml.'")

        # Validate instance_count if provided
        if 'instance_count' in runtime_params:
            instance_count = runtime_params['instance_count']
            # bool is a subclass of int, so reject it explicitly:
            # instance_count=True would otherwise slip through as 1.
            if isinstance(instance_count, bool) or not isinstance(instance_count, int):
                raise ValidationError("instance_count must be an integer")

            if instance_count < 1:
                raise ValidationError(f"instance_count must be at least 1, got {instance_count}")

        # Validate role_arn format if provided
        if 'role_arn' in runtime_params:
            role_arn = runtime_params['role_arn']
            if not isinstance(role_arn, str):
                raise ValidationError("role_arn must be a string")

            # Accept standard, China, and GovCloud partitions (str.startswith
            # accepts a tuple of prefixes).
            if not role_arn.startswith(self._VALID_ROLE_ARN_PREFIXES):
                raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")

        # Validate subnets if provided
        if 'subnets' in runtime_params:
            subnets = runtime_params['subnets']
            if not isinstance(subnets, list):
                raise ValidationError("subnets must be a list")

            if not subnets:
                raise ValidationError("subnets list cannot be empty")

        # Validate security_group_ids if provided
        if 'security_group_ids' in runtime_params:
            security_groups = runtime_params['security_group_ids']
            if not isinstance(security_groups, list):
                raise ValidationError("security_group_ids must be a list")

            if not security_groups:
                raise ValidationError("security_group_ids list cannot be empty")

        # Validate environment if provided
        if 'environment' in runtime_params:
            environment = runtime_params['environment']
            if not isinstance(environment, dict):
                raise ValidationError("environment must be a dictionary")