sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
mlp_sdk/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
mlp_sdk - A Python wrapper library for SageMaker SDK v3
|
|
3
|
+
|
|
4
|
+
This package provides simplified SageMaker operations with configuration-driven defaults.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
__version__ = "0.1.0"

# Re-export the session entry point and the exception types so callers can
# import them directly from the package root.
from .session import MLP_Session
from .exceptions import (
    MLPSDKError,
    ConfigurationError,
    ValidationError,
    AWSServiceError,
)

# Names exported by ``from mlp_sdk import *``.
__all__ = [
    "MLP_Session",
    "MLPSDKError",
    "ConfigurationError",
    "ValidationError",
    "AWSServiceError",
]
|
mlp_sdk/config.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration management for mlp_sdk
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Optional, Dict, Any, List, Union
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pydantic import BaseModel, Field, field_validator
|
|
9
|
+
import yaml
|
|
10
|
+
import os
|
|
11
|
+
import base64
|
|
12
|
+
import json
|
|
13
|
+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
14
|
+
from cryptography.hazmat.backends import default_backend
|
|
15
|
+
from .exceptions import ConfigurationError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class S3Config:
    """
    Plain container for the S3 section of the configuration.

    Attributes:
        default_bucket: Bucket name; required, has no default.
        input_prefix: Prefix string, defaults to ``"input/"``.
        output_prefix: Prefix string, defaults to ``"output/"``.
        model_prefix: Prefix string, defaults to ``"models/"``.
    """

    default_bucket: str
    input_prefix: str = "input/"
    output_prefix: str = "output/"
    model_prefix: str = "models/"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class NetworkingConfig:
    """
    Plain container for the networking section of the configuration.

    All three fields are required.

    Attributes:
        vpc_id: VPC identifier string.
        security_group_ids: List of security group identifier strings.
        subnets: List of subnet identifier strings.
    """

    vpc_id: str
    security_group_ids: List[str]
    subnets: List[str]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ComputeConfig:
    """
    Plain container for the compute section of the configuration.

    Every field has a default, so ``ComputeConfig()`` is valid.

    Attributes:
        processing_instance_type: Defaults to ``"ml.m5.large"``.
        training_instance_type: Defaults to ``"ml.m5.xlarge"``.
        processing_instance_count: Defaults to ``1``.
        training_instance_count: Defaults to ``1``.
    """

    processing_instance_type: str = "ml.m5.large"
    training_instance_type: str = "ml.m5.xlarge"
    processing_instance_count: int = 1
    training_instance_count: int = 1
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class FeatureStoreConfig:
    """
    Plain container for the feature store section of the configuration.

    Attributes:
        offline_store_s3_uri: S3 URI string; required, has no default.
        enable_online_store: Boolean flag, defaults to ``False``.
    """

    offline_store_s3_uri: str
    enable_online_store: bool = False
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class IAMConfig:
    """
    Plain container for the IAM section of the configuration.

    Attributes:
        execution_role: Role identifier string; required, has no default.
    """

    execution_role: str
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class KMSConfig:
    """
    Plain container for the KMS section of the configuration.

    Attributes:
        key_id: Optional key identifier string; defaults to ``None``.
    """

    key_id: Optional[str] = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class MLPConfig:
    """
    Top-level configuration object that groups one instance of each section.

    Every section is required except ``kms_config``, which defaults to
    ``None``.

    Attributes:
        s3_config: S3 section.
        networking_config: Networking section.
        compute_config: Compute section.
        feature_store_config: Feature store section.
        iam_config: IAM section.
        kms_config: KMS section, or ``None``.
    """

    s3_config: S3Config
    networking_config: NetworkingConfig
    compute_config: ComputeConfig
    feature_store_config: FeatureStoreConfig
    iam_config: IAMConfig
    kms_config: Optional[KMSConfig] = None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Pydantic models for YAML schema validation
|
|
75
|
+
class S3ConfigSchema(BaseModel):
    """
    Validation schema for the ``s3`` section of the YAML configuration.

    ``default_bucket`` is required and must be 3-63 characters long. The three
    prefix fields are optional and, when given, must consist of letters,
    digits, ``-``, ``_`` and ``/`` and end with ``/``.
    """

    default_bucket: str = Field(..., min_length=3, max_length=63)
    input_prefix: str = Field(
        default="input/",
        pattern=r"^[a-zA-Z0-9\-_/]*/$",
    )
    output_prefix: str = Field(
        default="output/",
        pattern=r"^[a-zA-Z0-9\-_/]*/$",
    )
    model_prefix: str = Field(
        default="models/",
        pattern=r"^[a-zA-Z0-9\-_/]*/$",
    )

    @field_validator('default_bucket')
    @classmethod
    def validate_bucket_name(cls, v):
        """
        Accept the bucket name when, ignoring hyphens and dots, the remaining
        characters are all alphanumeric (per ``str.isalnum``).

        NOTE(review): ``str.isalnum`` also accepts uppercase and non-ASCII
        alphanumerics, so this check is looser than S3's own naming rules.
        """
        remainder = v.replace('-', '').replace('.', '')
        if remainder.isalnum():
            return v
        raise ValueError('Bucket name must contain only alphanumeric characters, hyphens, and dots')
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class NetworkingConfigSchema(BaseModel):
    """
    Validation schema for the ``networking`` section of the YAML configuration.

    ``vpc_id`` must look like ``vpc-`` followed by 8 to 17 lowercase hex
    characters. Both list fields are required and must be non-empty; each
    element is additionally checked by the validators below.
    """

    vpc_id: str = Field(..., pattern=r"^vpc-[a-f0-9]{8,17}$")
    security_group_ids: List[str] = Field(..., min_length=1)
    subnets: List[str] = Field(..., min_length=1)

    @field_validator('security_group_ids')
    @classmethod
    def validate_security_group_id(cls, v):
        """Require every entry to start with ``sg-`` and be at least 11 characters long."""
        for group_id in v:
            well_formed = group_id.startswith('sg-') and len(group_id) >= 11
            if not well_formed:
                raise ValueError(f'Invalid security group ID format: {group_id}')
        return v

    @field_validator('subnets')
    @classmethod
    def validate_subnet_id(cls, v):
        """Require every entry to start with ``subnet-`` and be at least 15 characters long."""
        for candidate in v:
            well_formed = candidate.startswith('subnet-') and len(candidate) >= 15
            if not well_formed:
                raise ValueError(f'Invalid subnet ID format: {candidate}')
        return v
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ComputeConfigSchema(BaseModel):
    """
    Validation schema for the ``compute`` section of the YAML configuration.

    Instance types must have the form ``ml.<family>.<size>`` using lowercase
    letters and digits; instance counts must be between 1 and 100 inclusive.
    Every field has a default, so the whole section may be omitted.
    """

    processing_instance_type: str = Field(
        default="ml.m5.large",
        pattern=r"^ml\.[a-z0-9]+\.[a-z0-9]+$",
    )
    training_instance_type: str = Field(
        default="ml.m5.xlarge",
        pattern=r"^ml\.[a-z0-9]+\.[a-z0-9]+$",
    )
    processing_instance_count: int = Field(default=1, ge=1, le=100)
    training_instance_count: int = Field(default=1, ge=1, le=100)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class FeatureStoreConfigSchema(BaseModel):
    """
    Validation schema for the ``feature_store`` section of the YAML configuration.

    ``offline_store_s3_uri`` is required and must start with ``s3://`` followed
    by one or more letters, digits, ``-``, ``_``, ``.`` or ``/``.
    ``enable_online_store`` defaults to ``False``.
    """

    offline_store_s3_uri: str = Field(
        ...,
        pattern=r"^s3://[a-zA-Z0-9\-_./]+$",
    )
    enable_online_store: bool = Field(default=False)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class IAMConfigSchema(BaseModel):
    """
    Validation schema for the ``iam`` section of the YAML configuration.

    ``execution_role`` is required and must be an IAM role ARN of the form
    ``arn:<partition>:iam::<12-digit account>:role/<path-and-name>``.

    The partition may be ``aws`` or any ``aws-...`` variant (for example
    ``aws-cn`` or ``aws-us-gov``), so roles outside the commercial partition
    validate as well. Every ARN accepted by the previous ``arn:aws:``-only
    pattern is still accepted.
    """

    execution_role: str = Field(
        ...,
        pattern=r"^arn:aws(-[a-z]+)*:iam::\d{12}:role/[a-zA-Z0-9+=,.@\-_/]+$",
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class KMSConfigSchema(BaseModel):
    """
    Validation schema for the ``kms`` section of the YAML configuration.

    ``key_id`` is optional. When present it must be either a bare key ID or a
    key ARN (``arn:<partition>:kms:<region>:<12-digit account>:key/<key ID>``).

    Accepted key IDs:
        * 36 characters drawn from lowercase hex digits and ``-``
          (the form accepted before this change), or
        * a multi-Region key ID: ``mrk-`` followed by 32 lowercase hex digits.

    The partition may be ``aws`` or any ``aws-...`` variant (for example
    ``aws-cn`` or ``aws-us-gov``). Everything the previous pattern accepted is
    still accepted.
    """

    key_id: Optional[str] = Field(
        None,
        pattern=(
            r"^(arn:aws(-[a-z]+)*:kms:[a-z0-9\-]+:\d{12}:key/)?"
            r"(mrk-[a-f0-9]{32}|[a-f0-9\-]{36})$"
        ),
    )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class DefaultsConfigSchema(BaseModel):
    """
    Validation schema for the ``defaults`` mapping of the YAML configuration.

    ``s3``, ``networking``, ``feature_store`` and ``iam`` are required.
    ``compute`` falls back to a ``ComputeConfigSchema`` built from its own
    defaults, and ``kms`` falls back to ``None``.
    """

    s3: S3ConfigSchema
    networking: NetworkingConfigSchema
    compute: ComputeConfigSchema = Field(default_factory=ComputeConfigSchema)
    feature_store: FeatureStoreConfigSchema
    iam: IAMConfigSchema
    kms: Optional[KMSConfigSchema] = None
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class MLPConfigSchema(BaseModel):
    """
    Validation schema for the whole YAML document.

    The document must contain a single required top-level key, ``defaults``,
    validated by ``DefaultsConfigSchema``.
    """

    defaults: DefaultsConfigSchema
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class ConfigurationManager:
    """
    Handles loading and merging configuration from multiple sources.
    Supports encryption/decryption of sensitive configuration values.

    The YAML file at ``config_path`` is read and validated once, in the
    constructor. The encryption key given to the constructor is stored for the
    ``encrypt_*`` / ``decrypt_*`` helpers only; loading the configuration does
    not decrypt anything.
    """

    DEFAULT_CONFIG_PATH = "/home/sagemaker-user/.config/admin-config.yaml"

    # Sizes (in bytes) used by the AES-256-GCM helpers below.
    _KEY_SIZE = 32
    _NONCE_SIZE = 12

    def __init__(self, config_path: Optional[str] = None, encryption_key: Optional[Union[str, bytes]] = None):
        """
        Load config from specified path or default location.

        Args:
            config_path: Optional custom configuration file path
            encryption_key: Optional encryption key used by the encrypt/decrypt
                           helpers. Can be a base64-encoded string or raw bytes.

        Raises:
            ConfigurationError: If the key is malformed or the configuration
                file exists but cannot be loaded or validated
        """
        self.config_path = config_path or self.DEFAULT_CONFIG_PATH
        self._config: Dict[str, Any] = {}
        self._MLP_config: Optional["MLPConfig"] = None
        # A falsy key ("" / b"" / None) means "no key configured".
        self._encryption_key = self._process_encryption_key(encryption_key) if encryption_key else None
        self._load_configuration()

    def _load_configuration(self) -> None:
        """
        Load and validate configuration from YAML file.

        A missing or empty file is not an error: the manager is simply left
        without configuration (``has_config`` is False).

        Raises:
            ConfigurationError: If configuration loading or validation fails
        """
        config_file = Path(self.config_path)

        # If config file doesn't exist, use empty config (SageMaker SDK defaults will be used)
        if not config_file.exists():
            self._config = {}
            self._MLP_config = None
            return

        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                raw_config = yaml.safe_load(f)

            if not raw_config:
                self._config = {}
                self._MLP_config = None
                return

            # A list or scalar at the top level cannot be expanded into the
            # schema; report that explicitly instead of a bare TypeError.
            if not isinstance(raw_config, dict):
                raise ConfigurationError(
                    f"Failed to load configuration from {self.config_path}: "
                    f"top-level YAML content must be a mapping, got {type(raw_config).__name__}"
                )

            # Validate configuration using Pydantic schema
            validated_config = MLPConfigSchema(**raw_config)
            self._config = validated_config.model_dump()

            # Convert to dataclass structure
            defaults = self._config['defaults']
            self._MLP_config = MLPConfig(
                s3_config=S3Config(**defaults['s3']),
                networking_config=NetworkingConfig(**defaults['networking']),
                compute_config=ComputeConfig(**defaults['compute']),
                feature_store_config=FeatureStoreConfig(**defaults['feature_store']),
                iam_config=IAMConfig(**defaults['iam']),
                kms_config=KMSConfig(**defaults['kms']) if defaults.get('kms') else None
            )

        except ConfigurationError:
            raise
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Invalid YAML syntax in config file {self.config_path}: {e}") from e
        except Exception as e:
            raise ConfigurationError(f"Failed to load configuration from {self.config_path}: {e}") from e

    def get_default(self, key: str, fallback: Any = None) -> Any:
        """
        Get configuration value with fallback.

        Args:
            key: Configuration key relative to the ``defaults`` section
                 (supports dot notation like 's3.default_bucket')
            fallback: Fallback value if key not found

        Returns:
            Configuration value or fallback
        """
        if not self._config:
            return fallback

        # Navigate through nested dictionary using dot notation
        current = self._config.get('defaults', {})
        for part in key.split('.'):
            if not isinstance(current, dict) or part not in current:
                return fallback
            current = current[part]

        return current

    def merge_with_runtime(self, runtime_config: Dict) -> Dict:
        """
        Merge runtime parameters with defaults.

        The stored defaults are deep-copied before merging, so the returned
        dictionary never shares nested containers with the manager's internal
        state; callers may mutate the result freely. Values taken from
        ``runtime_config`` are inserted by reference, not copied.

        Args:
            runtime_config: Runtime configuration parameters

        Returns:
            Merged configuration dictionary with runtime values taking precedence.
            When no configuration is loaded, ``runtime_config`` itself is returned.
        """
        import copy  # not imported at module level in this file

        if not self._config:
            return runtime_config

        def deep_merge(base: Dict, override: Dict) -> Dict:
            """Recursively merge dictionaries"""
            result = base.copy()
            for name, value in override.items():
                if isinstance(result.get(name), dict) and isinstance(value, dict):
                    result[name] = deep_merge(result[name], value)
                else:
                    result[name] = value
            return result

        # Start with an independent copy of the defaults (runtime takes precedence)
        defaults = copy.deepcopy(self._config.get('defaults', {}))
        return deep_merge(defaults, runtime_config)

    @property
    def MLP_config(self) -> Optional["MLPConfig"]:
        """Get the parsed MLPConfig object"""
        return self._MLP_config

    @property
    def has_config(self) -> bool:
        """Check if configuration was successfully loaded"""
        return bool(self._config)

    def get_s3_config(self) -> Optional["S3Config"]:
        """Get S3 configuration"""
        return self._MLP_config.s3_config if self._MLP_config else None

    def get_networking_config(self) -> Optional["NetworkingConfig"]:
        """Get networking configuration"""
        return self._MLP_config.networking_config if self._MLP_config else None

    def get_compute_config(self) -> Optional["ComputeConfig"]:
        """Get compute configuration"""
        return self._MLP_config.compute_config if self._MLP_config else None

    def get_feature_store_config(self) -> Optional["FeatureStoreConfig"]:
        """Get feature store configuration"""
        return self._MLP_config.feature_store_config if self._MLP_config else None

    def get_iam_config(self) -> Optional["IAMConfig"]:
        """Get IAM configuration"""
        return self._MLP_config.iam_config if self._MLP_config else None

    def get_kms_config(self) -> Optional["KMSConfig"]:
        """Get KMS configuration"""
        return self._MLP_config.kms_config if self._MLP_config else None

    @staticmethod
    def _check_key_length(key_bytes: bytes, description: str) -> bytes:
        """
        Return ``key_bytes`` unchanged if it has the AES-256 key length.

        Args:
            key_bytes: Candidate key material
            description: Subject of the error message, e.g. "Encryption key"

        Raises:
            ConfigurationError: If the key is not exactly 32 bytes long
        """
        if len(key_bytes) != ConfigurationManager._KEY_SIZE:
            raise ConfigurationError(
                f"{description} must be 32 bytes for AES-256, got {len(key_bytes)} bytes"
            )
        return key_bytes

    def _process_encryption_key(self, key: Union[str, bytes]) -> bytes:
        """
        Process encryption key from various formats.

        Args:
            key: Encryption key as base64 string or raw bytes

        Returns:
            32-byte encryption key for AES-256

        Raises:
            ConfigurationError: If key format is invalid
        """
        if isinstance(key, str):
            try:
                key_bytes = base64.b64decode(key)
            except ValueError as e:  # includes binascii.Error
                raise ConfigurationError(f"Failed to process encryption key: {e}") from e
        elif isinstance(key, bytes):
            key_bytes = key
        else:
            raise ConfigurationError(f"Encryption key must be string or bytes, got {type(key)}")

        return self._check_key_length(key_bytes, "Encryption key")

    @staticmethod
    def load_key_from_env(env_var: str = "MLP_SDK_ENCRYPTION_KEY") -> Optional[bytes]:
        """
        Load encryption key from environment variable.

        Args:
            env_var: Environment variable name containing the base64-encoded key

        Returns:
            32-byte encryption key or None if the variable is unset or empty

        Raises:
            ConfigurationError: If key format is invalid
        """
        key_str = os.environ.get(env_var)
        if not key_str:
            return None

        try:
            decoded = base64.b64decode(key_str)
        except ValueError as e:  # includes binascii.Error
            raise ConfigurationError(
                f"Failed to load encryption key from environment variable {env_var}: {e}"
            ) from e

        return ConfigurationManager._check_key_length(decoded, f"Encryption key from {env_var}")

    @staticmethod
    def load_key_from_file(file_path: str) -> bytes:
        """
        Load encryption key from file.

        Args:
            file_path: Path to file containing base64-encoded encryption key
                       (surrounding whitespace is ignored)

        Returns:
            32-byte encryption key

        Raises:
            ConfigurationError: If file not found or key format is invalid
        """
        key_file = Path(file_path)
        if not key_file.exists():
            raise ConfigurationError(f"Encryption key file not found: {file_path}")

        try:
            key_str = key_file.read_text(encoding='utf-8').strip()
            decoded = base64.b64decode(key_str)
        except (OSError, ValueError) as e:  # I/O, decoding and base64 errors
            raise ConfigurationError(f"Failed to load encryption key from file {file_path}: {e}") from e

        return ConfigurationManager._check_key_length(decoded, f"Encryption key from {file_path}")

    @staticmethod
    def load_key_from_kms(key_id: str, region: Optional[str] = None) -> bytes:
        """
        Obtain an AES-256 data key from AWS KMS.

        NOTE(review): this calls ``generate_data_key``, which creates a new
        data key on every call, and returns only the ``Plaintext`` part of the
        response. The encrypted copy of the key (``CiphertextBlob``) is not
        returned or stored, so a later call cannot reproduce the same key.
        Values encrypted with the returned key can only be decrypted for as
        long as the caller keeps that key.

        Args:
            key_id: KMS key ID or ARN
            region: AWS region (optional, uses default if not specified)

        Returns:
            32-byte encryption key generated by KMS

        Raises:
            ConfigurationError: If boto3 is unavailable or the KMS operation fails
        """
        try:
            import boto3
        except ImportError as e:
            raise ConfigurationError(
                "boto3 is required for KMS key loading. Install with: pip install boto3"
            ) from e

        try:
            # An empty/None region lets boto3 fall back to its default region resolution.
            kms_client = boto3.client('kms', region_name=region or None)

            # Generate a data key using KMS
            response = kms_client.generate_data_key(
                KeyId=key_id,
                KeySpec='AES_256'
            )

            # Return the plaintext key (32 bytes for AES-256)
            return response['Plaintext']
        except Exception as e:
            raise ConfigurationError(f"Failed to load encryption key from KMS {key_id}: {e}") from e

    @staticmethod
    def generate_key() -> str:
        """
        Generate a new random 32-byte encryption key for AES-256-GCM.

        Returns:
            Base64-encoded 32-byte encryption key
        """
        return base64.b64encode(os.urandom(ConfigurationManager._KEY_SIZE)).decode('utf-8')

    def _resolve_encryption_key(self, key: Optional[Union[str, bytes]]) -> bytes:
        """
        Return the explicit ``key`` if given, otherwise the instance key.

        Raises:
            ConfigurationError: If neither key is available or ``key`` is malformed
        """
        encryption_key = self._process_encryption_key(key) if key else self._encryption_key
        if not encryption_key:
            raise ConfigurationError("No encryption key available. Provide key during initialization or as parameter.")
        return encryption_key

    def encrypt_value(self, plaintext: str, key: Optional[Union[str, bytes]] = None) -> str:
        """
        Encrypt a configuration value using AES-256-GCM.

        Args:
            plaintext: Value to encrypt
            key: Optional encryption key (uses instance key if not provided)

        Returns:
            Base64 encoding of the 12-byte random nonce followed by the
            ciphertext (which includes the authentication tag)

        Raises:
            ConfigurationError: If encryption fails or no key available
        """
        encryption_key = self._resolve_encryption_key(key)

        try:
            aesgcm = AESGCM(encryption_key)

            # Fresh random nonce per call (96 bits / 12 bytes recommended for GCM)
            nonce = os.urandom(self._NONCE_SIZE)
            ciphertext = aesgcm.encrypt(nonce, plaintext.encode('utf-8'), None)

            # Prepend the nonce so decrypt_value can recover it
            return base64.b64encode(nonce + ciphertext).decode('utf-8')

        except Exception as e:
            # Some exceptions carry no message; fall back to the type name.
            detail = str(e) or type(e).__name__
            raise ConfigurationError(f"Failed to encrypt value: {detail}") from e

    def decrypt_value(self, encrypted: str, key: Optional[Union[str, bytes]] = None) -> str:
        """
        Decrypt a configuration value using AES-256-GCM.

        Args:
            encrypted: Value produced by ``encrypt_value``: base64 encoding of
                       the 12-byte nonce followed by ciphertext and auth tag
            key: Optional encryption key (uses instance key if not provided)

        Returns:
            Decrypted plaintext value

        Raises:
            ConfigurationError: If decryption fails or no key available
        """
        encryption_key = self._resolve_encryption_key(key)

        try:
            combined = base64.b64decode(encrypted)

            # First 12 bytes are the nonce; the rest is ciphertext + auth tag
            nonce = combined[:self._NONCE_SIZE]
            ciphertext = combined[self._NONCE_SIZE:]

            plaintext_bytes = AESGCM(encryption_key).decrypt(nonce, ciphertext, None)
            return plaintext_bytes.decode('utf-8')

        except Exception as e:
            # Some exceptions carry no message; fall back to the type name.
            detail = str(e) or type(e).__name__
            raise ConfigurationError(f"Failed to decrypt value: {detail}") from e

    def _transform_config_file(self, input_path: str, output_path: str,
                               field_paths: List[str], transform, action: str) -> None:
        """
        Apply ``transform`` to selected string fields of a YAML file.

        The input file is loaded, each dotted path in ``field_paths`` is
        replaced by ``transform(str(current_value))``, and the result is
        written to ``output_path``. Nothing is written if any field fails.

        Args:
            input_path: Path to input YAML file
            output_path: Path to output YAML file
            field_paths: Field paths in dot notation
            transform: Callable taking and returning a string
            action: Verb used in the error message ("encrypt" or "decrypt")

        Raises:
            ConfigurationError: If a field is missing or any step fails
        """
        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            for field_path in field_paths:
                *parent_parts, final_part = field_path.split('.')
                current = config

                # Navigate to parent of target field
                for part in parent_parts:
                    if not isinstance(current, dict) or part not in current:
                        raise ConfigurationError(f"Field path not found: {field_path}")
                    current = current[part]

                if not isinstance(current, dict) or final_part not in current:
                    raise ConfigurationError(f"Field not found: {field_path}")

                current[final_part] = transform(str(current[final_part]))

            with open(output_path, 'w', encoding='utf-8') as f:
                yaml.safe_dump(config, f, default_flow_style=False)

        except ConfigurationError:
            raise
        except Exception as e:
            raise ConfigurationError(f"Failed to {action} config file: {e}") from e

    def encrypt_config_file(self, input_path: str, output_path: str,
                            fields_to_encrypt: List[str],
                            key: Optional[Union[str, bytes]] = None) -> None:
        """
        Encrypt specific fields in a YAML configuration file.

        Each selected value is converted with ``str()`` before encryption.

        Args:
            input_path: Path to input YAML file
            output_path: Path to output encrypted YAML file
            fields_to_encrypt: List of field paths to encrypt (dot notation, e.g., 'defaults.iam.execution_role')
            key: Optional encryption key (uses instance key if not provided)

        Raises:
            ConfigurationError: If encryption fails
        """
        encryption_key = self._resolve_encryption_key(key)
        self._transform_config_file(
            input_path,
            output_path,
            fields_to_encrypt,
            lambda value: self.encrypt_value(value, encryption_key),
            "encrypt",
        )

    def decrypt_config_file(self, input_path: str, output_path: str,
                            fields_to_decrypt: List[str],
                            key: Optional[Union[str, bytes]] = None) -> None:
        """
        Decrypt specific fields in a YAML configuration file.

        The output file contains the decrypted values in plain text.

        Args:
            input_path: Path to input encrypted YAML file
            output_path: Path to output decrypted YAML file
            fields_to_decrypt: List of field paths to decrypt (dot notation)
            key: Optional encryption key (uses instance key if not provided)

        Raises:
            ConfigurationError: If decryption fails
        """
        encryption_key = self._resolve_encryption_key(key)
        self._transform_config_file(
            input_path,
            output_path,
            fields_to_decrypt,
            lambda value: self.decrypt_value(value, encryption_key),
            "decrypt",
        )
|