@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
--- /dev/null
+++ package/templates/code/serve
@@ -0,0 +1,300 @@
+#!/bin/bash
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+<% if (modelServer === 'vllm') { %>
+echo "Starting vLLM server"
+<% } else if (modelServer === 'sglang') { %>
+echo "Starting SGLang server"
+<% } else if (modelServer === 'tensorrt-llm') { %>
+echo "Starting TensorRT-LLM server"
+<% } else if (modelServer === 'lmi') { %>
+echo "Starting LMI (Large Model Inference) server"
+<% } else if (modelServer === 'djl') { %>
+echo "Starting DJL Serving server"
+<% } %>
+
+<% if (modelServer === 'lmi' || modelServer === 'djl') { %>
+# LMI/DJL containers use serving.properties for configuration
+# The configuration file should be at /opt/ml/model/serving.properties
+# DJL Serving will automatically start with this configuration
+
+if [ ! -f /opt/ml/model/serving.properties ]; then
+    echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
+    exit 1
+fi
+
+echo "Using configuration from /opt/ml/model/serving.properties"
+cat /opt/ml/model/serving.properties
+
+# DJL Serving is already configured in the base image
+# This script is not typically needed for LMI/DJL as they have their own entrypoint
+# But we provide it for consistency with other model servers
+exit 0
+<% } else { %>
+
+<% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
+# ---------------------------------------------------------------------------
+# download_model_from_s3 — Download model artifacts from S3 to a local path
+# ---------------------------------------------------------------------------
+download_model_from_s3() {
+    local s3_uri="$1"
+    local dest_path="$2"
+    local start_time
+    start_time=$(date +%s)
+
+    if [ -z "$s3_uri" ] || [ -z "$dest_path" ]; then
+        echo "Error: download_model_from_s3 requires S3 URI and destination path" >&2
+        return 1
+    fi
+
+    echo "Downloading model from ${s3_uri} to ${dest_path}..." >&2
+    mkdir -p "${dest_path}"
+
+    if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
+        # Tarball: download and extract
+        if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
+            echo "Error: Failed to download tarball from ${s3_uri}" >&2
+            return 1
+        fi
+        if ! tar -xzf /tmp/model_archive.tar.gz -C "$dest_path"; then
+            echo "Error: Failed to extract tarball from ${s3_uri}" >&2
+            rm -f /tmp/model_archive.tar.gz
+            return 1
+        fi
+        rm -f /tmp/model_archive.tar.gz
+    elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
+        # Directory prefix: sync
+        if ! aws s3 sync "$s3_uri" "$dest_path"; then
+            echo "Error: Failed to sync from ${s3_uri}" >&2
+            return 1
+        fi
+    else
+        # Single file: copy
+        if ! aws s3 cp "$s3_uri" "$dest_path/"; then
+            echo "Error: Failed to copy ${s3_uri}" >&2
+            return 1
+        fi
+    fi
+
+    local duration
+    duration=$(( $(date +%s) - start_time ))
+    echo "Download complete: ${s3_uri} → ${dest_path} (${duration}s)" >&2
+}
+<% } %>
+
+# ---------------------------------------------------------------------------
+# Model Loading Adapter — resolve model based on MODEL_SOURCE env var
+# ---------------------------------------------------------------------------
+MODEL_SOURCE="${MODEL_SOURCE:-huggingface}"
+MODEL_ARTIFACT_URI="${MODEL_ARTIFACT_URI:-}"
+LOCAL_MODEL_PATH="/opt/ml/model"
+
+<% if (modelServer === 'vllm') { %>
+_MODEL_VAR="VLLM_MODEL"
+<% } else if (modelServer === 'sglang') { %>
+_MODEL_VAR="SGLANG_MODEL_PATH"
+<% } else if (modelServer === 'tensorrt-llm') { %>
+_MODEL_VAR="TRTLLM_MODEL"
+<% } %>
+
+resolve_model() {
+    case "$MODEL_SOURCE" in
+        huggingface)
+            # Pass model name directly — server fetches from HF Hub
+            echo "${!_MODEL_VAR}"
+            return
+            ;;
+        s3|jumpstart|jumpstart-hub|registry)
+            # Check for pre-mounted artifacts first
+            if [ -d "$LOCAL_MODEL_PATH" ] && [ -n "$(ls -A "$LOCAL_MODEL_PATH" 2>/dev/null)" ]; then
+                echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
+                echo "$LOCAL_MODEL_PATH"
+                return
+            fi
+
+            # For registry:// models, resolve artifact URI at runtime via SageMaker API
+            if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
+                local model_uri="${!_MODEL_VAR}"
+                local registry_prefix="registry://"
+                if [[ "$model_uri" == "${registry_prefix}"* ]]; then
+                    local registry_path="${model_uri#${registry_prefix}}"
+                    local group_name="${registry_path%%/*}"
+                    local version="${registry_path#*/}"
+                    local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
+
+                    # Get account ID for ARN construction
+                    local account_id
+                    account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
+                        echo "Error: Failed to get AWS account ID for model package ARN" >&2
+                        exit 1
+                    }
+
+                    local package_arn="arn:aws:sagemaker:${region}:${account_id}:model-package/${group_name}/${version}"
+                    echo "Resolving ${model_uri} via SageMaker DescribeModelPackage..." >&2
+                    echo " ARN: ${package_arn}" >&2
+
+                    local describe_output
+                    describe_output=$(aws sagemaker describe-model-package \
+                        --model-package-name "$package_arn" \
+                        --region "$region" \
+                        --output json 2>/dev/null) || {
+                        echo "Error: Failed to describe model package: ${package_arn}" >&2
+                        exit 1
+                    }
+
+                    # Try ModelDataUrl first, then S3DataSource.S3Uri, then description
+                    MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
+import sys, json, re
+try:
+    pkg = json.load(sys.stdin)
+    uri = ''
+    # Check InferenceSpecification.Containers[0]
+    containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
+    if containers:
+        c = containers[0]
+        uri = c.get('ModelDataUrl', '')
+        if not uri:
+            uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
+    # Fallback: extract S3 URI from ModelPackageDescription
+    if not uri:
+        desc = pkg.get('ModelPackageDescription', '')
+        m = re.search(r's3://[^\s]+', desc)
+        if m:
+            uri = m.group(0)
+    # Fallback: check ModelCard hyperparameters for model_artifacts_s3
+    if not uri:
+        try:
+            card = pkg.get('ModelCard', {})
+            content = card.get('ModelCardContent', '{}')
+            card_data = json.loads(content) if isinstance(content, str) else content
+            params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
+            for p in params:
+                if p.get('name') == 'model_artifacts_s3':
+                    uri = p.get('value', '')
+                    break
+        except:
+            pass
+    print(uri)
+except:
+    print('')
+" 2>/dev/null)
+
+                    if [ -n "$MODEL_ARTIFACT_URI" ] && [ "$MODEL_ARTIFACT_URI" != "None" ]; then
+                        echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
+                    else
+                        echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
+                        echo " Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
+                        echo " Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
+                        exit 1
+                    fi
+                fi
+            fi
+
+            # Need artifact URI for download
+            if [ -z "$MODEL_ARTIFACT_URI" ]; then
+                echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
+                exit 1
+            fi
+            # Download from S3
+            if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
+                echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
+                exit 1
+            fi
+            echo "$LOCAL_MODEL_PATH"
+            ;;
+        *)
+            # Unrecognized source — treat as huggingface
+            echo "${!_MODEL_VAR}"
+            return
+            ;;
+    esac
+}
+
+_RESOLVED_MODEL=$(resolve_model) || exit 1
+export "${_MODEL_VAR}=${_RESOLVED_MODEL}"
+echo "Resolved ${_MODEL_VAR}=${_RESOLVED_MODEL} (source: ${MODEL_SOURCE})"
+unset _MODEL_VAR _RESOLVED_MODEL
+
+# Initialize server arguments
+<% if (modelServer === 'tensorrt-llm') { %>
+# port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
+SERVER_ARGS=(--host 0.0.0.0 --port 8081)
+<% } else { %>
+# port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+SERVER_ARGS=(--host 0.0.0.0 --port 8080)
+<% } %>
+
+# Define the prefix for environment variables to look for
+<% if (modelServer === 'vllm') { %>
+PREFIX="VLLM_"
+<% } else if (modelServer === 'sglang') { %>
+PREFIX="SGLANG_"
+<% } else if (modelServer === 'tensorrt-llm') { %>
+PREFIX="TRTLLM_"
+<% } %>
+ARG_PREFIX="--"
+
+# Define environment variables to exclude (internal variables set by base images)
+<% if (modelServer === 'vllm') { %>
+EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
+<% } else if (modelServer === 'sglang') { %>
+EXCLUDE_VARS=()
+<% } else if (modelServer === 'tensorrt-llm') { %>
+# Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
+EXCLUDE_VARS=("TRTLLM_MODEL")
+<% } %>
+
+# Declare and populate array of matching environment variables
+mapfile -t env_vars < <(env | grep "^${PREFIX}")
+
+# Loop through the array and convert to command-line arguments
+for var in "${env_vars[@]}"; do
+    IFS='=' read -r key value <<< "$var"
+
+    # Skip excluded variables
+    skip=false
+    for exclude in "${EXCLUDE_VARS[@]}"; do
+        if [ "$key" = "$exclude" ]; then
+            skip=true
+            break
+        fi
+    done
+
+    if [ "$skip" = true ]; then
+        continue
+    fi
+
+    # Remove prefix, convert to lowercase, and replace underscores with dashes
+    arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+    SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
+    if [ -n "$value" ]; then
+        SERVER_ARGS+=("$value")
+    fi
+done
+
+echo "-------------------------------------------------------------------"
+<% if (modelServer === 'vllm') { %>
+echo "vLLM engine args: [${SERVER_ARGS[*]}]"
+<% } else if (modelServer === 'sglang') { %>
+echo "SGLang engine args: [${SERVER_ARGS[*]}]"
+<% } else if (modelServer === 'tensorrt-llm') { %>
+echo "TensorRT-LLM engine args: [${SERVER_ARGS[*]}]"
+<% } %>
+echo "-------------------------------------------------------------------"
+
+# Pass the collected arguments to the main entrypoint
+<% if (modelServer === 'vllm') { %>
+exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
+<% } else if (modelServer === 'sglang') { %>
+exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
+<% } else if (modelServer === 'tensorrt-llm') { %>
+# TensorRT-LLM requires the model as a positional argument
+# Syntax: trtllm-serve serve MODEL [OPTIONS]
+if [ -z "$TRTLLM_MODEL" ]; then
+    echo "Error: TRTLLM_MODEL environment variable is not set"
+    exit 1
+fi
+exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
+<% } %>
+<% } %>
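A note on the argument-mapping loop in the script above: after resolve_model exports the model variable, every environment variable matching the server prefix is turned into a CLI flag (prefix stripped, lowercased, underscores replaced with dashes). A minimal sketch of the convention for a rendered vLLM container, with hypothetical variable values (argument order beyond host/port follows env output and is not guaranteed):

    # Hypothetical endpoint environment:
    export VLLM_MODEL="meta-llama/Llama-2-7b-hf"
    export VLLM_MAX_MODEL_LEN=4096
    export VLLM_GPU_MEMORY_UTILIZATION=0.90

    # The loop then builds and execs, approximately:
    python3 -m vllm.entrypoints.openai.api_server \
        --host 0.0.0.0 --port 8080 \
        --model "meta-llama/Llama-2-7b-hf" \
        --max-model-len 4096 \
        --gpu-memory-utilization 0.90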
--- /dev/null
+++ package/templates/code/serve.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+SageMaker inference server script
+"""
+import os
+import logging
+<% if (modelServer === 'flask') { %>
+from flask import Flask, request, jsonify
+<% } else if (modelServer === 'fastapi') { %>
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
+<% } else if (modelServer === 'sglang') { %>
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse
+import asyncio
+from sglang import Runtime
+<% } %>
+from model_handler import ModelHandler
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+model_handler = None
+
+<% if (modelServer === 'flask') { %>
+app = Flask(__name__)
+
+def create_app():
+    app = Flask(__name__)
+
+    @app.route('/ping', methods=['GET'])
+    def ping():
+        """Health check endpoint"""
+        if model_handler and model_handler.is_loaded():
+            return jsonify({'status': 'healthy'})
+        return jsonify({'status': 'model not loaded'}), 503
+
+    @app.route('/invocations', methods=['POST'])
+    def invocations():
+        """Main inference endpoint"""
+        if not model_handler or not model_handler.is_loaded():
+            return jsonify({'error': 'Model not loaded'}), 503
+
+        try:
+            data = request.get_json() if request.is_json else request.data
+            result = model_handler.predict(data)
+            return jsonify(result)
+        except ValueError as e:
+            return jsonify({'error': f'Invalid input: {str(e)}'}), 400
+        except Exception as e:
+            logger.exception("Error during inference")
+            return jsonify({'error': str(e)}), 500
+
+    return app
+
+def load_model_for_worker():
+    """Load the model when the server starts"""
+    global model_handler
+    model_path = "/opt/ml/model"
+    logger.info(f"Loading model from {model_path}")
+    model_handler = ModelHandler(model_path)
+    model_handler.load_model()
+    logger.info("Model loaded successfully")
+<% } else if (modelServer === 'fastapi') { %>
+app = FastAPI()
+
+@app.on_event("startup")
+async def startup_event():
+    """Load the model when the server starts"""
+    global model_handler
+    model_path = "/opt/ml/model"
+    logger.info(f"Loading model from {model_path}")
+    model_handler = ModelHandler(model_path)
+    model_handler.load_model()
+    logger.info("Model loaded successfully")
+
+@app.get('/ping')
+async def ping():
+    """Health check endpoint"""
+    if model_handler and model_handler.is_loaded():
+        return {'status': 'healthy'}
+    raise HTTPException(status_code=503, detail={'status': 'model not loaded'})
+
+@app.post('/invocations')
+async def invocations(request: Request):
+    """Main inference endpoint"""
+    if not model_handler or not model_handler.is_loaded():
+        raise HTTPException(status_code=503, detail={'error': 'Model not loaded'})
+
+    try:
+        content_type = request.headers.get('content-type', '')
+        if 'application/json' in content_type:
+            data = await request.json()
+        else:
+            data = await request.body()
+
+        result = model_handler.predict(data)
+        return result
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail={'error': f'Invalid input: {str(e)}'})
+    except Exception as e:
+        logger.exception("Error during inference")
+        raise HTTPException(status_code=500, detail={'error': str(e)})
+
+<% } else if (modelServer === 'sglang') { %>
+app = FastAPI()
+sglang_runtime = None
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize SGLang runtime when the server starts"""
+    global sglang_runtime
+    model_id = "<%= modelName %>"
+    logger.info(f"Initializing SGLang runtime with model: {model_id}")
+
+    sglang_runtime = Runtime(
+        model_path=model_id,
+        tokenizer_path=model_id,
+        device="cuda",
+        mem_fraction_static=0.8
+    )
+    logger.info("SGLang runtime initialized successfully")
+
+@app.get('/ping')
+async def ping():
+    """Health check endpoint"""
+    if sglang_runtime:
+        return {'status': 'healthy'}
+    raise HTTPException(status_code=503, detail={'status': 'runtime not loaded'})
+
+@app.post('/invocations')
+async def invocations(request: Request):
+    """Main inference endpoint"""
+    if not sglang_runtime:
+        raise HTTPException(status_code=503, detail={'error': 'Runtime not loaded'})
+
+    try:
+        content_type = request.headers.get('content-type', '')
+        if 'application/json' in content_type:
+            data = await request.json()
+        else:
+            data = await request.body()
+
+        # Extract prompts from SageMaker format
+        if isinstance(data, dict):
+            prompts = data.get('instances', data.get('inputs', [data]))
+        else:
+            prompts = [data]
+
+        # Generate responses
+        outputs = sglang_runtime.generate(prompts)
+        return {'predictions': outputs}
+
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail={'error': f'Invalid input: {str(e)}'})
+    except Exception as e:
+        logger.exception("Error during inference")
+        raise HTTPException(status_code=500, detail={'error': str(e)})
+<% } %>
+
+if __name__ == '__main__':
+<% if (modelServer === 'flask') { %>
+    app = create_app()
+    load_model_for_worker()  # Load model for development server
+    port = int(os.environ.get("SAGEMAKER_BIND_TO_PORT", 8080))
+    app.run(host='0.0.0.0', port=port, debug=False)
+<% } else if (modelServer === 'fastapi' || modelServer === 'sglang') { %>
+    import uvicorn
+    port = int(os.environ.get("SAGEMAKER_BIND_TO_PORT", 8080))
+    uvicorn.run(app, host='0.0.0.0', port=port)
+<% } %>
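The rendered serve.py exposes the two routes SageMaker's container contract requires, /ping and /invocations, on port 8080. A local smoke test, as a sketch (the JSON body is illustrative only; the accepted payload shape is defined by the generated model_handler.py):

    # Run the rendered script directly, then probe both endpoints:
    python3 serve.py &
    sleep 10

    # /ping returns 200 only once ModelHandler.load_model() has completed
    curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8080/ping

    # /invocations passes the parsed body to model_handler.predict()
    curl -s -X POST http://localhost:8080/invocations \
        -H "Content-Type: application/json" \
        -d '{"inputs": "hello"}'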
--- /dev/null
+++ package/templates/code/serving.properties
@@ -0,0 +1,105 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# LMI/DJL Serving Configuration
+# Documentation: https://docs.djl.ai/master/docs/serving/serving/docs/lmi/deployment_guide/index.html
+
+<% if (modelServer === 'lmi') { %>
+# LMI Container Configuration
+# LMI provides optimized inference with multiple backend options
+
+# Model Configuration
+<% if (modelSource === 'huggingface' || !modelSource) { %>
+option.model_id=<%= modelName %>
+<% } else if (artifactUri) { %>
+option.model_id=<%= artifactUri %>
+<% } else { %>
+# Model will be loaded from /opt/ml/model at runtime
+# (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
+# option.model_id=/opt/ml/model
+<% } %>
+
+# Backend Selection
+# Options: vllm, lmi-dist (DeepSpeed), tensorrt-llm, transformers-neuronx
+# Leave unset to let LMI auto-select the best backend for your model
+# option.rolling_batch=vllm
+
+# Tensor Parallel Degree
+# Set to number of GPUs for multi-GPU inference
+# option.tensor_parallel_degree=1
+
+# Data Type
+# Options: fp16, bf16, fp32, int8
+# option.dtype=fp16
+
+# Max Rolling Batch Size
+# Maximum number of concurrent requests
+# option.max_rolling_batch_size=32
+
+<% if (hfToken && (!modelSource || modelSource === 'huggingface')) { %>
+# HuggingFace Authentication
+option.hf_token=<%= hfToken %>
+<% } %>
+
+<% if (chatTemplate) { %>
+# Chat Template
+# Custom chat template for formatting messages
+option.chat_template=<%= chatTemplate %>
+<% } %>
+
+# Performance Tuning
+# Uncomment and adjust based on your needs:
+# option.max_model_len=4096
+# option.gpu_memory_utilization=0.9
+# option.enable_chunked_prefill=true
+
+<% } else if (modelServer === 'djl') { %>
+# DJL Serving Configuration
+# DJL provides flexible model serving with multiple framework support
+
+# Model Configuration
+<% if (modelSource === 'huggingface' || !modelSource) { %>
+option.model_id=<%= modelName %>
+<% } else if (artifactUri) { %>
+option.model_id=<%= artifactUri %>
+<% } else { %>
+# Model will be loaded from /opt/ml/model at runtime
+# (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
+# option.model_id=/opt/ml/model
+<% } %>
+
+# Engine Selection
+# Options: Python, PyTorch, TensorFlow, MXNet
+engine=Python
+
+# Batch Size
+# Maximum batch size for inference
+# batch_size=1
+
+# Job Queue Size
+# Maximum number of queued requests
+# job_queue_size=100
+
+<% if (hfToken && (!modelSource || modelSource === 'huggingface')) { %>
+# HuggingFace Authentication
+option.hf_token=<%= hfToken %>
+<% } %>
+
+<% if (chatTemplate) { %>
+# Chat Template
+option.chat_template=<%= chatTemplate %>
+<% } %>
+
+# GPU Configuration
+# option.tensor_parallel_degree=1
+# option.device_map=auto
+
+<% } %>
+
+# Additional Environment-Specific Configuration
+<% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
+# Custom environment variables from configuration
+<% orderedEnvVars.forEach(({ key, value }) => { %>
+# <%= key %>=<%= value %>
+<% }); %>
+<% } %>
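Unlike the other templates, this file is configuration rather than code: DJL Serving reads it from /opt/ml/model/serving.properties at startup, which is the same path the serve template above checks for. A local-run sketch, where IMAGE is a placeholder (the real LMI/DJL base image URI comes from the base-image-picker catalogs):

    # Stage the rendered config where DJL Serving looks for it:
    mkdir -p model
    cp serving.properties model/

    IMAGE="<lmi-or-djl-base-image>"   # placeholder, not a real tag
    docker run --rm --gpus all -p 8080:8080 \
        -v "$(pwd)/model:/opt/ml/model" \
        "$IMAGE"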
--- /dev/null
+++ package/templates/code/start_server.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Script to start the SageMaker inference server
+"""
+import signal
+import subprocess
+import sys
+
+
+def signal_handler(signum, frame):
+    """Handle shutdown signals"""
+    print(f"Received signal {signum}, shutting down gracefully...")
+    sys.exit(0)
+
+
+if __name__ == '__main__':
+    # Register signal handlers
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+<% if (modelServer === 'flask') { %>
+    print("Starting SageMaker inference server with Gunicorn")
+    subprocess.run([
+        'gunicorn',
+        '--config', '/opt/ml/code/gunicorn_config.py',
+        'wsgi:application'
+    ])
+<% } else if (modelServer === 'fastapi' || modelServer === 'sglang') { %>
+    print("Starting SageMaker inference server with Uvicorn")
+    subprocess.run([
+        'uvicorn',
+        'serve:app',
+        '--host', '0.0.0.0',
+        '--port', '8080',
+        '--workers', '4'
+    ])
+<% } %>
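Stripped of templating, the script reduces to one of two launch commands, shown below as their shell equivalents. One design consequence worth noting (an observation, not something the template states): with --workers 4, each Uvicorn worker imports serve.py and loads its own copy of the model through the startup hook.

    # Flask variant (worker settings live in gunicorn_config.py):
    gunicorn --config /opt/ml/code/gunicorn_config.py wsgi:application

    # FastAPI / SGLang variant:
    uvicorn serve:app --host 0.0.0.0 --port 8080 --workers 4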
--- /dev/null
+++ package/templates/code/start_server.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+set -e
+
+echo "Starting TensorRT-LLM server on port 8081..."
+/usr/bin/serve_trtllm &
+TRTLLM_PID=$!
+
+# Wait for TensorRT-LLM to be ready
+echo "Waiting for TensorRT-LLM server to start..."
+for i in {1..300}; do
+    if curl -s http://localhost:8081/health > /dev/null 2>&1; then
+        echo "TensorRT-LLM server is ready!"
+        break
+    fi
+    if [ "$i" -eq 300 ]; then
+        echo "ERROR: TensorRT-LLM server failed to start within 300 seconds"
+        exit 1
+    fi
+    sleep 1
+done
+
+echo "Starting nginx reverse proxy on port 8080..."
+nginx -c /etc/nginx/nginx.conf &
+NGINX_PID=$!
+
+# Wait for either process to exit (this keeps the container running)
+wait -n "$TRTLLM_PID" "$NGINX_PID" && EXIT_CODE=0 || EXIT_CODE=$?
+
+# If we get here, one process exited - this is an error condition
+# (status is captured on the wait line itself so `set -e` cannot abort before we log it)
+echo "ERROR: A critical process exited unexpectedly (exit code: $EXIT_CODE)"
+
+# Kill any remaining processes
+kill $TRTLLM_PID $NGINX_PID 2>/dev/null || true
+
+exit $EXIT_CODE
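A smoke-test sketch for the resulting two-process topology, run from inside the container. The 8081 health path is the one this script itself polls; that nginx answers /ping on 8080 is an assumption based on SageMaker's contract and the bundled nginx-tensorrt.conf, which is not shown in this hunk:

    # Internal TensorRT-LLM server:
    curl -s http://localhost:8081/health

    # Public side, through the nginx reverse proxy:
    curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8080/ping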