@aws/ml-container-creator 0.8.0 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/bin/cli.js +31 -137
- package/package.json +7 -2
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/src/app.js +29 -2
- package/src/lib/config-manager.js +17 -0
- package/src/lib/generated/cli-options.js +467 -0
- package/src/lib/generated/validation-rules.js +202 -0
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +16 -2
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/templates/code/serve +5 -134
- package/templates/code/serve.d/lmi.ejs +19 -0
- package/templates/code/serve.d/sglang.ejs +47 -0
- package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
- package/templates/code/serve.d/vllm.ejs +48 -0
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/clean +1 -1387
- package/templates/do/clean.d/async-inference.ejs +508 -0
- package/templates/do/clean.d/batch-transform.ejs +512 -0
- package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
- package/templates/do/clean.d/managed-inference.ejs +1043 -0
- package/templates/do/deploy +1 -1766
- package/templates/do/deploy.d/async-inference.ejs +501 -0
- package/templates/do/deploy.d/batch-transform.ejs +529 -0
- package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
- package/templates/do/deploy.d/managed-inference.ejs +726 -0
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Request Builder
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that replicates the Python helper's (.train_build_request.py)
|
|
8
|
+
* logic for constructing a CreateTrainingJob JSON request from a parsed config.
|
|
9
|
+
*
|
|
10
|
+
* This module mirrors the behavior of the Python build_request() function,
|
|
11
|
+
* providing a testable implementation of the config-to-API mapping logic.
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
/**
 * Build a CreateTrainingJob request from a parsed training config.
 *
 * Maps config fields to the SageMaker CreateTrainingJob API structure:
 *   - image               → AlgorithmSpecification.TrainingImage
 *   - instance_type       → ResourceConfig.InstanceType
 *   - instance_count      → ResourceConfig.InstanceCount
 *   - volume_size_gb      → ResourceConfig.VolumeSizeInGB
 *   - dataset             → InputDataConfig[0].DataSource.S3DataSource.S3Uri
 *   - output_path         → OutputDataConfig.S3OutputPath
 *   - hyperparameters     → HyperParameters (string key-value pairs)
 *   - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 *   - enable_spot=true    → EnableManagedSpotTraining = true
 *   - enable_spot=true    → StoppingCondition.MaxWaitTimeInSeconds (from max_wait_seconds)
 *   - checkpoint_path     → CheckpointConfig.S3Uri
 *   - metric_definitions  → AlgorithmSpecification.MetricDefinitions
 *   - environment         → Environment (string key-value pairs)
 *   - tags                → Tags (converted from {k:v} to [{Key:k, Value:v}])
 *
 * Optional sections (HyperParameters, CheckpointConfig, MetricDefinitions,
 * Environment, Tags) are omitted entirely when the corresponding config value
 * is missing or empty. Numeric fields are parsed with base-10 parseInt, so a
 * missing or non-numeric value yields NaN rather than throwing.
 *
 * The returned request never shares the `environment` object with the config:
 * values are copied and coerced to strings, the same way hyperparameters and
 * tags are.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
    // Copies a {key: value} map, coercing every key and value to a string.
    const toStringMap = (map) =>
        Object.fromEntries(
            Object.entries(map).map(([key, value]) => [String(key), String(value)])
        );

    const request = {
        TrainingJobName: jobName,
        RoleArn: roleArn,
        AlgorithmSpecification: {
            TrainingImage: config.image,
            TrainingInputMode: 'File'
        },
        InputDataConfig: [
            {
                ChannelName: 'training',
                DataSource: {
                    S3DataSource: {
                        S3DataType: 'S3Prefix',
                        S3Uri: config.dataset,
                        S3DataDistributionType: 'FullyReplicated'
                    }
                }
            }
        ],
        OutputDataConfig: {
            S3OutputPath: config.output_path
        },
        ResourceConfig: {
            InstanceType: config.instance_type,
            InstanceCount: Number.parseInt(config.instance_count, 10),
            VolumeSizeInGB: Number.parseInt(config.volume_size_gb, 10)
        },
        StoppingCondition: {
            MaxRuntimeInSeconds: Number.parseInt(config.max_runtime_seconds, 10)
        }
    };

    // Hyperparameters — ensure all values are strings (SageMaker requirement)
    const hyperparams = config.hyperparameters || {};
    if (Object.keys(hyperparams).length > 0) {
        request.HyperParameters = toStringMap(hyperparams);
    }

    // Managed spot training — accepts both the string 'true' and boolean true
    const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
    if (enableSpot) {
        request.EnableManagedSpotTraining = true;
        request.StoppingCondition.MaxWaitTimeInSeconds = Number.parseInt(config.max_wait_seconds, 10);
    }

    // Checkpoint configuration (for spot training resumption)
    const checkpointPath = config.checkpoint_path || '';
    if (checkpointPath) {
        request.CheckpointConfig = {
            S3Uri: checkpointPath
        };
    }

    // Metric definitions (custom CloudWatch metrics)
    const metricDefs = config.metric_definitions || [];
    if (Array.isArray(metricDefs) && metricDefs.length > 0) {
        request.AlgorithmSpecification.MetricDefinitions = metricDefs.map((metric) => ({
            Name: metric.name,
            Regex: metric.regex
        }));
    }

    // Environment variables for the container. Copy and stringify so that
    // numeric/boolean config values are sent as strings and the caller's
    // config object is not aliased into the request.
    const environment = config.environment || {};
    if (Object.keys(environment).length > 0) {
        request.Environment = toStringMap(environment);
    }

    // Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    const tags = config.tags || {};
    if (Object.keys(tags).length > 0) {
        request.Tags = Object.entries(tags).map(([key, value]) => ({
            Key: String(key),
            Value: String(value)
        }));
    }

    return request;
}
|
package/templates/code/serve
CHANGED
|
@@ -10,35 +10,10 @@ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
|
|
|
10
10
|
# CUDA compatibility setup (required for newer SageMaker inference AMIs)
|
|
11
11
|
source /usr/bin/cuda_compat.sh 2>/dev/null || true
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
echo "Starting vLLM server"
|
|
15
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
16
|
-
echo "Starting SGLang server"
|
|
17
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
18
|
-
echo "Starting TensorRT-LLM server"
|
|
19
|
-
<% } else if (modelServer === 'lmi') { %>
|
|
20
|
-
echo "Starting LMI (Large Model Inference) server"
|
|
21
|
-
<% } else if (modelServer === 'djl') { %>
|
|
22
|
-
echo "Starting DJL Serving server"
|
|
23
|
-
<% } %>
|
|
13
|
+
echo "Starting <%= modelServer %> server"
|
|
24
14
|
|
|
25
15
|
<% if (modelServer === 'lmi' || modelServer === 'djl') { %>
|
|
26
|
-
|
|
27
|
-
# The configuration file should be at /opt/ml/model/serving.properties
|
|
28
|
-
# DJL Serving will automatically start with this configuration
|
|
29
|
-
|
|
30
|
-
if [ ! -f /opt/ml/model/serving.properties ]; then
|
|
31
|
-
echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
|
|
32
|
-
exit 1
|
|
33
|
-
fi
|
|
34
|
-
|
|
35
|
-
echo "Using configuration from /opt/ml/model/serving.properties"
|
|
36
|
-
cat /opt/ml/model/serving.properties
|
|
37
|
-
|
|
38
|
-
# DJL Serving is already configured in the base image
|
|
39
|
-
# This script is not typically needed for LMI/DJL as they have their own entrypoint
|
|
40
|
-
# But we provide it for consistency with other model servers
|
|
41
|
-
exit 0
|
|
16
|
+
<%- include('serve.d/lmi') %>
|
|
42
17
|
<% } else { %>
|
|
43
18
|
|
|
44
19
|
<% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
|
|
@@ -60,7 +35,6 @@ download_model_from_s3() {
|
|
|
60
35
|
mkdir -p "${dest_path}"
|
|
61
36
|
|
|
62
37
|
if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
|
|
63
|
-
# Tarball: download and extract
|
|
64
38
|
if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
|
|
65
39
|
echo "Error: Failed to download tarball from ${s3_uri}" >&2
|
|
66
40
|
return 1
|
|
@@ -72,13 +46,11 @@ download_model_from_s3() {
|
|
|
72
46
|
fi
|
|
73
47
|
rm -f /tmp/model_archive.tar.gz
|
|
74
48
|
elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
|
|
75
|
-
# Directory prefix: sync
|
|
76
49
|
if ! aws s3 sync "$s3_uri" "$dest_path"; then
|
|
77
50
|
echo "Error: Failed to sync from ${s3_uri}" >&2
|
|
78
51
|
return 1
|
|
79
52
|
fi
|
|
80
53
|
else
|
|
81
|
-
# Single file: copy
|
|
82
54
|
if ! aws s3 cp "$s3_uri" "$dest_path/"; then
|
|
83
55
|
echo "Error: Failed to copy ${s3_uri}" >&2
|
|
84
56
|
return 1
|
|
@@ -109,19 +81,16 @@ _MODEL_VAR="TRTLLM_MODEL"
|
|
|
109
81
|
resolve_model() {
|
|
110
82
|
case "$MODEL_SOURCE" in
|
|
111
83
|
huggingface)
|
|
112
|
-
# Pass model name directly — server fetches from HF Hub
|
|
113
84
|
echo "${!_MODEL_VAR}"
|
|
114
85
|
return
|
|
115
86
|
;;
|
|
116
87
|
s3|registry)
|
|
117
|
-
# Check for pre-mounted artifacts first
|
|
118
88
|
if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
|
|
119
89
|
echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
|
|
120
90
|
echo "$LOCAL_MODEL_PATH"
|
|
121
91
|
return
|
|
122
92
|
fi
|
|
123
93
|
|
|
124
|
-
# For registry:// models, resolve artifact URI at runtime via SageMaker API
|
|
125
94
|
if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
|
|
126
95
|
local model_uri="${!_MODEL_VAR}"
|
|
127
96
|
local registry_prefix="registry://"
|
|
@@ -131,7 +100,6 @@ resolve_model() {
|
|
|
131
100
|
local version="${registry_path#*/}"
|
|
132
101
|
local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
|
|
133
102
|
|
|
134
|
-
# Get account ID for ARN construction
|
|
135
103
|
local account_id
|
|
136
104
|
account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
|
|
137
105
|
echo "Error: Failed to get AWS account ID for model package ARN" >&2
|
|
@@ -151,38 +119,22 @@ resolve_model() {
|
|
|
151
119
|
exit 1
|
|
152
120
|
}
|
|
153
121
|
|
|
154
|
-
# Try ModelDataUrl first, then S3DataSource.S3Uri, then description
|
|
155
122
|
MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
|
|
156
123
|
import sys, json, re
|
|
157
124
|
try:
|
|
158
125
|
pkg = json.load(sys.stdin)
|
|
159
126
|
uri = ''
|
|
160
|
-
# Check InferenceSpecification.Containers[0]
|
|
161
127
|
containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
|
|
162
128
|
if containers:
|
|
163
129
|
c = containers[0]
|
|
164
130
|
uri = c.get('ModelDataUrl', '')
|
|
165
131
|
if not uri:
|
|
166
132
|
uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
|
|
167
|
-
# Fallback: extract S3 URI from ModelPackageDescription
|
|
168
133
|
if not uri:
|
|
169
134
|
desc = pkg.get('ModelPackageDescription', '')
|
|
170
135
|
m = re.search(r's3://[^\s]+', desc)
|
|
171
136
|
if m:
|
|
172
137
|
uri = m.group(0)
|
|
173
|
-
# Fallback: check ModelCard hyperparameters for model_artifacts_s3
|
|
174
|
-
if not uri:
|
|
175
|
-
try:
|
|
176
|
-
card = pkg.get('ModelCard', {})
|
|
177
|
-
content = card.get('ModelCardContent', '{}')
|
|
178
|
-
card_data = json.loads(content) if isinstance(content, str) else content
|
|
179
|
-
params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
|
|
180
|
-
for p in params:
|
|
181
|
-
if p.get('name') == 'model_artifacts_s3':
|
|
182
|
-
uri = p.get('value', '')
|
|
183
|
-
break
|
|
184
|
-
except:
|
|
185
|
-
pass
|
|
186
138
|
print(uri)
|
|
187
139
|
except:
|
|
188
140
|
print('')
|
|
@@ -192,19 +144,15 @@ except:
|
|
|
192
144
|
echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
|
|
193
145
|
else
|
|
194
146
|
echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
|
|
195
|
-
echo " Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
|
|
196
|
-
echo " Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
|
|
197
147
|
exit 1
|
|
198
148
|
fi
|
|
199
149
|
fi
|
|
200
150
|
fi
|
|
201
151
|
|
|
202
|
-
# Need artifact URI for download
|
|
203
152
|
if [ -z "$MODEL_ARTIFACT_URI" ]; then
|
|
204
153
|
echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
|
|
205
154
|
exit 1
|
|
206
155
|
fi
|
|
207
|
-
# Download from S3
|
|
208
156
|
if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
|
|
209
157
|
echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
|
|
210
158
|
exit 1
|
|
@@ -212,7 +160,6 @@ except:
|
|
|
212
160
|
echo "$LOCAL_MODEL_PATH"
|
|
213
161
|
;;
|
|
214
162
|
*)
|
|
215
|
-
# Unrecognized source — treat as huggingface
|
|
216
163
|
echo "${!_MODEL_VAR}"
|
|
217
164
|
return
|
|
218
165
|
;;
|
|
@@ -226,89 +173,13 @@ unset _MODEL_VAR _RESOLVED_MODEL
|
|
|
226
173
|
|
|
227
174
|
# Initialize server arguments
|
|
228
175
|
<% if (modelServer === 'tensorrt-llm') { %>
|
|
229
|
-
# port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
|
|
230
176
|
SERVER_ARGS=(--host 0.0.0.0 --port 8081)
|
|
231
177
|
<% } else { %>
|
|
232
|
-
# port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
|
|
233
178
|
SERVER_ARGS=(--host 0.0.0.0 --port 8080)
|
|
234
179
|
<% } %>
|
|
235
180
|
|
|
236
|
-
#
|
|
237
|
-
<% if (
|
|
238
|
-
|
|
239
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
240
|
-
PREFIX="SGLANG_"
|
|
241
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
242
|
-
PREFIX="TRTLLM_"
|
|
243
|
-
<% } %>
|
|
244
|
-
ARG_PREFIX="--"
|
|
245
|
-
|
|
246
|
-
# Define environment variables to exclude (internal variables set by base images)
|
|
247
|
-
<% if (modelServer === 'vllm') { %>
|
|
248
|
-
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
|
|
249
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
250
|
-
EXCLUDE_VARS=()
|
|
251
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
252
|
-
# Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
|
|
253
|
-
EXCLUDE_VARS=("TRTLLM_MODEL")
|
|
254
|
-
<% } %>
|
|
255
|
-
|
|
256
|
-
# Declare and populate array of matching environment variables
|
|
257
|
-
mapfile -t env_vars < <(env | grep "^${PREFIX}")
|
|
258
|
-
|
|
259
|
-
# Loop through the array and convert to command-line arguments
|
|
260
|
-
for var in "${env_vars[@]}"; do
|
|
261
|
-
IFS='=' read -r key value <<< "$var"
|
|
262
|
-
|
|
263
|
-
# Skip excluded variables
|
|
264
|
-
skip=false
|
|
265
|
-
for exclude in "${EXCLUDE_VARS[@]}"; do
|
|
266
|
-
if [ "$key" = "$exclude" ]; then
|
|
267
|
-
skip=true
|
|
268
|
-
break
|
|
269
|
-
fi
|
|
270
|
-
done
|
|
271
|
-
|
|
272
|
-
if [ "$skip" = true ]; then
|
|
273
|
-
continue
|
|
274
|
-
fi
|
|
275
|
-
|
|
276
|
-
# Remove prefix, convert to lowercase, and replace underscores with dashes
|
|
277
|
-
arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
|
|
278
|
-
|
|
279
|
-
# Boolean handling: true = flag only, false = skip entirely
|
|
280
|
-
if [ "$value" = "false" ]; then
|
|
281
|
-
continue
|
|
282
|
-
fi
|
|
283
|
-
|
|
284
|
-
SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
|
|
285
|
-
if [ -n "$value" ] && [ "$value" != "true" ]; then
|
|
286
|
-
SERVER_ARGS+=("$value")
|
|
287
|
-
fi
|
|
288
|
-
done
|
|
289
|
-
|
|
290
|
-
echo "-------------------------------------------------------------------"
|
|
291
|
-
<% if (modelServer === 'vllm') { %>
|
|
292
|
-
echo "vLLM engine args: [${SERVER_ARGS[@]}]"
|
|
293
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
294
|
-
echo "SGLang engine args: [${SERVER_ARGS[@]}]"
|
|
295
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
296
|
-
echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
|
|
297
|
-
<% } %>
|
|
298
|
-
echo "-------------------------------------------------------------------"
|
|
299
|
-
|
|
300
|
-
# Pass the collected arguments to the main entrypoint
|
|
301
|
-
<% if (modelServer === 'vllm') { %>
|
|
302
|
-
exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
|
|
303
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
304
|
-
exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
|
|
305
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
306
|
-
# TensorRT-LLM requires the model as a positional argument
|
|
307
|
-
# Syntax: trtllm-serve serve MODEL [OPTIONS]
|
|
308
|
-
if [ -z "$TRTLLM_MODEL" ]; then
|
|
309
|
-
echo "Error: TRTLLM_MODEL environment variable is not set"
|
|
310
|
-
exit 1
|
|
311
|
-
fi
|
|
312
|
-
exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
|
|
181
|
+
# --- Server-specific arg conversion and exec ---
|
|
182
|
+
<% if (['vllm', 'sglang', 'tensorrt-llm'].includes(modelServer)) { %>
|
|
183
|
+
<%- include('serve.d/' + modelServer) %>
|
|
313
184
|
<% } %>
|
|
314
185
|
<% } %>
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------
# LMI / DJL Server Configuration
# ---------------------------------------------------------------------------
# Config: /opt/ml/model/serving.properties
# Entrypoint: DJL Serving (built into base image)
# Port: 8080 (configured in serving.properties)
# ---------------------------------------------------------------------------

# LMI/DJL containers use serving.properties for configuration.
# Fail fast when it is missing; the error goes to stderr like the other
# fatal errors emitted by the serve script.
if [ ! -f /opt/ml/model/serving.properties ]; then
    echo "Error: serving.properties not found at /opt/ml/model/serving.properties" >&2
    exit 1
fi

# Log the effective configuration so it is visible in the container logs.
echo "Using configuration from /opt/ml/model/serving.properties"
cat /opt/ml/model/serving.properties

# DJL Serving is already configured in the base image entrypoint
exit 0
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------
# SGLang Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: SGLANG_
# Entrypoint: python3 -m sglang.launch_server
# Port: 8080 (SageMaker requirement)
# ---------------------------------------------------------------------------
# Expects SERVER_ARGS to be initialised by the including template.

PREFIX="SGLANG_"
ARG_PREFIX="--"

EXCLUDE_VARS=()

# Collect every NAME=value line of the environment whose name starts with PREFIX
mapfile -t env_vars < <(env | grep "^${PREFIX}")

# Convert SGLANG_ env vars to CLI arguments
for var in "${env_vars[@]}"; do
    # A line without '=' is not a NAME=value pair (e.g. the continuation of a
    # multi-line value that happens to start with the prefix) — ignore it.
    if [[ "$var" != *=* ]]; then continue; fi

    # Split on the FIRST '=' only, using parameter expansion. Unlike
    # `IFS='=' read`, this keeps the value verbatim, including any '='
    # characters at its end (read can drop a trailing delimiter).
    key="${var%%=*}"
    value="${var#*=}"

    # Skip excluded variables
    skip=false
    for exclude in "${EXCLUDE_VARS[@]}"; do
        if [ "$key" = "$exclude" ]; then
            skip=true
            break
        fi
    done
    if [ "$skip" = true ]; then continue; fi

    # Remove prefix, convert to lowercase, replace underscores with dashes
    arg_name="${key#"${PREFIX}"}"
    arg_name="${arg_name,,}"
    arg_name="${arg_name//_/-}"

    # Boolean handling: true = flag only, false = skip entirely
    if [ "$value" = "false" ]; then continue; fi

    SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
    if [ -n "$value" ] && [ "$value" != "true" ]; then
        SERVER_ARGS+=("$value")
    fi
done

echo "-------------------------------------------------------------------"
echo "SGLang engine args: [${SERVER_ARGS[*]}]"
echo "-------------------------------------------------------------------"

exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------
# TensorRT-LLM Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: TRTLLM_
# Entrypoint: trtllm-serve serve MODEL [OPTIONS]
# Port: 8081 (nginx proxies to 8080 for SageMaker)
# ---------------------------------------------------------------------------
# Expects SERVER_ARGS to be initialised by the including template.

PREFIX="TRTLLM_"
ARG_PREFIX="--"

# TRTLLM_MODEL is used as the positional argument, not a --flag
EXCLUDE_VARS=("TRTLLM_MODEL")

# Collect every NAME=value line of the environment whose name starts with PREFIX
mapfile -t env_vars < <(env | grep "^${PREFIX}")

# Convert TRTLLM_ env vars to CLI arguments
for var in "${env_vars[@]}"; do
    # A line without '=' is not a NAME=value pair (e.g. the continuation of a
    # multi-line value that happens to start with the prefix) — ignore it.
    if [[ "$var" != *=* ]]; then continue; fi

    # Split on the FIRST '=' only, using parameter expansion. Unlike
    # `IFS='=' read`, this keeps the value verbatim, including any '='
    # characters at its end (read can drop a trailing delimiter).
    key="${var%%=*}"
    value="${var#*=}"

    # Skip excluded variables
    skip=false
    for exclude in "${EXCLUDE_VARS[@]}"; do
        if [ "$key" = "$exclude" ]; then
            skip=true
            break
        fi
    done
    if [ "$skip" = true ]; then continue; fi

    # Remove prefix, convert to lowercase, replace underscores with dashes
    arg_name="${key#"${PREFIX}"}"
    arg_name="${arg_name,,}"
    arg_name="${arg_name//_/-}"

    # Boolean handling: true = flag only, false = skip entirely
    if [ "$value" = "false" ]; then continue; fi

    SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
    if [ -n "$value" ] && [ "$value" != "true" ]; then
        SERVER_ARGS+=("$value")
    fi
done

echo "-------------------------------------------------------------------"
echo "TensorRT-LLM engine args: [${SERVER_ARGS[*]}]"
echo "-------------------------------------------------------------------"

# TensorRT-LLM requires the model as a positional argument
if [ -z "$TRTLLM_MODEL" ]; then
    echo "Error: TRTLLM_MODEL environment variable is not set" >&2
    exit 1
fi
exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------
# vLLM Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: VLLM_
# Entrypoint: python3 -m vllm.entrypoints.openai.api_server
# Port: 8080 (SageMaker requirement)
# ---------------------------------------------------------------------------
# Expects SERVER_ARGS to be initialised by the including template.

PREFIX="VLLM_"
ARG_PREFIX="--"

# Internal variables set by the base image — not CLI args
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")

# Collect every NAME=value line of the environment whose name starts with PREFIX
mapfile -t env_vars < <(env | grep "^${PREFIX}")

# Convert VLLM_ env vars to CLI arguments
for var in "${env_vars[@]}"; do
    # A line without '=' is not a NAME=value pair (e.g. the continuation of a
    # multi-line value that happens to start with the prefix) — ignore it.
    if [[ "$var" != *=* ]]; then continue; fi

    # Split on the FIRST '=' only, using parameter expansion. Unlike
    # `IFS='=' read`, this keeps the value verbatim, including any '='
    # characters at its end (read can drop a trailing delimiter).
    key="${var%%=*}"
    value="${var#*=}"

    # Skip excluded variables
    skip=false
    for exclude in "${EXCLUDE_VARS[@]}"; do
        if [ "$key" = "$exclude" ]; then
            skip=true
            break
        fi
    done
    if [ "$skip" = true ]; then continue; fi

    # Remove prefix, convert to lowercase, replace underscores with dashes
    arg_name="${key#"${PREFIX}"}"
    arg_name="${arg_name,,}"
    arg_name="${arg_name//_/-}"

    # Boolean handling: true = flag only, false = skip entirely
    if [ "$value" = "false" ]; then continue; fi

    SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
    if [ -n "$value" ] && [ "$value" != "true" ]; then
        SERVER_ARGS+=("$value")
    fi
done

echo "-------------------------------------------------------------------"
echo "vLLM engine args: [${SERVER_ARGS[*]}]"
echo "-------------------------------------------------------------------"

exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Build the CreateTrainingJob JSON request for SageMaker.
|
|
7
|
+
|
|
8
|
+
This helper is called by do/train to construct the full API request body.
|
|
9
|
+
It handles conditional fields (spot training, metric definitions, environment,
|
|
10
|
+
tags) and writes the result to a JSON file for use with:
|
|
11
|
+
aws sagemaker create-training-job --cli-input-json file://path.json
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def parse_args():
    """Parse command-line arguments.

    Every option is mandatory and is returned as a plain string; numeric and
    JSON-encoded values are left unconverted for the caller to interpret.
    """
    # (flag, help text) pairs, registered in this order.
    required_options = (
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    )

    arg_parser = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
    for flag, help_text in required_options:
        arg_parser.add_argument(flag, required=True, help=help_text)
    return arg_parser.parse_args()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    Args:
        args: Namespace with the string attributes produced by parse_args()
            (job_name, role_arn, image, instance_type, instance_count,
            volume_size, dataset, output_path, max_runtime, hyperparams,
            enable_spot, max_wait, checkpoint_path, metric_definitions,
            environment, tags).

    Returns:
        dict: The request body. HyperParameters, CheckpointConfig,
        MetricDefinitions, Environment and Tags are only present when the
        corresponding input is non-empty; spot fields only when
        enable_spot is exactly 'true'.

    Raises:
        ValueError: If a JSON argument is not valid JSON
            (json.JSONDecodeError is a ValueError subclass), decodes to the
            wrong container type, a metric definition lacks 'name'/'regex',
            or a numeric argument is not an integer string.
    """
    def _load(raw, expected_type, flag):
        # Decode one JSON argument. Empty strings and falsy JSON values
        # (null, {}, []) mean "not provided" and yield an empty container.
        value = json.loads(raw) if raw else None
        if not value:
            return expected_type()
        if not isinstance(value, expected_type):
            kind = 'object' if expected_type is dict else 'array'
            raise ValueError(
                f'{flag} must be a JSON {kind}, got {type(value).__name__}'
            )
        return value

    # Parse JSON inputs
    hyperparams = _load(args.hyperparams, dict, '--hyperparams')
    metric_definitions = _load(args.metric_definitions, list, '--metric-definitions')
    environment = _load(args.environment, dict, '--environment')
    tags = _load(args.tags, dict, '--tags')

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement)
    if hyperparams:
        request['HyperParameters'] = {
            str(k): str(v) for k, v in hyperparams.items()
        }

    # Managed spot training
    if args.enable_spot == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {
            'S3Uri': args.checkpoint_path
        }

    # Metric definitions (custom CloudWatch metrics)
    if metric_definitions:
        try:
            request['AlgorithmSpecification']['MetricDefinitions'] = [
                {'Name': m['name'], 'Regex': m['regex']}
                for m in metric_definitions
            ]
        except (KeyError, TypeError) as exc:
            # Surface malformed entries as ValueError so they are reported
            # like any other invalid input instead of as a raw traceback.
            raise ValueError(
                "each metric definition must be an object with 'name' and 'regex' keys"
            ) from exc

    # Environment variables for the container — coerce to strings, the same
    # way hyperparameters and tags are, so numeric/boolean JSON values are
    # not sent as non-string types.
    if environment:
        request['Environment'] = {
            str(k): str(v) for k, v in environment.items()
        }

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [
            {'Key': str(k), 'Value': str(v)}
            for k, v in tags.items()
        ]

    return request
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def main():
    """Main entry point.

    Parses the command line, builds the request and writes it as indented
    JSON to the --output-file path. Exits with status 1, after printing a
    message to stderr, if the request cannot be built or the file cannot be
    written.
    """
    args = parse_args()

    try:
        request = build_request(args)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError (a subclass) and bad integer
        # strings. The other types arise when a JSON argument decodes to an
        # unexpected shape; report them cleanly instead of as a traceback.
        detail = str(e) if isinstance(e, ValueError) else f'{type(e).__name__}: {e}'
        print(f'❌ Failed to build request: {detail}', file=sys.stderr)
        sys.exit(1)

    # Write the JSON request to the output file
    try:
        with open(args.output_file, 'w') as f:
            json.dump(request, f, indent=2)
    except OSError as e:
        print(f'❌ Failed to write request file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|