sagemaker-mlp-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
1
+ """
2
+ Processing Job wrapper for mlp_sdk
3
+ Provides simplified processing job execution with configuration-driven defaults
4
+ """
5
+
6
+ from typing import Optional, Dict, Any, List
7
+ from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
8
+ from ..config import ConfigurationManager
9
+
10
+
11
class ProcessingWrapper:
    """
    Wrapper for SageMaker Processing Job operations.

    Merges defaults obtained from a ConfigurationManager with runtime
    overrides, builds a SageMaker processor from the result and starts the job.
    """

    # Runtime keys that _build_processing_config resolves explicitly; every
    # other runtime key is copied into the merged configuration unchanged.
    _EXPLICIT_KEYS = frozenset({
        'instance_type', 'instance_count', 'role_arn',
        'network_config', 'volume_kms_key', 'output_kms_key',
    })

    # Optional processor constructor arguments, forwarded only when truthy.
    _OPTIONAL_PROCESSOR_KEYS = (
        'base_job_name', 'volume_size_in_gb', 'volume_kms_key',
        'output_kms_key', 'max_runtime_in_seconds', 'env', 'tags',
        'network_config', 'image_uri',
    )

    # Fallbacks used when neither runtime parameters nor configuration
    # provide a compute setting.
    _DEFAULT_INSTANCE_TYPE = 'ml.m5.large'
    _DEFAULT_INSTANCE_COUNT = 1

    def __init__(self, config_manager: "ConfigurationManager",
                 logger: "Optional[MLPLogger]" = None):
        """
        Initialize Processing wrapper.

        Args:
            config_manager: Configuration manager instance
            logger: Optional logger instance; an MLPLogger named
                "mlp_sdk.processing" is created when omitted
        """
        self.config_manager = config_manager
        self.logger = logger or MLPLogger("mlp_sdk.processing")

    def run_processing_job(self,
                           sagemaker_session,
                           job_name: str,
                           processing_script: Optional[str] = None,
                           inputs: Optional[List[Dict[str, Any]]] = None,
                           outputs: Optional[List[Dict[str, Any]]] = None,
                           **kwargs) -> Any:
        """
        Execute processing job with default configurations.

        Defaults are taken from configuration for:
        - Instance type and count
        - IAM execution role
        - VPC configuration (security groups, subnets)
        - KMS encryption key (volume and output)

        Default S3 input/output locations are only logged; they are not
        attached to the job. Runtime parameters override configuration
        defaults.

        Args:
            sagemaker_session: SageMaker session object
            job_name: Name of the processing job
            processing_script: Optional path to custom processing script. When
                a script is given (here or via ``code=``) a ScriptProcessor is
                used, with ``command`` defaulting to ``['python3']``.
            inputs: Optional list of processing inputs, each either a dict of
                ProcessingInput keyword arguments or a ProcessingInput
            outputs: Optional list of processing outputs, each either a dict
                of ProcessingOutput keyword arguments or a ProcessingOutput
            **kwargs: Additional parameters that override defaults

        Returns:
            Processor object

        Raises:
            ValidationError: If required parameters are missing or invalid
            MLPSDKError: If the SageMaker SDK cannot be imported
            AWSServiceError: If processing job execution fails
        """
        self.logger.info("Running processing job", name=job_name)

        if not job_name:
            raise ValidationError("job_name is required")

        self.validate_parameter_override(kwargs)

        try:
            from sagemaker.processing import (
                Processor,
                ProcessingInput,
                ProcessingOutput,
                ScriptProcessor,
            )
        except ImportError as e:
            # The requirement is quoted so that a shell does not treat ">" as
            # an output redirection when the hint is copy-pasted.
            raise MLPSDKError(
                'SageMaker SDK not installed. Install with: pip install "sagemaker>=3.0.0"'
            ) from e

        config = self._build_processing_config(kwargs)

        processor_params = {
            'role': config['role_arn'],
            'instance_type': config['instance_type'],
            'instance_count': config['instance_count'],
            'sagemaker_session': sagemaker_session,
        }
        for key in self._OPTIONAL_PROCESSOR_KEYS:
            if config.get(key):
                processor_params[key] = config[key]

        # The base Processor.run() has no ``code`` parameter; a script can
        # only be submitted through ScriptProcessor, which also needs the
        # command used to launch it.
        code = processing_script or config.get('code')
        if code:
            processor = ScriptProcessor(
                command=config.get('command') or ['python3'],
                **processor_params,
            )
        else:
            processor = Processor(**processor_params)

        try:
            run_params = {'job_name': job_name}
            if code:
                run_params['code'] = code

            job_inputs = inputs or config.get('inputs')
            if job_inputs:
                run_params['inputs'] = [
                    self._as_sdk_object(item, ProcessingInput) for item in job_inputs
                ]

            job_outputs = outputs or config.get('outputs')
            if job_outputs:
                run_params['outputs'] = [
                    self._as_sdk_object(item, ProcessingOutput) for item in job_outputs
                ]

            if config.get('arguments'):
                run_params['arguments'] = config['arguments']

            # wait/logs are forwarded whenever supplied, including False.
            for flag in ('wait', 'logs'):
                if flag in config:
                    run_params[flag] = config[flag]

            self.logger.debug("Starting processing job",
                              name=job_name,
                              instance_type=config['instance_type'],
                              instance_count=config['instance_count'])

            processor.run(**run_params)

            self.logger.info("Processing job started successfully", name=job_name)
            return processor

        except Exception as e:
            self.logger.error("Failed to run processing job",
                              name=job_name,
                              error=e)
            raise AWSServiceError(
                f"Failed to run processing job '{job_name}': {e}",
                aws_error=e
            ) from e

    @staticmethod
    def _as_sdk_object(item, sdk_class):
        """Return item unchanged if it is an sdk_class instance, else build one from its mapping."""
        if isinstance(item, sdk_class):
            return item
        return sdk_class(**item)

    def _coerce_network_config(self, network_config):
        """Convert a non-empty dict into a sagemaker NetworkConfig; pass anything else through."""
        if not isinstance(network_config, dict) or not network_config:
            return network_config
        try:
            from sagemaker.network import NetworkConfig
        except ImportError:
            # Without the SDK there is nothing to convert to; keep the value.
            return network_config
        try:
            return NetworkConfig(**network_config)
        except TypeError as e:
            raise ValidationError(f"Invalid network_config dictionary: {e}") from e

    def _build_processing_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build processing job configuration by merging defaults with runtime parameters.

        Parameter precedence: runtime > config > built-in fallback.

        - instance_type / instance_count fall back to 'ml.m5.large' / 1 when
          neither runtime parameters nor a compute configuration exist.
        - role_arn has no fallback and is mandatory.
        - A runtime network_config given as a non-empty dict is converted to a
          sagemaker NetworkConfig.
        - Default S3 locations are logged only; nothing is added to the result
          for them.
        - Every remaining runtime parameter is copied through unchanged.

        Args:
            runtime_params: Runtime parameters provided by user

        Returns:
            Merged configuration dictionary

        Raises:
            ValidationError: If no IAM execution role is available, or a
                network_config dict cannot be converted to a NetworkConfig
        """
        config: Dict[str, Any] = {}

        compute_config = self.config_manager.get_compute_config()
        networking_config = self.config_manager.get_networking_config()
        iam_config = self.config_manager.get_iam_config()
        s3_config = self.config_manager.get_s3_config()
        kms_config = self.config_manager.get_kms_config()

        # Compute settings
        if 'instance_type' in runtime_params:
            config['instance_type'] = runtime_params['instance_type']
            self.logger.debug("Using runtime instance_type", value=runtime_params['instance_type'])
        elif compute_config:
            config['instance_type'] = compute_config.processing_instance_type
            self.logger.debug("Using config instance_type", value=compute_config.processing_instance_type)
        else:
            config['instance_type'] = self._DEFAULT_INSTANCE_TYPE
            self.logger.debug("Using default instance_type", value=self._DEFAULT_INSTANCE_TYPE)

        if 'instance_count' in runtime_params:
            config['instance_count'] = runtime_params['instance_count']
            self.logger.debug("Using runtime instance_count", value=runtime_params['instance_count'])
        elif compute_config:
            config['instance_count'] = compute_config.processing_instance_count
            self.logger.debug("Using config instance_count", value=compute_config.processing_instance_count)
        else:
            config['instance_count'] = self._DEFAULT_INSTANCE_COUNT
            self.logger.debug("Using default instance_count", value=self._DEFAULT_INSTANCE_COUNT)

        # IAM execution role (mandatory)
        if 'role_arn' in runtime_params:
            config['role_arn'] = runtime_params['role_arn']
            self.logger.debug("Using runtime role_arn")
        elif iam_config:
            config['role_arn'] = iam_config.execution_role
            self.logger.debug("Using config role_arn", role=iam_config.execution_role)
        else:
            raise ValidationError("IAM execution role is required. Provide via runtime parameter or configuration.")

        # Networking
        if 'network_config' in runtime_params:
            config['network_config'] = self._coerce_network_config(runtime_params['network_config'])
            self.logger.debug("Using runtime network_config")
        elif networking_config:
            try:
                from sagemaker.network import NetworkConfig
                config['network_config'] = NetworkConfig(
                    enable_network_isolation=False,
                    security_group_ids=networking_config.security_group_ids,
                    subnets=networking_config.subnets
                )
                self.logger.debug("Using config network_config",
                                  vpc_id=networking_config.vpc_id,
                                  security_groups=networking_config.security_group_ids,
                                  subnets=networking_config.subnets)
            except ImportError:
                # Deliberate best effort: without the SDK no network config is set.
                self.logger.debug("SageMaker SDK not available, skipping network config")

        # KMS: the single configured key is the default for both volume and output.
        for key_name in ('volume_kms_key', 'output_kms_key'):
            if key_name in runtime_params:
                config[key_name] = runtime_params[key_name]
                self.logger.debug(f"Using runtime {key_name}")
            elif kms_config and kms_config.key_id:
                config[key_name] = kms_config.key_id
                self.logger.debug(f"Using config {key_name}", key_id=kms_config.key_id)

        # S3 defaults are reported for diagnostics only.
        if s3_config:
            if 'inputs' not in runtime_params:
                default_input_uri = f"s3://{s3_config.default_bucket}/{s3_config.input_prefix}"
                self.logger.debug("Default input S3 URI available", uri=default_input_uri)

            if 'outputs' not in runtime_params:
                default_output_uri = f"s3://{s3_config.default_bucket}/{s3_config.output_prefix}"
                self.logger.debug("Default output S3 URI available", uri=default_output_uri)

        # Pass every other runtime parameter through unchanged.
        for key, value in runtime_params.items():
            if key not in self._EXPLICIT_KEYS:
                config[key] = value
                self.logger.debug(f"Using runtime parameter: {key}")

        return config

    @staticmethod
    def _is_iam_arn(value: str) -> bool:
        """Return True if value starts like an IAM ARN in any AWS partition (aws, aws-cn, aws-us-gov, ...)."""
        head, separator, _ = value.partition(':iam::')
        scheme, _, aws_partition = head.partition(':')
        return (
            bool(separator)
            and scheme == 'arn'
            and ':' not in aws_partition
            and (aws_partition == 'aws' or aws_partition.startswith('aws-'))
        )

    @staticmethod
    def _require_positive_int(name: str, value: Any) -> None:
        """Raise ValidationError unless value is a non-bool integer >= 1."""
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValidationError(f"{name} must be an integer")
        if value < 1:
            raise ValidationError(f"{name} must be at least 1, got {value}")

    def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
        """
        Validate runtime parameter overrides.

        Only keys present in runtime_params are checked:
        - instance_type: string starting with 'ml.'
        - instance_count, volume_size_in_gb, max_runtime_in_seconds:
          integers >= 1 (booleans are rejected)
        - role_arn: string shaped like an IAM ARN in any AWS partition
        - network_config: dict or object exposing ``security_group_ids``
        - inputs, outputs: lists

        Args:
            runtime_params: Runtime parameters to validate

        Raises:
            ValidationError: If runtime parameters are invalid
        """
        if 'instance_type' in runtime_params:
            instance_type = runtime_params['instance_type']
            if not isinstance(instance_type, str):
                raise ValidationError("instance_type must be a string")
            if not instance_type.startswith('ml.'):
                raise ValidationError(f"Invalid instance type format: {instance_type}. Must start with 'ml.'")

        if 'instance_count' in runtime_params:
            self._require_positive_int('instance_count', runtime_params['instance_count'])

        if 'role_arn' in runtime_params:
            role_arn = runtime_params['role_arn']
            if not isinstance(role_arn, str):
                raise ValidationError("role_arn must be a string")
            if not self._is_iam_arn(role_arn):
                raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")

        if 'network_config' in runtime_params:
            network_config = runtime_params['network_config']
            if not hasattr(network_config, 'security_group_ids') and not isinstance(network_config, dict):
                raise ValidationError("network_config must be a NetworkConfig object or dictionary")

        for name in ('volume_size_in_gb', 'max_runtime_in_seconds'):
            if name in runtime_params:
                self._require_positive_int(name, runtime_params[name])

        for name in ('inputs', 'outputs'):
            if name in runtime_params and not isinstance(runtime_params[name], list):
                raise ValidationError(f"{name} must be a list")