sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Training Job wrapper for mlp_sdk
|
|
3
|
+
Provides simplified training job execution with configuration-driven defaults
|
|
4
|
+
|
|
5
|
+
Uses SageMaker SDK v3 ModelTrainer API for modern training job execution.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Dict, Any, List
|
|
9
|
+
from ..exceptions import MLPSDKError, ValidationError, AWSServiceError, MLPLogger
|
|
10
|
+
from ..config import ConfigurationManager
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TrainingWrapper:
    """
    Wrapper for SageMaker Training Job operations using ModelTrainer (SDK v3).

    Applies default configurations from ConfigurationManager.

    Note: ModelTrainer was introduced in SageMaker Python SDK v2.x and is the
    recommended training API in SDK v3, replacing the legacy Estimator class.
    """

    # Runtime keys that _build_training_config resolves explicitly against
    # configuration defaults; every other runtime key is copied through as-is.
    _RESOLVED_KEYS = (
        'instance_type', 'instance_count', 'role_arn',
        'subnets', 'security_group_ids', 'output_path',
        'volume_kms_key', 'output_kms_key',
    )

    # Content type used for every InputData channel built from a dict of inputs,
    # unless the caller overrides it with the ``input_content_type`` runtime kwarg.
    _DEFAULT_INPUT_CONTENT_TYPE = 'text/csv'

    def __init__(self, config_manager: ConfigurationManager, logger: Optional[MLPLogger] = None):
        """
        Initialize Training wrapper.

        Args:
            config_manager: Configuration manager instance
            logger: Optional logger instance
        """
        self.config_manager = config_manager
        self.logger = logger or MLPLogger("mlp_sdk.training")

    def run_training_job(self,
                         sagemaker_session,
                         job_name: str,
                         training_image: Optional[str] = None,
                         source_code_dir: Optional[str] = None,
                         entry_script: Optional[str] = None,
                         requirements: Optional[str] = None,
                         inputs: Optional[Dict[str, Any]] = None,
                         **kwargs) -> Any:
        """
        Execute training job with default configurations using ModelTrainer.

        Applies defaults from configuration for:
        - Instance type and count (via Compute config)
        - IAM execution role
        - VPC networking (security groups, subnets)
        - S3 output path
        - KMS encryption key (volume and output)

        Runtime parameters override configuration defaults.

        Supports both custom training scripts and container images.

        Args:
            sagemaker_session: SageMaker session object; forwarded to ModelTrainer
                when not None
            job_name: Name of the training job (used as base_job_name)
            training_image: Container image URI for training. If omitted, the legacy
                ``image_uri`` keyword argument is used instead.
            source_code_dir: Directory containing training script and dependencies
            entry_script: Entry point script for training (e.g., 'train.py')
            requirements: Path to requirements.txt file for dependencies. Only
                applied when source_code_dir or entry_script is given.
            inputs: Training data inputs, either a dict mapping channel name to
                data source (converted to InputData objects) or a list of
                InputData objects (passed through unchanged)
            **kwargs: Additional parameters that override defaults
                (e.g., hyperparameters, environment, distributed_runner,
                input_content_type for dict inputs, default 'text/csv')

        Returns:
            ModelTrainer object

        Raises:
            ValidationError: If required parameters are missing or invalid
            MLPSDKError: If the SageMaker SDK v3 ModelTrainer API is not importable
            AWSServiceError: If training job execution fails
        """
        self.logger.info("Running training job with ModelTrainer", name=job_name)

        # Validate required parameters
        if not job_name:
            raise ValidationError("job_name is required")

        # Honour the documented legacy ``image_uri`` alias; previously it could
        # never be used because training_image was rejected first.
        if not training_image:
            training_image = kwargs.get('image_uri')

        if not training_image:
            raise ValidationError("training_image is required for ModelTrainer")

        # Validate runtime parameter overrides
        self.validate_parameter_override(kwargs)

        # ``inputs`` is a named parameter, so the kwargs validation above never
        # sees it; check it here so unsupported types are not silently dropped.
        if inputs is not None and not isinstance(inputs, (dict, list)):
            raise ValidationError("inputs must be a dictionary or list")

        try:
            from sagemaker.train.model_trainer import ModelTrainer, SourceCode, Compute, InputData
            from sagemaker.core.training.configs import Networking, StoppingCondition
            from sagemaker.core.shapes.shapes import OutputDataConfig
        except ImportError as e:
            raise MLPSDKError(
                "SageMaker SDK v3 not installed or ModelTrainer not available. "
                "Install with: pip install sagemaker>=3.0.0"
            ) from e

        # Build configuration with defaults
        config = self._build_training_config(kwargs)

        # Create SourceCode configuration if a script or source directory is provided.
        # requirements and command are only applied together with one of them.
        source_code = None
        if source_code_dir or entry_script:
            source_code_params = {}
            if source_code_dir:
                source_code_params['source_dir'] = source_code_dir
            if entry_script:
                source_code_params['entry_script'] = entry_script
            if requirements:
                source_code_params['requirements'] = requirements
            if config.get('command'):
                source_code_params['command'] = config['command']
            source_code = SourceCode(**source_code_params)

        # Create Compute configuration
        compute_params = {
            'instance_type': config['instance_type'],
            'instance_count': config['instance_count'],
        }
        if config.get('volume_size_in_gb'):
            compute_params['volume_size_in_gb'] = config['volume_size_in_gb']
        # The wrapper key is volume_kms_key; SDK v3 Compute names it volume_kms_key_id.
        if config.get('volume_kms_key'):
            compute_params['volume_kms_key_id'] = config['volume_kms_key']
        if config.get('keep_alive_period_in_seconds'):
            compute_params['keep_alive_period_in_seconds'] = config['keep_alive_period_in_seconds']
        compute = Compute(**compute_params)

        # Create ModelTrainer parameters
        trainer_params = {
            'training_image': training_image,
            'compute': compute,
            'base_job_name': job_name,
        }

        # Forward the caller's session; it was previously accepted but ignored.
        if sagemaker_session is not None:
            trainer_params['sagemaker_session'] = sagemaker_session

        if source_code:
            trainer_params['source_code'] = source_code

        if config.get('role_arn'):
            trainer_params['role'] = config['role_arn']

        if config.get('hyperparameters'):
            trainer_params['hyperparameters'] = config['hyperparameters']

        if config.get('environment'):
            trainer_params['environment'] = config['environment']

        # Create OutputDataConfig if output path or KMS key provided
        output_config_params = {}
        if config.get('output_path'):
            output_config_params['s3_output_path'] = config['output_path']
        if config.get('output_kms_key'):
            output_config_params['kms_key_id'] = config['output_kms_key']
        if output_config_params:
            trainer_params['output_data_config'] = OutputDataConfig(**output_config_params)

        # Add stopping condition if max runtime provided
        if config.get('max_run_in_seconds'):
            trainer_params['stopping_condition'] = StoppingCondition(
                max_runtime_in_seconds=config['max_run_in_seconds']
            )

        if config.get('tags'):
            trainer_params['tags'] = config['tags']

        if config.get('distributed_runner'):
            trainer_params['distributed'] = config['distributed_runner']

        if config.get('metric_definitions'):
            trainer_params['metric_definitions'] = config['metric_definitions']

        if config.get('checkpoint_s3_uri'):
            trainer_params['checkpoint_s3_uri'] = config['checkpoint_s3_uri']

        # Create Networking config if subnets, security groups, or encryption settings provided
        networking_params = {}
        if config.get('subnets'):
            networking_params['subnets'] = config['subnets']
        if config.get('security_group_ids'):
            networking_params['security_group_ids'] = config['security_group_ids']
        # Explicit False is meaningful for these flags, so test against None.
        if config.get('encrypt_inter_container_traffic') is not None:
            networking_params['enable_inter_container_traffic_encryption'] = config['encrypt_inter_container_traffic']
        if config.get('enable_network_isolation') is not None:
            networking_params['enable_network_isolation'] = config['enable_network_isolation']
        if networking_params:
            trainer_params['networking'] = Networking(**networking_params)

        # Create ModelTrainer
        model_trainer = ModelTrainer(**trainer_params)

        try:
            # Runtime inputs win; otherwise fall back to any inputs present in config.
            content_type = config.get('input_content_type') or self._DEFAULT_INPUT_CONTENT_TYPE
            input_data_config = self._to_input_data_config(
                InputData,
                inputs if inputs else config.get('inputs'),
                content_type,
            )

            # Build train parameters
            train_params = {}
            if input_data_config:
                train_params['input_data_config'] = input_data_config

            # Only pass wait when explicitly set, so ModelTrainer's own default applies otherwise.
            if 'wait' in config:
                train_params['wait'] = config['wait']

            self.logger.debug("Starting training job with ModelTrainer",
                              name=job_name,
                              instance_type=config['instance_type'],
                              instance_count=config['instance_count'],
                              has_source_code=source_code is not None)

            # Start the training job
            model_trainer.train(**train_params)

            self.logger.info("Training job started successfully", name=job_name)
            return model_trainer

        except Exception as e:
            self.logger.error("Failed to run training job",
                              name=job_name,
                              error=e)
            raise AWSServiceError(
                f"Failed to run training job '{job_name}': {e}",
                aws_error=e
            ) from e

    @staticmethod
    def _to_input_data_config(input_data_cls, inputs, content_type):
        """
        Normalise training inputs into the value passed as ``input_data_config``.

        Args:
            input_data_cls: The InputData class to instantiate for dict inputs
            inputs: A {channel_name: data_source} dict, an already-built list
                of InputData objects, or a falsy value
            content_type: Content type applied to every channel built from a dict

        Returns:
            A list of InputData objects for dict inputs, an empty list for falsy
            inputs, and any other value unchanged.
        """
        if not inputs:
            return []
        if isinstance(inputs, dict):
            return [
                input_data_cls(
                    channel_name=channel_name,
                    data_source=data_source,
                    content_type=content_type,
                )
                for channel_name, data_source in inputs.items()
            ]
        # Already a list of InputData objects (or a pre-built config value)
        return inputs

    def _build_training_config(self, runtime_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build training job configuration by merging defaults with runtime parameters.

        Parameter precedence: runtime > config > wrapper fallback defaults

        This implements the parameter override behavior specified in Requirements 5.1, 5.2, 5.3, 5.4:
        - Runtime parameters always take precedence over configuration defaults
        - Configuration defaults take precedence over built-in fallbacks
        - Fallbacks ('ml.m5.xlarge', 1 instance) are used for compute when
          neither runtime nor config provide values

        Args:
            runtime_params: Runtime parameters provided by user

        Returns:
            Merged configuration dictionary

        Raises:
            ValidationError: If no IAM role is available from runtime or configuration
        """
        config = {}

        # Get configuration objects
        compute_config = self.config_manager.get_compute_config()
        networking_config = self.config_manager.get_networking_config()
        iam_config = self.config_manager.get_iam_config()
        s3_config = self.config_manager.get_s3_config()
        kms_config = self.config_manager.get_kms_config()

        # Apply compute defaults (runtime > config > fallback)
        if 'instance_type' in runtime_params:
            config['instance_type'] = runtime_params['instance_type']
            self.logger.debug("Using runtime instance_type", value=runtime_params['instance_type'])
        elif compute_config:
            config['instance_type'] = compute_config.training_instance_type
            self.logger.debug("Using config instance_type", value=compute_config.training_instance_type)
        else:
            # Wrapper-level fallback when nothing is configured
            config['instance_type'] = 'ml.m5.xlarge'
            self.logger.debug("Using default instance_type", value='ml.m5.xlarge')

        if 'instance_count' in runtime_params:
            config['instance_count'] = runtime_params['instance_count']
            self.logger.debug("Using runtime instance_count", value=runtime_params['instance_count'])
        elif compute_config:
            config['instance_count'] = compute_config.training_instance_count
            self.logger.debug("Using config instance_count", value=compute_config.training_instance_count)
        else:
            # Wrapper-level fallback when nothing is configured
            config['instance_count'] = 1
            self.logger.debug("Using default instance_count", value=1)

        # Apply IAM role default (runtime > config); a role is mandatory
        if 'role_arn' in runtime_params:
            config['role_arn'] = runtime_params['role_arn']
            self.logger.debug("Using runtime role_arn")
        elif iam_config:
            config['role_arn'] = iam_config.execution_role
            self.logger.debug("Using config role_arn", role=iam_config.execution_role)
        else:
            raise ValidationError("IAM execution role is required. Provide via runtime parameter or configuration.")

        # Apply networking defaults (runtime > config)
        if 'subnets' in runtime_params:
            config['subnets'] = runtime_params['subnets']
            self.logger.debug("Using runtime subnets")
        elif networking_config:
            config['subnets'] = networking_config.subnets
            self.logger.debug("Using config subnets", subnets=networking_config.subnets)

        if 'security_group_ids' in runtime_params:
            config['security_group_ids'] = runtime_params['security_group_ids']
            self.logger.debug("Using runtime security_group_ids")
        elif networking_config:
            config['security_group_ids'] = networking_config.security_group_ids
            self.logger.debug("Using config security_group_ids",
                              security_groups=networking_config.security_group_ids)

        # Apply S3 output path defaults (runtime > config)
        if 'output_path' in runtime_params:
            config['output_path'] = runtime_params['output_path']
            self.logger.debug("Using runtime output_path")
        elif s3_config:
            config['output_path'] = f"s3://{s3_config.default_bucket}/{s3_config.model_prefix}"
            self.logger.debug("Using config output_path", path=config['output_path'])

        # Apply KMS encryption defaults (runtime > config); one configured key
        # is used for both the training volume and the output artifacts.
        if 'volume_kms_key' in runtime_params:
            config['volume_kms_key'] = runtime_params['volume_kms_key']
            self.logger.debug("Using runtime volume_kms_key")
        elif kms_config and kms_config.key_id:
            config['volume_kms_key'] = kms_config.key_id
            self.logger.debug("Using config volume_kms_key", key_id=kms_config.key_id)

        if 'output_kms_key' in runtime_params:
            config['output_kms_key'] = runtime_params['output_kms_key']
            self.logger.debug("Using runtime output_kms_key")
        elif kms_config and kms_config.key_id:
            config['output_kms_key'] = kms_config.key_id
            self.logger.debug("Using config output_kms_key", key_id=kms_config.key_id)

        # The configured S3 input location is only logged for visibility; it is
        # NOT applied as a training channel automatically.
        if s3_config:
            if 'inputs' not in runtime_params and 'inputs' not in config:
                default_input_uri = f"s3://{s3_config.default_bucket}/{s3_config.input_prefix}"
                self.logger.debug("Default input S3 URI available", uri=default_input_uri)

        # Apply any remaining runtime parameters (these override everything)
        for key, value in runtime_params.items():
            if key not in self._RESOLVED_KEYS:
                config[key] = value
                self.logger.debug(f"Using runtime parameter: {key}")

        return config

    @staticmethod
    def _is_iam_role_arn(value: str) -> bool:
        """
        Return True if value looks like an IAM ARN in any AWS partition.

        Accepts 'arn:<partition>:iam::...' where the partition is 'aws' or
        'aws-<suffix>' (e.g. 'aws-cn', 'aws-us-gov').
        """
        parts = value.split(':', 4)
        return (
            len(parts) == 5
            and parts[0] == 'arn'
            and (parts[1] == 'aws' or parts[1].startswith('aws-'))
            and parts[2] == 'iam'
            and parts[3] == ''
        )

    @staticmethod
    def _require_positive_int(runtime_params: Dict[str, Any], key: str) -> None:
        """
        Raise ValidationError unless runtime_params[key] (if present) is an int >= 1.

        bool is a subclass of int in Python, so it is rejected explicitly to
        prevent True/False from being accepted as 1/0.
        """
        if key not in runtime_params:
            return
        value = runtime_params[key]
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValidationError(f"{key} must be an integer")
        if value < 1:
            raise ValidationError(f"{key} must be at least 1, got {value}")

    def validate_parameter_override(self, runtime_params: Dict[str, Any]) -> None:
        """
        Validate runtime parameter overrides.

        This ensures that runtime parameters are valid and compatible with the configuration.
        Implements validation requirements from Requirements 5.1, 5.2, 5.3, 5.4, 5.5.

        Args:
            runtime_params: Runtime parameters to validate

        Raises:
            ValidationError: If runtime parameters are invalid
        """
        # Validate instance_type format if provided
        if 'instance_type' in runtime_params:
            instance_type = runtime_params['instance_type']
            if not isinstance(instance_type, str):
                raise ValidationError("instance_type must be a string")
            if not instance_type.startswith('ml.'):
                raise ValidationError(f"Invalid instance type format: {instance_type}. Must start with 'ml.'")

        self._require_positive_int(runtime_params, 'instance_count')

        # Validate role_arn format if provided (any AWS partition)
        if 'role_arn' in runtime_params:
            role_arn = runtime_params['role_arn']
            if not isinstance(role_arn, str):
                raise ValidationError("role_arn must be a string")
            if not self._is_iam_role_arn(role_arn):
                raise ValidationError(f"Invalid IAM role ARN format: {role_arn}")

        # Validate subnets if provided
        if 'subnets' in runtime_params:
            subnets = runtime_params['subnets']
            if not isinstance(subnets, list):
                raise ValidationError("subnets must be a list")
            if not subnets:
                raise ValidationError("subnets list cannot be empty")

        # Validate security_group_ids if provided
        if 'security_group_ids' in runtime_params:
            security_groups = runtime_params['security_group_ids']
            if not isinstance(security_groups, list):
                raise ValidationError("security_group_ids must be a list")
            if not security_groups:
                raise ValidationError("security_group_ids list cannot be empty")

        # Validate output_path if provided
        if 'output_path' in runtime_params:
            output_path = runtime_params['output_path']
            if not isinstance(output_path, str):
                raise ValidationError("output_path must be a string")
            if not output_path.startswith('s3://'):
                raise ValidationError(f"output_path must be an S3 URI starting with 's3://', got: {output_path}")

        self._require_positive_int(runtime_params, 'volume_size_in_gb')
        self._require_positive_int(runtime_params, 'max_run_in_seconds')

        # Validate inputs if provided
        if 'inputs' in runtime_params:
            inputs = runtime_params['inputs']
            if not isinstance(inputs, (dict, list)):
                raise ValidationError("inputs must be a dictionary or list")

        # Validate input_content_type if provided (applied to dict-built channels)
        if 'input_content_type' in runtime_params:
            if not isinstance(runtime_params['input_content_type'], str):
                raise ValidationError("input_content_type must be a string")

        # Validate hyperparameters if provided
        if 'hyperparameters' in runtime_params:
            hyperparameters = runtime_params['hyperparameters']
            if not isinstance(hyperparameters, dict):
                raise ValidationError("hyperparameters must be a dictionary")

        # Validate training_image if provided
        if 'training_image' in runtime_params:
            training_image = runtime_params['training_image']
            if not isinstance(training_image, str):
                raise ValidationError("training_image must be a string")

        # Also support legacy image_uri parameter name for backward compatibility
        if 'image_uri' in runtime_params:
            image_uri = runtime_params['image_uri']
            if not isinstance(image_uri, str):
                raise ValidationError("image_uri must be a string")
|