@aws/ml-container-creator 0.7.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/bin/cli.js +1 -1
- package/infra/ci-harness/buildspec.yml +4 -0
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/model-servers.json +80 -0
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/servers/model-picker/index.js +27 -16
- package/src/app.js +113 -23
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +39 -2
- package/src/lib/cross-cutting-checker.js +146 -33
- package/src/lib/deployment-config-resolver.js +10 -4
- package/src/lib/e2e-bootstrap.js +227 -0
- package/src/lib/e2e-catalog-validator.js +103 -0
- package/src/lib/e2e-quota-validator.js +135 -0
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +306 -24
- package/src/lib/prompts.js +9 -3
- package/src/lib/template-manager.js +10 -4
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/src/lib/tune-catalog-validator.js +5 -5
- package/templates/code/serve +2 -2
- package/templates/code/serving.properties +2 -2
- package/templates/diffusors/serve +3 -3
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/.tune_helper.py +2 -2
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/register +8 -2
- package/templates/do/test +5 -5
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
- package/templates/do/tune +2 -2
- package/templates/marketplace/config +118 -0
- package/templates/marketplace/deploy +890 -0
- package/templates/marketplace/test +453 -0
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
|
|
14
14
|
/**
|
|
15
15
|
* Look up a model entry in the catalog by model ID.
|
|
16
|
-
* @param {string} modelId - The
|
|
16
|
+
* @param {string} modelId - The model ID to look up
|
|
17
17
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
18
18
|
* @returns {Object|null} The catalog entry for the model, or null if not found
|
|
19
19
|
*/
|
|
@@ -29,7 +29,7 @@ export function lookupModel(modelId, catalog) {
|
|
|
29
29
|
|
|
30
30
|
/**
|
|
31
31
|
* Check whether a model ID is present in the Supported Model Catalog.
|
|
32
|
-
* @param {string} modelId - The
|
|
32
|
+
* @param {string} modelId - The model ID to check
|
|
33
33
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
34
34
|
* @returns {boolean} True if the model is in the catalog
|
|
35
35
|
*/
|
|
@@ -41,7 +41,7 @@ export function isTuneSupported(modelId, catalog) {
|
|
|
41
41
|
* Validate that a model ID exists in the catalog.
|
|
42
42
|
* Returns a descriptive error when the model is not supported, including
|
|
43
43
|
* the model name, supported families, and a reference to `do/train`.
|
|
44
|
-
* @param {string} modelId - The
|
|
44
|
+
* @param {string} modelId - The model ID to validate
|
|
45
45
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
46
46
|
* @returns {{ valid: boolean, error?: string }}
|
|
47
47
|
*/
|
|
@@ -65,7 +65,7 @@ export function validateModel(modelId, catalog) {
|
|
|
65
65
|
* Validate that a technique is supported for the given model.
|
|
66
66
|
* Returns a descriptive error listing the supported techniques when
|
|
67
67
|
* the requested technique is not available.
|
|
68
|
-
* @param {string} modelId - The
|
|
68
|
+
* @param {string} modelId - The model ID
|
|
69
69
|
* @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
|
|
70
70
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
71
71
|
* @returns {{ valid: boolean, error?: string }}
|
|
@@ -92,7 +92,7 @@ export function validateTechnique(modelId, technique, catalog) {
|
|
|
92
92
|
* Validate that a training type is supported for the given model and technique.
|
|
93
93
|
* Returns a descriptive error listing the supported training types when
|
|
94
94
|
* the requested type is not available.
|
|
95
|
-
* @param {string} modelId - The
|
|
95
|
+
* @param {string} modelId - The model ID
|
|
96
96
|
* @param {string} technique - The technique (e.g., 'sft', 'dpo')
|
|
97
97
|
* @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
|
|
98
98
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
package/templates/code/serve
CHANGED
|
@@ -113,7 +113,7 @@ resolve_model() {
|
|
|
113
113
|
echo "${!_MODEL_VAR}"
|
|
114
114
|
return
|
|
115
115
|
;;
|
|
116
|
-
s3|
|
|
116
|
+
s3|registry)
|
|
117
117
|
# Check for pre-mounted artifacts first
|
|
118
118
|
if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
|
|
119
119
|
echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
|
|
@@ -245,7 +245,7 @@ ARG_PREFIX="--"
|
|
|
245
245
|
|
|
246
246
|
# Define environment variables to exclude (internal variables set by base images)
|
|
247
247
|
<% if (modelServer === 'vllm') { %>
|
|
248
|
-
EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
|
|
248
|
+
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
|
|
249
249
|
<% } else if (modelServer === 'sglang') { %>
|
|
250
250
|
EXCLUDE_VARS=()
|
|
251
251
|
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
@@ -15,7 +15,7 @@ option.model_id=<%= modelName %>
|
|
|
15
15
|
option.model_id=<%= artifactUri %>
|
|
16
16
|
<% } else { %>
|
|
17
17
|
# Model will be loaded from /opt/ml/model at runtime
|
|
18
|
-
# (
|
|
18
|
+
# (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
|
|
19
19
|
# option.model_id=/opt/ml/model
|
|
20
20
|
<% } %>
|
|
21
21
|
|
|
@@ -71,7 +71,7 @@ option.model_id=<%= modelName %>
|
|
|
71
71
|
option.model_id=<%= artifactUri %>
|
|
72
72
|
<% } else { %>
|
|
73
73
|
# Model will be loaded from /opt/ml/model at runtime
|
|
74
|
-
# (
|
|
74
|
+
# (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
|
|
75
75
|
# option.model_id=/opt/ml/model
|
|
76
76
|
<% } %>
|
|
77
77
|
|
|
@@ -9,10 +9,10 @@ echo "Starting vLLM-Omni server (diffusion model serving)"
|
|
|
9
9
|
|
|
10
10
|
# Resolve model URI prefixes that engines cannot handle natively.
|
|
11
11
|
# The generator's model-picker may store provider-specific URIs
|
|
12
|
-
# (e.g.
|
|
13
|
-
#
|
|
12
|
+
# (e.g. registry://my-model-group/1) as the model identifier.
|
|
13
|
+
# vLLM expects a HuggingFace repo ID or local path.
|
|
14
14
|
_RAW_MODEL="${VLLM_MODEL:-}"
|
|
15
|
-
if [[ "$_RAW_MODEL" ==
|
|
15
|
+
if [[ "$_RAW_MODEL" == registry://* ]]; then
|
|
16
16
|
if [ -d /opt/ml/model ] && [ "$(ls -A /opt/ml/model 2>/dev/null)" ]; then
|
|
17
17
|
echo "Resolved VLLM_MODEL='${_RAW_MODEL}' â /opt/ml/model (local artifacts found)"
|
|
18
18
|
export VLLM_MODEL="/opt/ml/model"
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Build the CreateTrainingJob JSON request for SageMaker.
|
|
7
|
+
|
|
8
|
+
This helper is called by do/train to construct the full API request body.
|
|
9
|
+
It handles conditional fields (spot training, metric definitions, environment,
|
|
10
|
+
tags) and writes the result to a JSON file for use with:
|
|
11
|
+
aws sagemaker create-training-job --cli-input-json file://path.json
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def parse_args():
    """Read the do/train helper's command line.

    Every option is mandatory and is returned as a raw string; numeric and
    JSON-valued options are decoded later by build_request().
    """
    # (flag, help text) pairs, in the order they appear in --help output.
    option_specs = (
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    )
    cli = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
    for flag, help_text in option_specs:
        cli.add_argument(flag, required=True, help=help_text)
    return cli.parse_args()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    Args:
        args: Namespace from parse_args(). Numeric options arrive as strings
            and are converted with int(); --hyperparams, --metric-definitions,
            --environment and --tags are JSON strings, where an empty string
            (or JSON null) means "not provided".

    Returns:
        dict: Request body for
        ``aws sagemaker create-training-job --cli-input-json file://...``.

    Raises:
        ValueError: (json.JSONDecodeError is a subclass) when a numeric option
            is not an integer, a JSON option is malformed or has the wrong
            top-level type, or a metric definition lacks 'name'/'regex'.
    """
    def load_json(raw, expected_type, type_label, flag):
        # Empty input and JSON null both mean "not provided".
        if not raw:
            return expected_type()
        value = json.loads(raw)
        if value is None:
            return expected_type()
        # Reject e.g. a list for --tags up front; otherwise .items() would
        # raise an AttributeError that main() does not report cleanly.
        if not isinstance(value, expected_type):
            raise ValueError(
                f'{flag} must be a JSON {type_label}, got {type(value).__name__}'
            )
        return value

    hyperparams = load_json(args.hyperparams, dict, 'object', '--hyperparams')
    metric_definitions = load_json(args.metric_definitions, list, 'array', '--metric-definitions')
    environment = load_json(args.environment, dict, 'object', '--environment')
    tags = load_json(args.tags, dict, 'object', '--tags')

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement)
    if hyperparams:
        request['HyperParameters'] = {str(k): str(v) for k, v in hyperparams.items()}

    # Managed spot training; accept 'true' in any letter case.
    if str(args.enable_spot).strip().lower() == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {
            'S3Uri': args.checkpoint_path
        }

    # Metric definitions (custom CloudWatch metrics); each entry must carry
    # 'name' and 'regex' — report a clear error instead of a bare KeyError.
    if metric_definitions:
        converted = []
        for index, metric in enumerate(metric_definitions):
            if not isinstance(metric, dict) or 'name' not in metric or 'regex' not in metric:
                raise ValueError(
                    f"--metric-definitions entry #{index} must be an object with "
                    f"'name' and 'regex' keys, got {metric!r}"
                )
            converted.append({'Name': metric['name'], 'Regex': metric['regex']})
        request['AlgorithmSpecification']['MetricDefinitions'] = converted

    # Environment variables for the container
    if environment:
        request['Environment'] = environment

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [{'Key': str(k), 'Value': str(v)} for k, v in tags.items()]

    return request
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def main():
    """Build the CreateTrainingJob request and write it to --output-file.

    Exits with status 1, after printing a message to stderr, when the request
    cannot be built from the arguments or the output file cannot be written.
    """
    args = parse_args()

    try:
        request = build_request(args)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # json.JSONDecodeError subclasses ValueError. KeyError/TypeError/
        # AttributeError cover valid-but-malformed JSON input, such as a metric
        # entry without 'name' or a list where an object is expected.
        print(f'❌ Failed to build request: {e}', file=sys.stderr)
        sys.exit(1)

    # Write the JSON request to the output file
    try:
        with open(args.output_file, 'w') as f:
            json.dump(request, f, indent=2)
    except OSError as e:
        print(f'❌ Failed to write request file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Parse DescribeTrainingJob JSON for the polling loop in do/train.
|
|
7
|
+
|
|
8
|
+
Reads JSON from stdin and outputs structured key=value lines for bash consumption:
|
|
9
|
+
STATUS=<TrainingJobStatus>
|
|
10
|
+
SECONDARY=<SecondaryStatus>
|
|
11
|
+
FAILURE_REASON=<FailureReason or empty>
|
|
12
|
+
DISPLAY=<formatted single-line status display>
|
|
13
|
+
|
|
14
|
+
This keeps the bash poll loop simple while handling JSON parsing in Python.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import sys
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def format_duration(seconds):
    """Render a duration in seconds as '1h 2m 3s', '4m 5s' or '6s'.

    None or a negative value yields 'N/A'. Fractional seconds are truncated.
    """
    if seconds is None or seconds < 0:
        return 'N/A'
    hrs, remainder = divmod(seconds, 3600)
    mins, secs = divmod(remainder, 60)
    hrs, mins, secs = int(hrs), int(mins), int(secs)
    if hrs:
        return f'{hrs}h {mins}m {secs}s'
    if mins:
        return f'{mins}m {secs}s'
    return f'{secs}s'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def parse_iso_time(time_str):
    """Parse a DescribeTrainingJob timestamp into a timezone-aware datetime.

    Accepts ISO 8601 strings (a trailing 'Z' is treated as UTC) and numeric
    epoch seconds, given either as a JSON number or a numeric string. The
    AWS CLI emits epoch numbers for SageMaker when cli_timestamp_format is
    'none', the AWS CLI v1 default. A timestamp without an offset is assumed
    to be UTC so it can be compared with datetime.now(timezone.utc).

    Returns:
        datetime | None: aware datetime, or None if the value is empty or
        cannot be parsed.
    """
    if not time_str or isinstance(time_str, bool):
        return None
    try:
        if isinstance(time_str, (int, float)):
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        text = str(time_str).strip()
        if text.endswith('Z'):
            # fromisoformat() before Python 3.11 does not accept 'Z'.
            text = text[:-1] + '+00:00'
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            # Epoch seconds rendered as a string, e.g. "1700000000.123".
            parsed = datetime.fromtimestamp(float(text), tz=timezone.utc)
    except (ValueError, TypeError, OverflowError, OSError):
        return None
    if parsed.tzinfo is None:
        # Subtracting a naive datetime from an aware one raises TypeError.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def calculate_elapsed(start_time_str):
    """Return seconds elapsed since start_time_str, clamped at 0.

    Returns None when the timestamp is empty or unparseable.
    """
    start = parse_iso_time(start_time_str)
    if start is None:
        return None
    elapsed = (datetime.now(timezone.utc) - start).total_seconds()
    return max(0, elapsed)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def format_metrics(final_metrics):
    """Render FinalMetricDataList entries as 'name=value' pairs joined by ', '.

    Float values get more decimal places the closer they are to zero (6
    below 0.001, 4 below 1, otherwise 2); other values are printed as-is.
    An empty or missing list yields ''.
    """
    if not final_metrics:
        return ''

    def render(value):
        if not isinstance(value, float):
            return f'{value}'
        magnitude = abs(value)
        if magnitude < 0.001:
            return f'{value:.6f}'
        if magnitude < 1:
            return f'{value:.4f}'
        return f'{value:.2f}'

    return ', '.join(
        f"{entry.get('MetricName', 'unknown')}={render(entry.get('Value', 0))}"
        for entry in final_metrics
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Status emoji mapping, keyed by DescribeTrainingJob TrainingJobStatus.
# NOTE(review): glyphs were reconstructed from mis-encoded text. The ✅ and ⏸️
# glyphs are unambiguous; the others are best guesses — confirm upstream.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _one_line(value):
    """Return str(value) with any line breaks collapsed to single spaces.

    The bash poll loop reads one KEY=value pair per line, so an embedded
    newline (e.g. in a multi-line FailureReason) would otherwise spill text
    onto lines the caller treats as separate keys.
    """
    return ' '.join(str(value).splitlines())


def main():
    """Parse DescribeTrainingJob JSON from stdin and output structured lines.

    Prints STATUS, SECONDARY, FAILURE_REASON and DISPLAY (in that order), each
    value forced onto a single line. Exits 1 if stdin is not a JSON object.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'Error parsing JSON: {e}', file=sys.stderr)
        sys.exit(1)
    if not isinstance(job_data, dict):
        print(f'Error parsing JSON: expected an object, got {type(job_data).__name__}',
              file=sys.stderr)
        sys.exit(1)

    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    final_metrics = job_data.get('FinalMetricDataList', [])

    # Calculate elapsed time
    elapsed_str = ''
    if training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            elapsed_str = format_duration(elapsed)

    # Format metrics
    metrics_str = format_metrics(final_metrics)

    # Build display line
    # NOTE(review): fallback glyph reconstructed from mis-encoded text — confirm.
    emoji = STATUS_EMOJI.get(status, '❓')
    display_parts = [f' {emoji} {status}']
    if secondary_status:
        display_parts.append(f'| {secondary_status}')
    if elapsed_str:
        display_parts.append(f'| elapsed: {elapsed_str}')
    if metrics_str:
        display_parts.append(f'| {metrics_str}')
    display_line = ' '.join(display_parts)

    # Output structured lines for bash — one KEY=value per line.
    print(f'STATUS={_one_line(status)}')
    print(f'SECONDARY={_one_line(secondary_status)}')
    print(f'FAILURE_REASON={_one_line(failure_reason)}')
    print(f'DISPLAY={_one_line(display_line)}')


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Parse DescribeTrainingJob JSON response and display formatted status.
|
|
7
|
+
|
|
8
|
+
This helper is called by do/train --status to parse the AWS CLI JSON output
|
|
9
|
+
from DescribeTrainingJob and display a user-friendly status summary.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import sys
|
|
14
|
+
import time
|
|
15
|
+
from datetime import datetime, timezone
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Status emoji mapping, keyed by DescribeTrainingJob TrainingJobStatus.
# NOTE(review): glyphs were reconstructed from mis-encoded text. The ✅ and ⏸️
# glyphs are unambiguous; the others are best guesses — confirm upstream.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}

# Human-readable descriptions for DescribeTrainingJob SecondaryStatus values.
# Values not listed here are displayed without a description.
SECONDARY_DESCRIPTIONS = {
    'Starting': 'Preparing training instance',
    'LaunchingMLInstances': 'Launching ML instances',
    'PreparingTrainingStack': 'Preparing training stack',
    'Downloading': 'Downloading training data',
    'DownloadingTrainingImage': 'Downloading training image',
    'Training': 'Training in progress',
    'Uploading': 'Uploading model artifacts',
    'Completed': 'Training completed',
    'MaxRuntimeExceeded': 'Max runtime exceeded',
    'Stopped': 'Training stopped',
    'MaxWaitTimeExceeded': 'Max wait time exceeded (spot)',
    'Interrupted': 'Spot instance interrupted',
    'Stopping': 'Training stopping',
    'Failed': 'Training failed',
    'Restarting': 'Restarting training'
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def format_duration(seconds):
    """Render a duration in seconds as '1h 2m 3s', '4m 5s' or '6s'.

    None or a negative value yields 'N/A'. Fractional seconds are truncated.
    """
    if seconds is None or seconds < 0:
        return 'N/A'
    hrs, remainder = divmod(seconds, 3600)
    mins, secs = divmod(remainder, 60)
    hrs, mins, secs = int(hrs), int(mins), int(secs)
    if hrs:
        return f'{hrs}h {mins}m {secs}s'
    if mins:
        return f'{mins}m {secs}s'
    return f'{secs}s'
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def parse_iso_time(time_str):
    """Parse a DescribeTrainingJob timestamp into a timezone-aware datetime.

    Accepts ISO 8601 strings (a trailing 'Z' is treated as UTC) and numeric
    epoch seconds, given either as a JSON number or a numeric string. The
    AWS CLI emits epoch numbers for SageMaker when cli_timestamp_format is
    'none', the AWS CLI v1 default. A timestamp without an offset is assumed
    to be UTC so it can be compared with datetime.now(timezone.utc).

    Returns:
        datetime | None: aware datetime, or None if the value is empty or
        cannot be parsed.
    """
    if not time_str or isinstance(time_str, bool):
        return None
    try:
        if isinstance(time_str, (int, float)):
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        text = str(time_str).strip()
        if text.endswith('Z'):
            # fromisoformat() before Python 3.11 does not accept 'Z'.
            text = text[:-1] + '+00:00'
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            # Epoch seconds rendered as a string, e.g. "1700000000.123".
            parsed = datetime.fromtimestamp(float(text), tz=timezone.utc)
    except (ValueError, TypeError, OverflowError, OSError):
        return None
    if parsed.tzinfo is None:
        # Subtracting a naive datetime from an aware one raises TypeError.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def calculate_elapsed(start_time_str):
    """Return seconds elapsed since start_time_str, clamped at 0.

    Returns None when the timestamp is empty or unparseable.
    """
    start = parse_iso_time(start_time_str)
    if start is None:
        return None
    elapsed = (datetime.now(timezone.utc) - start).total_seconds()
    return max(0, elapsed)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def display_status(job_data):
    """Print a human-readable summary of a DescribeTrainingJob response.

    Args:
        job_data (dict): Parsed DescribeTrainingJob JSON. Missing keys are
            tolerated; each section is printed only when its data is present.

    Sections, in order:
        - overall status
        - secondary phase
        - elapsed time (in-progress jobs) or total training time
        - instance configuration
        - spot savings (completed spot jobs only)
        - final metrics
        - artifact location (completed jobs only)
        - failure reason and restart hint (failed jobs only)
        - spot-interruption guidance
    """
    # NOTE(review): several glyphs below were reconstructed from mis-encoded
    # text (📋, 📊, 💥, ℹ️ and the ❓ fallback are best guesses) — confirm.
    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    billable_seconds = job_data.get('BillableTimeInSeconds')
    training_seconds = job_data.get('TrainingTimeInSeconds')
    final_metrics = job_data.get('FinalMetricDataList', [])
    output_path = job_data.get('OutputDataConfig', {}).get('S3OutputPath', '')
    model_artifacts = job_data.get('ModelArtifacts', {}).get('S3ModelArtifacts', '')
    instance_type = job_data.get('ResourceConfig', {}).get('InstanceType', '')
    instance_count = job_data.get('ResourceConfig', {}).get('InstanceCount', 1)
    spot_enabled = job_data.get('EnableManagedSpotTraining', False)

    emoji = STATUS_EMOJI.get(status, '❓')

    print()
    print(f' {emoji} Status: {status}')

    # Secondary status with description
    if secondary_status:
        desc = SECONDARY_DESCRIPTIONS.get(secondary_status, '')
        if desc:
            print(f' 📋 Phase: {secondary_status} ({desc})')
        else:
            print(f' 📋 Phase: {secondary_status}')

    # Elapsed time while running; total training time otherwise.
    if status == 'InProgress' and training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif training_seconds is not None:
        print(f' ⏱️ Training time: {format_duration(training_seconds)}')

    # Instance info
    if instance_type:
        instance_info = f'{instance_type}'
        if instance_count and instance_count > 1:
            instance_info += f' x {instance_count}'
        if spot_enabled:
            instance_info += ' (spot)'
        print(f' 🖥️ Instance: {instance_info}')

    # Billable time and cost savings (for completed spot jobs)
    if status == 'Completed' and spot_enabled and billable_seconds is not None and training_seconds is not None:
        savings_seconds = training_seconds - billable_seconds
        if training_seconds > 0:
            savings_pct = (savings_seconds / training_seconds) * 100
            print(f' 💰 Spot savings: {format_duration(savings_seconds)} saved ({savings_pct:.0f}% discount)')
            print(f' Billable: {format_duration(billable_seconds)} / Total: {format_duration(training_seconds)}')

    # Training metrics; precision grows as the value approaches zero.
    if final_metrics:
        print(' 📊 Metrics:')
        for metric in final_metrics:
            name = metric.get('MetricName', 'unknown')
            value = metric.get('Value', 0)
            if isinstance(value, float):
                if abs(value) < 0.001:
                    print(f' {name}: {value:.6f}')
                elif abs(value) < 1:
                    print(f' {name}: {value:.4f}')
                else:
                    print(f' {name}: {value:.2f}')
            else:
                print(f' {name}: {value}')

    # Output artifacts (for completed jobs)
    if status == 'Completed' and model_artifacts:
        print(f' 📦 Artifacts: {model_artifacts}')
    elif status == 'Completed' and output_path:
        print(f' 📦 Output: {output_path}')

    # Failure reason; the restart hint is useful even when no reason is given.
    if status == 'Failed':
        if failure_reason:
            print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    # Spot interruption guidance
    if secondary_status == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def main():
    """Main entry point — read DescribeTrainingJob JSON from stdin and display it.

    Exits with status 1 if stdin is not valid JSON or is not a JSON object.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'❌ Failed to parse DescribeTrainingJob response: {e}', file=sys.stderr)
        sys.exit(1)

    # display_status() calls .get() on the payload, so reject e.g. null or a list.
    if not isinstance(job_data, dict):
        print(f'❌ Failed to parse DescribeTrainingJob response: expected a JSON object, '
              f'got {type(job_data).__name__}', file=sys.stderr)
        sys.exit(1)

    display_status(job_data)


if __name__ == '__main__':
    main()
|
|
@@ -176,7 +176,7 @@ def cmd_submit(args):
|
|
|
176
176
|
)
|
|
177
177
|
elif "ValidationException" in error_msg and "license" in error_msg.lower():
|
|
178
178
|
_error_exit(
|
|
179
|
-
f"Model license not accepted. Accept the license
|
|
179
|
+
f"Model license not accepted. Accept the model license before "
|
|
180
180
|
f"using this model for customization. Details: {error_msg}"
|
|
181
181
|
)
|
|
182
182
|
else:
|
|
@@ -660,7 +660,7 @@ def main():
|
|
|
660
660
|
|
|
661
661
|
# ââ submit ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ
|
|
662
662
|
submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
|
|
663
|
-
submit_parser.add_argument("--model-id", required=True, help="
|
|
663
|
+
submit_parser.add_argument("--model-id", required=True, help="Model ID")
|
|
664
664
|
submit_parser.add_argument("--technique", required=True,
|
|
665
665
|
choices=["sft", "dpo", "rlaif", "rlvr"],
|
|
666
666
|
help="Customization technique")
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Shared helper: post-completion feedback loop for training and tuning jobs.
|
|
6
|
+
# Sourced by do/tune and do/train — prints artifact locations and deployment suggestions.
|
|
7
|
+
|
|
8
|
+
# print_completion_feedback()
# Display completion summary with artifact path and next-step deployment commands.
# Tailors suggestions based on the detected artifact type (adapter vs full model);
# any other type gets a generic "inspect artifacts" suggestion so the
# "Next steps:" section is never empty.
#
# Arguments:
# $1 - output_path: S3 URI to the output artifacts
# $2 - output_type: "adapter" or "full-model"
# $3 - job_name: Job name for reference
# $4 - model_package_arn: (optional) Model package ARN if registered
print_completion_feedback() {
    local output_path="$1"
    local output_type="$2"
    local job_name="$3"
    local model_package_arn="${4:-}"

    echo ""
    echo "✅ Training complete: ${job_name}"
    echo ""
    echo " Artifacts: ${output_path}"
    if [ -n "${model_package_arn}" ]; then
        echo " Model Package: ${model_package_arn}"
    fi
    echo ""
    echo " Next steps:"

    case "${output_type}" in
        adapter)
            echo " • Deploy as LoRA adapter: ./do/adapter add my-adapter --weights ${output_path}"
            echo " • (Requires running endpoint with LoRA enabled)"
            ;;
        full-model)
            echo " • Deploy as new IC: ./do/add-ic my-model --model-data ${output_path}"
            echo " • Replace current base: ./do/deploy --force-ic --model-data ${output_path}"
            ;;
        *)
            # Unknown/undetected artifact type: still give the user something actionable.
            echo " • Inspect artifacts: aws s3 ls --recursive ${output_path}"
            ;;
    esac
    echo ""
}
|
package/templates/do/register
CHANGED
|
@@ -191,8 +191,14 @@ fi
|
|
|
191
191
|
# ============================================================
|
|
192
192
|
|
|
193
193
|
# DEPLOYMENT_CONFIG format: <architecture>-<backend> (e.g., transformers-vllm, http-flask, triton-fil)
|
|
194
|
-
|
|
195
|
-
|
|
194
|
+
# Special case: marketplace has no backend
|
|
195
|
+
if [ "${DEPLOYMENT_CONFIG}" = "marketplace" ]; then
|
|
196
|
+
ARCHITECTURE="marketplace"
|
|
197
|
+
BACKEND=""
|
|
198
|
+
else
|
|
199
|
+
ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
|
|
200
|
+
BACKEND="${DEPLOYMENT_CONFIG#*-}"
|
|
201
|
+
fi
|
|
196
202
|
|
|
197
203
|
echo "đ Registering deployment to registry"
|
|
198
204
|
echo " Project: ${PROJECT_NAME}"
|