@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
package/templates/sample_model/train_abalone.py
@@ -0,0 +1,252 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ import ssl
+ import numpy as np
+
+ # Handle SSL certificate issues
+ try:
+     import certifi
+     # Use certifi's CA bundle so downloads stay verified
+     ssl._create_default_https_context = lambda *args, **kwargs: ssl.create_default_context(cafile=certifi.where())
+ except ImportError:
+     # If certifi is not available, disable SSL verification as a fallback
+     ssl._create_default_https_context = ssl._create_unverified_context
+
+ <% if (architecture === 'triton') { %>
+ <% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
+ try:
+     import xgboost as xgb
+ except ImportError:
+     print("Error: xgboost is required. Install dependencies with: pip install -r requirements.txt")
+     raise
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
+ try:
+     import lightgbm as lgb
+ except ImportError:
+     print("Error: lightgbm is required. Install dependencies with: pip install -r requirements.txt")
+     raise
+ <% } else if (backend === 'onnxruntime') { %>
+ try:
+     from sklearn.ensemble import RandomForestRegressor
+     from sklearn.model_selection import train_test_split
+     from skl2onnx import convert_sklearn
+     from skl2onnx.common.data_types import FloatTensorType
+     import onnx
+ except ImportError:
+     print("Error: scikit-learn, skl2onnx, and onnx are required. Install dependencies with: pip install -r requirements.txt")
+     raise
+ <% } else if (backend === 'tensorflow') { %>
+ try:
+     import tensorflow as tf
+ except ImportError:
+     print("Error: tensorflow is required. Install dependencies with: pip install -r requirements.txt")
+     raise
+ <% } else if (backend === 'python') { %>
+ try:
+     from sklearn.ensemble import RandomForestRegressor
+     from sklearn.model_selection import train_test_split
+     import pickle
+ <% if (modelFormat === 'joblib') { %>
+     import joblib
+ <% } %>
+ except ImportError:
+     print("Error: scikit-learn is required. Install dependencies with: pip install -r requirements.txt")
+     raise
+ <% } %>
+ <% } else { %>
+ <% const effectiveFramework = engine || framework; %>
+ <% if (effectiveFramework === 'sklearn') { %>from sklearn.ensemble import RandomForestRegressor
+ from sklearn.model_selection import train_test_split
+ <% if (modelFormat === 'joblib') { %>import joblib
+ <% } else if (modelFormat === 'pkl') { %>import pickle
+ <% } %>
+ <% } else if (effectiveFramework === 'xgboost' || effectiveFramework === 'tensorflow') { %>
+ <% if (effectiveFramework === 'xgboost') { %>import xgboost as xgb<% } %>
+ <% if (effectiveFramework === 'tensorflow') { %>import tensorflow as tf<% } %>
+
+ def train_test_split(X, y, test_size=0.2, random_state=None):
+     if random_state is not None:
+         np.random.seed(random_state)
+
+     n_samples = len(X)
+     n_test = int(n_samples * test_size)
+
+     indices = np.random.permutation(n_samples)
+     test_indices = indices[:n_test]
+     train_indices = indices[n_test:]
+
+     return X.iloc[train_indices], X.iloc[test_indices], y[train_indices], y[test_indices]
+ <% } %>
+ <% } %>
+
+ from ucimlrepo import fetch_ucirepo
+ <% if (architecture === 'triton' && (backend === 'fil' || backend === 'tensorflow')) { %>
+
+ def train_test_split(X, y, test_size=0.2, random_state=None):
+     if random_state is not None:
+         np.random.seed(random_state)
+
+     n_samples = len(X)
+     n_test = int(n_samples * test_size)
+
+     indices = np.random.permutation(n_samples)
+     test_indices = indices[:n_test]
+     train_indices = indices[n_test:]
+
+     return X.iloc[train_indices], X.iloc[test_indices], y[train_indices], y[test_indices]
+ <% } %>
+
+ try:
+     abalone = fetch_ucirepo(id=1)
+     X = abalone.data.features.copy()
+     y = abalone.data.targets.values.ravel()
+ except Exception as e:
+     print(f"Warning: Could not download Abalone dataset from UCI repository: {e}")
+     print("Creating synthetic data for demonstration...")
+
+     # Create synthetic abalone-like data
+     np.random.seed(42)
+     n_samples = 4177  # Same as original dataset
+
+     # Create synthetic features similar to abalone dataset
+     # Features: Sex, Length, Diameter, Height, Whole weight, Shucked weight, Viscera weight, Shell weight
+     X = np.random.rand(n_samples, 8)
+     X[:, 0] = np.random.choice([0, 1, 2], n_samples)  # Sex (M=0, F=1, I=2)
+     X[:, 1:] = X[:, 1:] * np.array([0.815, 0.650, 0.265, 2.826, 1.488, 0.760, 1.005])  # Scale to realistic ranges
+
+     # Create synthetic target (rings/age)
+     y = (X[:, 1] * 10 + X[:, 4] * 5 + np.random.normal(0, 2, n_samples)).astype(int)
+     y = np.clip(y, 1, 29)  # Clip to realistic range
+
+     # Convert to DataFrame-like structure for compatibility
+     import pandas as pd
+     feature_names = ['Sex', 'Length', 'Diameter', 'Height', 'Whole_weight', 'Shucked_weight', 'Viscera_weight', 'Shell_weight']
+     X = pd.DataFrame(X, columns=feature_names)
+
+ # Encode Sex column if it's not already numeric
+ if hasattr(X, 'dtypes') and X['Sex'].dtype == 'object':
+     # Encode Sex column (M=0, F=1, I=2)
+     sex_map = {'M': 0, 'F': 1, 'I': 2}
+     X['Sex'] = X['Sex'].map(sex_map)
+
+ # Split data
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+ <% if (architecture === 'triton') { %>
+ <% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
+ # Train XGBoost model
+ model = xgb.XGBRegressor(n_estimators=100, random_state=42)
+ model.fit(X_train, y_train)
+
+ # Evaluate
+ test_score = model.score(X_test, y_test)
+ print(f"Model trained. Test score: {test_score:.3f}")
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
+ # Train LightGBM model
+ model = lgb.LGBMRegressor(n_estimators=100, random_state=42, verbose=-1)
+ model.fit(X_train, y_train)
+
+ # Evaluate
+ test_score = model.score(X_test, y_test)
+ print(f"Model trained. Test score: {test_score:.3f}")
+ <% } else if (backend === 'onnxruntime') { %>
+ # Train sklearn model
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
+ model.fit(X_train, y_train)
+
+ print(f"Model trained. Test score: {model.score(X_test, y_test):.3f}")
+
+ # Convert to ONNX format
+ initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]
+ onnx_model = convert_sklearn(model, 'abalone_model', initial_types=initial_type)
+ <% } else if (backend === 'tensorflow') { %>
+ # Train TensorFlow model
+ model = tf.keras.Sequential([
+     tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
+     tf.keras.layers.Dense(32, activation='relu'),
+     tf.keras.layers.Dense(1)
+ ])
+
+ model.compile(optimizer='adam', loss='mse', metrics=['mae'])
+
+ # Train model
+ model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.2, verbose=0)
+
+ # Calculate test score
+ test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
+ print(f"Model trained. Test MAE: {test_mae:.3f}")
+ <% } else if (backend === 'python') { %>
+ # Train sklearn model
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
+ model.fit(X_train, y_train)
+
+ print(f"Model trained. Test score: {model.score(X_test, y_test):.3f}")
+ <% } %>
+ <% } else { %>
+ <% const effectiveFramework = engine || framework; %>
+ <% if (effectiveFramework === 'sklearn') { %># Train sklearn model
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
+ model.fit(X_train, y_train)
+
+ print(f"Model trained. Test score: {model.score(X_test, y_test):.3f}")
+ <% } else if (effectiveFramework === 'xgboost') { %># Train XGBoost model
+ model = xgb.XGBRegressor(n_estimators=100, random_state=42)
+ model.fit(X_train, y_train)
+
+ # Evaluate
+ test_score = model.score(X_test, y_test)
+ print(f"Model trained. Test score: {test_score:.3f}")
+ <% } else if (effectiveFramework === 'tensorflow') { %># Train TensorFlow model
+ model = tf.keras.Sequential([
+     tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
+     tf.keras.layers.Dense(32, activation='relu'),
+     tf.keras.layers.Dense(1)
+ ])
+
+ model.compile(optimizer='adam', loss='mse', metrics=['mae'])
+
+ # Train model
+ model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.2, verbose=0)
+
+ # Calculate test score
+ test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
+ print(f"Model trained. Test MAE: {test_mae:.3f}")
+ <% } %>
+ <% } %>
+
+ # Save model
+ # Get the directory where this script is located
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+
+ <% if (architecture === 'triton') { %>
+ <% if (backend === 'fil' && modelFormat === 'xgboost_json') { %>model.save_model(os.path.join(script_dir, 'abalone_model.json'))
+ <% } else if (backend === 'fil' && modelFormat === 'xgboost_ubj') { %>model.save_model(os.path.join(script_dir, 'abalone_model.ubj'))
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>model.booster_.save_model(os.path.join(script_dir, 'abalone_model.txt'))
+ <% } else if (backend === 'onnxruntime') { %>onnx.save_model(onnx_model, os.path.join(script_dir, 'abalone_model.onnx'))
+ <% } else if (backend === 'tensorflow') { %>model.export(os.path.join(script_dir, 'abalone_model.savedmodel'))
+ <% } else if (backend === 'python' && modelFormat === 'pkl') { %>
+ with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
+     pickle.dump(model, f)
+ <% } else if (backend === 'python' && modelFormat === 'joblib') { %>joblib.dump(model, os.path.join(script_dir, 'abalone_model.joblib'))
+ <% } else if (backend === 'python') { %>
+ # Custom format: defaulting to pickle serialization
+ with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
+     pickle.dump(model, f)
+ <% } %>
+ <% } else { %>
+ <% if (modelFormat === 'joblib') { %>joblib.dump(model, os.path.join(script_dir, 'abalone_model.joblib'))
+ <% } else if (modelFormat === 'pkl') { -%>
+ with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
+     pickle.dump(model, f)
+ <% } else if (modelFormat === 'json') { %>model.save_model(os.path.join(script_dir, 'abalone_model.json'))
+ <% } else if (modelFormat === 'model') { %>model.save_model(os.path.join(script_dir, 'abalone_model.model'))
+ <% } else if (modelFormat === 'ubj') { %>model.save_model(os.path.join(script_dir, 'abalone_model.ubj'))
+ <% } else if (modelFormat === 'h5') { %>model.save(os.path.join(script_dir, 'abalone_model.h5'))
+ <% } else if (modelFormat === 'keras') { %>model.save(os.path.join(script_dir, 'abalone_model.keras'))
+ <% } else if (modelFormat === 'SavedModel') { %>model.export(os.path.join(script_dir, 'abalone_model'))
+ <% } %>
+ <% } %>
+
+ print("Model saved.")
package/templates/test/test_endpoint.sh
@@ -0,0 +1,79 @@
+ #!/bin/bash
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Exit on any error
+ set -e
+
+ <% if (framework !== 'transformers') { %>
+
+ # Check if endpoint name is provided
+ if [ $# -ne 1 ]; then
+     echo "Usage: $0 <endpoint-name>"
+     echo "Example: $0 <%= framework %>-endpoint-1234567890"
+     exit 1
+ fi
+
+ ENDPOINT_NAME=$1
+ AWS_REGION="us-east-1"
+
+ <% } else { %>
+ if [ $# -ne 2 ]; then
+     echo "Usage: $0 <endpoint-name> <model-id>"
+     echo "Example: $0 <%= framework %>-endpoint-1234567890 my-model-id"
+     exit 1
+ fi
+
+ ENDPOINT_NAME=$1
+ MODEL_ID=$2
+ AWS_REGION="us-east-1"
+
+ <% } %>
+
+ echo "Testing SageMaker endpoint: ${ENDPOINT_NAME}"
+
+ echo "Checking endpoint status..."
+ aws sagemaker describe-endpoint --endpoint-name ${ENDPOINT_NAME} --region ${AWS_REGION} --query 'EndpointStatus' --output text
+
+ echo "Testing inference endpoint..."
+
+ <% if (framework !== 'transformers') { %>
+ echo '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}' > input.json
+ <% } else { %>
+
+ cat > input.json << EOF
+ {
+   "model": "${MODEL_ID}",
+   "messages": [
+     {
+       "role": "user",
+       "content": "Hello, how are you?"
+     }
+   ],
+   "max_tokens": 100,
+   "temperature": 0.7
+ }
+ EOF
+
+ <% } %>
+
+ aws sagemaker-runtime invoke-endpoint \
+     --endpoint-name ${ENDPOINT_NAME} \
+     --region ${AWS_REGION} \
+     --content-type 'application/json' \
+     --body fileb://input.json \
+     response.json
+
+ echo "Response:"
+ if command -v jq &> /dev/null; then
+     # Pretty-print the response body when it is JSON
+     jq -r '.Body // .' response.json 2>/dev/null || cat response.json
+ else
+     cat response.json
+ fi
+ echo
+
+ echo "Cleaning up files..."
+ rm -f response.json input.json
+
+ echo "Test complete!"
package/templates/test/test_local_image.sh
@@ -0,0 +1,80 @@
+ #!/bin/bash
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ <% if (framework !== 'transformers') { %>
+ # Exit on any error
+ set -e
+
+ IMAGE_NAME="<%= projectName %>"
+ CONTAINER_NAME="<%= framework %>-test"
+ PORT=8080
+
+ echo "Building Docker image..."
+ docker build -t ${IMAGE_NAME} .
+
+ echo "Stopping any existing container..."
+ docker stop ${CONTAINER_NAME} 2>/dev/null || true
+ docker rm ${CONTAINER_NAME} 2>/dev/null || true
+
+ echo "Starting container on port ${PORT}..."
+ docker run -d --name ${CONTAINER_NAME} -p ${PORT}:8080 ${IMAGE_NAME}
+
+ echo "Waiting for container to start..."
+ sleep 10
+
+ echo "Testing health check endpoint..."
+ curl -f http://localhost:${PORT}/ping || echo "Health check failed"
+
+ echo -e "\nTesting inference endpoint..."
+ curl -X POST http://localhost:${PORT}/invocations \
+     -H "Content-Type: application/json" \
+     -d '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}' || echo "Inference failed"
+
+ echo -e "\nContainer logs:"
+ docker logs ${CONTAINER_NAME}
+
+ echo -e "\nCleaning up..."
+ docker stop ${CONTAINER_NAME}
+ docker rm ${CONTAINER_NAME}
+
+ echo "Test complete!"
+ <% } else { %><% if (modelServer !== 'vllm') { %>
+ # Exit on any error
+ set -e
+
+ IMAGE_NAME="<%= projectName %>"
+ CONTAINER_NAME="<%= framework %>-test"
+ PORT=8080
+
+ echo "Building Docker image..."
+ docker build \
+     --build-arg MODEL=<%= modelName %> \
+     --build-arg MODEL_NAME=<%= projectName %>.<%= modelName %> \
+     --platform=linux/amd64 \
+     -t ${IMAGE_NAME} \
+     .
+
+ echo "Stopping any existing container..."
+ docker stop ${CONTAINER_NAME} 2>/dev/null || true
+ docker rm ${CONTAINER_NAME} 2>/dev/null || true
+
+ echo "Starting container on port ${PORT}..."
+ docker run -d --name ${CONTAINER_NAME} -p ${PORT}:8080 ${IMAGE_NAME}
+
+ echo "Waiting for container to start..."
+ sleep 10
+
+ echo "Testing health check endpoint..."
+ curl -f http://localhost:${PORT}/ping || echo "Health check failed"
+
+ echo -e "\nContainer logs:"
+ docker logs ${CONTAINER_NAME}
+
+ echo -e "\nCleaning up..."
+ docker stop ${CONTAINER_NAME}
+ docker rm ${CONTAINER_NAME}
+
+ echo "Test complete!"
+ <% } %><% } %>
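The two curl probes above follow SageMaker's container contract (GET /ping for health, POST /invocations for inference). A sketch of the same checks in Python, assuming the container is already running on port 8080 and the requests package is installed:

    # Same health and inference probes as the curl calls above.
    import requests

    BASE = 'http://localhost:8080'

    ping = requests.get(f'{BASE}/ping', timeout=5)
    print('ping:', ping.status_code)  # SageMaker expects 200 when healthy

    resp = requests.post(
        f'{BASE}/invocations',
        json={"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]},
        timeout=30,
    )
    print('invocations:', resp.status_code, resp.text)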
package/templates/test/test_model_handler.py
@@ -0,0 +1,180 @@
+ #!/usr/bin/env python3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ <% if (framework === 'sglang') { %>
+ """
+ Local testing script for SGLang models
+
+ This script allows you to test your SGLang model locally before containerizing.
+ Unlike serve.py (which runs a production HTTP server), this is a CLI tool
+ for development and debugging.
+
+ Usage examples:
+     # Test with text input
+     python test_model_handler.py --input-data '"Hello, how are you?"'
+
+     # Test with SageMaker format
+     python test_model_handler.py --input-data '{"instances": ["Hello, world!", "How are you?"]}'
+
+     # Custom model
+     python test_model_handler.py --model-id microsoft/DialoGPT-small --input-data '"Hello!"'
+
+ This is NOT used in production - serve.py handles containerized inference.
+ """
+ import json
+ import argparse
+ import sys
+ import os
+ import asyncio
+ from sglang import Runtime
+
+ def usage():
+     """Print usage examples and exit"""
+     print("\nSGLANG Model Handler Test Tool")
+     print("=" * 40)
+     print("\nUsage examples:")
+     print("  # Basic test with text input:")
+     print('  python test_model_handler.py --input-data \'"Hello, how are you?"\'')
+     print("\n  # SageMaker format:")
+     print('  python test_model_handler.py --input-data \'{"instances": ["Hello!", "How are you?"]}\'')
+     print("\n  # Custom model:")
+     print('  python test_model_handler.py --model-id microsoft/DialoGPT-small --input-data \'"Hello!"\'')
+     print("\n  # Show this help:")
+     print("  python test_model_handler.py --help")
+     print("\nNote: This is for local testing only. Production uses serve.py in containers.\n")
+     sys.exit(0)
+
+ async def main():
+     parser = argparse.ArgumentParser(
+         description='Local CLI tool for testing SGLang model inference',
+         epilog='Use --usage for detailed examples'
+     )
+     parser.add_argument('--model-id', type=str, default='<%= model || "microsoft/DialoGPT-medium" %>',
+                         help='Model ID to load (default: <%= model || "microsoft/DialoGPT-medium" %>)')
+     parser.add_argument('--input-data', type=str,
+                         help='Input data as a JSON string')
+     parser.add_argument('--usage', action='store_true',
+                         help='Show detailed usage examples')
+
+     args = parser.parse_args()
+
+     if args.usage:
+         usage()
+
+     if not args.input_data:
+         print("Error: --input-data is required")
+         print("Use --usage for examples or --help for options")
+         sys.exit(1)
+
+     print(f"Loading SGLang model: {args.model_id}")
+     runtime = Runtime(
+         model_path=args.model_id,
+         tokenizer_path=args.model_id,
+         device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
+         mem_fraction_static=0.8
+     )
+
+     try:
+         input_data = json.loads(args.input_data)
+     except json.JSONDecodeError:
+         input_data = args.input_data
+
+     # Extract prompts
+     if isinstance(input_data, dict):
+         prompts = input_data.get('instances', input_data.get('inputs', [input_data]))
+     else:
+         prompts = [input_data]
+
+     print("Running inference...")
+     outputs = runtime.generate(prompts)
+
+     result = {'predictions': outputs}
+     print("\nResult:")
+     print(json.dumps(result, indent=2))
+
+ if __name__ == '__main__':
+     asyncio.run(main())
+ <% } else { %>
+ """
+ Local testing script for <%= framework %> models
+
+ This script allows you to test your model locally before containerizing.
+ Unlike serve.py (which runs a production HTTP server), this is a CLI tool
+ for development and debugging.
+
+ Usage examples:
+     # Test with array input
+     python test_model_handler.py --input-data '[[1,2,3,4]]'
+
+     # Test with SageMaker format
+     python test_model_handler.py --input-data '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}'
+
+     # Custom model path
+     python test_model_handler.py --model-path ./ --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'
+
+ This is NOT used in production - serve.py handles containerized inference.
+ """
+ import json
+ import argparse
+ import sys
+ import os
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'code'))
+ from model_handler import ModelHandler
+
+ def usage():
+     """Print usage examples and exit"""
+     print("\n<%= framework.toUpperCase() %> Model Handler Test Tool")
+     print("=" * 40)
+     print("\nUsage examples:")
+     print("  # Basic test with array input:")
+     print("  python test_model_handler.py --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'")
+     print("\n  # SageMaker format:")
+     print("  python test_model_handler.py --input-data '{\"instances\": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}'")
+     print("\n  # Custom model path:")
+     print("  python test_model_handler.py --model-path ../sample_model --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'")
+     print("\n  # Show this help:")
+     print("  python test_model_handler.py --help")
+     print("\nNote: This is for local testing only. Production uses serve.py in containers.\n")
+     sys.exit(0)
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description='Local CLI tool for testing <%= framework %> model inference',
+         epilog='Use --usage for detailed examples'
+     )
+     parser.add_argument('--model-path', type=str, default='sample_model',
+                         help='Path to model directory (default: sample_model)')
+     parser.add_argument('--input-data', type=str,
+                         help='Input data as a JSON string')
+     parser.add_argument('--usage', action='store_true',
+                         help='Show detailed usage examples')
+
+     args = parser.parse_args()
+
+     if args.usage:
+         usage()
+
+     if not args.input_data:
+         print("Error: --input-data is required")
+         print("Use --usage for examples or --help for options")
+         sys.exit(1)
+
+     print(f"Loading model from: {args.model_path}")
+     handler = ModelHandler(args.model_path)
+     handler.load_model()
+
+     try:
+         input_data = json.loads(args.input_data)
+     except json.JSONDecodeError:
+         input_data = args.input_data
+
+     print("Running inference...")
+     result = handler.predict(input_data)
+     print("\nResult:")
+     print(json.dumps(result, indent=2))
+
+ if __name__ == '__main__':
+     main()
+ <% } %>
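The non-SGLang branch imports ModelHandler from ../code/model_handler.py (file 102 in the list above, not shown in this diff). For orientation only, a hypothetical minimal handler covering just the two methods the test script calls; the real template is far more complete:

    # Hypothetical stand-in for templates/code/model_handler.py; only the
    # load_model/predict interface used by the test script is sketched here.
    import os
    import pickle

    class ModelHandler:
        def __init__(self, model_path):
            self.model_path = model_path
            self.model = None

        def load_model(self):
            # Assumes a pickled sklearn-style model saved by train_abalone.py
            with open(os.path.join(self.model_path, 'abalone_model.pkl'), 'rb') as f:
                self.model = pickle.load(f)

        def predict(self, input_data):
            # Accept a bare array or the SageMaker {"instances": ...} format
            instances = input_data.get('instances', input_data) if isinstance(input_data, dict) else input_data
            return {'predictions': self.model.predict(instances).tolist()}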