@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# AWS CodeBuild buildspec (rendered from an EJS template): logs in to ECR,
# builds and tags the project's Docker image, pushes the tags, and publishes
# imagedefinitions.json as the build artifact.
version: 0.2

env:
  variables:
    AWS_DEFAULT_REGION: <%= awsRegion %>
    # Empty placeholder; pre_build resolves the real account ID via STS.
    AWS_ACCOUNT_ID: ""
    ECR_REPOSITORY_NAME: "ml-container-creator"
    PROJECT_NAME: "<%= projectName %>"
    # Reassigned in pre_build to the resolved source version (or "latest");
    # it is only echoed there - the pushed tags are all derived from PROJECT_TAG.
    IMAGE_TAG: "latest"

phases:
  pre_build:
    commands:
      - echo Logging in to Amazon ECR...
      - AWS_ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
      - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com
      - REPOSITORY_URI=$AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPOSITORY_NAME
      - IMAGE_TAG=${CODEBUILD_RESOLVED_SOURCE_VERSION:-latest}
      # Timestamped per-build tag: "<project>-YYYYmmdd-HHMMSS".
      - PROJECT_TAG="$PROJECT_NAME-$(date +%Y%m%d-%H%M%S)"
      - echo Repository URI is $REPOSITORY_URI
      - echo Project tag is $PROJECT_TAG
      - echo Image tag is $IMAGE_TAG
    # Abort the build instead of continuing to the next phase if this phase fails.
    on-failure: ABORT
  build:
    commands:
      - echo Build started on `date`
      - echo Building the Docker image for project $PROJECT_NAME...
      # Build-time model download from a non-HuggingFace source: export the
      # CodeBuild role's credentials (best-effort, `|| true`) and pass them to
      # `docker build` as build args. Otherwise a plain build is run.
      # NOTE(review): --build-arg values can be recorded in image metadata/history
      # when the Dockerfile consumes them via ARG; BuildKit secrets (`--secret`)
      # would avoid that - confirm against the Dockerfile before changing.
<% if (typeof modelLoadStrategy !== 'undefined' && modelLoadStrategy === 'build-time' && typeof modelSource !== 'undefined' && modelSource && modelSource !== 'huggingface') { %>
      - |
        # Export IAM role credentials for docker build
        eval $(aws configure export-credentials --format env 2>/dev/null || true)
        docker build \
          --build-arg AWS_ACCESS_KEY_ID="${AWS_ACCESS_KEY_ID}" \
          --build-arg AWS_SECRET_ACCESS_KEY="${AWS_SECRET_ACCESS_KEY}" \
          --build-arg AWS_SESSION_TOKEN="${AWS_SESSION_TOKEN}" \
          --build-arg AWS_DEFAULT_REGION="${AWS_DEFAULT_REGION}" \
          -t $REPOSITORY_URI:$PROJECT_TAG .
<% } else { %>
      - docker build -t $REPOSITORY_URI:$PROJECT_TAG .
<% } %>
      # Point the moving "<project>-latest" and global "latest" tags at the same image.
      - docker tag $REPOSITORY_URI:$PROJECT_TAG $REPOSITORY_URI:$PROJECT_NAME-latest
      - docker tag $REPOSITORY_URI:$PROJECT_TAG $REPOSITORY_URI:latest
      - echo Build completed on `date`
    on-failure: ABORT
  post_build:
    commands:
      - echo Post-build started on `date`
      - echo Pushing the Docker images for project $PROJECT_NAME...
      # Each push prints which tag failed and then fails the command.
      - docker push $REPOSITORY_URI:$PROJECT_TAG || (echo "Failed to push project tag $PROJECT_TAG" && exit 1)
      - docker push $REPOSITORY_URI:$PROJECT_NAME-latest || (echo "Failed to push project latest tag" && exit 1)
      - docker push $REPOSITORY_URI:latest || (echo "Failed to push latest tag" && exit 1)
      - echo Successfully pushed images to ECR repository $ECR_REPOSITORY_NAME
      - echo "Available tags:"
      - echo " - $PROJECT_TAG (timestamped build)"
      - echo " - $PROJECT_NAME-latest (project latest)"
      - echo " - latest (global latest)"
      - echo Writing image definitions file...
      # Records the container name and the timestamped image URI.
      # NOTE(review): presumably consumed by a downstream deploy stage - confirm.
      - printf '[{"name":"%s","imageUri":"%s"}]' $PROJECT_NAME $REPOSITORY_URI:$PROJECT_TAG > imagedefinitions.json
      - echo Post-build completed on `date`

artifacts:
  files:
    - imagedefinitions.json
  name: <%= projectName %>-artifacts
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
<%= chatTemplate %>
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Gunicorn configuration file
|
|
6
|
+
"""
|
|
7
|
+
import os
|
|
8
|
+
import multiprocessing
|
|
9
|
+
from serve import load_model_for_worker
|
|
10
|
+
|
|
11
|
+
# Bind address: SageMaker sends inference traffic to port 8080 in the container.
# (A plain string - the previous f-string had no placeholders.)
bind = "0.0.0.0:8080"

# Worker processes: one per CPU, capped at 4 for memory efficiency, since every
# worker loads the model itself in the post_worker_init hook below.
workers = min(multiprocessing.cpu_count(), 4)

# Synchronous workers: each worker handles one request at a time.
worker_class = 'sync'

# Timeouts (seconds): a worker that stays silent longer than `timeout` is killed
# and restarted; `keepalive` is how long to wait for the next request on a
# keep-alive connection.
timeout = 120
keepalive = 5

# Worker lifecycle management: recycle a worker after about `max_requests`
# requests, randomised by up to `max_requests_jitter` so workers do not all
# restart at once. A replacement worker runs post_worker_init again.
max_requests = 500
max_requests_jitter = 100

# Logging
accesslog = '-'  # Log to stdout
errorlog = '-'  # Log to stderr
loglevel = 'info'
|
|
32
|
+
|
|
33
|
+
# Gunicorn server hook: runs in every worker process once it has initialised.
def post_worker_init(worker):
    """Load the model inside the freshly initialised Gunicorn worker.

    Because Gunicorn invokes this hook per worker, the model is loaded
    separately in each worker process.

    Args:
        worker: The Gunicorn worker object; required by the hook signature
            but not used here.
    """
    del worker  # unused; present only to satisfy the hook signature
    load_model_for_worker()
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
<% if (framework === 'sklearn') { %>
|
|
4
|
+
"""
|
|
5
|
+
SKLearn model handler for SageMaker inference
|
|
6
|
+
"""
|
|
7
|
+
import os
|
|
8
|
+
import json
|
|
9
|
+
import pickle
|
|
10
|
+
import joblib
|
|
11
|
+
import numpy as np
|
|
12
|
+
from typing import Any, Dict
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelHandler:
    """Load a scikit-learn model from a directory and run inference with it.

    The model file is the first file, in sorted name order, in ``model_path``
    whose name ends with the configured model-format suffix. It is loaded with
    ``joblib`` and, if that fails, with plain ``pickle``.
    """

    def __init__(self, model_path: str):
        """Record the model directory; nothing is loaded until load_model()."""
        self.model_path = model_path
        self.model = None
        self._loaded = False

    def load_model(self):
        """Load the SKLearn model from ``self.model_path``.

        When no matching model file exists, logs a warning and returns without
        loading, so the server can still start.

        Raises:
            Exception: any error from listing the directory or unpickling the
                model (logged before being re-raised).
        """
        try:
            # Sort so the chosen file is deterministic: os.listdir() returns
            # entries in arbitrary order.
            model_files = sorted(
                f for f in os.listdir(self.model_path) if f.endswith('<%= modelFormat %>')
            )

            if not model_files:
                logger.warning("No SKLearn model files found in model directory")
                logger.warning("Server will start but /invocations will fail until a model is provided")
                logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
                return

            model_file = os.path.join(self.model_path, model_files[0])
            logger.info(f"Loading model from {model_file}")

            # SECURITY: joblib and pickle can execute arbitrary code while
            # loading; only serve model files that come from a trusted source.
            try:
                self.model = joblib.load(model_file)
            except Exception as joblib_error:
                # `except Exception` rather than a bare except, so that
                # KeyboardInterrupt / SystemExit propagate instead of
                # triggering the pickle fallback.
                logger.info(f"joblib could not load the model ({joblib_error}); falling back to pickle")
                with open(model_file, 'rb') as f:
                    self.model = pickle.load(f)

            self._loaded = True
            logger.info("SKLearn model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def is_loaded(self) -> bool:
        """Return True once a model has been loaded successfully."""
        return self._loaded and self.model is not None

    def preprocess(self, raw_data: Any) -> np.ndarray:
        """Convert a request payload into a NumPy array.

        Dict payloads are unwrapped from their ``'instances'`` key, else their
        ``'data'`` key, else used as-is. If the result is a string, it is
        parsed as JSON.

        Raises:
            ValueError: if the payload cannot be converted.
        """
        try:
            if isinstance(raw_data, dict):
                data = raw_data.get('instances', raw_data.get('data', raw_data))
            else:
                data = raw_data

            if isinstance(data, str):
                data = json.loads(data)

            return np.array(data)

        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise ValueError(f"Invalid input data format: {str(e)}") from e

    def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
        """Wrap predictions as ``{'predictions': ...}``, converting arrays to lists."""
        try:
            if hasattr(predictions, 'tolist'):
                predictions = predictions.tolist()

            return {'predictions': predictions}

        except Exception as e:
            logger.error(f"Error in postprocessing: {str(e)}")
            raise

    def predict(self, input_data: Any) -> Dict[str, Any]:
        """Preprocess ``input_data``, call ``model.predict`` and postprocess the result.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if not self.is_loaded():
            raise RuntimeError("Model is not loaded")

        try:
            processed_input = self.preprocess(input_data)
            predictions = self.model.predict(processed_input)
            return self.postprocess(predictions)

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            raise
|
|
100
|
+
<% } else if (framework === 'xgboost') { %>
|
|
101
|
+
"""
|
|
102
|
+
XGBoost model handler for SageMaker inference
|
|
103
|
+
"""
|
|
104
|
+
import os
|
|
105
|
+
import json
|
|
106
|
+
import xgboost as xgb
|
|
107
|
+
import numpy as np
|
|
108
|
+
from typing import Any, Dict
|
|
109
|
+
import logging
|
|
110
|
+
|
|
111
|
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelHandler:
    """Load an XGBoost ``Booster`` from a directory and run inference with it.

    The model file is the first file, in sorted name order, in ``model_path``
    whose name ends with the configured model-format suffix.
    """

    def __init__(self, model_path: str):
        """Record the model directory; nothing is loaded until load_model()."""
        self.model_path = model_path
        self.model = None
        self._loaded = False

    def load_model(self):
        """Load the XGBoost model from ``self.model_path`` into a new ``xgb.Booster``.

        When no matching model file exists, logs a warning and returns without
        loading, so the server can still start.

        Raises:
            Exception: any error from listing the directory or loading the
                booster (logged before being re-raised).
        """
        try:
            # Sort so the chosen file is deterministic: os.listdir() returns
            # entries in arbitrary order.
            model_files = sorted(
                f for f in os.listdir(self.model_path) if f.endswith('<%= modelFormat %>')
            )

            if not model_files:
                logger.warning("No XGBoost model files found in model directory")
                logger.warning("Server will start but /invocations will fail until a model is provided")
                logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
                return

            model_file = os.path.join(self.model_path, model_files[0])
            logger.info(f"Loading model from {model_file}")

            self.model = xgb.Booster()
            self.model.load_model(model_file)

            self._loaded = True
            logger.info("XGBoost model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def is_loaded(self) -> bool:
        """Return True once a model has been loaded successfully."""
        return self._loaded and self.model is not None

    def preprocess(self, raw_data: Any) -> "xgb.DMatrix":
        """Convert a request payload into an ``xgb.DMatrix``.

        Dict payloads are unwrapped from their ``'instances'`` key, else their
        ``'data'`` key, else used as-is. If the result is a string, it is
        parsed as JSON.

        Raises:
            ValueError: if the payload cannot be converted.
        """
        try:
            if isinstance(raw_data, dict):
                data = raw_data.get('instances', raw_data.get('data', raw_data))
            else:
                data = raw_data

            if isinstance(data, str):
                data = json.loads(data)

            return xgb.DMatrix(np.array(data))

        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise ValueError(f"Invalid input data format: {str(e)}") from e

    def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
        """Wrap predictions as ``{'predictions': ...}``, converting arrays to lists."""
        try:
            if hasattr(predictions, 'tolist'):
                predictions = predictions.tolist()

            return {'predictions': predictions}

        except Exception as e:
            logger.error(f"Error in postprocessing: {str(e)}")
            raise

    def predict(self, input_data: Any) -> Dict[str, Any]:
        """Preprocess ``input_data``, call ``Booster.predict`` and postprocess the result.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if not self.is_loaded():
            raise RuntimeError("Model is not loaded")

        try:
            processed_input = self.preprocess(input_data)
            predictions = self.model.predict(processed_input)
            return self.postprocess(predictions)

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            raise
|
|
192
|
+
<% } else if (framework === 'tensorflow') { %>
|
|
193
|
+
"""
|
|
194
|
+
TensorFlow model handler for SageMaker inference
|
|
195
|
+
"""
|
|
196
|
+
import os
|
|
197
|
+
import json
|
|
198
|
+
import tensorflow as tf
|
|
199
|
+
import numpy as np
|
|
200
|
+
from typing import Any, Dict
|
|
201
|
+
import logging
|
|
202
|
+
|
|
203
|
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelHandler:
    """Load a TensorFlow model from a directory and run inference with it.

    A Keras file (``.keras`` / ``.h5``) in ``model_path`` is preferred. Without
    one, the directory itself is loaded as a SavedModel and its
    ``serving_default`` signature is used.
    """

    def __init__(self, model_path: str):
        """Record the model directory; nothing is loaded until load_model()."""
        self.model_path = model_path
        self.model = None
        # The loaded SavedModel object when serving through a signature; held
        # so the signature's owner stays referenced for the handler's lifetime.
        self._saved_model = None
        self._loaded = False

    def load_model(self):
        """Load a Keras model file or, failing that, a SavedModel signature.

        A SavedModel that cannot be loaded is only logged, so the server can
        still start without a model.

        Raises:
            Exception: errors from listing the directory or loading a Keras
                file (logged before being re-raised).
        """
        try:
            logger.info(f"Loading model from {self.model_path}")

            # Sort so the chosen file is deterministic: os.listdir() returns
            # entries in arbitrary order.
            model_files = sorted(
                f for f in os.listdir(self.model_path) if f.endswith(('.keras', '.h5'))
            )

            if not model_files:
                try:
                    saved_model = tf.saved_model.load(self.model_path)
                    self.model = saved_model.signatures['serving_default']
                    # NOTE(review): keep the loaded object alive; the signature
                    # function may depend on variables owned by it.
                    self._saved_model = saved_model
                except Exception as load_error:
                    # Best-effort: start without a model, but surface the reason
                    # instead of hiding it behind a bare except.
                    logger.warning(f"Could not load a SavedModel from {self.model_path}: {load_error}")
                    logger.warning("No TensorFlow model files found in model directory")
                    logger.warning("Server will start but /invocations will fail until a model is provided")
                    logger.warning("Mount a model directory with: MODEL_DIR=/path/to/model ./do/run")
                    return
            else:
                model_file = os.path.join(self.model_path, model_files[0])
                logger.info(f"Loading model from {model_file}")
                self.model = tf.keras.models.load_model(model_file, compile=False)

            self._loaded = True
            logger.info("TensorFlow model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def is_loaded(self) -> bool:
        """Return True once a model has been loaded successfully."""
        return self._loaded and self.model is not None

    def preprocess(self, raw_data: Any) -> np.ndarray:
        """Convert a request payload into a float32 NumPy array.

        Dict payloads are unwrapped from their ``'instances'`` key, else their
        ``'data'`` key, else used as-is. If the result is a string, it is
        parsed as JSON.

        Raises:
            ValueError: if the payload cannot be converted.
        """
        try:
            if isinstance(raw_data, dict):
                data = raw_data.get('instances', raw_data.get('data', raw_data))
            else:
                data = raw_data

            if isinstance(data, str):
                data = json.loads(data)

            return np.array(data, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise ValueError(f"Invalid input data format: {str(e)}") from e

    def postprocess(self, predictions: np.ndarray) -> Dict[str, Any]:
        """Wrap predictions as ``{'predictions': ...}``.

        Tensors are converted with ``.numpy()`` and arrays with ``.tolist()``.
        """
        try:
            if hasattr(predictions, 'numpy'):
                predictions = predictions.numpy()

            if hasattr(predictions, 'tolist'):
                predictions = predictions.tolist()

            return {'predictions': predictions}

        except Exception as e:
            logger.error(f"Error in postprocessing: {str(e)}")
            raise

    def _call_signature(self, processed_input: np.ndarray):
        """Run the SavedModel serving signature and return its first output.

        Signature functions loaded from a SavedModel take their inputs as
        keyword arguments, so the batch is bound to the signature's first
        declared input name. Multi-input signatures are not supported.
        """
        input_tensor = tf.constant(processed_input)
        _, input_specs = self.model.structured_input_signature
        if input_specs:
            input_name = next(iter(input_specs))
            result = self.model(**{input_name: input_tensor})
        else:
            result = self.model(input_tensor)
        return next(iter(result.values()))

    def predict(self, input_data: Any) -> Dict[str, Any]:
        """Preprocess ``input_data``, run the model and postprocess the result.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if not self.is_loaded():
            raise RuntimeError("Model is not loaded")

        try:
            processed_input = self.preprocess(input_data)

            # Handle different model types
            if hasattr(self.model, 'predict'):
                # Keras model
                predictions = self.model.predict(processed_input)
            else:
                # SavedModel signature
                predictions = self._call_signature(processed_input)

            return self.postprocess(predictions)

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            raise
|
|
301
|
+
<% } else if (framework === 'transformers' && modelServer === 'sglang') { %>
|
|
302
|
+
"""
|
|
303
|
+
SGLang model handler for SageMaker inference
|
|
304
|
+
"""
|
|
305
|
+
import os
|
|
306
|
+
import json
|
|
307
|
+
from typing import Any, Dict, List
|
|
308
|
+
import logging
|
|
309
|
+
from sglang import Runtime
|
|
310
|
+
|
|
311
|
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelHandler:
    """Serve text generation through an SGLang ``Runtime``.

    ``model_path`` is stored, but the runtime is started from the model ID
    configured when this file was generated.
    """

    def __init__(self, model_path: str):
        """Record the model directory; the runtime is started by load_model()."""
        self.model_path = model_path
        self.runtime = None
        self._loaded = False

    def load_model(self):
        """Start the SGLang runtime for the configured model on CUDA.

        Raises:
            Exception: any error raised while starting the runtime (logged
                before being re-raised).
        """
        try:
            model_id = '<%= modelName %>'
            logger.info(f"Loading SGLang model: {model_id}")

            self.runtime = Runtime(
                model_path=model_id,
                tokenizer_path=model_id,
                device="cuda",
                # Fraction of GPU memory SGLang allocates statically up front.
                mem_fraction_static=0.8
            )

            self._loaded = True
            logger.info("SGLang model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def is_loaded(self) -> bool:
        """Return True once the runtime has been started."""
        return self._loaded and self.runtime is not None

    def preprocess(self, raw_data: Any) -> List[str]:
        """Normalise a request payload into a list of prompt strings.

        Dict payloads are unwrapped from their ``'instances'`` key, else their
        ``'inputs'`` key, else used as-is. A single string becomes a one-item
        list; a list must contain only strings.

        Raises:
            ValueError: if the payload is not a string or a list of strings.
        """
        if isinstance(raw_data, dict):
            data = raw_data.get('instances', raw_data.get('inputs', raw_data))
        else:
            data = raw_data

        if isinstance(data, str):
            return [data]
        # Check element types as well, so malformed lists are rejected here
        # with a clear error rather than being handed to the runtime.
        if isinstance(data, list) and all(isinstance(item, str) for item in data):
            return data

        reason = "Input must be string or list of strings"
        logger.error(f"Error in preprocessing: {reason}")
        raise ValueError(f"Invalid input data format: {reason}")

    def postprocess(self, outputs: Any) -> Dict[str, Any]:
        """Wrap generation output as ``{'predictions': ...}``.

        SGLang's ``Runtime.generate`` returns its response as a JSON-encoded
        string; it is decoded here so callers receive structured data instead
        of a JSON string. Strings that are not valid JSON, and values that are
        already decoded, are returned unchanged.
        """
        try:
            if isinstance(outputs, str):
                try:
                    outputs = json.loads(outputs)
                except json.JSONDecodeError:
                    # Plain text rather than JSON: pass it through unchanged.
                    pass
            return {'predictions': outputs}

        except Exception as e:
            logger.error(f"Error in postprocessing: {str(e)}")
            raise

    def predict(self, input_data: Any) -> Dict[str, Any]:
        """Preprocess ``input_data``, generate with the runtime and postprocess.

        Raises:
            RuntimeError: if the runtime has not been started.
        """
        if not self.is_loaded():
            raise RuntimeError("Model is not loaded")

        try:
            prompts = self.preprocess(input_data)
            outputs = self.runtime.generate(prompts)
            return self.postprocess(outputs)

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            raise
|
|
387
|
+
<% } %>
|