@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
package/templates/code/serve
@@ -0,0 +1,300 @@
+ #!/bin/bash
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ <% if (modelServer === 'vllm') { %>
+ echo "Starting vLLM server"
+ <% } else if (modelServer === 'sglang') { %>
+ echo "Starting SGLang server"
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ echo "Starting TensorRT-LLM server"
+ <% } else if (modelServer === 'lmi') { %>
+ echo "Starting LMI (Large Model Inference) server"
+ <% } else if (modelServer === 'djl') { %>
+ echo "Starting DJL Serving server"
+ <% } %>
+
+ <% if (modelServer === 'lmi' || modelServer === 'djl') { %>
+ # LMI/DJL containers use serving.properties for configuration.
+ # The configuration file should be at /opt/ml/model/serving.properties;
+ # DJL Serving starts with this configuration automatically.
+
+ if [ ! -f /opt/ml/model/serving.properties ]; then
+     echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
+     exit 1
+ fi
+
+ echo "Using configuration from /opt/ml/model/serving.properties"
+ cat /opt/ml/model/serving.properties
+
+ # DJL Serving is already configured in the base image.
+ # This script is not typically needed for LMI/DJL, since those containers have
+ # their own entrypoint, but it is provided for consistency with other model servers.
+ exit 0
+ <% } else { %>
+
+ <% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
+ # ---------------------------------------------------------------------------
+ # download_model_from_s3 — Download model artifacts from S3 to a local path
+ # ---------------------------------------------------------------------------
+ download_model_from_s3() {
+     local s3_uri="$1"
+     local dest_path="$2"
+     local start_time
+     start_time=$(date +%s)
+
+     if [ -z "$s3_uri" ] || [ -z "$dest_path" ]; then
+         echo "Error: download_model_from_s3 requires an S3 URI and a destination path" >&2
+         return 1
+     fi
+
+     echo "Downloading model from ${s3_uri} to ${dest_path}..." >&2
+     mkdir -p "${dest_path}"
+
+     if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
+         # Tarball: download and extract
+         if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
+             echo "Error: Failed to download tarball from ${s3_uri}" >&2
+             return 1
+         fi
+         if ! tar -xzf /tmp/model_archive.tar.gz -C "$dest_path"; then
+             echo "Error: Failed to extract tarball from ${s3_uri}" >&2
+             rm -f /tmp/model_archive.tar.gz
+             return 1
+         fi
+         rm -f /tmp/model_archive.tar.gz
+     elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
+         # Directory prefix: sync recursively
+         if ! aws s3 sync "$s3_uri" "$dest_path"; then
+             echo "Error: Failed to sync from ${s3_uri}" >&2
+             return 1
+         fi
+     else
+         # Single file: copy
+         if ! aws s3 cp "$s3_uri" "$dest_path/"; then
+             echo "Error: Failed to copy ${s3_uri}" >&2
+             return 1
+         fi
+     fi
+
+     local duration
+     duration=$(( $(date +%s) - start_time ))
+     echo "Download complete: ${s3_uri} → ${dest_path} (${duration}s)" >&2
+ }
+ <% } %>
+
+ # ---------------------------------------------------------------------------
+ # Model Loading Adapter — resolve model based on MODEL_SOURCE env var
+ # ---------------------------------------------------------------------------
+ MODEL_SOURCE="${MODEL_SOURCE:-huggingface}"
+ MODEL_ARTIFACT_URI="${MODEL_ARTIFACT_URI:-}"
+ LOCAL_MODEL_PATH="/opt/ml/model"
+
+ <% if (modelServer === 'vllm') { %>
+ _MODEL_VAR="VLLM_MODEL"
+ <% } else if (modelServer === 'sglang') { %>
+ _MODEL_VAR="SGLANG_MODEL_PATH"
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ _MODEL_VAR="TRTLLM_MODEL"
+ <% } %>
+
+ resolve_model() {
+     case "$MODEL_SOURCE" in
+         huggingface)
+             # Pass the model name through directly — the server fetches it from the HF Hub
+             echo "${!_MODEL_VAR}"
+             return
+             ;;
+         s3|jumpstart|jumpstart-hub|registry)
+             # Check for pre-mounted artifacts first
+             if [ -d "$LOCAL_MODEL_PATH" ] && [ -n "$(ls -A "$LOCAL_MODEL_PATH" 2>/dev/null)" ]; then
+                 echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
+                 echo "$LOCAL_MODEL_PATH"
+                 return
+             fi
+
+             # For registry:// models, resolve the artifact URI at runtime via the SageMaker API
+             if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
+                 local model_uri="${!_MODEL_VAR}"
+                 local registry_prefix="registry://"
+                 if [[ "$model_uri" == "${registry_prefix}"* ]]; then
+                     local registry_path="${model_uri#${registry_prefix}}"
+                     local group_name="${registry_path%%/*}"
+                     local version="${registry_path#*/}"
+                     local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
+
+                     # Get the account ID for ARN construction
+                     local account_id
+                     account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
+                         echo "Error: Failed to get AWS account ID for model package ARN" >&2
+                         exit 1
+                     }
+
+                     local package_arn="arn:aws:sagemaker:${region}:${account_id}:model-package/${group_name}/${version}"
+                     echo "Resolving ${model_uri} via SageMaker DescribeModelPackage..." >&2
+                     echo "  ARN: ${package_arn}" >&2
+
+                     local describe_output
+                     describe_output=$(aws sagemaker describe-model-package \
+                         --model-package-name "$package_arn" \
+                         --region "$region" \
+                         --output json 2>/dev/null) || {
+                         echo "Error: Failed to describe model package: ${package_arn}" >&2
+                         exit 1
+                     }
+
+                     # Try ModelDataUrl first, then S3DataSource.S3Uri, then the description
+                     MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
+ import sys, json, re
+ try:
+     pkg = json.load(sys.stdin)
+     uri = ''
+     # Check InferenceSpecification.Containers[0]
+     containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
+     if containers:
+         c = containers[0]
+         uri = c.get('ModelDataUrl', '')
+         if not uri:
+             uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
+     # Fallback: extract an S3 URI from ModelPackageDescription
+     if not uri:
+         desc = pkg.get('ModelPackageDescription', '')
+         m = re.search(r's3://[^\s]+', desc)
+         if m:
+             uri = m.group(0)
+     # Fallback: check ModelCard hyperparameters for model_artifacts_s3
+     if not uri:
+         try:
+             card = pkg.get('ModelCard', {})
+             content = card.get('ModelCardContent', '{}')
+             card_data = json.loads(content) if isinstance(content, str) else content
+             params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
+             for p in params:
+                 if p.get('name') == 'model_artifacts_s3':
+                     uri = p.get('value', '')
+                     break
+         except Exception:
+             pass
+     print(uri)
+ except Exception:
+     print('')
+ " 2>/dev/null)
+
+                     if [ -n "$MODEL_ARTIFACT_URI" ] && [ "$MODEL_ARTIFACT_URI" != "None" ]; then
+                         echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
+                     else
+                         echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
+                         echo "  Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
+                         echo "  Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
+                         exit 1
+                     fi
+                 fi
+             fi
+
+             # An artifact URI is required for download
+             if [ -z "$MODEL_ARTIFACT_URI" ]; then
+                 echo "Error: ${MODEL_SOURCE} model requires an artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
+                 exit 1
+             fi
+             # Download from S3
+             if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
+                 echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
+                 exit 1
+             fi
+             echo "$LOCAL_MODEL_PATH"
+             ;;
+         *)
+             # Unrecognized source — treat as huggingface
+             echo "${!_MODEL_VAR}"
+             return
+             ;;
+     esac
+ }
+
+ _RESOLVED_MODEL=$(resolve_model) || exit 1
+ export "${_MODEL_VAR}=${_RESOLVED_MODEL}"
+ echo "Resolved ${_MODEL_VAR}=${_RESOLVED_MODEL} (source: ${MODEL_SOURCE})"
+ unset _MODEL_VAR _RESOLVED_MODEL
+
+ # Initialize server arguments
+ <% if (modelServer === 'tensorrt-llm') { %>
+ # Port 8081 for the internal TensorRT-LLM server (nginx proxies on 8080)
+ SERVER_ARGS=(--host 0.0.0.0 --port 8081)
+ <% } else { %>
+ # Port 8080 is required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
+ SERVER_ARGS=(--host 0.0.0.0 --port 8080)
+ <% } %>
+
+ # Define the prefix for environment variables to look for
+ <% if (modelServer === 'vllm') { %>
+ PREFIX="VLLM_"
+ <% } else if (modelServer === 'sglang') { %>
+ PREFIX="SGLANG_"
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ PREFIX="TRTLLM_"
+ <% } %>
+ ARG_PREFIX="--"
+
+ # Define environment variables to exclude (internal variables set by base images)
+ <% if (modelServer === 'vllm') { %>
+ EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
+ <% } else if (modelServer === 'sglang') { %>
+ EXCLUDE_VARS=()
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ # Exclude TRTLLM_MODEL since it is used as the positional MODEL argument
+ EXCLUDE_VARS=("TRTLLM_MODEL")
+ <% } %>
+
+ # Declare and populate an array of matching environment variables
+ mapfile -t env_vars < <(env | grep "^${PREFIX}")
+
+ # Loop through the array and convert each variable to a command-line argument
+ for var in "${env_vars[@]}"; do
+     IFS='=' read -r key value <<< "$var"
+
+     # Skip excluded variables
+     skip=false
+     for exclude in "${EXCLUDE_VARS[@]}"; do
+         if [ "$key" = "$exclude" ]; then
+             skip=true
+             break
+         fi
+     done
+
+     if [ "$skip" = true ]; then
+         continue
+     fi
+
+     # Remove the prefix, convert to lowercase, and replace underscores with dashes
+     arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
+     SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
+     if [ -n "$value" ]; then
+         SERVER_ARGS+=("$value")
+     fi
+ done
+
+ echo "-------------------------------------------------------------------"
+ <% if (modelServer === 'vllm') { %>
+ echo "vLLM engine args: [${SERVER_ARGS[*]}]"
+ <% } else if (modelServer === 'sglang') { %>
+ echo "SGLang engine args: [${SERVER_ARGS[*]}]"
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ echo "TensorRT-LLM engine args: [${SERVER_ARGS[*]}]"
+ <% } %>
+ echo "-------------------------------------------------------------------"
+
+ # Pass the collected arguments to the main entrypoint
+ <% if (modelServer === 'vllm') { %>
+ exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
+ <% } else if (modelServer === 'sglang') { %>
+ exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
+ <% } else if (modelServer === 'tensorrt-llm') { %>
+ # TensorRT-LLM requires the model as a positional argument
+ # Syntax: trtllm-serve serve MODEL [OPTIONS]
+ if [ -z "$TRTLLM_MODEL" ]; then
+     echo "Error: TRTLLM_MODEL environment variable is not set"
+     exit 1
+ fi
+ exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
+ <% } %>
+ <% } %>
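
The prefix-driven loop above is this template's main extension point: any `VLLM_*`/`SGLANG_*`/`TRTLLM_*` environment variable set on the container becomes a CLI flag for the engine. A quick illustration of the resulting behavior — a hedged sketch, where the image tag is a placeholder and flag support depends on the engine version:

```bash
# Hypothetical invocation of a generated vLLM image (tag is illustrative).
docker run --rm --gpus all -p 8080:8080 \
  -e VLLM_MODEL="mistralai/Mistral-7B-Instruct-v0.2" \
  -e VLLM_MAX_MODEL_LEN=4096 \
  -e VLLM_GPU_MEMORY_UTILIZATION=0.9 \
  my-vllm-image serve

# Inside the container, the conversion loop rewrites those variables into:
#   python3 -m vllm.entrypoints.openai.api_server \
#     --host 0.0.0.0 --port 8080 \
#     --model mistralai/Mistral-7B-Instruct-v0.2 \
#     --max-model-len 4096 \
#     --gpu-memory-utilization 0.9
```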
package/templates/code/serve.py
@@ -0,0 +1,175 @@
+ #!/usr/bin/env python3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ SageMaker inference server script
+ """
+ import os
+ import logging
+ <% if (modelServer === 'flask') { %>
+ from flask import Flask, request, jsonify
+ <% } else if (modelServer === 'fastapi') { %>
+ from fastapi import FastAPI, HTTPException, Request
+ <% } else if (modelServer === 'sglang') { %>
+ from fastapi import FastAPI, HTTPException, Request
+ from sglang import Runtime
+ <% } %>
+ from model_handler import ModelHandler
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+ model_handler = None
+
+ <% if (modelServer === 'flask') { %>
+ def create_app():
+     app = Flask(__name__)
+
+     @app.route('/ping', methods=['GET'])
+     def ping():
+         """Health check endpoint"""
+         if model_handler and model_handler.is_loaded():
+             return jsonify({'status': 'healthy'})
+         return jsonify({'status': 'model not loaded'}), 503
+
+     @app.route('/invocations', methods=['POST'])
+     def invocations():
+         """Main inference endpoint"""
+         if not model_handler or not model_handler.is_loaded():
+             return jsonify({'error': 'Model not loaded'}), 503
+
+         try:
+             data = request.get_json() if request.is_json else request.data
+             result = model_handler.predict(data)
+             return jsonify(result)
+         except ValueError as e:
+             return jsonify({'error': f'Invalid input: {str(e)}'}), 400
+         except Exception as e:
+             logger.exception("Error during inference")
+             return jsonify({'error': str(e)}), 500
+
+     return app
+
+ def load_model_for_worker():
+     """Load the model when the server starts"""
+     global model_handler
+     model_path = "/opt/ml/model"
+     logger.info(f"Loading model from {model_path}")
+     model_handler = ModelHandler(model_path)
+     model_handler.load_model()
+     logger.info("Model loaded successfully")
+ <% } else if (modelServer === 'fastapi') { %>
+ app = FastAPI()
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Load the model when the server starts"""
+     global model_handler
+     model_path = "/opt/ml/model"
+     logger.info(f"Loading model from {model_path}")
+     model_handler = ModelHandler(model_path)
+     model_handler.load_model()
+     logger.info("Model loaded successfully")
+
+ @app.get('/ping')
+ async def ping():
+     """Health check endpoint"""
+     if model_handler and model_handler.is_loaded():
+         return {'status': 'healthy'}
+     raise HTTPException(status_code=503, detail={'status': 'model not loaded'})
+
+ @app.post('/invocations')
+ async def invocations(request: Request):
+     """Main inference endpoint"""
+     if not model_handler or not model_handler.is_loaded():
+         raise HTTPException(status_code=503, detail={'error': 'Model not loaded'})
+
+     try:
+         content_type = request.headers.get('content-type', '')
+         if 'application/json' in content_type:
+             data = await request.json()
+         else:
+             data = await request.body()
+
+         result = model_handler.predict(data)
+         return result
+     except ValueError as e:
+         raise HTTPException(status_code=400, detail={'error': f'Invalid input: {str(e)}'})
+     except Exception as e:
+         logger.exception("Error during inference")
+         raise HTTPException(status_code=500, detail={'error': str(e)})
+
+ <% } else if (modelServer === 'sglang') { %>
+ app = FastAPI()
+ sglang_runtime = None
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize the SGLang runtime when the server starts"""
+     global sglang_runtime
+     model_id = "<%= modelName %>"
+     logger.info(f"Initializing SGLang runtime with model: {model_id}")
+
+     sglang_runtime = Runtime(
+         model_path=model_id,
+         tokenizer_path=model_id,
+         device="cuda",
+         mem_fraction_static=0.8
+     )
+     logger.info("SGLang runtime initialized successfully")
+
+ @app.get('/ping')
+ async def ping():
+     """Health check endpoint"""
+     if sglang_runtime:
+         return {'status': 'healthy'}
+     raise HTTPException(status_code=503, detail={'status': 'runtime not loaded'})
+
+ @app.post('/invocations')
+ async def invocations(request: Request):
+     """Main inference endpoint"""
+     if not sglang_runtime:
+         raise HTTPException(status_code=503, detail={'error': 'Runtime not loaded'})
+
+     try:
+         content_type = request.headers.get('content-type', '')
+         if 'application/json' in content_type:
+             data = await request.json()
+         else:
+             data = await request.body()
+
+         # Extract prompts from the SageMaker request format
+         if isinstance(data, dict):
+             prompts = data.get('instances', data.get('inputs', [data]))
+         else:
+             prompts = [data]
+
+         # Generate responses
+         outputs = sglang_runtime.generate(prompts)
+         return {'predictions': outputs}
+
+     except ValueError as e:
+         raise HTTPException(status_code=400, detail={'error': f'Invalid input: {str(e)}'})
+     except Exception as e:
+         logger.exception("Error during inference")
+         raise HTTPException(status_code=500, detail={'error': str(e)})
+ <% } %>
+
+ if __name__ == '__main__':
+ <% if (modelServer === 'flask') { %>
+     app = create_app()
+     load_model_for_worker()  # Load the model for the development server
+     port = int(os.environ.get("SAGEMAKER_BIND_TO_PORT", 8080))
+     app.run(host='0.0.0.0', port=port, debug=False)
+ <% } else if (modelServer === 'fastapi' || modelServer === 'sglang') { %>
+     import uvicorn
+     port = int(os.environ.get("SAGEMAKER_BIND_TO_PORT", 8080))
+     uvicorn.run(app, host='0.0.0.0', port=port)
+ <% } %>
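
Once a Flask or FastAPI image is built, the two endpoints SageMaker requires can be smoke-tested locally. A minimal sketch, assuming the server is listening on port 8080 and that your `ModelHandler.predict()` accepts this payload shape:

```bash
# Health check — returns {"status": "healthy"} once the model is loaded
curl -s http://localhost:8080/ping

# Inference — a JSON body is deserialized and handed to ModelHandler.predict()
curl -s -X POST http://localhost:8080/invocations \
  -H "Content-Type: application/json" \
  -d '{"inputs": ["Hello, world"]}'
```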
package/templates/code/serving.properties
@@ -0,0 +1,105 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # LMI/DJL Serving Configuration
+ # Documentation: https://docs.djl.ai/master/docs/serving/serving/docs/lmi/deployment_guide/index.html
+
+ <% if (modelServer === 'lmi') { %>
+ # LMI Container Configuration
+ # LMI provides optimized inference with multiple backend options
+
+ # Model Configuration
+ <% if (modelSource === 'huggingface' || !modelSource) { %>
+ option.model_id=<%= modelName %>
+ <% } else if (artifactUri) { %>
+ option.model_id=<%= artifactUri %>
+ <% } else { %>
+ # Model will be loaded from /opt/ml/model at runtime
+ # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
+ # option.model_id=/opt/ml/model
+ <% } %>
+
+ # Backend Selection
+ # Options: vllm, lmi-dist, tensorrt-llm, transformers-neuronx
+ # Leave unset to let LMI auto-select the best backend for your model
+ # option.rolling_batch=vllm
+
+ # Tensor Parallel Degree
+ # Set to the number of GPUs for multi-GPU inference
+ # option.tensor_parallel_degree=1
+
+ # Data Type
+ # Options: fp16, bf16, fp32, int8
+ # option.dtype=fp16
+
+ # Max Rolling Batch Size
+ # Maximum number of concurrent requests
+ # option.max_rolling_batch_size=32
+
+ <% if (hfToken && (!modelSource || modelSource === 'huggingface')) { %>
+ # HuggingFace Authentication
+ option.hf_token=<%= hfToken %>
+ <% } %>
+
+ <% if (chatTemplate) { %>
+ # Chat Template
+ # Custom chat template for formatting messages
+ option.chat_template=<%= chatTemplate %>
+ <% } %>
+
+ # Performance Tuning
+ # Uncomment and adjust based on your needs:
+ # option.max_model_len=4096
+ # option.gpu_memory_utilization=0.9
+ # option.enable_chunked_prefill=true
+
+ <% } else if (modelServer === 'djl') { %>
+ # DJL Serving Configuration
+ # DJL provides flexible model serving with multi-framework support
+
+ # Model Configuration
+ <% if (modelSource === 'huggingface' || !modelSource) { %>
+ option.model_id=<%= modelName %>
+ <% } else if (artifactUri) { %>
+ option.model_id=<%= artifactUri %>
+ <% } else { %>
+ # Model will be loaded from /opt/ml/model at runtime
+ # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
+ # option.model_id=/opt/ml/model
+ <% } %>
+
+ # Engine Selection
+ # Options: Python, PyTorch, TensorFlow, MXNet
+ engine=Python
+
+ # Batch Size
+ # Maximum batch size for inference
+ # batch_size=1
+
+ # Job Queue Size
+ # Maximum number of queued requests
+ # job_queue_size=100
+
+ <% if (hfToken && (!modelSource || modelSource === 'huggingface')) { %>
+ # HuggingFace Authentication
+ option.hf_token=<%= hfToken %>
+ <% } %>
+
+ <% if (chatTemplate) { %>
+ # Chat Template
+ option.chat_template=<%= chatTemplate %>
+ <% } %>
+
+ # GPU Configuration
+ # option.tensor_parallel_degree=1
+ # option.device_map=auto
+
+ <% } %>
+
+ # Additional Environment-Specific Configuration
+ <% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
+ # Custom environment variables from configuration
+ <% orderedEnvVars.forEach(({ key, value }) => { %>
+ # <%= key %>=<%= value %>
+ <% }); %>
+ <% } %>
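
For orientation, this is roughly what the LMI branch renders to once the commented tuning options are enabled — the model id and values below are illustrative assumptions, not output shipped with the package:

```properties
# Illustrative rendered serving.properties for a Hugging Face model
option.model_id=mistralai/Mistral-7B-Instruct-v0.2
option.rolling_batch=vllm
option.tensor_parallel_degree=1
option.dtype=fp16
option.max_rolling_batch_size=32
```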
package/templates/code/start_server.py
@@ -0,0 +1,39 @@
+ #!/usr/bin/env python3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ Script to start the SageMaker inference server
+ """
+ import signal
+ import subprocess
+ import sys
+
+
+ def signal_handler(signum, frame):
+     """Handle shutdown signals"""
+     print(f"Received signal {signum}, shutting down gracefully...")
+     sys.exit(0)
+
+
+ if __name__ == '__main__':
+     # Register signal handlers
+     signal.signal(signal.SIGTERM, signal_handler)
+     signal.signal(signal.SIGINT, signal_handler)
+
+ <% if (modelServer === 'flask') { %>
+     print("Starting SageMaker inference server with Gunicorn")
+     result = subprocess.run([
+         'gunicorn',
+         '--config', '/opt/ml/code/gunicorn_config.py',
+         'wsgi:application'
+     ])
+ <% } else if (modelServer === 'fastapi' || modelServer === 'sglang') { %>
+     print("Starting SageMaker inference server with Uvicorn")
+     result = subprocess.run([
+         'uvicorn',
+         'serve:app',
+         '--host', '0.0.0.0',
+         '--port', '8080',
+         '--workers', '4'
+     ])
+ <% } %>
+     sys.exit(result.returncode)  # Propagate the server's exit code to the container
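
The Gunicorn branch delegates worker tuning to `/opt/ml/code/gunicorn_config.py` (shipped as `package/templates/code/flask/gunicorn_config.py` in the file list above). A config in that role typically looks something like the following — a hedged sketch under stated assumptions, not the packaged file:

```python
# Hypothetical sketch of a gunicorn_config.py for this layout — the real file
# ships with the package; the names and values here are assumptions.
import multiprocessing
import os

bind = f"0.0.0.0:{os.environ.get('SAGEMAKER_BIND_TO_PORT', '8080')}"
workers = int(os.environ.get('GUNICORN_WORKERS', multiprocessing.cpu_count()))
timeout = 300          # allow slow model inference without killing workers
preload_app = False    # load the model per worker, not in the master process

def post_worker_init(worker):
    # Load the model once per worker process, after the fork
    from serve import load_model_for_worker
    load_model_for_worker()
```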
package/templates/code/start_server.sh
@@ -0,0 +1,39 @@
+ #!/bin/bash
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ set -e
+
+ echo "Starting TensorRT-LLM server on port 8081..."
+ /usr/bin/serve_trtllm &
+ TRTLLM_PID=$!
+
+ # Wait for TensorRT-LLM to be ready
+ echo "Waiting for TensorRT-LLM server to start..."
+ for i in {1..300}; do
+     if curl -s http://localhost:8081/health > /dev/null 2>&1; then
+         echo "TensorRT-LLM server is ready!"
+         break
+     fi
+     if [ $i -eq 300 ]; then
+         echo "ERROR: TensorRT-LLM server failed to start within 300 seconds"
+         exit 1
+     fi
+     sleep 1
+ done
+
+ echo "Starting nginx reverse proxy on port 8080..."
+ nginx -c /etc/nginx/nginx.conf &
+ NGINX_PID=$!
+
+ # Wait for either process to exit (this keeps the container running).
+ # Capture the exit status explicitly — under `set -e` a bare `wait -n`
+ # returning nonzero would abort the script before the status was recorded.
+ EXIT_CODE=0
+ wait -n $TRTLLM_PID $NGINX_PID || EXIT_CODE=$?
+
+ # If we get here, one process exited — this is an error condition
+ echo "ERROR: A critical process exited unexpectedly (exit code: $EXIT_CODE)"
+
+ # Kill any remaining processes
+ kill $TRTLLM_PID $NGINX_PID 2>/dev/null || true
+
+ exit $EXIT_CODE
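
The proxy configuration referenced here ships separately (`package/templates/nginx-tensorrt.conf` in the file list). The essential shape of such a config — a hedged sketch, not the packaged file — maps SageMaker's required port 8080 onto the internal engine on 8081:

```nginx
# Hypothetical sketch of the 8080 -> 8081 reverse proxy; the packaged
# nginx-tensorrt.conf is the authoritative version.
events {}

http {
    server {
        listen 8080;

        location /ping {
            # Map SageMaker's /ping to the engine's health check endpoint
            proxy_pass http://127.0.0.1:8081/health;
        }

        location /invocations {
            proxy_pass http://127.0.0.1:8081;
            proxy_read_timeout 300s;   # tolerate long-running generation requests
        }
    }
}
```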