sagemaker-mlp-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,492 @@
+ """
+ Training Job wrapper for mlp_sdk
+ Provides simplified training job execution with configuration-driven defaults
+
+ Uses SageMaker SDK v3 ModelTrainer API for modern training job execution.
+ """
+
+ from typing import Optional, Dict, Any, List
+ from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
+ from ..config import ConfigurationManager
+
+
+ class TrainingWrapper:
+     """
+     Wrapper for SageMaker Training Job operations using ModelTrainer (SDK v3).
+     Applies default configurations from ConfigurationManager.
+
+     Note: This wrapper uses the modern ModelTrainer API introduced in SageMaker SDK v2.x
+     and recommended for SDK v3, which replaces the legacy Estimator class.
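+
+     Example (illustrative sketch only; the import paths and the
+     ConfigurationManager construction shown here are assumptions based on
+     this package's layout, not documented fixtures):
+
+         from mlp_sdk.config import ConfigurationManager
+         from mlp_sdk.training import TrainingWrapper
+
+         # Hypothetical: build the configuration manager however your
+         # deployment provides it, then hand it to the wrapper.
+         config_manager = ConfigurationManager()
+         wrapper = TrainingWrapper(config_manager)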
+     """
+
+     def __init__(self, config_manager: ConfigurationManager, logger: Optional[MLPLogger] = None):
+         """
+         Initialize Training wrapper.
+
+         Args:
+             config_manager: Configuration manager instance
+             logger: Optional logger instance
+         """
+         self.config_manager = config_manager
+         self.logger = logger or MLPLogger("mlp_sdk.training")
+
+     def run_training_job(self,
+                          sagemaker_session,
+                          job_name: str,
+                          training_image: Optional[str] = None,
+                          source_code_dir: Optional[str] = None,
+                          entry_script: Optional[str] = None,
+                          requirements: Optional[str] = None,
+                          inputs: Optional[Dict[str, Any]] = None,
+                          **kwargs) -> Any:
+         """
+         Execute training job with default configurations using ModelTrainer.
+
+         Applies defaults from configuration for:
+         - Instance type and count (via Compute config)
+         - IAM execution role
+         - VPC configuration (VPC ID, security groups, subnets)
+         - S3 input/output paths
+         - KMS encryption key
+
+         Runtime parameters override configuration defaults.
+
+         Supports both custom training scripts and container images.
+
+         Args:
+             sagemaker_session: SageMaker session object
+             job_name: Name of the training job (used as base_job_name)
+             training_image: Container image URI for training
+             source_code_dir: Directory containing training script and dependencies
+             entry_script: Entry point script for training (e.g., 'train.py')
+             requirements: Path to requirements.txt file for dependencies
+             inputs: Training data inputs (S3 paths or InputData objects)
+             **kwargs: Additional parameters that override defaults
+                 (e.g., hyperparameters, environment, distributed_runner)
+
+         Returns:
+             ModelTrainer object
+
+         Raises:
+             ValidationError: If required parameters are missing or invalid
+             AWSServiceError: If training job execution fails
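+
+         Example (illustrative sketch only; the image URI, bucket, channel
+         names, and the `session` object are placeholders, not values shipped
+         with this package):
+
+             trainer = wrapper.run_training_job(
+                 sagemaker_session=session,
+                 job_name="churn-model",
+                 training_image="123456789012.dkr.ecr.us-east-1.amazonaws.com/train:latest",
+                 source_code_dir="src/",
+                 entry_script="train.py",
+                 inputs={"train": "s3://example-bucket/input/train/"},
+                 hyperparameters={"epochs": "10"},
+             )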
+         """
+         self.logger.info("Running training job with ModelTrainer", name=job_name)
+
+         # Validate required parameters
+         if not job_name:
+             raise ValidationError("job_name is required")
+
+         if not training_image:
+             raise ValidationError("training_image is required for ModelTrainer")
+
+         # Validate runtime parameter overrides
+         self.validate_parameter_override(kwargs)
+
+         try:
+             from sagemaker.train.model_trainer import ModelTrainer, SourceCode, Compute, InputData
+             from sagemaker.core.training.configs import Networking, StoppingCondition
+             from sagemaker.core.shapes.shapes import OutputDataConfig
+         except ImportError as e:
+             raise MLPSDKError(
+                 "SageMaker SDK v3 not installed or ModelTrainer not available. "
+                 "Install with: pip install sagemaker>=3.0.0"
+             ) from e
+
+         # Build configuration with defaults
+         config = self._build_training_config(kwargs)
+
+         # Create SourceCode configuration if script provided
+         source_code = None
+         if source_code_dir or entry_script:
+             source_code_params = {}
+             if source_code_dir:
+                 source_code_params['source_dir'] = source_code_dir
+             if entry_script:
+                 source_code_params['entry_script'] = entry_script
+             if requirements:
+                 source_code_params['requirements'] = requirements
+
+             # Add any custom command if provided
+             if config.get('command'):
+                 source_code_params['command'] = config['command']
+
+             source_code = SourceCode(**source_code_params)
+
+         # Create Compute configuration
+         compute_params = {
+             'instance_type': config['instance_type'],
+             'instance_count': config['instance_count'],
+         }
+
+         # Add volume size if provided
+         if config.get('volume_size_in_gb'):
+             compute_params['volume_size_in_gb'] = config['volume_size_in_gb']
+
+         # Add volume KMS key if provided (SDK v3 uses volume_kms_key_id)
+         if config.get('volume_kms_key'):
+             compute_params['volume_kms_key_id'] = config['volume_kms_key']
+
+         # Add keep alive period if provided
+         if config.get('keep_alive_period_in_seconds'):
+             compute_params['keep_alive_period_in_seconds'] = config['keep_alive_period_in_seconds']
+
+         compute = Compute(**compute_params)
+
+         # Create ModelTrainer parameters
+         trainer_params = {
+             'training_image': training_image,
+             'compute': compute,
+             'base_job_name': job_name,
+         }
+
+         # Add source code if provided
+         if source_code:
+             trainer_params['source_code'] = source_code
+
+         # Add role if available
+         if config.get('role_arn'):
+             trainer_params['role'] = config['role_arn']
+
+         # Add hyperparameters if provided
+         if config.get('hyperparameters'):
+             trainer_params['hyperparameters'] = config['hyperparameters']
+
+         # Add environment variables if provided
+         if config.get('environment'):
+             trainer_params['environment'] = config['environment']
+
+         # Create OutputDataConfig if output path or KMS key provided
+         if config.get('output_path') or config.get('output_kms_key'):
+             output_config_params = {}
+             if config.get('output_path'):
+                 output_config_params['s3_output_path'] = config['output_path']
+             if config.get('output_kms_key'):
+                 output_config_params['kms_key_id'] = config['output_kms_key']
+
+             if output_config_params:
+                 trainer_params['output_data_config'] = OutputDataConfig(**output_config_params)
+
+         # Add stopping condition if max runtime provided
+         if config.get('max_run_in_seconds'):
+             trainer_params['stopping_condition'] = StoppingCondition(
+                 max_runtime_in_seconds=config['max_run_in_seconds']
+             )
+
+         # Add tags if provided
+         if config.get('tags'):
+             trainer_params['tags'] = config['tags']
+
+         # Add distributed runner if provided
+         if config.get('distributed_runner'):
+             trainer_params['distributed'] = config['distributed_runner']
+
+         # Add metric definitions if provided
+         if config.get('metric_definitions'):
+             trainer_params['metric_definitions'] = config['metric_definitions']
+
+         # Add checkpoint config if provided
+         if config.get('checkpoint_s3_uri'):
+             trainer_params['checkpoint_s3_uri'] = config['checkpoint_s3_uri']
+
+         # Create Networking config if subnets, security groups, or encryption settings provided
+         networking_params = {}
+         if config.get('subnets'):
+             networking_params['subnets'] = config['subnets']
+         if config.get('security_group_ids'):
+             networking_params['security_group_ids'] = config['security_group_ids']
+         if config.get('encrypt_inter_container_traffic') is not None:
+             networking_params['enable_inter_container_traffic_encryption'] = config['encrypt_inter_container_traffic']
+         if config.get('enable_network_isolation') is not None:
+             networking_params['enable_network_isolation'] = config['enable_network_isolation']
+
+         if networking_params:
+             trainer_params['networking'] = Networking(**networking_params)
+
+         # Create ModelTrainer
+         model_trainer = ModelTrainer(**trainer_params)
+
+         try:
+             # Prepare input data configuration
+             input_data_config = []
+             if inputs:
+                 if isinstance(inputs, dict):
+                     # Convert dict to list of InputData objects
+                     for channel_name, data_source in inputs.items():
+                         input_data_config.append(
+                             InputData(
+                                 channel_name=channel_name,
+                                 data_source=data_source,
+                                 content_type='text/csv'  # Default content type for CSV data
+                             )
+                         )
+                 elif isinstance(inputs, list):
+                     # Already a list of InputData objects
+                     input_data_config = inputs
+             elif config.get('inputs'):
+                 # Use inputs from config
+                 if isinstance(config['inputs'], dict):
+                     for channel_name, data_source in config['inputs'].items():
+                         input_data_config.append(
+                             InputData(
+                                 channel_name=channel_name,
+                                 data_source=data_source,
+                                 content_type='text/csv'  # Default content type for CSV data
+                             )
+                         )
+                 else:
+                     input_data_config = config['inputs']
+
+             # Build train parameters
+             train_params = {}
+             if input_data_config:
+                 train_params['input_data_config'] = input_data_config
+
+             # Add wait flag if provided (default is True in ModelTrainer)
+             if 'wait' in config:
+                 train_params['wait'] = config['wait']
+
+             self.logger.debug("Starting training job with ModelTrainer",
+                               name=job_name,
+                               instance_type=config['instance_type'],
+                               instance_count=config['instance_count'],
+                               has_source_code=source_code is not None)
+
+             # Start the training job
+             model_trainer.train(**train_params)
+
+             self.logger.info("Training job started successfully", name=job_name)
+             return model_trainer
+
+         except Exception as e:
+             self.logger.error("Failed to run training job",
+                               name=job_name,
+                               error=e)
+             raise AWSServiceError(
+                 f"Failed to run training job '{job_name}': {e}",
+                 aws_error=e
+             ) from e
+
+     def _build_training_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Build training job configuration by merging defaults with runtime parameters.
+
+         Parameter precedence: runtime > config > SageMaker defaults
+
+         This implements the parameter override behavior specified in Requirements 5.1, 5.2, 5.3, 5.4:
+         - Runtime parameters always take precedence over configuration defaults
+         - Configuration defaults take precedence over SageMaker SDK defaults
+         - SageMaker SDK defaults are used when neither runtime nor config provide values
+
+         Args:
+             runtime_params: Runtime parameters provided by user
+
+         Returns:
+             Merged configuration dictionary
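+
+         Example (illustrative sketch; the values shown are hypothetical, not
+         defaults shipped with this package):
+
+             # Config supplies training_instance_type 'ml.m5.xlarge'; the runtime
+             # override below wins, while instance_count still comes from config.
+             merged = wrapper._build_training_config({'instance_type': 'ml.c5.2xlarge'})
+             assert merged['instance_type'] == 'ml.c5.2xlarge'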
+         """
+         config = {}
+
+         # Get configuration objects
+         compute_config = self.config_manager.get_compute_config()
+         networking_config = self.config_manager.get_networking_config()
+         iam_config = self.config_manager.get_iam_config()
+         s3_config = self.config_manager.get_s3_config()
+         kms_config = self.config_manager.get_kms_config()
+
+         # Apply compute defaults (runtime > config)
+         if 'instance_type' in runtime_params:
+             config['instance_type'] = runtime_params['instance_type']
+             self.logger.debug("Using runtime instance_type", value=runtime_params['instance_type'])
+         elif compute_config:
+             config['instance_type'] = compute_config.training_instance_type
+             self.logger.debug("Using config instance_type", value=compute_config.training_instance_type)
+         else:
+             # Fall back to the wrapper's built-in default when neither runtime nor config provide a value
+             config['instance_type'] = 'ml.m5.xlarge'
+             self.logger.debug("Using default instance_type", value='ml.m5.xlarge')
+
+         if 'instance_count' in runtime_params:
+             config['instance_count'] = runtime_params['instance_count']
+             self.logger.debug("Using runtime instance_count", value=runtime_params['instance_count'])
+         elif compute_config:
+             config['instance_count'] = compute_config.training_instance_count
+             self.logger.debug("Using config instance_count", value=compute_config.training_instance_count)
+         else:
+             # Fall back to the wrapper's built-in default of a single instance
+             config['instance_count'] = 1
+             self.logger.debug("Using default instance_count", value=1)
+
+         # Apply IAM role default (runtime > config)
+         if 'role_arn' in runtime_params:
+             config['role_arn'] = runtime_params['role_arn']
+             self.logger.debug("Using runtime role_arn")
+         elif iam_config:
+             config['role_arn'] = iam_config.execution_role
+             self.logger.debug("Using config role_arn", role=iam_config.execution_role)
+         else:
+             raise ValidationError("IAM execution role is required. Provide via runtime parameter or configuration.")
+
+         # Apply networking defaults (runtime > config)
+         if 'subnets' in runtime_params:
+             config['subnets'] = runtime_params['subnets']
+             self.logger.debug("Using runtime subnets")
+         elif networking_config:
+             config['subnets'] = networking_config.subnets
+             self.logger.debug("Using config subnets", subnets=networking_config.subnets)
+
+         if 'security_group_ids' in runtime_params:
+             config['security_group_ids'] = runtime_params['security_group_ids']
+             self.logger.debug("Using runtime security_group_ids")
+         elif networking_config:
+             config['security_group_ids'] = networking_config.security_group_ids
+             self.logger.debug("Using config security_group_ids",
+                               security_groups=networking_config.security_group_ids)
+
+         # Apply S3 output path defaults (runtime > config)
+         if 'output_path' in runtime_params:
+             config['output_path'] = runtime_params['output_path']
+             self.logger.debug("Using runtime output_path")
+         elif s3_config:
+             config['output_path'] = f"s3://{s3_config.default_bucket}/{s3_config.model_prefix}"
+             self.logger.debug("Using config output_path", path=config['output_path'])
+
+         # Apply KMS encryption defaults (runtime > config)
+         if 'volume_kms_key' in runtime_params:
+             config['volume_kms_key'] = runtime_params['volume_kms_key']
+             self.logger.debug("Using runtime volume_kms_key")
+         elif kms_config and kms_config.key_id:
+             config['volume_kms_key'] = kms_config.key_id
+             self.logger.debug("Using config volume_kms_key", key_id=kms_config.key_id)
+
+         if 'output_kms_key' in runtime_params:
+             config['output_kms_key'] = runtime_params['output_kms_key']
+             self.logger.debug("Using runtime output_kms_key")
+         elif kms_config and kms_config.key_id:
+             config['output_kms_key'] = kms_config.key_id
+             self.logger.debug("Using config output_kms_key", key_id=kms_config.key_id)
+
+         # Apply S3 defaults for inputs if not provided (runtime > config)
+         if s3_config:
+             if 'inputs' not in runtime_params and 'inputs' not in config:
+                 # Default input location is logged for visibility but not applied automatically
+                 default_input_uri = f"s3://{s3_config.default_bucket}/{s3_config.input_prefix}"
+                 self.logger.debug("Default input S3 URI available", uri=default_input_uri)
+
+         # Apply any remaining runtime parameters (these override everything)
+         for key, value in runtime_params.items():
+             if key not in ['instance_type', 'instance_count', 'role_arn',
+                            'subnets', 'security_group_ids', 'output_path',
+                            'volume_kms_key', 'output_kms_key']:
+                 config[key] = value
+                 self.logger.debug(f"Using runtime parameter: {key}")
+
+         return config
+
+     def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
+         """
+         Validate runtime parameter overrides.
+
+         This ensures that runtime parameters are valid and compatible with the configuration.
+         Implements validation requirements from Requirements 5.1, 5.2, 5.3, 5.4, 5.5.
+
+         Args:
+             runtime_params: Runtime parameters to validate
+
+         Raises:
+             ValidationError: If runtime parameters are invalid
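+
+         Example (illustrative sketch; the values shown are placeholders):
+
+             wrapper.validate_parameter_override({'instance_count': 2})   # passes
+             wrapper.validate_parameter_override({'instance_type': 'm5.xlarge'})
+             # raises ValidationError: instance type must start with 'ml.'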
+         """
+         # Validate instance_type format if provided
+         if 'instance_type' in runtime_params:
+             instance_type = runtime_params['instance_type']
+             if not isinstance(instance_type, str):
+                 raise ValidationError("instance_type must be a string")
+
+             if not instance_type.startswith('ml.'):
+                 raise ValidationError(f"Invalid instance type format: {instance_type}. Must start with 'ml.'")
+
+         # Validate instance_count if provided
+         if 'instance_count' in runtime_params:
+             instance_count = runtime_params['instance_count']
+             if not isinstance(instance_count, int):
+                 raise ValidationError("instance_count must be an integer")
+
+             if instance_count < 1:
+                 raise ValidationError(f"instance_count must be at least 1, got {instance_count}")
+
+         # Validate role_arn format if provided
+         if 'role_arn' in runtime_params:
+             role_arn = runtime_params['role_arn']
+             if not isinstance(role_arn, str):
+                 raise ValidationError("role_arn must be a string")
+
+             if not role_arn.startswith('arn:aws:iam::'):
+                 raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")
+
+         # Validate subnets if provided
+         if 'subnets' in runtime_params:
+             subnets = runtime_params['subnets']
+             if not isinstance(subnets, list):
+                 raise ValidationError("subnets must be a list")
+
+             if not subnets:
+                 raise ValidationError("subnets list cannot be empty")
+
+         # Validate security_group_ids if provided
+         if 'security_group_ids' in runtime_params:
+             security_groups = runtime_params['security_group_ids']
+             if not isinstance(security_groups, list):
+                 raise ValidationError("security_group_ids must be a list")
+
+             if not security_groups:
+                 raise ValidationError("security_group_ids list cannot be empty")
+
+         # Validate output_path if provided
+         if 'output_path' in runtime_params:
+             output_path = runtime_params['output_path']
+             if not isinstance(output_path, str):
+                 raise ValidationError("output_path must be a string")
+
+             if not output_path.startswith('s3://'):
+                 raise ValidationError(f"output_path must be an S3 URI starting with 's3://', got: {output_path}")
+
+         # Validate volume_size_in_gb if provided (ModelTrainer uses volume_size_in_gb)
+         if 'volume_size_in_gb' in runtime_params:
+             volume_size = runtime_params['volume_size_in_gb']
+             if not isinstance(volume_size, int):
+                 raise ValidationError("volume_size_in_gb must be an integer")
+
+             if volume_size < 1:
+                 raise ValidationError(f"volume_size_in_gb must be at least 1, got {volume_size}")
+
+         # Validate max_run_in_seconds if provided (ModelTrainer uses max_run_in_seconds)
+         if 'max_run_in_seconds' in runtime_params:
+             max_run = runtime_params['max_run_in_seconds']
+             if not isinstance(max_run, int):
+                 raise ValidationError("max_run_in_seconds must be an integer")
+
+             if max_run < 1:
+                 raise ValidationError(f"max_run_in_seconds must be at least 1, got {max_run}")
+
+         # Validate inputs if provided
+         if 'inputs' in runtime_params:
+             inputs = runtime_params['inputs']
+             if not isinstance(inputs, (dict, list)):
+                 raise ValidationError("inputs must be a dictionary or list")
+
+         # Validate hyperparameters if provided
+         if 'hyperparameters' in runtime_params:
+             hyperparameters = runtime_params['hyperparameters']
+             if not isinstance(hyperparameters, dict):
+                 raise ValidationError("hyperparameters must be a dictionary")
+
+         # Validate training_image if provided (ModelTrainer's name for the container image URI)
+         if 'training_image' in runtime_params:
+             training_image = runtime_params['training_image']
+             if not isinstance(training_image, str):
+                 raise ValidationError("training_image must be a string")
+
+         # Also support the legacy image_uri parameter name for backward compatibility
+         if 'image_uri' in runtime_params:
+             image_uri = runtime_params['image_uri']
+             if not isinstance(image_uri, str):
+                 raise ValidationError("image_uri must be a string")