@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -1,120 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Train Request Builder
6
- *
7
- * JavaScript module that replicates the Python helper's (.train_build_request.py)
8
- * logic for constructing a CreateTrainingJob JSON request from a parsed config.
9
- *
10
- * This module mirrors the behavior of the Python build_request() function,
11
- * providing a testable implementation of the config-to-API mapping logic.
12
- */
13
-
14
/**
 * Build a CreateTrainingJob request from a parsed training config.
 *
 * Maps config fields to the SageMaker CreateTrainingJob API structure:
 * - image → AlgorithmSpecification.TrainingImage
 * - instance_type → ResourceConfig.InstanceType
 * - instance_count → ResourceConfig.InstanceCount
 * - volume_size_gb → ResourceConfig.VolumeSizeInGB
 * - dataset → InputDataConfig[0].DataSource.S3DataSource.S3Uri
 * - output_path → OutputDataConfig.S3OutputPath
 * - hyperparameters → HyperParameters (values coerced to strings)
 * - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 * - enable_spot=true → EnableManagedSpotTraining = true
 * - enable_spot=true → StoppingCondition.MaxWaitTimeInSeconds (from max_wait_seconds)
 * - checkpoint_path → CheckpointConfig.S3Uri
 * - metric_definitions → AlgorithmSpecification.MetricDefinitions
 * - environment → Environment (values coerced to strings)
 * - tags → Tags (converted from {k:v} to [{Key:k, Value:v}])
 *
 * Numeric fields go through parseInt(…, 10). A missing or non-numeric value
 * yields NaN, so callers are expected to supply them.
 * NOTE(review): confirm the config parser always fills these defaults.
 *
 * The returned request never shares the caller's hyperparameter or
 * environment objects; both are copied.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: {
      TrainingImage: config.image,
      TrainingInputMode: 'File'
    },
    InputDataConfig: [
      {
        ChannelName: 'training',
        DataSource: {
          S3DataSource: {
            S3DataType: 'S3Prefix',
            S3Uri: config.dataset,
            S3DataDistributionType: 'FullyReplicated'
          }
        }
      }
    ],
    OutputDataConfig: {
      S3OutputPath: config.output_path
    },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: parseInt(config.instance_count, 10),
      VolumeSizeInGB: parseInt(config.volume_size_gb, 10)
    },
    StoppingCondition: {
      MaxRuntimeInSeconds: parseInt(config.max_runtime_seconds, 10)
    }
  };

  // Hyperparameters — SageMaker requires every value to be a string.
  const hyperparams = config.hyperparameters || {};
  if (Object.keys(hyperparams).length > 0) {
    request.HyperParameters = {};
    for (const [k, v] of Object.entries(hyperparams)) {
      request.HyperParameters[String(k)] = String(v);
    }
  }

  // Managed spot training. The flag may arrive as a boolean or the string 'true'.
  const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
  if (enableSpot) {
    request.EnableManagedSpotTraining = true;
    request.StoppingCondition.MaxWaitTimeInSeconds = parseInt(config.max_wait_seconds, 10);
  }

  // Checkpoint configuration (lets spot training resume after interruption).
  const checkpointPath = config.checkpoint_path || '';
  if (checkpointPath) {
    request.CheckpointConfig = {
      S3Uri: checkpointPath
    };
  }

  // Metric definitions (custom CloudWatch metrics scraped from training logs).
  const metricDefs = config.metric_definitions || [];
  if (Array.isArray(metricDefs) && metricDefs.length > 0) {
    request.AlgorithmSpecification.MetricDefinitions = metricDefs.map(m => ({
      Name: m.name,
      Regex: m.regex
    }));
  }

  // Container environment. CreateTrainingJob's Environment is a string→string
  // map, like HyperParameters. Config values are not guaranteed to be strings
  // (numbers/booleans are possible), so coerce them the same way. Building a new
  // object also keeps the request from aliasing config.environment.
  const environment = config.environment || {};
  if (Object.keys(environment).length > 0) {
    request.Environment = Object.fromEntries(
      Object.entries(environment).map(([k, v]) => [String(k), String(v)])
    );
  }

  // Tags — convert from {key: value} map to [{Key: k, Value: v}] array.
  const tags = config.tags || {};
  if (Object.keys(tags).length > 0) {
    request.Tags = Object.entries(tags).map(([k, v]) => ({
      Key: String(k),
      Value: String(v)
    }));
  }

  return request;
}
@@ -1,279 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Tune Dataset Validator
6
- *
7
- * Parses dataset arguments (S3 URIs and Hugging Face references) and
8
- * validates JSONL dataset lines against catalog-driven schemas.
9
- *
10
- * Requirements: 3.1, 3.5, 3.6, 3.7, 3.8, 3.10, 3.11, 3.12
11
- */
12
-
13
/**
 * Turn a dataset argument into a structured descriptor.
 *
 * Two schemes are recognised after trimming surrounding whitespace:
 *   - `s3://bucket/key`         → delegated to the S3 parser
 *   - `hf://org/name[/split]`   → delegated to the Hugging Face parser
 * Anything else (or a missing / non-string / empty argument) produces
 * `{ valid: false, error }`.
 *
 * @param {string} datasetStr - The dataset argument string
 * @returns {{ valid: boolean, type?: string, bucket?: string, key?: string, org?: string, name?: string, split?: string, error?: string }}
 */
export function parseDatasetArg(datasetStr) {
  const isUsable = typeof datasetStr === 'string' && datasetStr !== '';
  if (!isUsable) {
    return {
      valid: false,
      error: 'Dataset argument is required and must be a non-empty string.'
    };
  }

  const candidate = datasetStr.trim();

  // Both supported schemes are exactly five characters long.
  switch (candidate.slice(0, 5)) {
    case 's3://':
      return _parseS3Uri(candidate);
    case 'hf://':
      return _parseHfReference(candidate);
    default:
      return {
        valid: false,
        error: `Invalid dataset format: "${candidate}". Expected s3://bucket/key or hf://org/name[/split].`
      };
  }
}
44
-
45
/**
 * Validate JSONL lines against a dataset schema from the catalog.
 *
 * Only the first 10 lines are inspected. Falsy entries and whitespace-only
 * lines are skipped. A UTF-8 byte-order mark at the start of the first line is
 * ignored: JSON.parse rejects it, and some editors prepend it when saving.
 * A truthy non-string entry is reported as an error instead of throwing.
 *
 * @param {string[]} lines - Array of JSONL line strings
 * @param {Object} schema - The datasetSchema object from the catalog
 * @param {string[]} schema.required - Array of required top-level keys
 * @param {Object} [schema.types] - Object mapping key to expected type ("string", "array", "object", "number")
 * @returns {{ valid: boolean, error: string|null, lineNumber: number|null, malformedLine: string|null, expectedFormat: string|null }}
 *   On failure, lineNumber is 1-based and malformedLine echoes the offending line.
 */
export function validateDatasetFormat(lines, schema) {
  if (!lines || !Array.isArray(lines)) {
    return {
      valid: false,
      error: 'Lines must be provided as an array.',
      lineNumber: null,
      malformedLine: null,
      expectedFormat: _buildExpectedFormat(schema)
    };
  }

  if (!schema || !schema.required || !Array.isArray(schema.required)) {
    return {
      valid: false,
      error: 'Schema must include a "required" array of keys.',
      lineNumber: null,
      malformedLine: null,
      expectedFormat: null
    };
  }

  // Failure result for one specific line; the format hint is only built on failure.
  const fail = (error, lineNumber, malformedLine) => ({
    valid: false,
    error,
    lineNumber,
    malformedLine,
    expectedFormat: _buildExpectedFormat(schema)
  });

  const linesToInspect = lines.slice(0, 10);

  for (let i = 0; i < linesToInspect.length; i++) {
    const lineNumber = i + 1;
    let line = linesToInspect[i];

    // Skip missing / falsy entries ('' , undefined holes, etc.).
    if (!line) {
      continue;
    }

    // A truthy non-string would make line.trim() throw a TypeError.
    if (typeof line !== 'string') {
      return fail(`Line ${lineNumber} must be a string, got "${_getType(line)}".`, lineNumber, String(line));
    }

    // Strip a leading UTF-8 BOM from the first line; JSON.parse rejects it.
    if (i === 0 && line.charCodeAt(0) === 0xfeff) {
      line = line.slice(1);
    }

    // Skip whitespace-only lines.
    if (line.trim() === '') {
      continue;
    }

    let parsed;
    try {
      parsed = JSON.parse(line);
    } catch (e) {
      return fail(`Line ${lineNumber} is not valid JSON: ${e.message}`, lineNumber, line);
    }

    // Each JSONL record must be a plain JSON object (not null / array / scalar).
    if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
      return fail(`Line ${lineNumber} must be a JSON object.`, lineNumber, line);
    }

    // Required keys must be own properties of the record.
    for (const key of schema.required) {
      if (!Object.hasOwn(parsed, key)) {
        return fail(`Line ${lineNumber} is missing required key "${key}".`, lineNumber, line);
      }
    }

    // Type constraints apply only to keys that are present.
    if (schema.types) {
      for (const [key, expectedType] of Object.entries(schema.types)) {
        if (!Object.hasOwn(parsed, key)) {
          continue;
        }

        const value = parsed[key];
        if (!_checkType(value, expectedType)) {
          return fail(
            `Line ${lineNumber} has key "${key}" with wrong type. Expected "${expectedType}", got "${_getType(value)}".`,
            lineNumber,
            line
          );
        }
      }
    }
  }

  return {
    valid: true,
    error: null,
    lineNumber: null,
    malformedLine: null,
    expectedFormat: null
  };
}
154
-
155
/**
 * Parse an S3 URI into bucket and key components.
 *
 * Failure cases, each with a specific message:
 *   - no "/" after the bucket ("s3://bucket")  → generic format error
 *   - empty bucket ("s3:///key")               → "Bucket name is empty."
 *   - empty key ("s3://bucket/")               → "Key path is empty."
 *
 * @param {string} uri - The S3 URI (e.g., "s3://bucket/path/to/file.jsonl")
 * @returns {Object} `{ valid: true, type: 's3', bucket, key }` or `{ valid: false, error }`
 * @private
 */
function _parseS3Uri(uri) {
  const withoutScheme = uri.slice('s3://'.length);
  const slashIndex = withoutScheme.indexOf('/');

  if (slashIndex === -1) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Expected format: s3://bucket/key.`
    };
  }

  // Previously this case was folded into the generic error above, which left
  // the dedicated "Bucket name is empty" branch unreachable.
  if (slashIndex === 0) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Bucket name is empty.`
    };
  }

  const bucket = withoutScheme.slice(0, slashIndex);
  const key = withoutScheme.slice(slashIndex + 1);

  if (!key) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Key path is empty.`
    };
  }

  return {
    valid: true,
    type: 's3',
    bucket,
    key
  };
}
196
-
197
/**
 * Parse a Hugging Face dataset reference into org, name, and split.
 *
 * Accepts exactly `hf://org/name` or `hf://org/name/split`. Trailing slashes
 * are ignored, and the split defaults to 'train'. Extra path segments
 * (e.g. `hf://org/name/config/split`) are rejected. Previously they were
 * silently dropped, which could misread a config name as the split.
 *
 * @param {string} ref - The HF reference (e.g., "hf://org/name" or "hf://org/name/split")
 * @returns {Object} `{ valid: true, type: 'hf', org, name, split }` or `{ valid: false, error }`
 * @private
 */
function _parseHfReference(ref) {
  // Drop trailing slashes so "hf://org/name/" still means org/name.
  const withoutScheme = ref.slice('hf://'.length).replace(/\/+$/, '');
  const parts = withoutScheme.split('/');

  if (parts.length < 2 || parts.length > 3 || parts.some(part => !part)) {
    return {
      valid: false,
      error: `Invalid Hugging Face reference: "${ref}". Expected format: hf://org/name[/split].`
    };
  }

  const [org, name] = parts;
  const split = parts.length === 3 ? parts[2] : 'train';

  return {
    valid: true,
    type: 'hf',
    org,
    name,
    split
  };
}
227
-
228
// Predicates for the schema type names the catalog may use. A Map is used so
// lookups never fall through to Object.prototype members.
const _TYPE_PREDICATES = new Map([
  ['string', v => typeof v === 'string'],
  ['number', v => typeof v === 'number'],
  ['array', v => Array.isArray(v)],
  ['object', v => typeof v === 'object' && v !== null && !Array.isArray(v)]
]);

/**
 * Decide whether a value satisfies a schema type name.
 * Type names not listed in the predicate table are treated as "anything goes".
 *
 * @param {*} value - The value to check
 * @param {string} expectedType - One of "string", "array", "object", "number"
 * @returns {boolean} True if the value matches the expected type
 * @private
 */
function _checkType(value, expectedType) {
  const predicate = _TYPE_PREDICATES.get(expectedType);
  return predicate === undefined ? true : predicate(value);
}
249
-
250
/**
 * Describe a value's type for error messages.
 * Distinguishes `null` and arrays, which `typeof` reports as "object".
 *
 * @param {*} value - The value to describe
 * @returns {string} The type name
 * @private
 */
function _getType(value) {
  return value === null ? 'null' : Array.isArray(value) ? 'array' : typeof value;
}
261
-
262
/**
 * Build a human-readable expected format description from a schema.
 *
 * Returns null when the schema has no usable `required` array. This function
 * is also called on schemas that have not been validated yet, e.g. when the
 * `lines` argument is bad. A non-array `required` must therefore yield null
 * rather than throw.
 *
 * @param {Object} schema - The dataset schema
 * @returns {string|null} Description of expected format
 * @private
 */
function _buildExpectedFormat(schema) {
  if (!schema || !Array.isArray(schema.required)) {
    return null;
  }

  const types = schema.types || {};
  const fields = schema.required.map(key => {
    // Only own properties count, so a key like "constructor" does not pick up
    // an inherited Object.prototype member as its "type".
    const type = Object.hasOwn(types, key) && types[key] ? types[key] : 'any';
    return `"${key}": <${type}>`;
  });

  return `Each line must be a JSON object with: {${fields.join(', ')}}`;
}
@@ -1,66 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Tune Output Resolver
6
- *
7
- * Detects output type from training type and generates context-aware
8
- * next-step commands for deploying tune job artifacts.
9
- *
10
- * Requirements: 8.3, 8.11
11
- */
12
-
13
/**
 * Infer what kind of artifact a tune job produced from its training type.
 *
 * 'full-rank' training yields a complete model ('full-model'). Every other
 * value, including 'lora' and unrecognised or missing types, yields
 * 'adapter'.
 *
 * @param {string} trainingType - The training type ('lora' or 'full-rank')
 * @returns {string} 'full-model' for full-rank, otherwise 'adapter'
 */
export function detectOutputType(trainingType) {
  return trainingType === 'full-rank' ? 'full-model' : 'adapter';
}
29
-
30
/**
 * Suggest follow-up commands for deploying a tune job's output.
 *
 * Adapter output suggests registering it as `tuned-<technique>`. There are
 * three variants: from the latest tune, from a technique-specific tune, and
 * from an explicit weights path.
 *
 * Full-model output suggests adding an inference component `tuned-v1`
 * (from the tune, or from an explicit model path), or redeploying the base
 * with the new model data.
 *
 * Any other output type yields no suggestions.
 *
 * @param {string} outputType - The output type ('adapter' or 'full-model')
 * @param {string} technique - The technique used (e.g., 'sft', 'dpo')
 * @param {string} artifactPath - The S3 path to the output artifact
 * @returns {string[]} Array of suggested next-step commands
 */
export function generateNextStepCommands(outputType, technique, artifactPath) {
  switch (outputType) {
    case 'adapter': {
      const adapterName = `tuned-${technique}`;
      return [
        `./do/adapter add ${adapterName} --from-tune`,
        `./do/adapter add ${adapterName} --from-tune ${technique}`,
        `./do/adapter add ${adapterName} --weights ${artifactPath}`
      ];
    }
    case 'full-model':
      return [
        './do/add-ic tuned-v1 --from-tune',
        `./do/add-ic tuned-v1 --model-data ${artifactPath}`,
        `./do/deploy --force-ic --model-data ${artifactPath}`
      ];
    default:
      return [];
  }
}
@@ -1,135 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
- # SPDX-License-Identifier: Apache-2.0
4
-
5
- """
6
- Parse DescribeTrainingJob JSON for the polling loop in do/train.
7
-
8
- Reads JSON from stdin and outputs structured key=value lines for bash consumption:
9
- STATUS=<TrainingJobStatus>
10
- SECONDARY=<SecondaryStatus>
11
- FAILURE_REASON=<FailureReason or empty>
12
- DISPLAY=<formatted single-line status display>
13
-
14
- This keeps the bash poll loop simple while handling JSON parsing in Python.
15
- """
16
-
17
- import json
18
- import sys
19
- from datetime import datetime, timezone
20
-
21
-
22
def format_duration(seconds):
    """Render *seconds* as ``'Hh Mm Ss'``, ``'Mm Ss'`` or ``'Ss'``.

    Leading zero units are dropped, so 61 seconds becomes ``'1m 1s'``.
    Fractional seconds are truncated rather than rounded. ``None`` or a
    negative value yields ``'N/A'``.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    whole_hours, remainder = divmod(seconds, 3600)
    whole_minutes, leftover = divmod(remainder, 60)
    h, m, s = int(whole_hours), int(whole_minutes), int(leftover)

    if h:
        return f'{h}h {m}m {s}s'
    if m:
        return f'{m}m {s}s'
    return f'{s}s'
35
-
36
-
37
def parse_iso_time(time_str):
    """Parse a DescribeTrainingJob timestamp into a timezone-aware datetime.

    Accepted inputs:

    * ISO 8601 strings. A trailing ``Z`` is rewritten to ``+00:00`` because
      ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on.
      Strings without a UTC offset are assumed to be UTC.
    * Numeric epoch seconds (int/float). These appear when the AWS CLI emits
      raw timestamps, e.g. v1's default ``cli_timestamp_format = none``.
      Previously such a value crashed on ``.replace``.

    The result is always aware, so callers can subtract it from
    ``datetime.now(timezone.utc)`` without a naive/aware TypeError.

    Returns:
        An aware ``datetime``, or ``None`` for empty or unparseable input.
    """
    if not time_str:
        return None

    # Epoch seconds (bool is an int subclass but never a timestamp).
    if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
        try:
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None

    if not isinstance(time_str, str):
        return None

    try:
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except (ValueError, TypeError):
        return None

    if parsed.tzinfo is None:
        # NOTE(review): assumes offset-less timestamps are UTC, as AWS reports them.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
46
-
47
-
48
def calculate_elapsed(start_time_str):
    """Return seconds elapsed between *start_time_str* and now.

    The start time is parsed with ``parse_iso_time``. The result is clamped at
    0, which protects against small clock skew putting the start "in the
    future".

    Returns:
        Elapsed seconds as a float (or ``0``), or ``None`` if the start time
        cannot be parsed.
    """
    start = parse_iso_time(start_time_str)
    if start is None:
        return None
    if start.tzinfo is None:
        # Subtracting a naive datetime from the aware "now" raises TypeError.
        # NOTE(review): treats offset-less timestamps as UTC, as AWS reports them.
        start = start.replace(tzinfo=timezone.utc)
    elapsed = (datetime.now(timezone.utc) - start).total_seconds()
    return max(0, elapsed)
56
-
57
-
58
def format_metrics(final_metrics):
    """Render FinalMetricDataList entries as ``'name=value, name=value'``.

    Float precision depends on magnitude: 6 decimals below 0.001, 4 below 1,
    otherwise 2. Non-float values are printed verbatim. Missing names become
    ``'unknown'`` and missing values become ``0``. Empty or ``None`` input
    gives ``''``.
    """
    if not final_metrics:
        return ''

    def _render(metric):
        label = metric.get('MetricName', 'unknown')
        raw = metric.get('Value', 0)
        if not isinstance(raw, float):
            return f'{label}={raw}'
        magnitude = abs(raw)
        precision = 6 if magnitude < 0.001 else 4 if magnitude < 1 else 2
        return f'{label}={raw:.{precision}f}'

    return ', '.join(_render(entry) for entry in final_metrics)
76
-
77
-
78
# TrainingJobStatus -> emoji shown at the start of the DISPLAY line.
# Escapes are used so the icons survive being viewed or saved with the wrong
# text encoding. The previous literals had been mangled into mojibake.
STATUS_EMOJI = {
    'InProgress': '\U0001F504',    # counterclockwise arrows (🔄)
    'Completed': '\u2705',         # check mark button (✅)
    'Failed': '\u274C',            # cross mark (❌)
    'Stopping': '\u23F8\uFE0F',    # pause button (⏸️)
    'Stopped': '\u23F9\uFE0F',     # stop button (⏹️)
}
86
-
87
-
88
def _one_line(value):
    """Collapse *value* onto a single line so it cannot break the KEY=value output."""
    return ' '.join(str(value).splitlines())


def main():
    """Parse DescribeTrainingJob JSON from stdin and print KEY=value lines.

    Output (one line each, consumed by the bash poll loop in do/train)::

        STATUS=<TrainingJobStatus>
        SECONDARY=<SecondaryStatus>
        FAILURE_REASON=<FailureReason or empty>
        DISPLAY=<formatted single-line status display>

    Every value is collapsed to one line. A multi-line FailureReason, such as
    an embedded traceback, would otherwise leak extra lines into the stream
    and break line-by-line parsing.

    Exits with status 1 (message on stderr) if stdin is not valid JSON or is
    not a JSON object.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'Error parsing JSON: {e}', file=sys.stderr)
        sys.exit(1)

    # .get() below needs a mapping; a list/scalar payload would raise AttributeError.
    if not isinstance(job_data, dict):
        print('Error parsing JSON: expected a JSON object', file=sys.stderr)
        sys.exit(1)

    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    final_metrics = job_data.get('FinalMetricDataList', [])

    # Elapsed time is shown only once the job has a training start time.
    elapsed_str = ''
    if training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            elapsed_str = format_duration(elapsed)

    metrics_str = format_metrics(final_metrics)

    # Build display line: "<emoji> <status> | <secondary> | elapsed: ... | metrics".
    emoji = STATUS_EMOJI.get(status, '\u2753')  # question mark (❓) for unknown statuses
    display_parts = [f' {emoji} {status}']

    if secondary_status:
        display_parts.append(f'| {secondary_status}')

    if elapsed_str:
        display_parts.append(f'| elapsed: {elapsed_str}')

    if metrics_str:
        display_parts.append(f'| {metrics_str}')

    display_line = ' '.join(display_parts)

    # Output structured lines for bash, guaranteeing exactly one line per key.
    print(f'STATUS={_one_line(status)}')
    print(f'SECONDARY={_one_line(secondary_status)}')
    print(f'FAILURE_REASON={_one_line(failure_reason)}')
    print(f'DISPLAY={_one_line(display_line)}')
132
-
133
-
134
if __name__ == "__main__":
    # Script entry point: the do/train polling loop pipes DescribeTrainingJob JSON in on stdin.
    main()