sagemaker-mlp-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlp_sdk/__init__.py +18 -0
- mlp_sdk/config.py +630 -0
- mlp_sdk/exceptions.py +421 -0
- mlp_sdk/models.py +62 -0
- mlp_sdk/session.py +1160 -0
- mlp_sdk/wrappers/__init__.py +11 -0
- mlp_sdk/wrappers/deployment.py +459 -0
- mlp_sdk/wrappers/feature_store.py +308 -0
- mlp_sdk/wrappers/pipeline.py +452 -0
- mlp_sdk/wrappers/processing.py +381 -0
- mlp_sdk/wrappers/training.py +492 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/METADATA +569 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/RECORD +16 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/WHEEL +5 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/licenses/LICENSE +21 -0
- sagemaker_mlp_sdk-0.1.0.dist-info/top_level.txt +1 -0
mlp_sdk/exceptions.py
ADDED
@@ -0,0 +1,421 @@
"""
Custom exceptions for mlp_sdk
"""

import logging
from typing import Optional, Dict, Any
from datetime import datetime
import json


class MLPSDKError(Exception):
    """Base exception for all mlp_sdk errors"""
    pass


class ConfigurationError(MLPSDKError):
    """Configuration loading or validation errors"""
    pass


class ValidationError(MLPSDKError):
    """Parameter validation errors"""
    pass
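
All custom exceptions derive from MLPSDKError, so callers can catch the base class to handle any SDK failure or catch the specific subclasses for configuration and validation problems. A minimal usage sketch; the parse_bucket helper below is hypothetical and exists only to show the hierarchy:

    # Illustrative only: not part of the package.
    from mlp_sdk.exceptions import MLPSDKError, ConfigurationError, ValidationError

    def parse_bucket(name: str) -> str:
        # Toy validation standing in for real config parsing.
        if not name:
            raise ConfigurationError("default_bucket is missing")
        if not name.islower():
            raise ValidationError(f"bucket name must be lowercase: {name}")
        return name

    try:
        parse_bucket("MyBucket")
    except (ConfigurationError, ValidationError) as exc:
        print(f"config problem: {exc}")
    except MLPSDKError as exc:
        # Base class catches any other SDK error, e.g. AWSServiceError below.
        print(f"SDK failure: {exc}")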


class AWSServiceError(MLPSDKError):
    """AWS service operation errors with detailed error information"""

    def __init__(self, message: str, aws_error: Optional[Exception] = None):
        """
        Initialize AWS service error with detailed error information.

        Extracts and preserves AWS error details including:
        - Error code
        - Error message
        - Request ID
        - HTTP status code
        - Operation name

        Args:
            message: Error message
            aws_error: Original AWS exception (ClientError, BotoCoreError, etc.)
        """
        super().__init__(message)
        self.aws_error = aws_error
        self.error_code = None
        self.error_message = None
        self.request_id = None
        self.http_status_code = None
        self.operation_name = None

        # Extract AWS error details if available
        if aws_error:
            self._extract_aws_error_details(aws_error)

    def _extract_aws_error_details(self, aws_error: Exception) -> None:
        """
        Extract detailed error information from AWS exception.

        Handles both boto3 ClientError and other AWS SDK exceptions.

        Args:
            aws_error: AWS exception object
        """
        try:
            # Check if it's a boto3 ClientError
            if hasattr(aws_error, 'response'):
                response = aws_error.response

                # Extract error details from response
                if isinstance(response, dict):
                    # Get error code and message
                    error_info = response.get('Error', {})
                    self.error_code = error_info.get('Code')
                    self.error_message = error_info.get('Message')

                    # Get request ID
                    self.request_id = response.get('ResponseMetadata', {}).get('RequestId')

                    # Get HTTP status code
                    self.http_status_code = response.get('ResponseMetadata', {}).get('HTTPStatusCode')

            # Check if it's a SageMaker SDK exception with operation_name
            if hasattr(aws_error, 'operation_name'):
                self.operation_name = aws_error.operation_name

        except Exception:
            # If extraction fails, just keep the original error
            pass

    def get_error_details(self) -> Dict[str, Any]:
        """
        Get structured error details.

        Returns:
            Dictionary with error details including code, message, request ID, etc.
        """
        details = {
            'message': str(self),
            'error_code': self.error_code,
            'error_message': self.error_message,
            'request_id': self.request_id,
            'http_status_code': self.http_status_code,
            'operation_name': self.operation_name,
        }

        # Remove None values
        return {k: v for k, v in details.items() if v is not None}

    def __str__(self) -> str:
        """
        String representation with AWS error details.

        Returns:
            Formatted error message with AWS details
        """
        base_message = super().__str__()

        # Add AWS error details if available
        details = []
        if self.error_code:
            details.append(f"ErrorCode: {self.error_code}")
        if self.request_id:
            details.append(f"RequestId: {self.request_id}")
        if self.http_status_code:
            details.append(f"HTTPStatus: {self.http_status_code}")
        if self.operation_name:
            details.append(f"Operation: {self.operation_name}")

        if details:
            return f"{base_message} [{', '.join(details)}]"

        return base_message
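
AWSServiceError wraps a raw boto3/botocore exception and surfaces its error code, request ID, HTTP status, and operation name. A hedged usage sketch follows; the synthetic ClientError mirrors the response shape botocore produces, but the surrounding scenario is illustrative rather than taken from the package:

    # Illustrative only: build a ClientError the way botocore would and wrap it.
    from botocore.exceptions import ClientError
    from mlp_sdk.exceptions import AWSServiceError

    fake_response = {
        'Error': {'Code': 'ValidationException', 'Message': 'Instance type not supported'},
        'ResponseMetadata': {'RequestId': 'abcd-1234', 'HTTPStatusCode': 400},
    }
    try:
        raise ClientError(fake_response, 'CreateTrainingJob')
    except ClientError as client_err:
        wrapped = AWSServiceError("Failed to create training job", aws_error=client_err)
        print(wrapped)                      # message plus [ErrorCode, RequestId, HTTPStatus, Operation]
        print(wrapped.get_error_details())  # structured dict with None values removed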


class SessionError(MLPSDKError):
    """Session initialization or lifecycle errors"""
    pass


# Logging infrastructure
class MLPLogger:
    """
    Structured logging for mlp_sdk operations.
    Provides configurable log levels and audit trail functionality.
    """

    def __init__(self, name: str = "mlp_sdk", level: int = logging.INFO):
        """
        Initialize logger with configurable level.

        Args:
            name: Logger name
            level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        """
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)

        # Add console handler if not already present
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setLevel(level)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def set_level(self, level: int) -> None:
        """
        Set logging level.

        Args:
            level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        """
        self.logger.setLevel(level)
        for handler in self.logger.handlers:
            handler.setLevel(level)

    def debug(self, message: str, **kwargs: Any) -> None:
        """Log debug message with optional context"""
        self.logger.debug(self._format_message(message, kwargs))

    def info(self, message: str, **kwargs: Any) -> None:
        """Log info message with optional context"""
        self.logger.info(self._format_message(message, kwargs))

    def warning(self, message: str, **kwargs: Any) -> None:
        """Log warning message with optional context"""
        self.logger.warning(self._format_message(message, kwargs))

    def error(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None:
        """
        Log error message with optional exception details.

        Automatically extracts AWS error details if the error is an AWSServiceError.

        Args:
            message: Error message
            error: Optional exception object
            **kwargs: Additional context
        """
        if error:
            kwargs['error_type'] = type(error).__name__
            kwargs['error_details'] = str(error)

            # Extract AWS error details if available
            if hasattr(error, 'get_error_details'):
                aws_details = error.get_error_details()
                for key, value in aws_details.items():
                    if key != 'message':  # Avoid duplicate message
                        kwargs[f'aws_{key}'] = value
            elif hasattr(error, 'response'):
                # Handle boto3 ClientError directly
                try:
                    response = error.response
                    if isinstance(response, dict):
                        error_info = response.get('Error', {})
                        kwargs['aws_error_code'] = error_info.get('Code')
                        kwargs['aws_error_message'] = error_info.get('Message')
                        kwargs['aws_request_id'] = response.get('ResponseMetadata', {}).get('RequestId')
                        kwargs['aws_http_status'] = response.get('ResponseMetadata', {}).get('HTTPStatusCode')
                except Exception:
                    pass

        self.logger.error(self._format_message(message, kwargs))

    def critical(self, message: str, error: Optional[Exception] = None, **kwargs: Any) -> None:
        """
        Log critical message with optional exception details.

        Args:
            message: Critical error message
            error: Optional exception object
            **kwargs: Additional context
        """
        if error:
            kwargs['error_type'] = type(error).__name__
            kwargs['error_details'] = str(error)
        self.logger.critical(self._format_message(message, kwargs))

    def _format_message(self, message: str, context: Dict[str, Any]) -> str:
        """
        Format log message with context.

        Args:
            message: Base message
            context: Additional context dictionary

        Returns:
            Formatted message string
        """
        if not context:
            return message

        context_str = " | ".join(f"{k}={v}" for k, v in context.items())
        return f"{message} | {context_str}"


class AuditTrail:
    """
    Maintains audit trail for mlp_sdk operations.
    Records operation history for debugging and compliance.
    """

    def __init__(self, max_entries: int = 1000):
        """
        Initialize audit trail.

        Args:
            max_entries: Maximum number of entries to keep in memory
        """
        self._entries: list = []
        self._max_entries = max_entries

    def record(self, operation: str, status: str, **kwargs: Any) -> None:
        """
        Record an operation in the audit trail.

        Args:
            operation: Operation name (e.g., 'create_feature_group')
            status: Operation status ('started', 'completed', 'failed')
            **kwargs: Additional operation details
        """
        entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'operation': operation,
            'status': status,
            **kwargs
        }

        self._entries.append(entry)

        # Maintain max entries limit
        if len(self._entries) > self._max_entries:
            self._entries.pop(0)

    def get_entries(self, operation: Optional[str] = None,
                    status: Optional[str] = None,
                    limit: Optional[int] = None) -> list:
        """
        Get audit trail entries with optional filtering.

        Args:
            operation: Filter by operation name
            status: Filter by status
            limit: Maximum number of entries to return

        Returns:
            List of audit trail entries
        """
        entries = self._entries

        if operation:
            entries = [e for e in entries if e['operation'] == operation]

        if status:
            entries = [e for e in entries if e['status'] == status]

        if limit:
            entries = entries[-limit:]

        return entries

    def get_last_entry(self, operation: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Get the most recent audit trail entry.

        Args:
            operation: Optional operation name filter

        Returns:
            Most recent entry or None
        """
        entries = self.get_entries(operation=operation, limit=1)
        return entries[0] if entries else None

    def get_summary(self) -> Dict[str, Any]:
        """
        Get summary statistics for the audit trail.

        Returns:
            Dictionary with summary statistics including:
            - total_entries: Total number of entries
            - operations: Count by operation type
            - statuses: Count by status
            - failed_operations: List of failed operations
        """
        summary = {
            'total_entries': len(self._entries),
            'operations': {},
            'statuses': {},
            'failed_operations': []
        }

        for entry in self._entries:
            # Count by operation
            operation = entry.get('operation', 'unknown')
            summary['operations'][operation] = summary['operations'].get(operation, 0) + 1

            # Count by status
            status = entry.get('status', 'unknown')
            summary['statuses'][status] = summary['statuses'].get(status, 0) + 1

            # Track failed operations
            if status == 'failed':
                summary['failed_operations'].append({
                    'timestamp': entry.get('timestamp'),
                    'operation': operation,
                    'error': entry.get('error', 'Unknown error')
                })

        return summary

    def clear(self) -> None:
        """Clear all audit trail entries"""
        self._entries.clear()

    def export_json(self, file_path: str) -> None:
        """
        Export audit trail to JSON file.

        Args:
            file_path: Path to output JSON file
        """
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self._entries, f, indent=2)

    def export_csv(self, file_path: str) -> None:
        """
        Export audit trail to CSV file.

        Args:
            file_path: Path to output CSV file
        """
        import csv

        if not self._entries:
            # Create empty CSV with headers
            with open(file_path, 'w', encoding='utf-8', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['timestamp', 'operation', 'status'])
            return

        # Get all unique keys from entries
        all_keys = set()
        for entry in self._entries:
            all_keys.update(entry.keys())

        # Sort keys for consistent column order
        fieldnames = sorted(all_keys)

        with open(file_path, 'w', encoding='utf-8', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self._entries)

    def __len__(self) -> int:
        """Return number of entries in audit trail"""
        return len(self._entries)
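
AuditTrail keeps an in-memory, size-capped list of operation records that can be filtered, summarized, and exported. A brief usage sketch (illustrative operation names, not taken from the package):

    # Illustrative only: record a pipeline run and inspect the trail.
    from mlp_sdk.exceptions import AuditTrail

    trail = AuditTrail(max_entries=500)
    trail.record("create_feature_group", "started", feature_group_name="customers")
    trail.record("create_feature_group", "completed", feature_group_name="customers")
    trail.record("start_training_job", "failed", error="ResourceLimitExceeded")

    print(len(trail))                                    # 3
    print(trail.get_last_entry("create_feature_group"))  # most recent matching entry
    print(trail.get_summary()["failed_operations"])      # the start_training_job failure
    trail.export_json("audit_trail.json")                # or trail.export_csv("audit_trail.csv")
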
mlp_sdk/models.py
ADDED
@@ -0,0 +1,62 @@
"""
Data models and configuration schemas for mlp_sdk
"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class S3Config:
    """S3 configuration settings"""
    default_bucket: str
    input_prefix: str = "input/"
    output_prefix: str = "output/"
    model_prefix: str = "models/"


@dataclass
class NetworkingConfig:
    """Networking configuration settings"""
    vpc_id: str
    security_group_ids: List[str]
    subnets: List[str]


@dataclass
class ComputeConfig:
    """Compute configuration settings"""
    processing_instance_type: str = "ml.m5.large"
    training_instance_type: str = "ml.m5.xlarge"
    processing_instance_count: int = 1
    training_instance_count: int = 1


@dataclass
class FeatureStoreConfig:
    """Feature Store configuration settings"""
    offline_store_s3_uri: str
    enable_online_store: bool = False


@dataclass
class IAMConfig:
    """IAM configuration settings"""
    execution_role: str


@dataclass
class KMSConfig:
    """KMS configuration settings"""
    key_id: Optional[str] = None


@dataclass
class MLPConfig:
    """Main configuration container"""
    s3_config: S3Config
    networking_config: NetworkingConfig
    compute_config: ComputeConfig
    feature_store_config: FeatureStoreConfig
    iam_config: IAMConfig
    kms_config: Optional[KMSConfig] = None
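
MLPConfig is a plain dataclass container that nests the per-concern settings; config.py presumably builds it from a configuration source, but it can also be constructed directly. A hedged sketch with placeholder resource identifiers (none of these values come from the package):

    # Illustrative only: assemble an MLPConfig by hand with placeholder AWS resources.
    from mlp_sdk.models import (
        MLPConfig, S3Config, NetworkingConfig, ComputeConfig,
        FeatureStoreConfig, IAMConfig, KMSConfig,
    )

    config = MLPConfig(
        s3_config=S3Config(default_bucket="example-ml-bucket"),
        networking_config=NetworkingConfig(
            vpc_id="vpc-0123456789abcdef0",
            security_group_ids=["sg-0123456789abcdef0"],
            subnets=["subnet-0123456789abcdef0"],
        ),
        compute_config=ComputeConfig(),  # defaults: ml.m5.large / ml.m5.xlarge, one instance each
        feature_store_config=FeatureStoreConfig(
            offline_store_s3_uri="s3://example-ml-bucket/feature-store/"
        ),
        iam_config=IAMConfig(execution_role="arn:aws:iam::123456789012:role/ExampleSageMakerRole"),
        kms_config=KMSConfig(),  # key_id defaults to None (no customer-managed key)
    )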