@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
+++ package/templates/buildspec.yml
@@ -0,0 +1,64 @@
+ version: 0.2
+
+ env:
+   variables:
+     AWS_DEFAULT_REGION: <%= awsRegion %>
+     AWS_ACCOUNT_ID: ""
+     ECR_REPOSITORY_NAME: "ml-container-creator"
+     PROJECT_NAME: "<%= projectName %>"
+     IMAGE_TAG: "latest"
+
+ phases:
+   pre_build:
+     commands:
+       - echo Logging in to Amazon ECR...
+       - AWS_ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
+       - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com
+       - REPOSITORY_URI=$AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPOSITORY_NAME
+       - IMAGE_TAG=${CODEBUILD_RESOLVED_SOURCE_VERSION:-latest}
+       - PROJECT_TAG="$PROJECT_NAME-$(date +%Y%m%d-%H%M%S)"
+       - echo Repository URI is $REPOSITORY_URI
+       - echo Project tag is $PROJECT_TAG
+       - echo Image tag is $IMAGE_TAG
+     on-failure: ABORT
+   build:
+     commands:
+       - echo Build started on `date`
+       - echo Building the Docker image for project $PROJECT_NAME...
+       <% if (typeof modelLoadStrategy !== 'undefined' && modelLoadStrategy === 'build-time' && typeof modelSource !== 'undefined' && modelSource && modelSource !== 'huggingface') { %>
+       - |
+         # Export IAM role credentials for docker build
+         eval $(aws configure export-credentials --format env 2>/dev/null || true)
+         docker build \
+           --build-arg AWS_ACCESS_KEY_ID="${AWS_ACCESS_KEY_ID}" \
+           --build-arg AWS_SECRET_ACCESS_KEY="${AWS_SECRET_ACCESS_KEY}" \
+           --build-arg AWS_SESSION_TOKEN="${AWS_SESSION_TOKEN}" \
+           --build-arg AWS_DEFAULT_REGION="${AWS_DEFAULT_REGION}" \
+           -t $REPOSITORY_URI:$PROJECT_TAG .
+       <% } else { %>
+       - docker build -t $REPOSITORY_URI:$PROJECT_TAG .
+       <% } %>
+       - docker tag $REPOSITORY_URI:$PROJECT_TAG $REPOSITORY_URI:$PROJECT_NAME-latest
+       - docker tag $REPOSITORY_URI:$PROJECT_TAG $REPOSITORY_URI:latest
+       - echo Build completed on `date`
+     on-failure: ABORT
+   post_build:
+     commands:
+       - echo Post-build started on `date`
+       - echo Pushing the Docker images for project $PROJECT_NAME...
+       - docker push $REPOSITORY_URI:$PROJECT_TAG || (echo "Failed to push project tag $PROJECT_TAG" && exit 1)
+       - docker push $REPOSITORY_URI:$PROJECT_NAME-latest || (echo "Failed to push project latest tag" && exit 1)
+       - docker push $REPOSITORY_URI:latest || (echo "Failed to push latest tag" && exit 1)
+       - echo Successfully pushed images to ECR repository $ECR_REPOSITORY_NAME
+       - echo "Available tags:"
+       - echo " - $PROJECT_TAG (timestamped build)"
+       - echo " - $PROJECT_NAME-latest (project latest)"
+       - echo " - latest (global latest)"
+       - echo Writing image definitions file...
+       - printf '[{"name":"%s","imageUri":"%s"}]' $PROJECT_NAME $REPOSITORY_URI:$PROJECT_TAG > imagedefinitions.json
+       - echo Post-build completed on `date`
+
+ artifacts:
+   files:
+     - imagedefinitions.json
+   name: <%= projectName %>-artifacts
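The `<%= awsRegion %>` and `<%= projectName %>` markers above are EJS-style placeholders, presumably filled in by the generator's template engine (package/src/lib/template-engine.js in the file list) when the project is written out. A minimal sketch of that interpolation, assuming plain `<%= key %>` substitution and ignoring the `<% if %>` logic blocks:

import re

def render(template: str, context: dict) -> str:
    """Substitute <%= key %> markers with values from context."""
    return re.sub(r"<%=\s*(\w+)\s*%>",
                  lambda m: str(context[m.group(1)]),
                  template)

print(render("AWS_DEFAULT_REGION: <%= awsRegion %>",
             {"awsRegion": "us-west-2"}))
# -> AWS_DEFAULT_REGION: us-west-2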
+++ package/templates/code/chat_template.jinja
@@ -0,0 +1 @@
+ <%= chatTemplate %>
+++ package/templates/code/flask/gunicorn_config.py
@@ -0,0 +1,35 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """
+ Gunicorn configuration file
+ """
+ import os
+ import multiprocessing
+ from serve import load_model_for_worker
+
+ # Bind address
+ bind = "0.0.0.0:8080"
+
+ # Worker processes - cap at 4 for memory efficiency
+ workers = min(multiprocessing.cpu_count(), 4)
+
+ # Worker type
+ worker_class = 'sync'
+
+ # Timeouts
+ timeout = 120
+ keepalive = 5
+
+ # Worker lifecycle management
+ max_requests = 500
+ max_requests_jitter = 100
+
+ # Logging
+ accesslog = '-'  # Log to stdout
+ errorlog = '-'   # Log to stderr
+ loglevel = 'info'
+
+ # Worker initialization hook - load model in each worker
+ def post_worker_init(worker):
+     load_model_for_worker()
+++ package/templates/code/flask/wsgi.py
@@ -0,0 +1,10 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """
+ WSGI entry point for Gunicorn
+ """
+ from serve import create_app
+
+ # Create the Flask application
+ application = create_app()
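Both gunicorn_config.py and wsgi.py import from a `serve` module, which ships separately as package/templates/code/serve.py (+175 lines in the file list). The contract they assume is that `create_app()` returns the Flask app and `load_model_for_worker()` loads the model once per worker after fork, via the `post_worker_init` hook. An illustrative reduction of that contract (not the packaged serve.py; the /ping and /invocations routes are assumptions based on the standard SageMaker container interface):

from flask import Flask, jsonify, request

from model_handler import ModelHandler

handler = None  # populated per worker by post_worker_init

def load_model_for_worker():
    # Called from Gunicorn's post_worker_init hook so each forked
    # worker loads its own copy of the model
    global handler
    handler = ModelHandler("/opt/ml/model")
    handler.load_model()

def create_app():
    app = Flask(__name__)

    @app.route("/ping", methods=["GET"])
    def ping():
        # SageMaker health check: 200 only once the model is usable
        return "", 200 if handler and handler.is_loaded() else 503

    @app.route("/invocations", methods=["POST"])
    def invocations():
        return jsonify(handler.predict(request.get_json(force=True)))

    return app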
+++ package/templates/code/model_handler.py
@@ -0,0 +1,387 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ <% if (framework === 'sklearn') { %>
+ """
+ SKLearn model handler for SageMaker inference
+ """
+ import os
+ import json
+ import pickle
+ import joblib
+ import numpy as np
+ from typing import Any, Dict
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ModelHandler:
+     """Handle SKLearn model loading and inference"""
+
+     def __init__(self, model_path: str):
+         self.model_path = model_path
+         self.model = None
+         self._loaded = False
+
+     def load_model(self):
+         """Load the SKLearn model"""
+         try:
+             model_files = [f for f in os.listdir(self.model_path) if f.endswith('<%= modelFormat %>')]
+
+             if not model_files:
+                 logger.warning("No SKLearn model files found in model directory")
+                 logger.warning("Server will start but /invocations will fail until a model is provided")
+                 logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
+                 return
+
+             model_file = os.path.join(self.model_path, model_files[0])
+             logger.info(f"Loading model from {model_file}")
+
+             # Load with joblib first, fall back to pickle
+             try:
+                 self.model = joblib.load(model_file)
+             except Exception:
+                 with open(model_file, 'rb') as f:
+                     self.model = pickle.load(f)
+
+             self._loaded = True
+             logger.info("SKLearn model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded"""
+         return self._loaded and self.model is not None
+
+     def preprocess(self, raw_data: Any) -> np.ndarray:
+         """Preprocess input data for SKLearn model"""
+         try:
+             if isinstance(raw_data, dict):
+                 data = raw_data.get('instances', raw_data.get('data', raw_data))
+             else:
+                 data = raw_data
+
+             if isinstance(data, str):
+                 data = json.loads(data)
+
+             return np.array(data)
+
+         except Exception as e:
+             logger.error(f"Error in preprocessing: {str(e)}")
+             raise ValueError(f"Invalid input data format: {str(e)}")
+
+     def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
+         """Postprocess SKLearn model predictions"""
+         try:
+             if hasattr(predictions, 'tolist'):
+                 predictions = predictions.tolist()
+
+             return {'predictions': predictions}
+
+         except Exception as e:
+             logger.error(f"Error in postprocessing: {str(e)}")
+             raise
+
+     def predict(self, input_data: Any) -> Dict[str, Any]:
+         """Run inference on input data"""
+         if not self.is_loaded():
+             raise RuntimeError("Model is not loaded")
+
+         try:
+             processed_input = self.preprocess(input_data)
+             predictions = self.model.predict(processed_input)
+             return self.postprocess(predictions)
+
+         except Exception as e:
+             logger.error(f"Error during inference: {str(e)}")
+             raise
+ <% } else if (framework === 'xgboost') { %>
+ """
+ XGBoost model handler for SageMaker inference
+ """
+ import os
+ import json
+ import xgboost as xgb
+ import numpy as np
+ from typing import Any, Dict
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ModelHandler:
+     """Handle XGBoost model loading and inference"""
+
+     def __init__(self, model_path: str):
+         self.model_path = model_path
+         self.model = None
+         self._loaded = False
+
+     def load_model(self):
+         """Load the XGBoost model"""
+         try:
+             model_files = [f for f in os.listdir(self.model_path) if f.endswith('<%= modelFormat %>')]
+
+             if not model_files:
+                 logger.warning("No XGBoost model files found in model directory")
+                 logger.warning("Server will start but /invocations will fail until a model is provided")
+                 logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
+                 return
+
+             model_file = os.path.join(self.model_path, model_files[0])
+             logger.info(f"Loading model from {model_file}")
+
+             self.model = xgb.Booster()
+             self.model.load_model(model_file)
+
+             self._loaded = True
+             logger.info("XGBoost model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded"""
+         return self._loaded and self.model is not None
+
+     def preprocess(self, raw_data: Any) -> xgb.DMatrix:
+         """Preprocess input data for XGBoost model"""
+         try:
+             if isinstance(raw_data, dict):
+                 data = raw_data.get('instances', raw_data.get('data', raw_data))
+             else:
+                 data = raw_data
+
+             if isinstance(data, str):
+                 data = json.loads(data)
+
+             return xgb.DMatrix(np.array(data))
+
+         except Exception as e:
+             logger.error(f"Error in preprocessing: {str(e)}")
+             raise ValueError(f"Invalid input data format: {str(e)}")
+
+     def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
+         """Postprocess XGBoost model predictions"""
+         try:
+             if hasattr(predictions, 'tolist'):
+                 predictions = predictions.tolist()
+
+             return {'predictions': predictions}
+
+         except Exception as e:
+             logger.error(f"Error in postprocessing: {str(e)}")
+             raise
+
+     def predict(self, input_data: Any) -> Dict[str, Any]:
+         """Run inference on input data"""
+         if not self.is_loaded():
+             raise RuntimeError("Model is not loaded")
+
+         try:
+             processed_input = self.preprocess(input_data)
+             predictions = self.model.predict(processed_input)
+             return self.postprocess(predictions)
+
+         except Exception as e:
+             logger.error(f"Error during inference: {str(e)}")
+             raise
+ <% } else if (framework === 'tensorflow') { %>
+ """
+ TensorFlow model handler for SageMaker inference
+ """
+ import os
+ import json
+ import tensorflow as tf
+ import numpy as np
+ from typing import Any, Dict
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ModelHandler:
+     """Handle TensorFlow model loading and inference"""
+
+     def __init__(self, model_path: str):
+         self.model_path = model_path
+         self.model = None
+         self._loaded = False
+
+     def load_model(self):
+         """Load the TensorFlow model"""
+         try:
+             logger.info(f"Loading model from {self.model_path}")
+
+             model_files = [f for f in os.listdir(self.model_path) if f.endswith(('.keras', '.h5'))]
+
+             if not model_files:
+                 try:
+                     self.model = tf.saved_model.load(self.model_path)
+                     self.model = self.model.signatures['serving_default']
+                 except Exception:
+                     logger.warning("No TensorFlow model files found in model directory")
+                     logger.warning("Server will start but /invocations will fail until a model is provided")
+                     logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
+                     return
+             else:
+                 model_file = os.path.join(self.model_path, model_files[0])
+                 logger.info(f"Loading model from {model_file}")
+                 self.model = tf.keras.models.load_model(model_file, compile=False)
+
+             self._loaded = True
+             logger.info("TensorFlow model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded"""
+         return self._loaded and self.model is not None
+
+     def preprocess(self, raw_data: Any) -> np.ndarray:
+         """Preprocess input data for TensorFlow model"""
+         try:
+             if isinstance(raw_data, dict):
+                 data = raw_data.get('instances', raw_data.get('data', raw_data))
+             else:
+                 data = raw_data
+
+             if isinstance(data, str):
+                 data = json.loads(data)
+
+             return np.array(data, dtype=np.float32)
+
+         except Exception as e:
+             logger.error(f"Error in preprocessing: {str(e)}")
+             raise ValueError(f"Invalid input data format: {str(e)}")
+
+     def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
+         """Postprocess TensorFlow model predictions"""
+         try:
+             if hasattr(predictions, 'numpy'):
+                 predictions = predictions.numpy()
+
+             if hasattr(predictions, 'tolist'):
+                 predictions = predictions.tolist()
+
+             return {'predictions': predictions}
+
+         except Exception as e:
+             logger.error(f"Error in postprocessing: {str(e)}")
+             raise
+
+     def predict(self, input_data: Any) -> Dict[str, Any]:
+         """Run inference on input data"""
+         if not self.is_loaded():
+             raise RuntimeError("Model is not loaded")
+
+         try:
+             processed_input = self.preprocess(input_data)
+
+             # Handle different model types
+             if hasattr(self.model, 'predict'):
+                 # Keras model
+                 predictions = self.model.predict(processed_input)
+             else:
+                 # SavedModel signature
+                 input_tensor = tf.constant(processed_input)
+                 result = self.model(input_tensor)
+                 predictions = list(result.values())[0]
+
+             return self.postprocess(predictions)
+
+         except Exception as e:
+             logger.error(f"Error during inference: {str(e)}")
+             raise
+ <% } else if (framework === 'transformers' && modelServer === 'sglang') { %>
+ """
+ SGLang model handler for SageMaker inference
+ """
+ import os
+ import json
+ from typing import Any, Dict, List
+ import logging
+ from sglang import Runtime
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ModelHandler:
+     """Handle SGLang model loading and inference"""
+
+     def __init__(self, model_path: str):
+         self.model_path = model_path
+         self.runtime = None
+         self._loaded = False
+
+     def load_model(self):
+         """Initialize SGLang runtime"""
+         try:
+             model_id = '<%= modelName %>'
+             logger.info(f"Loading SGLang model: {model_id}")
+
+             self.runtime = Runtime(
+                 model_path=model_id,
+                 tokenizer_path=model_id,
+                 device="cuda",
+                 mem_fraction_static=0.8
+             )
+
+             self._loaded = True
+             logger.info("SGLang model loaded successfully")
+
+         except Exception as e:
+             logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded"""
+         return self._loaded and self.runtime is not None
+
+     def preprocess(self, raw_data: Any) -> List[str]:
+         """Preprocess input data for SGLang model"""
+         try:
+             if isinstance(raw_data, dict):
+                 data = raw_data.get('instances', raw_data.get('inputs', raw_data))
+             else:
+                 data = raw_data
+
+             if isinstance(data, str):
+                 return [data]
+             elif isinstance(data, list):
+                 return data
+             else:
+                 raise ValueError("Input must be string or list of strings")
+
+         except Exception as e:
+             logger.error(f"Error in preprocessing: {str(e)}")
+             raise ValueError(f"Invalid input data format: {str(e)}")
+
+     def postprocess(self, outputs: List[str]) -> Dict[str, Any]:
+         """Postprocess SGLang model outputs"""
+         try:
+             return {'predictions': outputs}
+
+         except Exception as e:
+             logger.error(f"Error in postprocessing: {str(e)}")
+             raise
+
+     def predict(self, input_data: Any) -> Dict[str, Any]:
+         """Run inference on input data"""
+         if not self.is_loaded():
+             raise RuntimeError("Model is not loaded")
+
+         try:
+             prompts = self.preprocess(input_data)
+             outputs = self.runtime.generate(prompts)
+             return self.postprocess(outputs)
+
+         except Exception as e:
+             logger.error(f"Error during inference: {str(e)}")
+             raise
+ <% } %>
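Across all four framework branches the generated handler exposes the same surface (load_model, is_loaded, preprocess, predict, postprocess) and accepts either a bare array/string or a JSON object keyed by `instances`. A quick local smoke test of that contract (hypothetical paths; assumes the sklearn branch was rendered and a model file sits under ./model):

from model_handler import ModelHandler

handler = ModelHandler("./model")  # directory containing the model file
handler.load_model()
result = handler.predict({"instances": [[5.1, 3.5, 1.4, 0.2]]})
print(result)  # e.g. {'predictions': [...]}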