@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import ssl
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
# Configure HTTPS certificate handling for the dataset download below.
# When certifi is installed, Python's default HTTPS context is pointed at
# certifi's CA bundle so certificates are still verified on interpreters that
# lack usable system root certificates. Verification is never disabled: without
# certifi the interpreter's default (verifying) context is kept, and a failed
# download is handled by the synthetic-data fallback around fetch_ucirepo.
try:
    import certifi
except ImportError:
    certifi = None

if certifi is not None:
    def _certifi_https_context(*args, **kwargs):
        """Create a certificate-verifying SSL context backed by certifi's CA bundle."""
        kwargs.setdefault("cafile", certifi.where())
        return ssl.create_default_context(*args, **kwargs)

    ssl._create_default_https_context = _certifi_https_context
|
|
15
|
+
|
|
16
|
+
<%# Training-library imports, chosen at generation time. %>
<%# Triton: one try/except per backend, so a missing package prints an install hint and then re-raises. %>
<%# Other architectures: imports depend on effectiveFramework = engine || framework. %>
<% if (architecture === 'triton') { %>
<% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
try:
    import xgboost as xgb
except ImportError:
    print("Error: xgboost is required. Install dependencies with: pip install -r requirements.txt")
    raise
<% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
try:
    import lightgbm as lgb
except ImportError:
    print("Error: lightgbm is required. Install dependencies with: pip install -r requirements.txt")
    raise
<% } else if (backend === 'onnxruntime') { %>
# scikit-learn trains the model; skl2onnx and onnx convert and save it.
try:
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.model_selection import train_test_split
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    import onnx
except ImportError:
    print("Error: scikit-learn, skl2onnx, and onnx are required. Install dependencies with: pip install -r requirements.txt")
    raise
<% } else if (backend === 'tensorflow') { %>
try:
    import tensorflow as tf
except ImportError:
    print("Error: tensorflow is required. Install dependencies with: pip install -r requirements.txt")
    raise
<% } else if (backend === 'python') { %>
# pickle is imported for every format: the save step also uses it as the
# fallback serializer for custom formats.
try:
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.model_selection import train_test_split
    import pickle
    <% if (modelFormat === 'joblib') { %>
    import joblib
    <% } %>
except ImportError:
    print("Error: scikit-learn is required. Install dependencies with: pip install -r requirements.txt")
    raise
<% } %>
<%# NOTE(review): Triton backends not matched above get no training-library import. %>
<% } else { %>
<% const effectiveFramework = engine || framework; %>
<% if (effectiveFramework === 'sklearn') { %>from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
<% if (modelFormat === 'joblib') { %>import joblib
<% } else if (modelFormat === 'pkl') { %>import pickle
<% } %>
<% } else if (effectiveFramework === 'xgboost' || effectiveFramework === 'tensorflow') { %>
<%# xgboost/tensorflow projects do not install scikit-learn, so a local train_test_split helper follows. %>
<% if (effectiveFramework === 'xgboost') { %>import xgboost as xgb<% } %>
<% if (effectiveFramework === 'tensorflow') { %>import tensorflow as tf<% } %>
|
|
67
|
+
|
|
68
|
+
def train_test_split(X, y, test_size=0.2, random_state=None):
    """Randomly split features and targets into train and test partitions.

    Minimal stand-in for ``sklearn.model_selection.train_test_split`` used when
    scikit-learn is not a dependency of the generated project.

    Args:
        X: feature table supporting positional row selection via ``.iloc``
            (e.g. a pandas DataFrame).
        y: targets aligned row-for-row with ``X``; any array-like is accepted
            and converted with ``np.asarray``.
        test_size: fraction of rows placed in the test split; the row count is
            truncated with ``int()``.
        random_state: optional integer seed. When given, a private
            ``RandomState`` is used so NumPy's global RNG is not reseeded.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.
    """
    # RandomState(seed).permutation(n) yields the same order as
    # np.random.seed(seed) followed by np.random.permutation(n), but does not
    # touch the global RNG state.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    n_samples = len(X)
    n_test = int(n_samples * test_size)

    indices = rng.permutation(n_samples)
    test_indices = indices[:n_test]
    train_indices = indices[n_test:]

    # Integer-array indexing needs an ndarray; lists/Series are converted so the
    # split is always positional.
    y = np.asarray(y)
    return X.iloc[train_indices], X.iloc[test_indices], y[train_indices], y[test_indices]
|
|
80
|
+
<% } %>
|
|
81
|
+
<% } %>
|
|
82
|
+
|
|
83
|
+
from ucimlrepo import fetch_ucirepo
|
|
84
|
+
<% if (architecture === 'triton' && (backend === 'fil' || backend === 'tensorflow')) { %>
|
|
85
|
+
|
|
86
|
+
def train_test_split(X, y, test_size=0.2, random_state=None):
    """Randomly split features and targets into train and test partitions.

    Minimal stand-in for ``sklearn.model_selection.train_test_split`` used by
    Triton backends whose dependencies do not include scikit-learn.

    Args:
        X: feature table supporting positional row selection via ``.iloc``
            (e.g. a pandas DataFrame).
        y: targets aligned row-for-row with ``X``; any array-like is accepted
            and converted with ``np.asarray``.
        test_size: fraction of rows placed in the test split; the row count is
            truncated with ``int()``.
        random_state: optional integer seed. When given, a private
            ``RandomState`` is used so NumPy's global RNG is not reseeded.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``.
    """
    # RandomState(seed).permutation(n) yields the same order as
    # np.random.seed(seed) followed by np.random.permutation(n), but does not
    # touch the global RNG state.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    n_samples = len(X)
    n_test = int(n_samples * test_size)

    indices = rng.permutation(n_samples)
    test_indices = indices[:n_test]
    train_indices = indices[n_test:]

    # Integer-array indexing needs an ndarray; lists/Series are converted so the
    # split is always positional.
    y = np.asarray(y)
    return X.iloc[train_indices], X.iloc[test_indices], y[train_indices], y[test_indices]
|
|
98
|
+
<% } %>
|
|
99
|
+
|
|
100
|
+
# Load the UCI Abalone dataset (id=1). Any failure while downloading is
# reported, and reproducible synthetic data with the same column names is used
# instead so the rest of the script can still run.
try:
    abalone = fetch_ucirepo(id=1)
    X = abalone.data.features.copy()
    y = abalone.data.targets.values.ravel()
except Exception as e:
    print(f"Warning: Could not download Abalone dataset from UCI repository: {e}")
    print("Creating synthetic data for demonstration...")

    # Fixed seed so the fallback dataset is identical on every run.
    np.random.seed(42)
    n_samples = 4177  # row count of the real Abalone dataset

    # Uniform [0, 1) features in the UCI column order:
    # Sex, Length, Diameter, Height, Whole, Shucked, Viscera and Shell weight.
    X = np.random.rand(n_samples, 8)
    # Sex is categorical, coded as M=0, F=1, I=2.
    X[:, 0] = np.random.choice([0, 1, 2], n_samples)
    # Rescale the seven numeric columns to plausible magnitudes.
    X[:, 1:] *= np.array([0.815, 0.650, 0.265, 2.826, 1.488, 0.760, 1.005])

    # Ring count (age proxy): linear in Length and Whole weight plus Gaussian
    # noise, truncated to int and clipped to the range 1..29.
    y = np.clip(
        (X[:, 1] * 10 + X[:, 4] * 5 + np.random.normal(0, 2, n_samples)).astype(int),
        1,
        29,
    )

    # Wrap in a DataFrame so later code can use column names and .iloc.
    import pandas as pd
    feature_names = [
        'Sex', 'Length', 'Diameter', 'Height',
        'Whole_weight', 'Shucked_weight', 'Viscera_weight', 'Shell_weight',
    ]
    X = pd.DataFrame(X, columns=feature_names)

# The downloaded dataset stores Sex as strings; map them to the same integer
# codes the synthetic data uses. Numeric Sex columns are left untouched.
if hasattr(X, 'dtypes') and X['Sex'].dtype == 'object':
    sex_map = {'M': 0, 'F': 1, 'I': 2}
    X['Sex'] = X['Sex'].map(sex_map)

# Hold out 20% of the rows for evaluation, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
|
135
|
+
|
|
136
|
+
<%# Model training: fit the model chosen at generation time on X_train/y_train and print a held-out test metric. %>
<% if (architecture === 'triton') { %>
<% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
# Train XGBoost model
model = xgb.XGBRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Evaluate
test_score = model.score(X_test, y_test)
print(f"Model trained. Test score: {test_score:.3f}")
<% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
# Train LightGBM model
model = lgb.LGBMRegressor(n_estimators=100, random_state=42, verbose=-1)
model.fit(X_train, y_train)

# Evaluate
test_score = model.score(X_test, y_test)
print(f"Model trained. Test score: {test_score:.3f}")
<% } else if (backend === 'onnxruntime') { %>
# Train sklearn model
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

print(f"Model trained. Test score: {model.score(X_test, y_test):.3f}")

# Convert to ONNX format
# Single float input named 'float_input'; None leaves the batch dimension
# unspecified and the second dimension is the feature count.
initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))]
onnx_model = convert_sklearn(model, 'abalone_model', initial_types=initial_type)
<% } else if (backend === 'tensorflow') { %>
# Train TensorFlow model
# Fully connected regressor: Dense(64, relu) -> Dense(32, relu) -> Dense(1).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1)
])

# MSE loss; mean absolute error is tracked and reported below.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Train model (validation_split=0.2 holds out part of the training rows for validation)
model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.2, verbose=0)

# Calculate test score
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"Model trained. Test MAE: {test_mae:.3f}")
<% } else if (backend === 'python') { %>
# Train sklearn model
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

print(f"Model trained. Test score: {model.score(X_test, y_test):.3f}")
<% } %>
<%# NOTE(review): Triton backend/format combinations not matched above train no model. %>
<% } else { %>
<% const effectiveFramework = engine || framework; %>
<%# NOTE(review): the sklearn and tensorflow messages below say "saved", but this section writes no file. %>
<% if (effectiveFramework === 'sklearn') { %># Train sklearn model
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

print(f"Model trained and saved. Test score: {model.score(X_test, y_test):.3f}")
<% } else if (effectiveFramework === 'xgboost') { %># Train XGBoost model
model = xgb.XGBRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Evaluate
test_score = model.score(X_test, y_test)
print(f"Model trained. Test score: {test_score:.3f}")
<% } else if (effectiveFramework === 'tensorflow') { %># Train TensorFlow model
# Create TensorFlow model
# Fully connected regressor: Dense(64, relu) -> Dense(32, relu) -> Dense(1).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1)
])

# MSE loss; mean absolute error is tracked and reported below.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Train model (validation_split=0.2 holds out part of the training rows for validation)
model.fit(X_train, y_train, epochs=50, batch_size=32, validation_split=0.2, verbose=0)

# Calculate test score
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"Model trained and saved. Test MAE: {test_mae:.3f}")
<% } %>
<% } %>
|
|
218
|
+
|
|
219
|
+
# Save model
# Get the directory where this script is located; artifacts are written there
# rather than to the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

<%# Persist the model with the API that matches the chosen backend/format; artifacts are named abalone_model.<ext>, or abalone_model for a non-Triton SavedModel. %>
<% if (architecture === 'triton') { %>
<% if (backend === 'fil' && modelFormat === 'xgboost_json') { %>model.save_model(os.path.join(script_dir, 'abalone_model.json'))
<% } else if (backend === 'fil' && modelFormat === 'xgboost_ubj') { %>model.save_model(os.path.join(script_dir, 'abalone_model.ubj'))
<% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>model.booster_.save_model(os.path.join(script_dir, 'abalone_model.txt'))
<% } else if (backend === 'onnxruntime') { %>onnx.save_model(onnx_model, os.path.join(script_dir, 'abalone_model.onnx'))
<% } else if (backend === 'tensorflow') { %>model.export(os.path.join(script_dir, 'abalone_model.savedmodel'))
<% } else if (backend === 'python' && modelFormat === 'pkl') { %>
with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
    pickle.dump(model, f)
<% } else if (backend === 'python' && modelFormat === 'joblib') { %>joblib.dump(model, os.path.join(script_dir, 'abalone_model.joblib'))
<% } else if (backend === 'python') { %>
# Custom format: defaulting to pickle serialization
with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
    pickle.dump(model, f)
<% } %>
<% } else { %>
<%# NOTE(review): if modelFormat matches none of the branches below, no file is written, yet "Model saved." is still printed. %>
<%# NOTE(review): the 'pkl' branch needs pickle, which the import section only imports for the sklearn framework - confirm other frameworks never select pkl. %>
<% if (modelFormat === 'joblib') { %>joblib.dump(model, os.path.join(script_dir, 'abalone_model.joblib'))
<% } else if (modelFormat === 'pkl') { -%>
with open(os.path.join(script_dir, 'abalone_model.pkl'), 'wb') as f:
    pickle.dump(model, f)
<% } else if (modelFormat === 'json') { %>model.save_model(os.path.join(script_dir, 'abalone_model.json'))
<% } else if (modelFormat === 'model') { %>model.save_model(os.path.join(script_dir, 'abalone_model.model'))
<% } else if (modelFormat === 'ubj') { %>model.save_model(os.path.join(script_dir, 'abalone_model.ubj'))
<% } else if (modelFormat === 'h5') { %>model.save(os.path.join(script_dir, 'abalone_model.h5'))
<% } else if (modelFormat === 'keras') { %>model.save(os.path.join(script_dir, 'abalone_model.keras'))
<% } else if (modelFormat === 'SavedModel') { %>model.export(os.path.join(script_dir, 'abalone_model'))
<% } %>
<% } %>

print("Model saved.")
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
# Exit on any error
|
|
6
|
+
set -e
|
|
7
|
+
|
|
8
|
+
<% if (framework !== 'transformers') { %>
|
|
9
|
+
|
|
10
|
+
# Check if endpoint name is provided
|
|
11
|
+
if [ $# -ne 1 ]; then
|
|
12
|
+
echo "Usage: $0 <endpoint-name>"
|
|
13
|
+
echo "Example: $0 <%= framework %>-endpoint-1234567890"
|
|
14
|
+
exit 1
|
|
15
|
+
fi
|
|
16
|
+
|
|
17
|
+
ENDPOINT_NAME=$1
|
|
18
|
+
AWS_REGION="us-east-1"
|
|
19
|
+
|
|
20
|
+
<% } else { %>
|
|
21
|
+
if [ $# -ne 2 ]; then
|
|
22
|
+
echo "Usage: $0 <endpoint-name> <model-id>"
|
|
23
|
+
echo "Example: $0 <%= framework %>-endpoint-1234567890"
|
|
24
|
+
exit 1
|
|
25
|
+
fi
|
|
26
|
+
|
|
27
|
+
ENDPOINT_NAME=$1
|
|
28
|
+
MODEL_ID=$2
|
|
29
|
+
AWS_REGION="us-east-1"
|
|
30
|
+
|
|
31
|
+
<% } %>
|
|
32
|
+
|
|
33
|
+
echo "Testing SageMaker endpoint: ${ENDPOINT_NAME}"
|
|
34
|
+
|
|
35
|
+
echo "Checking endpoint status..."
|
|
36
|
+
aws sagemaker describe-endpoint --endpoint-name ${ENDPOINT_NAME} --region ${AWS_REGION} --query 'EndpointStatus' --output text
|
|
37
|
+
|
|
38
|
+
echo "Testing inference endpoint..."
|
|
39
|
+
|
|
40
|
+
<% if (framework !== 'transformers') { %>
|
|
41
|
+
echo '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}' > input.json
|
|
42
|
+
<% } else {%>
|
|
43
|
+
|
|
44
|
+
cat > input.json << EOF
|
|
45
|
+
{
|
|
46
|
+
"model": "${MODEL_ID}",
|
|
47
|
+
"messages": [
|
|
48
|
+
{
|
|
49
|
+
"role": "user",
|
|
50
|
+
"content": "Hello, how are you?"
|
|
51
|
+
}
|
|
52
|
+
],
|
|
53
|
+
"max_tokens": 100,
|
|
54
|
+
"temperature": 0.7
|
|
55
|
+
}
|
|
56
|
+
EOF
|
|
57
|
+
|
|
58
|
+
<% } %>
|
|
59
|
+
|
|
60
|
+
aws sagemaker-runtime invoke-endpoint \
|
|
61
|
+
--endpoint-name ${ENDPOINT_NAME} \
|
|
62
|
+
--region ${AWS_REGION} \
|
|
63
|
+
--content-type 'application/json' \
|
|
64
|
+
--body fileb://input.json \
|
|
65
|
+
response.json
|
|
66
|
+
|
|
67
|
+
echo "Response:"
|
|
68
|
+
if command -v jq &> /dev/null; then
|
|
69
|
+
# Decode base64 if response is encoded
|
|
70
|
+
jq -r '.Body // .' response.json 2>/dev/null || cat response.json
|
|
71
|
+
else
|
|
72
|
+
cat response.json
|
|
73
|
+
fi
|
|
74
|
+
echo
|
|
75
|
+
|
|
76
|
+
echo "Cleaning up files..."
|
|
77
|
+
rm -f response.json input.json
|
|
78
|
+
|
|
79
|
+
echo "Test complete!"
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
<% if (framework !== 'transformers' || modelServer !== 'vllm') { %>
# Exit on any error
set -e

IMAGE_NAME="<%= projectName %>"
CONTAINER_NAME="<%= framework %>-test"
PORT=8080
# Maximum whole seconds to wait for /ping to succeed after the container starts.
STARTUP_TIMEOUT="${STARTUP_TIMEOUT:-60}"

# Stop and remove the test container; succeeds even if it does not exist.
cleanup() {
    docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true
    docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true
}

echo "Building Docker image..."
<% if (framework !== 'transformers') { %>
docker build -t "${IMAGE_NAME}" .
<% } else { %>
docker build \
    --build-arg MODEL="<%= modelName %>" \
    --build-arg MODEL_NAME="<%= projectName %>.<%= modelName %>" \
    --platform=linux/amd64 \
    -t "${IMAGE_NAME}" \
    .
<% } %>

echo "Stopping any existing container..."
cleanup

# If any later step fails under `set -e`, still remove the container.
trap cleanup EXIT

echo "Starting container on port ${PORT}..."
docker run -d --name "${CONTAINER_NAME}" -p "${PORT}:8080" "${IMAGE_NAME}"

echo "Waiting for container to start..."
# Poll /ping instead of sleeping a fixed interval: fast containers are tested
# sooner, and slow model loads get up to STARTUP_TIMEOUT seconds.
waited=0
until curl -sf "http://localhost:${PORT}/ping" >/dev/null 2>&1; do
    if [ "${waited}" -ge "${STARTUP_TIMEOUT}" ]; then
        break
    fi
    sleep 2
    waited=$((waited + 2))
done

echo "Testing health check endpoint..."
curl -f "http://localhost:${PORT}/ping" || echo "Health check failed"
<% if (framework !== 'transformers') { %>

echo -e "\nTesting inference endpoint..."
curl -X POST "http://localhost:${PORT}/invocations" \
    -H "Content-Type: application/json" \
    -d '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}' || echo "Inference failed"
<% } %>

echo -e "\nContainer logs:"
docker logs "${CONTAINER_NAME}"

echo -e "\nCleaning up..."
docker stop "${CONTAINER_NAME}"
docker rm "${CONTAINER_NAME}"
# Normal cleanup is done; the EXIT trap is no longer needed.
trap - EXIT

echo "Test complete!"
<% } %>
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
<% if (framework === 'sglang') { %>
|
|
5
|
+
#!/usr/bin/env python3
|
|
6
|
+
"""
|
|
7
|
+
Local testing script for SGLang models
|
|
8
|
+
|
|
9
|
+
This script allows you to test your SGLang model locally before containerizing.
|
|
10
|
+
Unlike serve.py (which runs a production HTTP server), this is a CLI tool
|
|
11
|
+
for development and debugging.
|
|
12
|
+
|
|
13
|
+
Usage examples:
|
|
14
|
+
# Test with text input
|
|
15
|
+
python test_model_handler.py --input-data '"Hello, how are you?"'
|
|
16
|
+
|
|
17
|
+
# Test with SageMaker format
|
|
18
|
+
python test_model_handler.py --input-data '{"instances": ["Hello, world!", "How are you?"]}'
|
|
19
|
+
|
|
20
|
+
# Custom model
|
|
21
|
+
python test_model_handler.py --model-id microsoft/DialoGPT-small --input-data '"Hello!"'
|
|
22
|
+
|
|
23
|
+
This is NOT used in production - serve.py handles containerized inference.
|
|
24
|
+
"""
|
|
25
|
+
import json
|
|
26
|
+
import argparse
|
|
27
|
+
import sys
|
|
28
|
+
import os
|
|
29
|
+
import asyncio
|
|
30
|
+
from sglang import Runtime
|
|
31
|
+
|
|
32
|
+
def usage():
    """Print example command lines for the SGLang test tool, then exit with status 0."""
    banner = ("\nSGLANG Model Handler Test Tool", "=" * 40, "\nUsage examples:")
    examples = (
        (" # Basic test with text input:",
         ' python test_model_handler.py --input-data \'"Hello, how are you?"\''),
        ("\n # SageMaker format:",
         ' python test_model_handler.py --input-data \'{"instances": ["Hello!", "How are you?"]}\''),
        ("\n # Custom model:",
         ' python test_model_handler.py --model-id microsoft/DialoGPT-small --input-data \'"Hello!"\''),
        ("\n # Show this help:",
         " python test_model_handler.py --help"),
    )
    for text in banner:
        print(text)
    for heading, command in examples:
        print(heading)
        print(command)
    print("\nNote: This is for local testing only. Production uses serve.py in containers.\n")
    sys.exit(0)
|
|
47
|
+
|
|
48
|
+
async def main():
    """Parse CLI arguments, run SGLang generation on the given prompts and print the result.

    The SGLang ``Runtime`` is shut down in a ``finally`` block, so it is
    released even when generation raises.
    """
    parser = argparse.ArgumentParser(
        description='Local CLI tool for testing SGLang model inference',
        epilog='Use --usage for detailed examples'
    )
    parser.add_argument('--model-id', type=str, default='<%= model || "microsoft/DialoGPT-medium" %>',
                        help='Model ID to load (default: <%= model || "microsoft/DialoGPT-medium" %>)')
    parser.add_argument('--input-data', type=str,
                        help='Input data as JSON string')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to run on (default: "cuda" if CUDA_VISIBLE_DEVICES is set, otherwise "cpu")')
    parser.add_argument('--usage', action='store_true',
                        help='Show detailed usage examples')

    args = parser.parse_args()

    if args.usage:
        usage()

    if not args.input_data:
        print("Error: --input-data is required")
        print("Use --usage for examples or --help for options")
        sys.exit(1)

    # Parse input before starting the runtime; non-JSON text is used as a raw prompt.
    try:
        input_data = json.loads(args.input_data)
    except json.JSONDecodeError:
        input_data = args.input_data

    # Extract prompts: {"instances": [...]} or {"inputs": [...]} for SageMaker
    # style payloads, a JSON list as-is, and anything else as a single prompt.
    if isinstance(input_data, dict):
        prompts = input_data.get('instances', input_data.get('inputs', [input_data]))
    elif isinstance(input_data, list):
        prompts = input_data
    else:
        prompts = [input_data]

    device = args.device
    if device is None:
        # NOTE(review): GPU hosts without CUDA_VISIBLE_DEVICES fall back to "cpu";
        # pass --device cuda there.
        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"

    print(f"Loading SGLang model: {args.model_id}")
    runtime = Runtime(
        model_path=args.model_id,
        tokenizer_path=args.model_id,
        device=device,
        mem_fraction_static=0.8
    )

    try:
        print("Running inference...")
        outputs = runtime.generate(prompts)
    finally:
        # The runtime presumably owns a background server process (per the
        # SGLang docs); stop it explicitly instead of relying on interpreter exit.
        runtime.shutdown()

    result = {'predictions': outputs}
    print("\nResult:")
    print(json.dumps(result, indent=2))


if __name__ == '__main__':
    asyncio.run(main())
|
|
98
|
+
<% } else { %>
|
|
99
|
+
#!/usr/bin/env python3
|
|
100
|
+
"""
|
|
101
|
+
Local testing script for <%= framework %> models
|
|
102
|
+
|
|
103
|
+
This script allows you to test your model locally before containerizing.
|
|
104
|
+
Unlike serve.py (which runs a production HTTP server), this is a CLI tool
|
|
105
|
+
for development and debugging.
|
|
106
|
+
|
|
107
|
+
Usage examples:
|
|
108
|
+
# Test with array input
|
|
109
|
+
python test_model_handler.py --input-data '[[1,2,3,4]]'
|
|
110
|
+
|
|
111
|
+
# Test with SageMaker format
|
|
112
|
+
python test_model_handler.py --input-data '{"instances": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}'
|
|
113
|
+
|
|
114
|
+
# Custom model path
|
|
115
|
+
python test_model_handler.py --model-path ./ --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'
|
|
116
|
+
|
|
117
|
+
This is NOT used in production - serve.py handles containerized inference.
|
|
118
|
+
"""
|
|
119
|
+
import json
|
|
120
|
+
import argparse
|
|
121
|
+
import sys
|
|
122
|
+
import os
|
|
123
|
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'code'))
|
|
124
|
+
from model_handler import ModelHandler
|
|
125
|
+
|
|
126
|
+
def usage():
    """Show example command lines for this model test tool and exit with status 0."""
    print(
        "\n<%= framework.toUpperCase() %> Model Handler Test Tool",
        "=" * 40,
        "\nUsage examples:",
        " # Basic test with array input:",
        " python test_model_handler.py --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'",
        "\n # SageMaker format:",
        " python test_model_handler.py --input-data '{\"instances\": [[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]}'",
        "\n # Custom model path:",
        " python test_model_handler.py --model-path ../sample_model --input-data '[[1, 0.455, 0.365, 0.095, 0.514, 0.2245, 0.101, 0.15]]'",
        "\n # Show this help:",
        " python test_model_handler.py --help",
        "\nNote: This is for local testing only. Production uses serve.py in containers.\n",
        sep="\n",
    )
    sys.exit(0)
|
|
141
|
+
|
|
142
|
+
def main():
    """Load the local model, run one prediction on --input-data and print the JSON result."""
    parser = argparse.ArgumentParser(
        description='Local CLI tool for testing <%= framework %> model inference',
        epilog='Use --usage for detailed examples',
    )
    parser.add_argument(
        '--model-path', type=str, default='sample_model',
        help='Path to model directory (default: sample_model)',
    )
    parser.add_argument(
        '--input-data', type=str,
        help='Input data as application/json string',
    )
    parser.add_argument(
        '--usage', action='store_true',
        help='Show detailed usage examples',
    )
    args = parser.parse_args()

    if args.usage:
        usage()  # prints examples and exits

    if not args.input_data:
        print("Error: --input-data is required")
        print("Use --usage for examples or --help for options")
        sys.exit(1)

    print(f"Loading model from: {args.model_path}")
    handler = ModelHandler(args.model_path)
    handler.load_model()

    # JSON input is decoded; anything that is not valid JSON is passed through as text.
    try:
        payload = json.loads(args.input_data)
    except json.JSONDecodeError:
        payload = args.input_data

    print("Running inference...")
    prediction = handler.predict(payload)
    print("\nResult:")
    print(json.dumps(prediction, indent=2))


if __name__ == '__main__':
    main()
|
|
180
|
+
<% } %>
|