@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -1,187 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
|
|
5
|
-
"""
|
|
6
|
-
Parse DescribeTrainingJob JSON response and display formatted status.
|
|
7
|
-
|
|
8
|
-
This helper is called by do/train --status to parse the AWS CLI JSON output
|
|
9
|
-
from DescribeTrainingJob and display a user-friendly status summary.
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
-
import json
|
|
13
|
-
import sys
|
|
14
|
-
import time
|
|
15
|
-
from datetime import datetime, timezone
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
# Glyph printed next to each primary TrainingJobStatus value.
# Statuses missing from this table are rendered with a fallback glyph
# chosen by the caller.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️',
}

# Short human-readable explanation for each SecondaryStatus phase.
# Phases missing from this table are printed without an explanation.
SECONDARY_DESCRIPTIONS = {
    'Starting': 'Preparing training instance',
    'LaunchingMLInstances': 'Launching ML instances',
    'PreparingTrainingStack': 'Preparing training stack',
    'Downloading': 'Downloading training data',
    'DownloadingTrainingImage': 'Downloading training image',
    'Training': 'Training in progress',
    'Uploading': 'Uploading model artifacts',
    'Completed': 'Training completed',
    'MaxRuntimeExceeded': 'Max runtime exceeded',
    'Stopped': 'Training stopped',
    'MaxWaitTimeExceeded': 'Max wait time exceeded (spot)',
    'Interrupted': 'Spot instance interrupted',
}
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def format_duration(seconds):
    """Render a number of seconds as a compact ``Xh Ym Zs`` string.

    Larger units are shown only once the duration reaches them: durations
    under a minute print as ``Zs``, under an hour as ``Ym Zs``. Fractional
    seconds are truncated.

    Args:
        seconds: Duration in seconds (int or float), or ``None``.

    Returns:
        The formatted duration, or ``'N/A'`` when ``seconds`` is ``None``
        or negative.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    whole_hours, within_hour = divmod(seconds, 3600)
    whole_hours = int(whole_hours)
    whole_minutes = int(within_hour // 60)
    leftover = int(seconds % 60)

    if whole_hours > 0:
        return f'{whole_hours}h {whole_minutes}m {leftover}s'
    if whole_minutes > 0:
        return f'{whole_minutes}m {leftover}s'
    return f'{leftover}s'
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def parse_iso_time(time_str):
    """Parse an AWS timestamp into a timezone-aware datetime.

    Accepts ISO 8601 strings (with a trailing ``Z`` or an explicit offset)
    and numeric epoch seconds. The latter can appear in the JSON when the
    AWS CLI emits timestamps in wire format rather than ISO 8601.

    Args:
        time_str: ISO 8601 string, epoch seconds (int/float), or a falsy
            value when the timestamp is absent.

    Returns:
        A timezone-aware ``datetime``, or ``None`` when the value is
        missing or cannot be parsed. Strings without an offset are
        assumed to be UTC so the result can always be compared with
        ``datetime.now(timezone.utc)``.
    """
    if not time_str:
        return None
    try:
        # bool is an int subclass; a stray True must not be read as epoch 1.
        if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
            return datetime.fromtimestamp(time_str, tz=timezone.utc)

        text = time_str.strip()
        # Older fromisoformat() implementations reject the 'Z' suffix.
        # Rewrite only a trailing one so the rest of the string is untouched.
        if text.endswith(('Z', 'z')):
            text = text[:-1] + '+00:00'
        parsed = datetime.fromisoformat(text)
    except (AttributeError, ValueError, TypeError, OverflowError, OSError):
        # AttributeError: unsupported input type (no .strip()).
        # OverflowError/OSError: epoch value outside the platform's range.
        return None

    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def calculate_elapsed(start_time_str):
    """Return the seconds elapsed between a start timestamp and now.

    Args:
        start_time_str: Start timestamp in a form accepted by
            ``parse_iso_time``.

    Returns:
        Elapsed seconds as a non-negative float, or ``None`` when the
        timestamp is missing or unparseable.
    """
    start = parse_iso_time(start_time_str)
    if start is None:
        return None

    # Subtracting a naive datetime from an aware one raises TypeError.
    # Treat an offset-less start time as UTC.
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)

    elapsed = (datetime.now(timezone.utc) - start).total_seconds()
    # Clamp at zero so clock skew never yields a negative duration.
    return max(0, elapsed)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def display_status(job_data):
    """Print a formatted summary of a DescribeTrainingJob response.

    Writes to stdout, in order: the primary status, the secondary phase,
    elapsed or total training time, instance details, spot savings (for
    completed spot jobs), final metrics, artifact location (for completed
    jobs), the failure reason (for failed jobs) and spot-interruption
    guidance.

    Args:
        job_data: Parsed DescribeTrainingJob JSON object (a dict). Every
            field is optional; absent or null fields are simply omitted
            from the output.
    """
    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    billable_seconds = job_data.get('BillableTimeInSeconds')
    training_seconds = job_data.get('TrainingTimeInSeconds')
    final_metrics = job_data.get('FinalMetricDataList') or []
    spot_enabled = job_data.get('EnableManagedSpotTraining', False)

    # `or {}` also covers a key that is present with a null value, which
    # `.get(key, {})` would pass through as None.
    resource_config = job_data.get('ResourceConfig') or {}
    instance_type = resource_config.get('InstanceType', '')
    instance_count = resource_config.get('InstanceCount', 1)
    output_path = (job_data.get('OutputDataConfig') or {}).get('S3OutputPath', '')
    model_artifacts = (job_data.get('ModelArtifacts') or {}).get('S3ModelArtifacts', '')

    emoji = STATUS_EMOJI.get(status, '❓')

    print()
    print(f' {emoji} Status: {status}')

    # Secondary status with description
    if secondary_status:
        desc = SECONDARY_DESCRIPTIONS.get(secondary_status, '')
        if desc:
            print(f' 📍 Phase: {secondary_status} ({desc})')
        else:
            print(f' 📍 Phase: {secondary_status}')

    # Running jobs show wall-clock time since start; otherwise fall back
    # to the reported training time when there is one.
    if status == 'InProgress' and training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif training_seconds is not None:
        print(f' ⏱️ Training time: {format_duration(training_seconds)}')

    # Instance info
    if instance_type:
        instance_info = f'{instance_type}'
        if instance_count and instance_count > 1:
            instance_info += f' x {instance_count}'
        if spot_enabled:
            instance_info += ' (spot)'
        print(f' 🖥️ Instance: {instance_info}')

    # Billable time and cost savings (for completed spot jobs)
    if (
        status == 'Completed'
        and spot_enabled
        and billable_seconds is not None
        and training_seconds is not None
    ):
        savings_seconds = training_seconds - billable_seconds
        # Guard the percentage against division by zero.
        if training_seconds > 0:
            savings_pct = (savings_seconds / training_seconds) * 100
            print(f' 💰 Spot savings: {format_duration(savings_seconds)} saved ({savings_pct:.0f}% discount)')
            print(f' Billable: {format_duration(billable_seconds)} / Total: {format_duration(training_seconds)}')

    # Training metrics
    if final_metrics:
        print(' 📈 Metrics:')
        for metric in final_metrics:
            name = metric.get('MetricName', 'unknown')
            value = metric.get('Value', 0)
            if isinstance(value, float):
                # Smaller magnitudes get more decimals so they do not
                # collapse to 0.00.
                magnitude = abs(value)
                if magnitude < 0.001:
                    print(f' {name}: {value:.6f}')
                elif magnitude < 1:
                    print(f' {name}: {value:.4f}')
                else:
                    print(f' {name}: {value:.2f}')
            else:
                print(f' {name}: {value}')

    # Output artifacts (for completed jobs); prefer the concrete artifact
    # path over the configured output prefix.
    if status == 'Completed' and model_artifacts:
        print(f' 📦 Artifacts: {model_artifacts}')
    elif status == 'Completed' and output_path:
        print(f' 📦 Output: {output_path}')

    # Failure reason
    if status == 'Failed' and failure_reason:
        print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    # Spot interruption guidance
    if secondary_status == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def main():
    """Main entry point — reads DescribeTrainingJob JSON from stdin.

    Parses the payload and prints the formatted status. Exits with
    status 1 and a message on stderr when stdin does not contain valid
    JSON, or when the JSON is not an object (for example ``null`` or a
    list), since the status display needs key lookups.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'❌ Failed to parse DescribeTrainingJob response: {e}', file=sys.stderr)
        sys.exit(1)

    # Valid JSON is not necessarily an object; fail cleanly instead of
    # crashing with an AttributeError traceback further down.
    if not isinstance(job_data, dict):
        print(
            '❌ Failed to parse DescribeTrainingJob response: '
            f'expected a JSON object, got {type(job_data).__name__}',
            file=sys.stderr,
        )
        sys.exit(1)

    display_status(job_data)


if __name__ == '__main__':
    main()
|
|
File without changes
|