@aws/ml-container-creator 0.7.1 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/LICENSE-THIRD-PARTY +50760 -16218
  2. package/bin/cli.js +1 -1
  3. package/infra/ci-harness/buildspec.yml +4 -0
  4. package/package.json +3 -1
  5. package/servers/lib/catalogs/instances.json +52 -1275
  6. package/servers/lib/catalogs/model-servers.json +80 -0
  7. package/servers/lib/catalogs/models.json +0 -132
  8. package/servers/lib/catalogs/popular-diffusors.json +1 -110
  9. package/servers/model-picker/index.js +27 -16
  10. package/src/app.js +113 -23
  11. package/src/lib/cli-handler.js +1 -1
  12. package/src/lib/config-manager.js +39 -2
  13. package/src/lib/cross-cutting-checker.js +146 -33
  14. package/src/lib/deployment-config-resolver.js +10 -4
  15. package/src/lib/e2e-bootstrap.js +227 -0
  16. package/src/lib/e2e-catalog-validator.js +103 -0
  17. package/src/lib/e2e-quota-validator.js +135 -0
  18. package/src/lib/mcp-client.js +16 -1
  19. package/src/lib/mcp-command-handler.js +10 -2
  20. package/src/lib/prompt-runner.js +306 -24
  21. package/src/lib/prompts.js +9 -3
  22. package/src/lib/template-manager.js +10 -4
  23. package/src/lib/train-config-parser.js +136 -0
  24. package/src/lib/train-config-persistence.js +143 -0
  25. package/src/lib/train-config-validator.js +112 -0
  26. package/src/lib/train-feedback.js +46 -0
  27. package/src/lib/train-idempotency.js +97 -0
  28. package/src/lib/train-request-builder.js +120 -0
  29. package/src/lib/tune-catalog-validator.js +5 -5
  30. package/templates/code/serve +2 -2
  31. package/templates/code/serving.properties +2 -2
  32. package/templates/diffusors/serve +3 -3
  33. package/templates/do/.train_build_request.py +141 -0
  34. package/templates/do/.train_poll_parser.py +135 -0
  35. package/templates/do/.train_status_parser.py +187 -0
  36. package/templates/do/.tune_helper.py +2 -2
  37. package/templates/do/lib/feedback.sh +41 -0
  38. package/templates/do/register +8 -2
  39. package/templates/do/test +5 -5
  40. package/templates/do/train +786 -0
  41. package/templates/do/training/config.yaml +140 -0
  42. package/templates/do/training/train.py +463 -0
  43. package/templates/do/tune +2 -2
  44. package/templates/marketplace/config +118 -0
  45. package/templates/marketplace/deploy +890 -0
  46. package/templates/marketplace/test +453 -0
@@ -13,7 +13,7 @@
13
13
 
14
14
  /**
15
15
  * Look up a model entry in the catalog by model ID.
16
- * @param {string} modelId - The JumpStart model ID to look up
16
+ * @param {string} modelId - The model ID to look up
17
17
  * @param {Object} catalog - The tune catalog object with a `models` map
18
18
  * @returns {Object|null} The catalog entry for the model, or null if not found
19
19
  */
@@ -29,7 +29,7 @@ export function lookupModel(modelId, catalog) {
29
29
 
30
30
  /**
31
31
  * Check whether a model ID is present in the Supported Model Catalog.
32
- * @param {string} modelId - The JumpStart model ID to check
32
+ * @param {string} modelId - The model ID to check
33
33
  * @param {Object} catalog - The tune catalog object with a `models` map
34
34
  * @returns {boolean} True if the model is in the catalog
35
35
  */
@@ -41,7 +41,7 @@ export function isTuneSupported(modelId, catalog) {
41
41
  * Validate that a model ID exists in the catalog.
42
42
  * Returns a descriptive error when the model is not supported, including
43
43
  * the model name, supported families, and a reference to `do/train`.
44
- * @param {string} modelId - The JumpStart model ID to validate
44
+ * @param {string} modelId - The model ID to validate
45
45
  * @param {Object} catalog - The tune catalog object with a `models` map
46
46
  * @returns {{ valid: boolean, error?: string }}
47
47
  */
@@ -65,7 +65,7 @@ export function validateModel(modelId, catalog) {
65
65
  * Validate that a technique is supported for the given model.
66
66
  * Returns a descriptive error listing the supported techniques when
67
67
  * the requested technique is not available.
68
- * @param {string} modelId - The JumpStart model ID
68
+ * @param {string} modelId - The model ID
69
69
  * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
70
70
  * @param {Object} catalog - The tune catalog object with a `models` map
71
71
  * @returns {{ valid: boolean, error?: string }}
@@ -92,7 +92,7 @@ export function validateTechnique(modelId, technique, catalog) {
92
92
  * Validate that a training type is supported for the given model and technique.
93
93
  * Returns a descriptive error listing the supported training types when
94
94
  * the requested type is not available.
95
- * @param {string} modelId - The JumpStart model ID
95
+ * @param {string} modelId - The model ID
96
96
  * @param {string} technique - The technique (e.g., 'sft', 'dpo')
97
97
  * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
98
98
  * @param {Object} catalog - The tune catalog object with a `models` map
@@ -113,7 +113,7 @@ resolve_model() {
113
113
  echo "${!_MODEL_VAR}"
114
114
  return
115
115
  ;;
116
- s3|jumpstart|jumpstart-hub|registry)
116
+ s3|registry)
117
117
  # Check for pre-mounted artifacts first
118
118
  if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
119
119
  echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
@@ -245,7 +245,7 @@ ARG_PREFIX="--"
245
245
 
246
246
  # Define environment variables to exclude (internal variables set by base images)
247
247
  <% if (modelServer === 'vllm') { %>
248
- EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
248
+ EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
249
249
  <% } else if (modelServer === 'sglang') { %>
250
250
  EXCLUDE_VARS=()
251
251
  <% } else if (modelServer === 'tensorrt-llm') { %>
@@ -15,7 +15,7 @@ option.model_id=<%= modelName %>
15
15
  option.model_id=<%= artifactUri %>
16
16
  <% } else { %>
17
17
  # Model will be loaded from /opt/ml/model at runtime
18
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
18
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
19
19
  # option.model_id=/opt/ml/model
20
20
  <% } %>
21
21
 
@@ -71,7 +71,7 @@ option.model_id=<%= modelName %>
71
71
  option.model_id=<%= artifactUri %>
72
72
  <% } else { %>
73
73
  # Model will be loaded from /opt/ml/model at runtime
74
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
74
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
75
75
  # option.model_id=/opt/ml/model
76
76
  <% } %>
77
77
 
@@ -9,10 +9,10 @@ echo "Starting vLLM-Omni server (diffusion model serving)"
9
9
 
10
10
  # Resolve model URI prefixes that engines cannot handle natively.
11
11
  # The generator's model-picker may store provider-specific URIs
12
- # (e.g. jumpstart://model-txt2img-stabilityai-stable-diffusion-v2-1-base)
13
- # as the model identifier. vLLM expects a HuggingFace repo ID or local path.
12
+ # (e.g. registry://my-model-group/1) as the model identifier.
13
+ # vLLM expects a HuggingFace repo ID or local path.
14
14
  _RAW_MODEL="${VLLM_MODEL:-}"
15
- if [[ "$_RAW_MODEL" == jumpstart://* ]] || [[ "$_RAW_MODEL" == jumpstart-hub://* ]] || [[ "$_RAW_MODEL" == registry://* ]]; then
15
+ if [[ "$_RAW_MODEL" == registry://* ]]; then
16
16
  if [ -d /opt/ml/model ] && [ "$(ls -A /opt/ml/model 2>/dev/null)" ]; then
17
17
  echo "Resolved VLLM_MODEL='${_RAW_MODEL}' → /opt/ml/model (local artifacts found)"
18
18
  export VLLM_MODEL="/opt/ml/model"
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Build the CreateTrainingJob JSON request for SageMaker.
7
+
8
+ This helper is called by do/train to construct the full API request body.
9
+ It handles conditional fields (spot training, metric definitions, environment,
10
+ tags) and writes the result to a JSON file for use with:
11
+ aws sagemaker create-training-job --cli-input-json file://path.json
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+
18
+
19
def parse_args():
    """Read the request-builder options from the command line.

    Every option is mandatory and is returned as the raw string the caller
    supplied; no numeric or JSON conversion happens here.

    Returns:
        argparse.Namespace: one attribute per option (dashes become
        underscores, e.g. ``--job-name`` -> ``job_name``).
    """
    # (flag, help text) pairs, kept in the order they appear in --help.
    option_table = (
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    )

    cli = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
    for flag, description in option_table:
        cli.add_argument(flag, required=True, help=description)
    return cli.parse_args()
40
+
41
+
42
def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    Args:
        args: namespace with the string attributes produced by the command
            line (``job_name``, ``role_arn``, ``image``, ``instance_type``,
            ``instance_count``, ``volume_size``, ``dataset``, ``output_path``,
            ``max_runtime``, ``hyperparams``, ``enable_spot``, ``max_wait``,
            ``checkpoint_path``, ``metric_definitions``, ``environment``,
            ``tags``). The JSON-valued attributes may be empty strings.

    Returns:
        dict: request body. Optional sections (HyperParameters, spot
        settings, CheckpointConfig, MetricDefinitions, Environment, Tags)
        are only present when the corresponding input is non-empty.

    Raises:
        json.JSONDecodeError: a JSON-valued option is not valid JSON.
        ValueError: a numeric option is not an integer, a JSON option has
            the wrong shape, or a metric definition lacks ``name``/``regex``.
    """
    # Parse JSON inputs; an empty string or a falsy JSON value (null, {}, [])
    # means "section not supplied".
    hyperparams = (json.loads(args.hyperparams) if args.hyperparams else {}) or {}
    metric_definitions = (
        json.loads(args.metric_definitions) if args.metric_definitions else []
    ) or []
    environment = (json.loads(args.environment) if args.environment else {}) or {}
    tags = (json.loads(args.tags) if args.tags else {}) or {}

    # Reject wrong-shaped JSON up front with ValueError so the failure is
    # reported like any other bad input instead of surfacing later as an
    # AttributeError/KeyError.
    for option, value, expected in (
        ('--hyperparams', hyperparams, dict),
        ('--metric-definitions', metric_definitions, list),
        ('--environment', environment, dict),
        ('--tags', tags, dict),
    ):
        if not isinstance(value, expected):
            raise ValueError(
                f'{option} must be a JSON {"object" if expected is dict else "array"}'
            )
    for entry in metric_definitions:
        if not isinstance(entry, dict) or 'name' not in entry or 'regex' not in entry:
            raise ValueError(
                '--metric-definitions entries must be objects with "name" and "regex"'
            )

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement)
    if hyperparams:
        request['HyperParameters'] = {
            str(k): str(v) for k, v in hyperparams.items()
        }

    # Managed spot training. The flag is compared case-insensitively so that
    # "True"/"TRUE" are not silently treated as disabled.
    if str(args.enable_spot).strip().lower() == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {
            'S3Uri': args.checkpoint_path
        }

    # Metric definitions (custom CloudWatch metrics)
    if metric_definitions:
        request['AlgorithmSpecification']['MetricDefinitions'] = [
            {'Name': m['name'], 'Regex': m['regex']}
            for m in metric_definitions
        ]

    # Environment variables for the container — stringified like the
    # hyperparameters and tags, since JSON numbers/booleans would otherwise be
    # passed through as non-string values.
    if environment:
        request['Environment'] = {
            str(k): str(v) for k, v in environment.items()
        }

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [
            {'Key': str(k), 'Value': str(v)}
            for k, v in tags.items()
        ]

    return request
119
+
120
+
121
def main():
    """Build the request from the command line and write it as JSON.

    Exits with status 1 and a message on stderr when the request cannot be
    built or the output file cannot be written.
    """
    args = parse_args()

    try:
        request = build_request(args)
    except (json.JSONDecodeError, ValueError, KeyError, TypeError, AttributeError) as e:
        # KeyError/TypeError/AttributeError cover input that is valid JSON but
        # has the wrong shape (e.g. a list where an object is expected, or a
        # metric definition without "name"), so those are reported cleanly
        # instead of as a raw traceback.
        print(f'❌ Failed to build request: {e}', file=sys.stderr)
        sys.exit(1)

    # Write the JSON request to the output file
    try:
        with open(args.output_file, 'w') as f:
            json.dump(request, f, indent=2)
    except IOError as e:
        print(f'❌ Failed to write request file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
@@ -0,0 +1,135 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Parse DescribeTrainingJob JSON for the polling loop in do/train.
7
+
8
+ Reads JSON from stdin and outputs structured key=value lines for bash consumption:
9
+ STATUS=<TrainingJobStatus>
10
+ SECONDARY=<SecondaryStatus>
11
+ FAILURE_REASON=<FailureReason or empty>
12
+ DISPLAY=<formatted single-line status display>
13
+
14
+ This keeps the bash poll loop simple while handling JSON parsing in Python.
15
+ """
16
+
17
+ import json
18
+ import sys
19
+ from datetime import datetime, timezone
20
+
21
+
22
def format_duration(seconds):
    """Render a number of seconds as ``Hh Mm Ss``, ``Mm Ss`` or ``Ss``.

    ``None`` or a negative value gives ``'N/A'``. Units above the largest
    non-zero one are omitted; fractional seconds are truncated.
    """
    if seconds is None:
        return 'N/A'
    if seconds < 0:
        return 'N/A'

    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    leftover = int(seconds % 60)

    if whole_hours > 0:
        return '{}h {}m {}s'.format(whole_hours, whole_minutes, leftover)
    if whole_minutes > 0:
        return '{}m {}s'.format(whole_minutes, leftover)
    return '{}s'.format(leftover)
35
+
36
+
37
def parse_iso_time(time_str):
    """Parse a job timestamp into a timezone-aware datetime.

    Args:
        time_str: an ISO 8601 string (a ``Z`` suffix is accepted), or a
            numeric Unix epoch value in seconds.
            NOTE(review): epoch numbers are what the AWS CLI is expected to
            emit when its timestamp format is "wire" — confirm against the
            CLI version used by do/train.

    Returns:
        datetime | None: an aware datetime, or ``None`` when the input is
        empty or cannot be parsed. A timestamp without an offset is assumed
        to be UTC so that it can be subtracted from an aware "now".
    """
    if not time_str:
        return None
    try:
        # Numeric epoch seconds (bool is excluded: it is an int subclass).
        if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
            return datetime.fromtimestamp(time_str, timezone.utc)
        # fromisoformat does not accept a trailing 'Z' on older Pythons.
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except (ValueError, TypeError, AttributeError, OverflowError, OSError):
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
46
+
47
+
48
def calculate_elapsed(start_time_str):
    """Return the seconds between ``start_time_str`` and the current UTC time.

    Gives ``None`` when the timestamp cannot be parsed, and ``0`` instead of
    a negative number when the start lies in the future.
    """
    started_at = parse_iso_time(start_time_str)
    if not started_at:
        return None
    delta = datetime.now(timezone.utc) - started_at
    return max(0, delta.total_seconds())
56
+
57
+
58
def format_metrics(final_metrics):
    """Join a FinalMetricDataList into ``name=value, name=value``.

    Float values get 6, 4 or 2 decimals depending on magnitude (below 0.001,
    below 1, otherwise); non-float values are printed as-is. An empty or
    missing list gives an empty string.
    """
    if not final_metrics:
        return ''

    def render(entry):
        label = entry.get('MetricName', 'unknown')
        number = entry.get('Value', 0)
        if not isinstance(number, float):
            return f'{label}={number}'
        magnitude = abs(number)
        if magnitude < 0.001:
            return f'{label}={number:.6f}'
        if magnitude < 1:
            return f'{label}={number:.4f}'
        return f'{label}={number:.2f}'

    return ', '.join(render(entry) for entry in final_metrics)
76
+
77
+
78
# Status emoji mapping: TrainingJobStatus value -> glyph shown in front of it.
# Statuses missing from this map are handled by the caller's fallback.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}
86
+
87
+
88
def main():
    """Parse DescribeTrainingJob JSON from stdin and output structured lines.

    Prints exactly four ``KEY=value`` lines (STATUS, SECONDARY,
    FAILURE_REASON, DISPLAY). Exits with status 1 when stdin is not valid
    JSON.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'Error parsing JSON: {e}', file=sys.stderr)
        sys.exit(1)

    def one_line(value):
        # The output is one KEY=value pair per line; a value containing line
        # breaks (multi-line failure reasons, for instance) would otherwise
        # spill into extra lines that carry no key.
        return ' '.join(str(value).splitlines())

    status = one_line(job_data.get('TrainingJobStatus', 'Unknown'))
    secondary_status = one_line(job_data.get('SecondaryStatus', ''))
    failure_reason = one_line(job_data.get('FailureReason', ''))
    training_start = job_data.get('TrainingStartTime', '')
    final_metrics = job_data.get('FinalMetricDataList', [])

    # Calculate elapsed time
    elapsed_str = ''
    if training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            elapsed_str = format_duration(elapsed)

    # Format metrics
    metrics_str = format_metrics(final_metrics)

    # Build display line
    emoji = STATUS_EMOJI.get(status, '❓')
    display_parts = [f' {emoji} {status}']

    if secondary_status:
        display_parts.append(f'| {secondary_status}')

    if elapsed_str:
        display_parts.append(f'| elapsed: {elapsed_str}')

    if metrics_str:
        display_parts.append(f'| {metrics_str}')

    display_line = ' '.join(display_parts)

    # Output structured lines for bash
    print(f'STATUS={status}')
    print(f'SECONDARY={secondary_status}')
    print(f'FAILURE_REASON={failure_reason}')
    print(f'DISPLAY={display_line}')


if __name__ == '__main__':
    main()
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Parse DescribeTrainingJob JSON response and display formatted status.
7
+
8
+ This helper is called by do/train --status to parse the AWS CLI JSON output
9
+ from DescribeTrainingJob and display a user-friendly status summary.
10
+ """
11
+
12
+ import json
13
+ import sys
14
+ import time
15
+ from datetime import datetime, timezone
16
+
17
+
18
# Status emoji mapping: TrainingJobStatus value -> glyph printed next to it.
# Statuses missing from this map are handled by the caller's fallback.
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}
26
+
27
# Secondary status descriptions: SecondaryStatus value -> short explanation
# shown in parentheses after the phase name.
SECONDARY_DESCRIPTIONS = dict(
    Starting='Preparing training instance',
    LaunchingMLInstances='Launching ML instances',
    PreparingTrainingStack='Preparing training stack',
    Downloading='Downloading training data',
    DownloadingTrainingImage='Downloading training image',
    Training='Training in progress',
    Uploading='Uploading model artifacts',
    Completed='Training completed',
    MaxRuntimeExceeded='Max runtime exceeded',
    Stopped='Training stopped',
    MaxWaitTimeExceeded='Max wait time exceeded (spot)',
    Interrupted='Spot instance interrupted',
)
42
+
43
+
44
def format_duration(seconds):
    """Turn a second count into a compact ``1h 2m 5s`` style string.

    Returns ``'N/A'`` for ``None`` or negative input. Hours are shown only
    when non-zero; minutes are shown when hours or minutes are non-zero.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)

    # Start from the always-present seconds and prepend the larger units.
    pieces = [f'{s}s']
    if h > 0 or m > 0:
        pieces.insert(0, f'{m}m')
    if h > 0:
        pieces.insert(0, f'{h}h')
    return ' '.join(pieces)
57
+
58
+
59
def parse_iso_time(time_str):
    """Parse a job timestamp into a timezone-aware datetime.

    Args:
        time_str: an ISO 8601 string (a ``Z`` suffix is accepted), or a
            numeric Unix epoch value in seconds.
            NOTE(review): epoch numbers are what the AWS CLI is expected to
            emit when its timestamp format is "wire" — confirm against the
            CLI version used by do/train.

    Returns:
        datetime | None: an aware datetime, or ``None`` when the input is
        empty or cannot be parsed. A timestamp without an offset is assumed
        to be UTC so that it can be subtracted from an aware "now".
    """
    if not time_str:
        return None
    try:
        # Numeric epoch seconds (bool is excluded: it is an int subclass).
        if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
            return datetime.fromtimestamp(time_str, timezone.utc)
        # Handle various AWS timestamp formats:
        # replace a trailing 'Z' with +00:00 for fromisoformat.
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except (ValueError, TypeError, AttributeError, OverflowError, OSError):
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
70
+
71
+
72
def calculate_elapsed(start_time_str):
    """Seconds elapsed between ``start_time_str`` and the current UTC time.

    ``None`` is returned for an unparseable timestamp; a start time in the
    future is reported as ``0`` rather than a negative duration.
    """
    began = parse_iso_time(start_time_str)
    if not began:
        return None
    span = datetime.now(timezone.utc) - began
    return max(0, span.total_seconds())
80
+
81
+
82
def display_status(job_data):
    """Display formatted training job status.

    Prints, in order and only when the data is present: overall status,
    secondary status (phase), elapsed or total training time, instance
    details, spot savings, final metrics, artifact location, failure reason
    and spot-interruption guidance.

    Args:
        job_data (dict): decoded DescribeTrainingJob response.
    """
    status = job_data.get('TrainingJobStatus', 'Unknown')
    secondary_status = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    training_start = job_data.get('TrainingStartTime', '')
    billable_seconds = job_data.get('BillableTimeInSeconds')
    training_seconds = job_data.get('TrainingTimeInSeconds')
    final_metrics = job_data.get('FinalMetricDataList', [])
    output_path = job_data.get('OutputDataConfig', {}).get('S3OutputPath', '')
    model_artifacts = job_data.get('ModelArtifacts', {}).get('S3ModelArtifacts', '')
    instance_type = job_data.get('ResourceConfig', {}).get('InstanceType', '')
    instance_count = job_data.get('ResourceConfig', {}).get('InstanceCount', 1)
    spot_enabled = job_data.get('EnableManagedSpotTraining', False)

    emoji = STATUS_EMOJI.get(status, '❓')

    print()
    print(f' {emoji} Status: {status}')

    # Secondary status with description
    if secondary_status:
        desc = SECONDARY_DESCRIPTIONS.get(secondary_status, '')
        if desc:
            print(f' 📍 Phase: {secondary_status} ({desc})')
        else:
            print(f' 📍 Phase: {secondary_status}')

    # Elapsed time: live elapsed time while running, otherwise the reported
    # training time when available.
    if status == 'InProgress' and training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif training_seconds is not None:
        print(f' ⏱️ Training time: {format_duration(training_seconds)}')

    # Instance info
    if instance_type:
        instance_info = f'{instance_type}'
        if instance_count and instance_count > 1:
            instance_info += f' x {instance_count}'
        if spot_enabled:
            instance_info += ' (spot)'
        print(f' 🖥️ Instance: {instance_info}')

    # Billable time and cost savings (for completed spot jobs)
    if status == 'Completed' and spot_enabled and billable_seconds is not None and training_seconds is not None:
        savings_seconds = training_seconds - billable_seconds
        # Guard the percentage against division by zero.
        if training_seconds > 0:
            savings_pct = (savings_seconds / training_seconds) * 100
            print(f' 💰 Spot savings: {format_duration(savings_seconds)} saved ({savings_pct:.0f}% discount)')
            print(f' Billable: {format_duration(billable_seconds)} / Total: {format_duration(training_seconds)}')

    # Training metrics
    if final_metrics:
        print(' 📈 Metrics:')
        for metric in final_metrics:
            name = metric.get('MetricName', 'unknown')
            value = metric.get('Value', 0)
            # Format value nicely: more decimals for smaller magnitudes
            if isinstance(value, float):
                if abs(value) < 0.001:
                    print(f' {name}: {value:.6f}')
                elif abs(value) < 1:
                    print(f' {name}: {value:.4f}')
                else:
                    print(f' {name}: {value:.2f}')
            else:
                print(f' {name}: {value}')

    # Output artifacts (for completed jobs); fall back to the output prefix
    # when no model artifact URI is reported.
    if status == 'Completed' and model_artifacts:
        print(f' 📦 Artifacts: {model_artifacts}')
    elif status == 'Completed' and output_path:
        print(f' 📦 Output: {output_path}')

    # Failure reason
    if status == 'Failed' and failure_reason:
        print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    # Spot interruption guidance
    if secondary_status == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
173
+
174
+
175
def main():
    """Main entry point — decodes the JSON piped on stdin and prints its status.

    Exits with status 1 and a message on stderr when stdin is not valid JSON.
    """
    try:
        payload = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'❌ Failed to parse DescribeTrainingJob response: {e}', file=sys.stderr)
        sys.exit(1)
    else:
        display_status(payload)


if __name__ == '__main__':
    main()
@@ -176,7 +176,7 @@ def cmd_submit(args):
176
176
  )
177
177
  elif "ValidationException" in error_msg and "license" in error_msg.lower():
178
178
  _error_exit(
179
- f"Model license not accepted. Accept the license in JumpStart before "
179
+ f"Model license not accepted. Accept the model license before "
180
180
  f"using this model for customization. Details: {error_msg}"
181
181
  )
182
182
  else:
@@ -660,7 +660,7 @@ def main():
660
660
 
661
661
  # ── submit ────────────────────────────────────────────────────────────────
662
662
  submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
663
- submit_parser.add_argument("--model-id", required=True, help="JumpStart model ID")
663
+ submit_parser.add_argument("--model-id", required=True, help="Model ID")
664
664
  submit_parser.add_argument("--technique", required=True,
665
665
  choices=["sft", "dpo", "rlaif", "rlvr"],
666
666
  help="Customization technique")
@@ -0,0 +1,41 @@
1
+ #!/bin/bash
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Shared helper: post-completion feedback loop for training and tuning jobs.
6
+ # Sourced by do/tune and do/train — prints artifact locations and deployment suggestions.
7
+
8
# print_completion_feedback()
# Display completion summary with artifact path and next-step deployment commands.
# Tailors suggestions based on the detected artifact type (adapter vs full model);
# any other type prints the summary with no suggestions under "Next steps:".
#
# Arguments:
#   $1 - output_path: S3 URI to the output artifacts
#   $2 - output_type: "adapter" or "full-model"
#   $3 - job_name: Job name for reference
#   $4 - model_package_arn: (optional) Model package ARN if registered
print_completion_feedback() {
    local output_path="$1"
    local output_type="$2"
    local job_name="$3"
    local model_package_arn="${4:-}"

    echo ""
    echo "✅ Training complete: ${job_name}"
    echo ""
    echo " Artifacts: ${output_path}"
    # The model package line is only shown when an ARN was passed in.
    if [ -n "${model_package_arn}" ]; then
        echo " Model Package: ${model_package_arn}"
    fi
    echo ""
    echo " Next steps:"

    if [ "${output_type}" = "adapter" ]; then
        echo " • Deploy as LoRA adapter: ./do/adapter add my-adapter --weights ${output_path}"
        echo " • (Requires running endpoint with LoRA enabled)"
    elif [ "${output_type}" = "full-model" ]; then
        echo " • Deploy as new IC: ./do/add-ic my-model --model-data ${output_path}"
        echo " • Replace current base: ./do/deploy --force-ic --model-data ${output_path}"
    fi
    echo ""
}
@@ -191,8 +191,14 @@ fi
191
191
  # ============================================================
192
192
 
193
193
  # DEPLOYMENT_CONFIG format: <architecture>-<backend> (e.g., transformers-vllm, http-flask, triton-fil)
194
- ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
195
- BACKEND="${DEPLOYMENT_CONFIG#*-}"
194
+ # Special case: marketplace has no backend
195
+ if [ "${DEPLOYMENT_CONFIG}" = "marketplace" ]; then
196
+ ARCHITECTURE="marketplace"
197
+ BACKEND=""
198
+ else
199
+ ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
200
+ BACKEND="${DEPLOYMENT_CONFIG#*-}"
201
+ fi
196
202
 
197
203
  echo "📋 Registering deployment to registry"
198
204
  echo " Project: ${PROJECT_NAME}"