sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Processing Job wrapper for mlp_sdk
|
|
3
|
+
Provides simplified processing job execution with configuration-driven defaults
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional, Dict, Any, List
|
|
7
|
+
from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
|
|
8
|
+
from ..config import ConfigurationManager
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProcessingWrapper:
    """
    Wrapper for SageMaker Processing Job operations.

    Builds a ``sagemaker.processing.Processor`` from the defaults exposed by a
    ConfigurationManager, lets per-call keyword arguments override those
    defaults, and starts the job.
    """

    # Optional Processor constructor arguments. Each one is forwarded only
    # when the merged configuration holds a truthy value for it; otherwise the
    # SageMaker SDK's own default applies.
    _OPTIONAL_PROCESSOR_KEYS = (
        'base_job_name',
        'volume_size_in_gb',
        'volume_kms_key',
        'output_kms_key',
        'max_runtime_in_seconds',
        'env',
        'tags',
        'network_config',
        'image_uri',
    )

    # Runtime keys that _build_processing_config resolves explicitly (with a
    # configuration fallback); every other runtime key is copied verbatim.
    _RESOLVED_RUNTIME_KEYS = frozenset({
        'instance_type',
        'instance_count',
        'role_arn',
        'network_config',
        'volume_kms_key',
        'output_kms_key',
    })

    # Annotations are quoted so they are resolved lazily (forward references).
    def __init__(self, config_manager: "ConfigurationManager",
                 logger: Optional["MLPLogger"] = None):
        """
        Initialize Processing wrapper.

        Args:
            config_manager: Configuration manager instance
            logger: Optional logger instance; when omitted (or falsy) a new
                MLPLogger named "mlp_sdk.processing" is created
        """
        self.config_manager = config_manager
        self.logger = logger or MLPLogger("mlp_sdk.processing")

    def run_processing_job(self,
                           sagemaker_session,
                           job_name: str,
                           processing_script: Optional[str] = None,
                           inputs: Optional[List[Dict[str, Any]]] = None,
                           outputs: Optional[List[Dict[str, Any]]] = None,
                           **kwargs) -> Any:
        """
        Execute processing job with default configurations.

        Applies defaults from configuration for:
        - Instance type and count
        - IAM execution role
        - Network configuration (security groups, subnets)
        - KMS encryption key (volume and output)

        Default S3 input/output URIs derived from the S3 configuration are
        only logged; they are NOT attached to the job. Inputs and outputs
        must be passed explicitly.

        Runtime parameters override configuration defaults.

        Args:
            sagemaker_session: SageMaker session object
            job_name: Name of the processing job
            processing_script: Optional path to custom processing script;
                falls back to the ``code`` keyword argument when omitted
            inputs: Optional list of dicts, each expanded as keyword
                arguments to ``ProcessingInput``
            outputs: Optional list of dicts, each expanded as keyword
                arguments to ``ProcessingOutput``
            **kwargs: Additional parameters that override defaults

        Returns:
            Processor object

        Raises:
            ValidationError: If required parameters are missing or invalid
            MLPSDKError: If the SageMaker SDK cannot be imported
            AWSServiceError: If processing job execution fails
        """
        self.logger.info("Running processing job", name=job_name)

        # Validate required parameters
        if not job_name:
            raise ValidationError("job_name is required")

        # Validate runtime parameter overrides
        self.validate_parameter_override(kwargs)

        try:
            from sagemaker.processing import Processor, ProcessingInput, ProcessingOutput
            # Imported only to fail fast if the SDK's network module is missing.
            from sagemaker.network import NetworkConfig  # noqa: F401
        except ImportError as e:
            raise MLPSDKError(
                "SageMaker SDK not installed. Install with: pip install sagemaker>=3.0.0"
            ) from e

        # Build configuration with defaults
        config = self._build_processing_config(kwargs)

        # Mandatory Processor arguments
        processor_params = {
            'role': config['role_arn'],
            'instance_type': config['instance_type'],
            'instance_count': config['instance_count'],
            'sagemaker_session': sagemaker_session,
        }

        # Optional arguments: forward only the ones that were actually set.
        # When 'image_uri' is absent nothing is passed for it.
        for key in self._OPTIONAL_PROCESSOR_KEYS:
            if config.get(key):
                processor_params[key] = config[key]

        # Create processor
        processor = Processor(**processor_params)

        try:
            # Build run parameters
            run_params = {'job_name': job_name}

            # Explicit argument wins over the 'code' keyword argument.
            # NOTE(review): in SageMaker SDK v2 the base Processor.run() does
            # not take 'code' (ScriptProcessor/FrameworkProcessor do) -
            # confirm against the SDK version this package targets.
            code = processing_script or config.get('code')
            if code:
                run_params['code'] = code

            # Explicit inputs/outputs win over any configured ones.
            job_inputs = inputs or config.get('inputs')
            if job_inputs:
                run_params['inputs'] = [ProcessingInput(**inp) for inp in job_inputs]

            job_outputs = outputs or config.get('outputs')
            if job_outputs:
                run_params['outputs'] = [ProcessingOutput(**out) for out in job_outputs]

            # Add arguments if provided
            if config.get('arguments'):
                run_params['arguments'] = config['arguments']

            # 'wait' / 'logs' are forwarded whenever present, even if False.
            for flag in ('wait', 'logs'):
                if flag in config:
                    run_params[flag] = config[flag]

            self.logger.debug("Starting processing job",
                              name=job_name,
                              instance_type=config['instance_type'],
                              instance_count=config['instance_count'])

            # Run the processing job
            processor.run(**run_params)

            self.logger.info("Processing job started successfully", name=job_name)
            return processor

        except Exception as e:
            # Boundary handler: anything raised while assembling or starting
            # the job is reported as a service failure, with the cause chained.
            self.logger.error("Failed to run processing job",
                              name=job_name,
                              error=e)
            raise AWSServiceError(
                f"Failed to run processing job '{job_name}': {e}",
                aws_error=e
            ) from e

    def _build_processing_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build processing job configuration by merging defaults with runtime parameters.

        Parameter precedence: runtime > config > built-in fallback

        - Runtime parameters always take precedence over configuration defaults
        - Configuration defaults are used when no runtime value is given
        - Instance type/count fall back to 'ml.m5.large' / 1 when neither
          source provides them; other optional keys are simply left unset

        Args:
            runtime_params: Runtime parameters provided by user

        Returns:
            Merged configuration dictionary

        Raises:
            ValidationError: If no IAM execution role is available from
                either the runtime parameters or the configuration
        """
        config = {}

        # Get configuration objects
        compute_config = self.config_manager.get_compute_config()
        networking_config = self.config_manager.get_networking_config()
        iam_config = self.config_manager.get_iam_config()
        s3_config = self.config_manager.get_s3_config()
        kms_config = self.config_manager.get_kms_config()

        # Apply compute defaults (runtime > config > fallback)
        if 'instance_type' in runtime_params:
            config['instance_type'] = runtime_params['instance_type']
            self.logger.debug("Using runtime instance_type", value=runtime_params['instance_type'])
        elif compute_config:
            config['instance_type'] = compute_config.processing_instance_type
            self.logger.debug("Using config instance_type", value=compute_config.processing_instance_type)
        else:
            config['instance_type'] = 'ml.m5.large'
            self.logger.debug("Using default instance_type", value='ml.m5.large')

        if 'instance_count' in runtime_params:
            config['instance_count'] = runtime_params['instance_count']
            self.logger.debug("Using runtime instance_count", value=runtime_params['instance_count'])
        elif compute_config:
            config['instance_count'] = compute_config.processing_instance_count
            self.logger.debug("Using config instance_count", value=compute_config.processing_instance_count)
        else:
            config['instance_count'] = 1
            self.logger.debug("Using default instance_count", value=1)

        # Apply IAM role default (runtime > config); there is no fallback
        if 'role_arn' in runtime_params:
            config['role_arn'] = runtime_params['role_arn']
            self.logger.debug("Using runtime role_arn")
        elif iam_config:
            config['role_arn'] = iam_config.execution_role
            self.logger.debug("Using config role_arn", role=iam_config.execution_role)
        else:
            raise ValidationError("IAM execution role is required. Provide via runtime parameter or configuration.")

        # Apply networking defaults (runtime > config)
        if 'network_config' in runtime_params:
            config['network_config'] = runtime_params['network_config']
            self.logger.debug("Using runtime network_config")
        elif networking_config:
            try:
                from sagemaker.network import NetworkConfig
            except ImportError:
                # Best effort: without the SDK no network config is set.
                self.logger.debug("SageMaker SDK not available, skipping network config")
            else:
                config['network_config'] = NetworkConfig(
                    enable_network_isolation=False,
                    security_group_ids=networking_config.security_group_ids,
                    subnets=networking_config.subnets
                )
                self.logger.debug("Using config network_config",
                                  vpc_id=networking_config.vpc_id,
                                  security_groups=networking_config.security_group_ids,
                                  subnets=networking_config.subnets)

        # Apply KMS encryption defaults (runtime > config). The same
        # configured key is used for both volume and output encryption.
        for kms_param in ('volume_kms_key', 'output_kms_key'):
            if kms_param in runtime_params:
                config[kms_param] = runtime_params[kms_param]
                self.logger.debug(f"Using runtime {kms_param}")
            elif kms_config and kms_config.key_id:
                config[kms_param] = kms_config.key_id
                self.logger.debug(f"Using config {kms_param}", key_id=kms_config.key_id)

        # S3 defaults are informational only: the default URIs are logged so
        # callers can see them, but they are not written into the config.
        if s3_config:
            if 'inputs' not in runtime_params:
                default_input_uri = f"s3://{s3_config.default_bucket}/{s3_config.input_prefix}"
                self.logger.debug("Default input S3 URI available", uri=default_input_uri)

            if 'outputs' not in runtime_params:
                default_output_uri = f"s3://{s3_config.default_bucket}/{s3_config.output_prefix}"
                self.logger.debug("Default output S3 URI available", uri=default_output_uri)

        # Copy every remaining runtime parameter through unchanged
        for key, value in runtime_params.items():
            if key not in self._RESOLVED_RUNTIME_KEYS:
                config[key] = value
                self.logger.debug(f"Using runtime parameter: {key}")

        return config

    @staticmethod
    def _require_positive_int(name: str, value: Any) -> None:
        """Raise ValidationError unless value is a non-bool integer >= 1."""
        # bool is a subclass of int, so True/False must be rejected explicitly.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValidationError(f"{name} must be an integer")
        if value < 1:
            raise ValidationError(f"{name} must be at least 1, got {value}")

    @staticmethod
    def _is_iam_arn(value: str) -> bool:
        """Return True if value has the shape arn:<aws partition>:iam::..."""
        parts = value.split(':', 4)
        if len(parts) < 5:
            return False
        prefix, partition, service, region = parts[:4]
        # Partitions are 'aws' or 'aws-<suffix>' (e.g. aws-cn, aws-us-gov);
        # IAM is a global service, so the region field is empty.
        return (
            prefix == 'arn'
            and (partition == 'aws' or partition.startswith('aws-'))
            and service == 'iam'
            and region == ''
        )

    def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
        """
        Validate runtime parameter overrides.

        Only keys that are present are checked; unknown keys are ignored.

        Checks performed:
        - instance_type: string starting with 'ml.'
        - instance_count, volume_size_in_gb, max_runtime_in_seconds:
          integer (not bool) of at least 1
        - role_arn: string shaped like an IAM ARN in any AWS partition
        - network_config: dict, or object with a 'security_group_ids' attribute
        - inputs, outputs: list

        Args:
            runtime_params: Runtime parameters to validate

        Raises:
            ValidationError: If runtime parameters are invalid
        """
        # Validate instance_type format if provided
        if 'instance_type' in runtime_params:
            instance_type = runtime_params['instance_type']
            if not isinstance(instance_type, str):
                raise ValidationError("instance_type must be a string")
            if not instance_type.startswith('ml.'):
                raise ValidationError(f"Invalid instance type format: {instance_type}. Must start with 'ml.'")

        # Validate instance_count if provided
        if 'instance_count' in runtime_params:
            self._require_positive_int('instance_count', runtime_params['instance_count'])

        # Validate role_arn format if provided
        if 'role_arn' in runtime_params:
            role_arn = runtime_params['role_arn']
            if not isinstance(role_arn, str):
                raise ValidationError("role_arn must be a string")
            if not self._is_iam_arn(role_arn):
                raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")

        # Validate network_config if provided (duck-typed NetworkConfig or dict)
        if 'network_config' in runtime_params:
            network_config = runtime_params['network_config']
            if not hasattr(network_config, 'security_group_ids') and not isinstance(network_config, dict):
                raise ValidationError("network_config must be a NetworkConfig object or dictionary")

        # Validate volume_size_in_gb if provided
        if 'volume_size_in_gb' in runtime_params:
            self._require_positive_int('volume_size_in_gb', runtime_params['volume_size_in_gb'])

        # Validate max_runtime_in_seconds if provided
        if 'max_runtime_in_seconds' in runtime_params:
            self._require_positive_int('max_runtime_in_seconds', runtime_params['max_runtime_in_seconds'])

        # Validate inputs if provided
        if 'inputs' in runtime_params and not isinstance(runtime_params['inputs'], list):
            raise ValidationError("inputs must be a list")

        # Validate outputs if provided
        if 'outputs' in runtime_params and not isinstance(runtime_params['outputs'], list):
            raise ValidationError("outputs must be a list")
|