@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
package/templates/triton/Dockerfile
@@ -0,0 +1,128 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# NVIDIA Triton Inference Server Dockerfile
+# Backend: <%= backend %>
+
+<% if (comments && comments.acceleratorInfo) { %>
+<%= comments.acceleratorInfo %>
+<% } %>
+
+<% if (comments && comments.validationInfo) { %>
+<%= comments.validationInfo %>
+<% } %>
+
+# Triton Inference Server base image from NVIDIA NGC
+# Public image - no NGC authentication required
+ARG BASE_IMAGE=<%= baseImage || 'nvcr.io/nvidia/tritonserver:24.08-py3' %>
+FROM ${BASE_IMAGE}
+
+# Set a docker label naming this project, suffixed with the build time
+LABEL project.name="<%= projectName %>-<%= buildTimestamp %>" \
+      project.base-name="<%= projectName %>" \
+      project.build-time="<%= buildTimestamp %>"
+
+# Set a docker label to advertise multi-model support on the container
+LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
+# Set a docker label so the container honors the SAGEMAKER_BIND_TO_PORT environment variable if present
+LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
+
+# Set working directory
+WORKDIR /opt/ml
+
+<% if (backend === 'vllm' || backend === 'tensorrtllm') { %>
+# HuggingFace model configuration for LLM backends
+ENV HF_MODEL_ID="<%= modelName %>"
+<% if (hfToken) { %>
+# Set HuggingFace authentication token for gated models
+ENV HF_TOKEN="<%= hfToken %>"
+<% } %>
+<% } %>
+
+<% if (backend === 'python') { %>
+# Install Python backend dependencies
+COPY triton/requirements.txt /tmp/triton_requirements.txt
+RUN pip install --no-cache-dir -r /tmp/triton_requirements.txt && \
+    rm /tmp/triton_requirements.txt
+<% } %>
+
+# Set up model repository directory structure
+# Triton expects models at: /opt/ml/model/model_repository/<model-name>/<version>/
+RUN mkdir -p /opt/ml/model/model_repository/<%= modelName || 'model' %>/1
+
+# Set permissions for model repository
+RUN chmod -R 755 /opt/ml/model/model_repository
+
+# Copy Triton model configuration
+COPY triton/config.pbtxt /opt/ml/model/model_repository/<%= modelName || 'model' %>/config.pbtxt
+
+<% if (backend === 'python') { %>
+# Copy Python backend model implementation
+COPY triton/model.py /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.py
+<% } %>
+
+<% if (includeSampleModel) { %>
+# Copy sample model artifact
+<% if (backend === 'fil') { %>
+<% if (modelFormat === 'xgboost_json') { %>
+COPY sample_model/abalone_model.json /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/xgboost.json
+<% } else if (modelFormat === 'xgboost_ubj') { %>
+COPY sample_model/abalone_model.ubj /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/xgboost.ubj
+<% } else if (modelFormat === 'lightgbm_txt') { %>
+COPY sample_model/abalone_model.txt /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.txt
+<% } %>
+<% } else if (backend === 'onnxruntime') { %>
+COPY sample_model/abalone_model.onnx /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.onnx
+<% } else if (backend === 'tensorflow') { %>
+COPY sample_model/abalone_model.savedmodel /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.savedmodel/
+<% } else if (backend === 'pytorch') { %>
+COPY sample_model/abalone_model.pt /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.pt
+<% } else if (backend === 'python') { %>
+<% if (modelFormat === 'pkl') { %>
+COPY sample_model/abalone_model.pkl /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.pkl
+<% } else if (modelFormat === 'joblib') { %>
+COPY sample_model/abalone_model.joblib /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.joblib
+<% } %>
+<% } %>
+# Also copy training script for reference
+COPY sample_model/ /opt/ml/sample_model/
+<% } else { %>
+# Model artifacts should be placed in:
+# /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/
+# COPY your_model_files /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/
+<% } %>
+
+<% if (comments && comments.envVarExplanations && Object.keys(comments.envVarExplanations).length > 0) { %>
+# Environment Variables Configuration
+<% for (const [category, comment] of Object.entries(comments.envVarExplanations)) { %>
+<%= comment %>
+<% } %>
+<% } %>
+
+# Triton environment variables
+ENV TRITON_MODEL_REPOSITORY=/opt/ml/model/model_repository
+
+<% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
+# Additional environment variables from configuration
+<% orderedEnvVars.forEach(({ key, value }) => { %>
+ENV <%= key %>=<%= value %>
+<% }); %>
+<% } %>
+
+# Expose port 8080 for SageMaker compatibility
+# Triton default ports: 8000 (HTTP), 8001 (gRPC), 8002 (metrics)
+# SageMaker requires port 8080
+EXPOSE 8080
+
+<% if (comments && comments.troubleshooting) { %>
+<%= comments.troubleshooting %>
+<% } %>
+
+# Start Triton Inference Server
+# --http-port=8080: SageMaker requires port 8080
+# --model-repository: Path to model repository
+# --strict-model-config=false: Allow Triton to auto-complete config for some backends
+ENTRYPOINT ["tritonserver", \
+            "--http-port=8080", \
+            "--model-repository=/opt/ml/model/model_repository", \
+            "--strict-model-config=false"]
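The ENTRYPOINT above serves Triton's native KServe v2 HTTP API on port 8080, so a rendered image can be exercised locally before any SageMaker deployment. A minimal smoke-test sketch; the running container, the default model name "model", and the 4-feature input row are illustrative assumptions, not part of the package:

# Hypothetical local smoke test; assumes the built image is running, e.g.
#   docker run --rm -p 8080:8080 <your-image>:latest
import json
import urllib.request

BASE = "http://localhost:8080"  # matches --http-port=8080 in the ENTRYPOINT

# Triton's standard readiness endpoint
with urllib.request.urlopen(f"{BASE}/v2/health/ready") as resp:
    assert resp.status == 200

# Minimal inference request against the python-backend config
# (tensor names INPUT/OUTPUT, TYPE_FP32, dynamic dims)
payload = {
    "inputs": [
        {"name": "INPUT", "shape": [1, 4], "datatype": "FP32",
         "data": [0.5, 0.3, 0.2, 0.1]}
    ]
}
req = urllib.request.Request(
    f"{BASE}/v2/models/model/infer",  # 'model' is the template's default name
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["outputs"])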
package/templates/triton/config.pbtxt
@@ -0,0 +1,163 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Triton Model Configuration
+# Backend: <%= backend %>
+# Model: <%= modelName || 'model' %>
+
+name: "<%= modelName || 'model' %>"
+backend: "<%= backend %>"
+
+<% if (backend === 'vllm' || backend === 'tensorrtllm') { %>
+# LLM backends (vllm, tensorrtllm) auto-configure most settings
+# Minimal configuration is sufficient
+max_batch_size: 0
+
+<% if (backend === 'vllm') { %>
+# vLLM backend parameters
+parameters: {
+  key: "model"
+  value: { string_value: "<%= modelName %>" }
+}
+parameters: {
+  key: "gpu_memory_utilization"
+  value: { string_value: "0.9" }
+}
+<% } else if (backend === 'tensorrtllm') { %>
+# TensorRT-LLM backend parameters
+parameters: {
+  key: "model"
+  value: { string_value: "<%= modelName %>" }
+}
+<% } %>
+<% } else { %>
+# Maximum batch size (0 = batching disabled)
+max_batch_size: 8
+
+<% if (backend === 'fil') { %>
+# FIL (Forest Inference Library) backend for tree-based models
+# Supports XGBoost, LightGBM, and scikit-learn random forests
+input [
+  {
+    name: "input__0"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic feature dimension
+  }
+]
+output [
+  {
+    name: "output__0"
+    data_type: TYPE_FP32
+    dims: [ 1 ]  # Single prediction value
+  }
+]
+
+# FIL-specific parameters
+parameters: {
+  key: "model_type"
+<% if (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj') { %>
+  value: { string_value: "xgboost_json" }
+<% } else if (modelFormat === 'lightgbm_txt') { %>
+  value: { string_value: "lightgbm" }
+<% } %>
+}
+parameters: {
+  key: "output_class"
+  value: { string_value: "false" }
+}
+<% } else if (backend === 'onnxruntime') { %>
+# ONNX Runtime backend
+# Input/output shapes depend on your specific model
+input [
+  {
+    name: "input"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic input dimension
+  }
+]
+output [
+  {
+    name: "output"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic output dimension
+  }
+]
+<% } else if (backend === 'tensorflow') { %>
+# TensorFlow SavedModel backend
+# Input/output shapes depend on your specific model
+input [
+  {
+    name: "input"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic input dimension
+  }
+]
+output [
+  {
+    name: "output"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic output dimension
+  }
+]
+
+# TensorFlow-specific parameters
+parameters: {
+  key: "TF_INTER_OP_PARALLELISM"
+  value: { string_value: "0" }
+}
+parameters: {
+  key: "TF_INTRA_OP_PARALLELISM"
+  value: { string_value: "0" }
+}
+<% } else if (backend === 'pytorch') { %>
+# PyTorch TorchScript backend
+# Input/output shapes depend on your specific model
+input [
+  {
+    name: "INPUT__0"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic input dimension
+  }
+]
+output [
+  {
+    name: "OUTPUT__0"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic output dimension
+  }
+]
+<% } else if (backend === 'python') { %>
+# Python backend
+# Custom model implementation in model.py
+input [
+  {
+    name: "INPUT"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic input dimension
+  }
+]
+output [
+  {
+    name: "OUTPUT"
+    data_type: TYPE_FP32
+    dims: [ -1 ]  # Dynamic output dimension
+  }
+]
+<% } %>
+
+# Instance group configuration
+# Specifies how many instances of the model to run and on which device
+instance_group [
+  {
+    count: 1
+    kind: KIND_AUTO  # Automatically select CPU or GPU based on availability
+  }
+]
+
+# Dynamic batching configuration
+# Enables automatic request batching for improved throughput
+dynamic_batching {
+  preferred_batch_size: [ 4, 8 ]
+  max_queue_delay_microseconds: 100
+}
+<% } %>
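The tensor names declared in this config are what a client must reference in its requests. A hedged sketch of a matching call for the FIL branch, assuming the tritonclient package (pip install "tritonclient[http]") and a server on localhost:8080; the 8-feature row and the default model name "model" are illustrative:

# Hypothetical client for the FIL config above (input__0 / output__0, FP32)
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8080")

# One row of 8 features; with max_batch_size: 8 and dims: [ -1 ],
# the on-the-wire tensor shape is [batch, features]
features = np.random.rand(1, 8).astype(np.float32)

infer_input = httpclient.InferInput("input__0", list(features.shape), "FP32")
infer_input.set_data_from_numpy(features)

result = client.infer("model", inputs=[infer_input])
print(result.as_numpy("output__0"))  # shape [1, 1]: one regression value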
package/templates/triton/model.py
@@ -0,0 +1,130 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Triton Python Backend Model Implementation
+
+This module implements the TritonPythonModel interface for serving
+custom Python models via NVIDIA Triton Inference Server.
+
+Backend: python
+Model: <%= modelName || 'model' %>
+"""
+
+import json
+import os
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+<% if (modelFormat === 'pkl') { %>
+import pickle
+<% } else if (modelFormat === 'joblib') { %>
+import joblib
+<% } %>
+
+
+class TritonPythonModel:
+    """Triton Python backend model implementation.
+
+    This class implements the required interface for Triton's Python backend:
+    - initialize(): Called once when the model is loaded
+    - execute(): Called for each inference request batch
+    - finalize(): Called once when the model is unloaded
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Called once when the model is loaded by Triton. Use this method to
+        load model artifacts and set up any required resources.
+
+        Args:
+            args: Dictionary containing model configuration:
+                - model_config: JSON string of the model configuration
+                - model_instance_kind: Device type (CPU/GPU)
+                - model_instance_device_id: Device ID
+                - model_repository: Path to the model repository
+                - model_version: Model version being loaded
+                - model_name: Name of the model
+        """
+        self.model_config = json.loads(args['model_config'])
+        model_repository = args['model_repository']
+        model_version = args['model_version']
+
+        # Construct path to model artifact
+        model_dir = os.path.join(model_repository, model_version)
+
+<% if (modelFormat === 'pkl') { %>
+        # Load pickle model
+        model_path = os.path.join(model_dir, 'model.pkl')
+        with open(model_path, 'rb') as f:
+            self.model = pickle.load(f)
+<% } else if (modelFormat === 'joblib') { %>
+        # Load joblib model
+        model_path = os.path.join(model_dir, 'model.joblib')
+        self.model = joblib.load(model_path)
+<% } else { %>
+        # Custom model loading
+        # TODO: Implement your model loading logic here
+        # model_path = os.path.join(model_dir, 'your_model_file')
+        self.model = None
+<% } %>
+
+        # Get output configuration from model config
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, 'OUTPUT'
+        )
+        self.output_dtype = pb_utils.triton_string_to_numpy(
+            output_config['data_type']
+        )
+
+    def execute(self, requests):
+        """Handle inference requests.
+
+        Called for each batch of inference requests. Processes input tensors
+        and returns output tensors.
+
+        Args:
+            requests: List of pb_utils.InferenceRequest objects
+
+        Returns:
+            List of pb_utils.InferenceResponse objects
+        """
+        responses = []
+
+        for request in requests:
+            # Get input tensor
+            input_tensor = pb_utils.get_input_tensor_by_name(request, 'INPUT')
+            input_data = input_tensor.as_numpy()
+
+<% if (modelFormat === 'pkl' || modelFormat === 'joblib') { %>
+            # Run prediction
+            predictions = self.model.predict(input_data)
+            output_data = np.array(predictions, dtype=self.output_dtype)
+<% } else { %>
+            # Custom inference logic
+            # TODO: Implement your inference logic here
+            output_data = np.zeros(
+                (input_data.shape[0], 1), dtype=self.output_dtype
+            )
+<% } %>
+
+            # Create output tensor
+            output_tensor = pb_utils.Tensor('OUTPUT', output_data)
+
+            # Create inference response
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[output_tensor]
+            )
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        """Clean up resources.
+
+        Called once when the model is being unloaded. Use this method to
+        release any resources held by the model.
+        """
+        self.model = None
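initialize() above expects a ready-made artifact under <model-name>/1/ in the repository. A sketch of how a compatible model.joblib might be produced with scikit-learn; the feature count and synthetic data are illustrative only (the package ships its own sample_model/train_abalone.py for its sample models):

# Hypothetical companion script producing an artifact that initialize() loads
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.random((256, 8)).astype(np.float32)  # 8 features, as in the client sketch
y = X.sum(axis=1)                            # toy regression target

model = RandomForestRegressor(n_estimators=50, random_state=0)
model.fit(X, y)

# execute() calls self.model.predict(input_data) on the loaded object,
# so any estimator exposing predict() works here.
joblib.dump(model, "model.joblib")  # place under <model-name>/1/ in the repository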
package/templates/triton/requirements.txt
@@ -0,0 +1,11 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Python dependencies for Triton Python backend
+numpy>=1.24.0
+<% if (modelFormat === 'pkl') { %>
+scikit-learn>=1.3.0
+<% } else if (modelFormat === 'joblib') { %>
+scikit-learn>=1.3.0
+joblib>=1.3.0
+<% } %>