@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -1,187 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
- # SPDX-License-Identifier: Apache-2.0
4
-
5
- """
6
- Parse DescribeTrainingJob JSON response and display formatted status.
7
-
8
- This helper is called by do/train --status to parse the AWS CLI JSON output
9
- from DescribeTrainingJob and display a user-friendly status summary.
10
- """
11
-
12
- import json
13
- import sys
14
- import time
15
- from datetime import datetime, timezone
16
-
17
-
18
# Emoji shown next to each TrainingJobStatus value. display_status() falls
# back to '❓' for any status not listed here.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}

# Short human-readable descriptions for SageMaker SecondaryStatus values,
# printed after the phase name. Statuses missing from this table are shown
# without a description. Entries cover the documented SecondaryStatus values
# (including Stopping/Failed/Pending/Restarting/Updating, which previously had
# no description).
SECONDARY_DESCRIPTIONS = {
    'Starting': 'Preparing training instance',
    'Pending': 'Waiting to start',
    'LaunchingMLInstances': 'Launching ML instances',
    'PreparingTrainingStack': 'Preparing training stack',
    'Downloading': 'Downloading training data',
    'DownloadingTrainingImage': 'Downloading training image',
    'Training': 'Training in progress',
    'Uploading': 'Uploading model artifacts',
    'Completed': 'Training completed',
    'Failed': 'Training failed',
    'MaxRuntimeExceeded': 'Max runtime exceeded',
    'Stopping': 'Stopping training job',
    'Stopped': 'Training stopped',
    'MaxWaitTimeExceeded': 'Max wait time exceeded (spot)',
    'Interrupted': 'Spot instance interrupted',
    'Restarting': 'Restarting training job',
    'Updating': 'Updating training job',
}
42
-
43
-
44
def format_duration(seconds):
    """Render a number of seconds as ``'1h 2m 5s'``, ``'2m 5s'`` or ``'5s'``.

    ``None`` or a negative value yields ``'N/A'``. Fractional seconds are
    truncated, and leading zero units (hours, then minutes) are omitted.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    whole_hours, remainder = divmod(seconds, 3600)
    whole_minutes, leftover = divmod(remainder, 60)
    h, m, s = int(whole_hours), int(whole_minutes), int(leftover)

    if h:
        return f'{h}h {m}m {s}s'
    if m:
        return f'{m}m {s}s'
    return f'{s}s'
57
-
58
-
59
def parse_iso_time(time_str):
    """Parse an AWS timestamp into a timezone-aware ``datetime``.

    Accepts:
      * ISO 8601 strings, with or without a trailing ``Z``. An offset-less
        string is assumed to be UTC so that callers never receive a naive
        datetime (subtracting naive from aware raises TypeError).
      * Numbers, interpreted as Unix epoch seconds. The AWS CLI can emit this
        form depending on its ``cli_timestamp_format`` setting.

    Returns ``None`` for empty/missing input, unsupported types, or values
    that cannot be parsed.
    """
    if not time_str:
        return None

    # bool is an int subclass; it is never a meaningful timestamp.
    if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
        try:
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None

    if not isinstance(time_str, str):
        return None

    try:
        # Replace 'Z' with '+00:00' for fromisoformat on Python < 3.11.
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except ValueError:
        return None

    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
70
-
71
-
72
def calculate_elapsed(start_time_str):
    """Return the seconds elapsed between ``start_time_str`` and now (UTC).

    ``start_time_str`` is anything parse_iso_time() accepts. Returns ``None``
    when it is missing or unparseable. A start time in the future (clock
    skew) is clamped to 0.
    """
    start = parse_iso_time(start_time_str)
    if start is None:
        return None
    if start.tzinfo is None:
        # A naive datetime would make the aware-minus-naive subtraction below
        # raise TypeError; assume UTC, the zone AWS reports timestamps in.
        start = start.replace(tzinfo=timezone.utc)
    now = datetime.now(timezone.utc)
    return max(0, (now - start).total_seconds())
80
-
81
-
82
def display_status(job_data):
    """Print a human-readable summary of a DescribeTrainingJob response.

    Args:
        job_data: dict parsed from the DescribeTrainingJob JSON output. Every
            field is optional. Nested sections (OutputDataConfig,
            ModelArtifacts, ResourceConfig, FinalMetricDataList) that are
            missing or null are treated as empty.

    Writes to stdout only; returns None.
    """
    # ``get(key) or {}`` also covers keys present with a null value, which
    # ``get(key, {})`` would return as None and then crash on ``.get``.
    output_config = job_data.get('OutputDataConfig') or {}
    artifacts_config = job_data.get('ModelArtifacts') or {}
    resource_config = job_data.get('ResourceConfig') or {}

    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    billable_seconds = job_data.get('BillableTimeInSeconds')
    training_seconds = job_data.get('TrainingTimeInSeconds')
    final_metrics = job_data.get('FinalMetricDataList') or []
    output_path = output_config.get('S3OutputPath', '')
    model_artifacts = artifacts_config.get('S3ModelArtifacts', '')
    instance_type = resource_config.get('InstanceType', '')
    instance_count = resource_config.get('InstanceCount', 1)
    spot_enabled = job_data.get('EnableManagedSpotTraining', False)

    emoji = STATUS_EMOJI.get(status, '❓')

    print()
    print(f' {emoji} Status: {status}')

    # Secondary status, with a description when one is known.
    if secondary_status:
        desc = SECONDARY_DESCRIPTIONS.get(secondary_status, '')
        if desc:
            print(f' 📍 Phase: {secondary_status} ({desc})')
        else:
            print(f' 📍 Phase: {secondary_status}')

    # Running jobs show wall-clock time since start; finished jobs show the
    # total training time reported by SageMaker.
    if status == 'InProgress' and training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif training_seconds is not None:
        print(f' ⏱️ Training time: {format_duration(training_seconds)}')

    # Instance info.
    if instance_type:
        instance_info = str(instance_type)
        if instance_count and instance_count > 1:
            instance_info += f' x {instance_count}'
        if spot_enabled:
            instance_info += ' (spot)'
        print(f' 🖥️ Instance: {instance_info}')

    # Billable time and savings for completed spot jobs. The training_seconds
    # > 0 guard avoids division by zero in the percentage.
    if status == 'Completed' and spot_enabled and billable_seconds is not None and training_seconds is not None:
        savings_seconds = training_seconds - billable_seconds
        if training_seconds > 0:
            savings_pct = (savings_seconds / training_seconds) * 100
            print(f' 💰 Spot savings: {format_duration(savings_seconds)} saved ({savings_pct:.0f}% discount)')
            print(f' Billable: {format_duration(billable_seconds)} / Total: {format_duration(training_seconds)}')

    # Final training metrics; floats get more decimals the smaller they are.
    if final_metrics:
        print(' 📈 Metrics:')
        for metric in final_metrics:
            name = metric.get('MetricName', 'unknown')
            value = metric.get('Value', 0)
            if isinstance(value, float):
                if abs(value) < 0.001:
                    print(f' {name}: {value:.6f}')
                elif abs(value) < 1:
                    print(f' {name}: {value:.4f}')
                else:
                    print(f' {name}: {value:.2f}')
            else:
                print(f' {name}: {value}')

    # Output location for completed jobs: prefer the model artifact path.
    if status == 'Completed' and model_artifacts:
        print(f' 📦 Artifacts: {model_artifacts}')
    elif status == 'Completed' and output_path:
        print(f' 📦 Output: {output_path}')

    # Failure reason plus a hint for restarting.
    if status == 'Failed' and failure_reason:
        print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    # Spot interruption guidance.
    if secondary_status == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
173
-
174
-
175
def main():
    """Read a DescribeTrainingJob JSON document from stdin and print its status.

    Exits with status 1, after printing an error to stderr, when stdin is not
    valid JSON / UTF-8 or does not contain a JSON object.
    """
    try:
        job_data = json.load(sys.stdin)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f'❌ Failed to parse DescribeTrainingJob response: {e}', file=sys.stderr)
        sys.exit(1)

    # display_status() reads fields with dict.get(); any other JSON value
    # (list, string, number, null) would crash there with AttributeError.
    if not isinstance(job_data, dict):
        print('❌ Failed to parse DescribeTrainingJob response: expected a JSON object', file=sys.stderr)
        sys.exit(1)

    display_status(job_data)


if __name__ == '__main__':
    main()