@aws/ml-container-creator 0.8.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/LICENSE-THIRD-PARTY +50760 -16218
  2. package/bin/cli.js +31 -137
  3. package/package.json +7 -2
  4. package/servers/lib/catalogs/instances.json +52 -1275
  5. package/servers/lib/catalogs/models.json +0 -132
  6. package/servers/lib/catalogs/popular-diffusors.json +1 -110
  7. package/src/app.js +29 -2
  8. package/src/lib/config-manager.js +17 -0
  9. package/src/lib/generated/cli-options.js +467 -0
  10. package/src/lib/generated/validation-rules.js +202 -0
  11. package/src/lib/mcp-client.js +16 -1
  12. package/src/lib/mcp-command-handler.js +10 -2
  13. package/src/lib/prompt-runner.js +16 -2
  14. package/src/lib/train-config-parser.js +136 -0
  15. package/src/lib/train-config-persistence.js +143 -0
  16. package/src/lib/train-config-validator.js +112 -0
  17. package/src/lib/train-feedback.js +46 -0
  18. package/src/lib/train-idempotency.js +97 -0
  19. package/src/lib/train-request-builder.js +120 -0
  20. package/templates/code/serve +5 -134
  21. package/templates/code/serve.d/lmi.ejs +19 -0
  22. package/templates/code/serve.d/sglang.ejs +47 -0
  23. package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
  24. package/templates/code/serve.d/vllm.ejs +48 -0
  25. package/templates/do/.train_build_request.py +141 -0
  26. package/templates/do/.train_poll_parser.py +135 -0
  27. package/templates/do/.train_status_parser.py +187 -0
  28. package/templates/do/clean +1 -1387
  29. package/templates/do/clean.d/async-inference.ejs +508 -0
  30. package/templates/do/clean.d/batch-transform.ejs +512 -0
  31. package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
  32. package/templates/do/clean.d/managed-inference.ejs +1043 -0
  33. package/templates/do/deploy +1 -1766
  34. package/templates/do/deploy.d/async-inference.ejs +501 -0
  35. package/templates/do/deploy.d/batch-transform.ejs +529 -0
  36. package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
  37. package/templates/do/deploy.d/managed-inference.ejs +726 -0
  38. package/templates/do/lib/feedback.sh +41 -0
  39. package/templates/do/train +786 -0
  40. package/templates/do/training/config.yaml +140 -0
  41. package/templates/do/training/train.py +463 -0
@@ -0,0 +1,120 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Request Builder
6
+ *
7
+ * JavaScript module that replicates the Python helper's (.train_build_request.py)
8
+ * logic for constructing a CreateTrainingJob JSON request from a parsed config.
9
+ *
10
+ * This module mirrors the behavior of the Python build_request() function,
11
+ * providing a testable implementation of the config-to-API mapping logic.
12
+ */
13
+
14
/**
 * Convert a training-config value to an integer.
 *
 * Mirrors what Python's int() does with the plain decimal strings the Python
 * helper receives on its command line: an optional sign and surrounding
 * whitespace are accepted, while empty strings, fractional values, trailing
 * garbage and missing values are rejected. Bare parseInt is not used because
 * it silently truncates "1.5" to 1 and turns undefined into NaN, which
 * JSON.stringify then serializes as null in the request body.
 *
 * @param {string|number} value - Raw config value
 * @param {string} field - Config field name, used in the error message
 * @returns {number} The integer value
 * @throws {Error} If value is neither an integer nor an integer string
 */
function toInteger(value, field) {
  if (typeof value === 'number' && Number.isInteger(value)) {
    return value;
  }
  if (typeof value === 'string' && /^\s*[+-]?\d+\s*$/.test(value)) {
    return Number.parseInt(value, 10);
  }
  throw new Error(`Invalid integer for ${field}: ${String(value)}`);
}

/**
 * Build a CreateTrainingJob request from a parsed training config.
 *
 * Maps config fields to the SageMaker CreateTrainingJob API structure:
 * - image → AlgorithmSpecification.TrainingImage
 * - instance_type → ResourceConfig.InstanceType
 * - instance_count → ResourceConfig.InstanceCount
 * - volume_size_gb → ResourceConfig.VolumeSizeInGB
 * - dataset → InputDataConfig[0].DataSource.S3DataSource.S3Uri
 * - output_path → OutputDataConfig.S3OutputPath
 * - hyperparameters → HyperParameters (string key-value pairs)
 * - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 * - enable_spot=true → EnableManagedSpotTraining = true
 * - enable_spot=true → StoppingCondition.MaxWaitTimeInSeconds (from max_wait_seconds)
 * - checkpoint_path → CheckpointConfig.S3Uri
 * - metric_definitions → AlgorithmSpecification.MetricDefinitions
 * - environment → Environment (shallow copy of the config object)
 * - tags → Tags (converted from {k:v} to [{Key:k, Value:v}])
 *
 * Optional sections (hyperparameters, spot, checkpoint, metrics, environment,
 * tags) are omitted entirely when empty, matching the Python helper.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 * @throws {Error} If an integer field (instance_count, volume_size_gb,
 *   max_runtime_seconds, or max_wait_seconds when spot is enabled) is not a
 *   valid integer — the Python helper fails with ValueError in the same cases.
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: {
      TrainingImage: config.image,
      TrainingInputMode: 'File'
    },
    InputDataConfig: [
      {
        ChannelName: 'training',
        DataSource: {
          S3DataSource: {
            S3DataType: 'S3Prefix',
            S3Uri: config.dataset,
            S3DataDistributionType: 'FullyReplicated'
          }
        }
      }
    ],
    OutputDataConfig: {
      S3OutputPath: config.output_path
    },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: toInteger(config.instance_count, 'instance_count'),
      VolumeSizeInGB: toInteger(config.volume_size_gb, 'volume_size_gb')
    },
    StoppingCondition: {
      MaxRuntimeInSeconds: toInteger(config.max_runtime_seconds, 'max_runtime_seconds')
    }
  };

  // Hyperparameters — ensure all values are strings (SageMaker requirement)
  const hyperparams = config.hyperparameters || {};
  if (Object.keys(hyperparams).length > 0) {
    request.HyperParameters = {};
    for (const [k, v] of Object.entries(hyperparams)) {
      request.HyperParameters[String(k)] = String(v);
    }
  }

  // Managed spot training — accept both the string form (as passed to the
  // Python helper) and a real boolean.
  const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
  if (enableSpot) {
    request.EnableManagedSpotTraining = true;
    request.StoppingCondition.MaxWaitTimeInSeconds = toInteger(
      config.max_wait_seconds,
      'max_wait_seconds'
    );
  }

  // Checkpoint configuration (for spot training resumption)
  const checkpointPath = config.checkpoint_path || '';
  if (checkpointPath) {
    request.CheckpointConfig = {
      S3Uri: checkpointPath
    };
  }

  // Metric definitions (custom CloudWatch metrics)
  const metricDefs = config.metric_definitions || [];
  if (Array.isArray(metricDefs) && metricDefs.length > 0) {
    request.AlgorithmSpecification.MetricDefinitions = metricDefs.map((m) => ({
      Name: m.name,
      Regex: m.regex
    }));
  }

  // Environment variables for the container. Copied so that later edits to
  // the request cannot mutate the caller's config object.
  const environment = config.environment || {};
  if (Object.keys(environment).length > 0) {
    request.Environment = { ...environment };
  }

  // Tags — convert from {key: value} map to [{Key: k, Value: v}] array
  const tags = config.tags || {};
  if (Object.keys(tags).length > 0) {
    request.Tags = Object.entries(tags).map(([k, v]) => ({
      Key: String(k),
      Value: String(v)
    }));
  }

  return request;
}
@@ -10,35 +10,10 @@ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
10
10
  # CUDA compatibility setup (required for newer SageMaker inference AMIs)
11
11
  source /usr/bin/cuda_compat.sh 2>/dev/null || true
12
12
 
13
- <% if (modelServer === 'vllm') { %>
14
- echo "Starting vLLM server"
15
- <% } else if (modelServer === 'sglang') { %>
16
- echo "Starting SGLang server"
17
- <% } else if (modelServer === 'tensorrt-llm') { %>
18
- echo "Starting TensorRT-LLM server"
19
- <% } else if (modelServer === 'lmi') { %>
20
- echo "Starting LMI (Large Model Inference) server"
21
- <% } else if (modelServer === 'djl') { %>
22
- echo "Starting DJL Serving server"
23
- <% } %>
13
+ echo "Starting <%= modelServer %> server"
24
14
 
25
15
  <% if (modelServer === 'lmi' || modelServer === 'djl') { %>
26
- # LMI/DJL containers use serving.properties for configuration
27
- # The configuration file should be at /opt/ml/model/serving.properties
28
- # DJL Serving will automatically start with this configuration
29
-
30
- if [ ! -f /opt/ml/model/serving.properties ]; then
31
- echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
32
- exit 1
33
- fi
34
-
35
- echo "Using configuration from /opt/ml/model/serving.properties"
36
- cat /opt/ml/model/serving.properties
37
-
38
- # DJL Serving is already configured in the base image
39
- # This script is not typically needed for LMI/DJL as they have their own entrypoint
40
- # But we provide it for consistency with other model servers
41
- exit 0
16
+ <%- include('serve.d/lmi') %>
42
17
  <% } else { %>
43
18
 
44
19
  <% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
@@ -60,7 +35,6 @@ download_model_from_s3() {
60
35
  mkdir -p "${dest_path}"
61
36
 
62
37
  if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
63
- # Tarball: download and extract
64
38
  if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
65
39
  echo "Error: Failed to download tarball from ${s3_uri}" >&2
66
40
  return 1
@@ -72,13 +46,11 @@ download_model_from_s3() {
72
46
  fi
73
47
  rm -f /tmp/model_archive.tar.gz
74
48
  elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
75
- # Directory prefix: sync
76
49
  if ! aws s3 sync "$s3_uri" "$dest_path"; then
77
50
  echo "Error: Failed to sync from ${s3_uri}" >&2
78
51
  return 1
79
52
  fi
80
53
  else
81
- # Single file: copy
82
54
  if ! aws s3 cp "$s3_uri" "$dest_path/"; then
83
55
  echo "Error: Failed to copy ${s3_uri}" >&2
84
56
  return 1
@@ -109,19 +81,16 @@ _MODEL_VAR="TRTLLM_MODEL"
109
81
  resolve_model() {
110
82
  case "$MODEL_SOURCE" in
111
83
  huggingface)
112
- # Pass model name directly — server fetches from HF Hub
113
84
  echo "${!_MODEL_VAR}"
114
85
  return
115
86
  ;;
116
87
  s3|registry)
117
- # Check for pre-mounted artifacts first
118
88
  if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
119
89
  echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
120
90
  echo "$LOCAL_MODEL_PATH"
121
91
  return
122
92
  fi
123
93
 
124
- # For registry:// models, resolve artifact URI at runtime via SageMaker API
125
94
  if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
126
95
  local model_uri="${!_MODEL_VAR}"
127
96
  local registry_prefix="registry://"
@@ -131,7 +100,6 @@ resolve_model() {
131
100
  local version="${registry_path#*/}"
132
101
  local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
133
102
 
134
- # Get account ID for ARN construction
135
103
  local account_id
136
104
  account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
137
105
  echo "Error: Failed to get AWS account ID for model package ARN" >&2
@@ -151,38 +119,22 @@ resolve_model() {
151
119
  exit 1
152
120
  }
153
121
 
154
- # Try ModelDataUrl first, then S3DataSource.S3Uri, then description
155
122
  MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
156
123
  import sys, json, re
157
124
  try:
158
125
  pkg = json.load(sys.stdin)
159
126
  uri = ''
160
- # Check InferenceSpecification.Containers[0]
161
127
  containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
162
128
  if containers:
163
129
  c = containers[0]
164
130
  uri = c.get('ModelDataUrl', '')
165
131
  if not uri:
166
132
  uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
167
- # Fallback: extract S3 URI from ModelPackageDescription
168
133
  if not uri:
169
134
  desc = pkg.get('ModelPackageDescription', '')
170
135
  m = re.search(r's3://[^\s]+', desc)
171
136
  if m:
172
137
  uri = m.group(0)
173
- # Fallback: check ModelCard hyperparameters for model_artifacts_s3
174
- if not uri:
175
- try:
176
- card = pkg.get('ModelCard', {})
177
- content = card.get('ModelCardContent', '{}')
178
- card_data = json.loads(content) if isinstance(content, str) else content
179
- params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
180
- for p in params:
181
- if p.get('name') == 'model_artifacts_s3':
182
- uri = p.get('value', '')
183
- break
184
- except:
185
- pass
186
138
  print(uri)
187
139
  except:
188
140
  print('')
@@ -192,19 +144,15 @@ except:
192
144
  echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
193
145
  else
194
146
  echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
195
- echo " Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
196
- echo " Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
197
147
  exit 1
198
148
  fi
199
149
  fi
200
150
  fi
201
151
 
202
- # Need artifact URI for download
203
152
  if [ -z "$MODEL_ARTIFACT_URI" ]; then
204
153
  echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
205
154
  exit 1
206
155
  fi
207
- # Download from S3
208
156
  if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
209
157
  echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
210
158
  exit 1
@@ -212,7 +160,6 @@ except:
212
160
  echo "$LOCAL_MODEL_PATH"
213
161
  ;;
214
162
  *)
215
- # Unrecognized source — treat as huggingface
216
163
  echo "${!_MODEL_VAR}"
217
164
  return
218
165
  ;;
@@ -226,89 +173,13 @@ unset _MODEL_VAR _RESOLVED_MODEL
226
173
 
227
174
  # Initialize server arguments
228
175
  <% if (modelServer === 'tensorrt-llm') { %>
229
- # port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
230
176
  SERVER_ARGS=(--host 0.0.0.0 --port 8081)
231
177
  <% } else { %>
232
- # port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
233
178
  SERVER_ARGS=(--host 0.0.0.0 --port 8080)
234
179
  <% } %>
235
180
 
236
- # Define the prefix for environment variables to look for
237
- <% if (modelServer === 'vllm') { %>
238
- PREFIX="VLLM_"
239
- <% } else if (modelServer === 'sglang') { %>
240
- PREFIX="SGLANG_"
241
- <% } else if (modelServer === 'tensorrt-llm') { %>
242
- PREFIX="TRTLLM_"
243
- <% } %>
244
- ARG_PREFIX="--"
245
-
246
- # Define environment variables to exclude (internal variables set by base images)
247
- <% if (modelServer === 'vllm') { %>
248
- EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
249
- <% } else if (modelServer === 'sglang') { %>
250
- EXCLUDE_VARS=()
251
- <% } else if (modelServer === 'tensorrt-llm') { %>
252
- # Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
253
- EXCLUDE_VARS=("TRTLLM_MODEL")
254
- <% } %>
255
-
256
- # Declare and populate array of matching environment variables
257
- mapfile -t env_vars < <(env | grep "^${PREFIX}")
258
-
259
- # Loop through the array and convert to command-line arguments
260
- for var in "${env_vars[@]}"; do
261
- IFS='=' read -r key value <<< "$var"
262
-
263
- # Skip excluded variables
264
- skip=false
265
- for exclude in "${EXCLUDE_VARS[@]}"; do
266
- if [ "$key" = "$exclude" ]; then
267
- skip=true
268
- break
269
- fi
270
- done
271
-
272
- if [ "$skip" = true ]; then
273
- continue
274
- fi
275
-
276
- # Remove prefix, convert to lowercase, and replace underscores with dashes
277
- arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
278
-
279
- # Boolean handling: true = flag only, false = skip entirely
280
- if [ "$value" = "false" ]; then
281
- continue
282
- fi
283
-
284
- SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
285
- if [ -n "$value" ] && [ "$value" != "true" ]; then
286
- SERVER_ARGS+=("$value")
287
- fi
288
- done
289
-
290
- echo "-------------------------------------------------------------------"
291
- <% if (modelServer === 'vllm') { %>
292
- echo "vLLM engine args: [${SERVER_ARGS[@]}]"
293
- <% } else if (modelServer === 'sglang') { %>
294
- echo "SGLang engine args: [${SERVER_ARGS[@]}]"
295
- <% } else if (modelServer === 'tensorrt-llm') { %>
296
- echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
297
- <% } %>
298
- echo "-------------------------------------------------------------------"
299
-
300
- # Pass the collected arguments to the main entrypoint
301
- <% if (modelServer === 'vllm') { %>
302
- exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
303
- <% } else if (modelServer === 'sglang') { %>
304
- exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
305
- <% } else if (modelServer === 'tensorrt-llm') { %>
306
- # TensorRT-LLM requires the model as a positional argument
307
- # Syntax: trtllm-serve serve MODEL [OPTIONS]
308
- if [ -z "$TRTLLM_MODEL" ]; then
309
- echo "Error: TRTLLM_MODEL environment variable is not set"
310
- exit 1
311
- fi
312
- exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
181
+ # --- Server-specific arg conversion and exec ---
182
+ <% if (['vllm', 'sglang', 'tensorrt-llm'].includes(modelServer)) { %>
183
+ <%- include('serve.d/' + modelServer) %>
313
184
  <% } %>
314
185
  <% } %>
@@ -0,0 +1,19 @@
1
# ---------------------------------------------------------------------------
# LMI / DJL Server Configuration
# ---------------------------------------------------------------------------
# Config: /opt/ml/model/serving.properties
# Entrypoint: DJL Serving (built into base image)
# Port: 8080 (configured in serving.properties)
# ---------------------------------------------------------------------------
# This fragment does not launch a server itself: it verifies that
# serving.properties exists, logs its contents, and then exits the serve
# script (status 0 on success, 1 if the file is missing).

# LMI/DJL containers use serving.properties for configuration
if [ ! -f /opt/ml/model/serving.properties ]; then
  # Errors go to stderr, consistent with the rest of the serve script
  echo "Error: serving.properties not found at /opt/ml/model/serving.properties" >&2
  exit 1
fi

echo "Using configuration from /opt/ml/model/serving.properties"
cat /opt/ml/model/serving.properties

# DJL Serving is already configured in the base image entrypoint
exit 0
@@ -0,0 +1,47 @@
1
# ---------------------------------------------------------------------------
# SGLang Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: SGLANG_
# Entrypoint: python3 -m sglang.launch_server
# Port: 8080 (SageMaker requirement)
# ---------------------------------------------------------------------------
# Every SGLANG_<NAME>=<value> environment variable is appended to SERVER_ARGS
# (initialised by the enclosing serve template) as --<name>, where NAME is
# lower-cased and '_' becomes '-'. Value handling:
#   "false"        -> option omitted
#   "true" or ""   -> bare flag
#   anything else  -> flag followed by the value
# NOTE(review): SGLang may also read some SGLANG_* variables as runtime
# settings; any such variable would be forwarded as a --flag here unless it
# is added to EXCLUDE_VARS — confirm against the base image.

PREFIX="SGLANG_"
ARG_PREFIX="--"

EXCLUDE_VARS=()

# Collect every "NAME=value" line whose name starts with the prefix
mapfile -t env_vars < <(env | grep "^${PREFIX}")

for entry in "${env_vars[@]}"; do
  # Split on the first '='; the remainder (including further '=') is the value
  IFS='=' read -r key value <<< "$entry"

  is_excluded=false
  for excluded in "${EXCLUDE_VARS[@]}"; do
    if [ "$key" = "$excluded" ]; then
      is_excluded=true
      break
    fi
  done
  if [ "$is_excluded" = true ]; then
    continue
  fi

  # SGLANG_MAX_TOTAL_TOKENS -> max-total-tokens
  option="${key#"${PREFIX}"}"
  option="${option,,}"
  option="${option//_/-}"

  case "$value" in
    false)
      continue
      ;;
    true|"")
      SERVER_ARGS+=("${ARG_PREFIX}${option}")
      ;;
    *)
      SERVER_ARGS+=("${ARG_PREFIX}${option}" "$value")
      ;;
  esac
done

echo "-------------------------------------------------------------------"
echo "SGLang engine args: [${SERVER_ARGS[@]}]"
echo "-------------------------------------------------------------------"

exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
@@ -0,0 +1,53 @@
1
# ---------------------------------------------------------------------------
# TensorRT-LLM Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: TRTLLM_
# Entrypoint: trtllm-serve serve MODEL [OPTIONS]
# Port: 8081 (nginx proxies to 8080 for SageMaker)
# ---------------------------------------------------------------------------
# Every TRTLLM_<NAME>=<value> environment variable except TRTLLM_MODEL is
# appended to SERVER_ARGS (initialised by the enclosing serve template) as
# --<name>, where NAME is lower-cased and '_' becomes '-'. "true" yields a
# bare flag, "false" omits the option, any other non-empty value follows the
# flag.

PREFIX="TRTLLM_"
ARG_PREFIX="--"

# TRTLLM_MODEL is used as the positional argument, not a --flag
EXCLUDE_VARS=("TRTLLM_MODEL")

# trtllm-serve requires the model as a positional argument; fail fast before
# converting and logging options that could never be used.
if [ -z "$TRTLLM_MODEL" ]; then
  echo "Error: TRTLLM_MODEL environment variable is not set" >&2
  exit 1
fi

# Declare and populate array of matching environment variables
mapfile -t env_vars < <(env | grep "^${PREFIX}")

# Convert TRTLLM_ env vars to CLI arguments
for var in "${env_vars[@]}"; do
  # Split on the first '='; the remainder (including further '=') is the value
  IFS='=' read -r key value <<< "$var"

  # Skip excluded variables
  skip=false
  for exclude in "${EXCLUDE_VARS[@]}"; do
    if [ "$key" = "$exclude" ]; then
      skip=true
      break
    fi
  done
  if [ "$skip" = true ]; then continue; fi

  # Remove prefix, convert to lowercase, replace underscores with dashes
  arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')

  # Boolean handling: true = flag only, false = skip entirely
  if [ "$value" = "false" ]; then continue; fi

  SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
  if [ -n "$value" ] && [ "$value" != "true" ]; then
    SERVER_ARGS+=("$value")
  fi
done

echo "-------------------------------------------------------------------"
echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
echo "-------------------------------------------------------------------"

exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
@@ -0,0 +1,48 @@
1
# ---------------------------------------------------------------------------
# vLLM Server Configuration
# ---------------------------------------------------------------------------
# Env prefix: VLLM_
# Entrypoint: python3 -m vllm.entrypoints.openai.api_server
# Port: 8080 (SageMaker requirement)
# ---------------------------------------------------------------------------
# Every VLLM_<NAME>=<value> environment variable not listed in EXCLUDE_VARS is
# appended to SERVER_ARGS (initialised by the enclosing serve template) as
# --<name>, where NAME is lower-cased and '_' becomes '-'. Value handling:
#   "false"        -> option omitted
#   "true" or ""   -> bare flag
#   anything else  -> flag followed by the value
# NOTE(review): vLLM also reads many VLLM_* variables as runtime settings
# rather than CLI options; any such variable not in EXCLUDE_VARS would be
# forwarded as a --flag — confirm the exclusion list matches the base image.

PREFIX="VLLM_"
ARG_PREFIX="--"

# Internal variables set by the base image — not CLI args
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")

# Collect every "NAME=value" line whose name starts with the prefix
mapfile -t env_vars < <(env | grep "^${PREFIX}")

for entry in "${env_vars[@]}"; do
  # Split on the first '='; the remainder (including further '=') is the value
  IFS='=' read -r key value <<< "$entry"

  is_excluded=false
  for excluded in "${EXCLUDE_VARS[@]}"; do
    if [ "$key" = "$excluded" ]; then
      is_excluded=true
      break
    fi
  done
  if [ "$is_excluded" = true ]; then
    continue
  fi

  # VLLM_MAX_MODEL_LEN -> max-model-len
  option="${key#"${PREFIX}"}"
  option="${option,,}"
  option="${option//_/-}"

  case "$value" in
    false)
      continue
      ;;
    true|"")
      SERVER_ARGS+=("${ARG_PREFIX}${option}")
      ;;
    *)
      SERVER_ARGS+=("${ARG_PREFIX}${option}" "$value")
      ;;
  esac
done

echo "-------------------------------------------------------------------"
echo "vLLM engine args: [${SERVER_ARGS[@]}]"
echo "-------------------------------------------------------------------"

exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Build the CreateTrainingJob JSON request for SageMaker.
7
+
8
+ This helper is called by do/train to construct the full API request body.
9
+ It handles conditional fields (spot training, metric definitions, environment,
10
+ tags) and writes the result to a JSON file for use with:
11
+ aws sagemaker create-training-job --cli-input-json file://path.json
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+
18
+
19
def parse_args():
    """Define and parse this helper's command-line interface.

    Every option is mandatory and arrives as a plain string; integer and JSON
    values are converted later by build_request().

    Returns:
        argparse.Namespace: The parsed arguments (attribute names use
        underscores, e.g. ``job_name`` for ``--job-name``).
    """
    # (flag, help text) in the order they should appear in --help output
    option_specs = (
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    )

    cli = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
    for flag, help_text in option_specs:
        cli.add_argument(flag, required=True, help=help_text)
    return cli.parse_args()
40
+
41
+
42
def _metric_definition(entry):
    """Map one config metric ``{name, regex}`` to the API's ``{Name, Regex}``.

    Raises:
        ValueError: If ``entry`` is not an object with both ``name`` and
            ``regex`` keys (reported cleanly by main() instead of a traceback).
    """
    if not isinstance(entry, dict) or 'name' not in entry or 'regex' not in entry:
        raise ValueError(
            f"metric definition must be an object with 'name' and 'regex': {entry!r}"
        )
    return {'Name': entry['name'], 'Regex': entry['regex']}


def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    Args:
        args: Namespace from parse_args(). Integer-valued fields
            (instance_count, volume_size, max_runtime, max_wait) arrive as
            strings. JSON-valued fields (hyperparams, metric_definitions,
            environment, tags) arrive as JSON text; an empty string means
            "not provided".

    Returns:
        dict: Request body for ``aws sagemaker create-training-job
        --cli-input-json``. Optional sections are omitted when empty.

    Raises:
        json.JSONDecodeError: If a JSON-valued argument is not valid JSON.
        ValueError: If an integer argument is not a valid integer, or a
            metric definition lacks ``name``/``regex``.
    """
    # Parse JSON inputs
    hyperparams = json.loads(args.hyperparams) if args.hyperparams else {}
    metric_definitions = json.loads(args.metric_definitions) if args.metric_definitions else []
    environment = json.loads(args.environment) if args.environment else {}
    tags = json.loads(args.tags) if args.tags else {}

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement)
    if hyperparams:
        request['HyperParameters'] = {
            str(k): str(v) for k, v in hyperparams.items()
        }

    # Managed spot training
    if args.enable_spot == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {
            'S3Uri': args.checkpoint_path
        }

    # Metric definitions (custom CloudWatch metrics)
    if metric_definitions:
        request['AlgorithmSpecification']['MetricDefinitions'] = [
            _metric_definition(m) for m in metric_definitions
        ]

    # Environment variables for the container
    if environment:
        request['Environment'] = environment

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [
            {'Key': str(k), 'Value': str(v)}
            for k, v in tags.items()
        ]

    return request
119
+
120
+
121
def main():
    """Command-line entry point.

    Parses the CLI arguments, builds the CreateTrainingJob request and writes
    it to ``--output-file`` as JSON indented by two spaces. A build failure
    (bad JSON or a bad value) or a write failure is reported on stderr and the
    process exits with status 1.
    """
    def _exit_with(message):
        # Report a fatal error on stderr and terminate with a non-zero status.
        print(message, file=sys.stderr)
        sys.exit(1)

    cli_args = parse_args()

    try:
        training_request = build_request(cli_args)
    except (json.JSONDecodeError, ValueError) as exc:
        _exit_with(f'❌ Failed to build request: {exc}')

    try:
        with open(cli_args.output_file, 'w') as out_fh:
            json.dump(training_request, out_fh, indent=2)
    except IOError as exc:
        _exit_with(f'❌ Failed to write request file: {exc}')


if __name__ == '__main__':
    main()