@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
@@ -0,0 +1,128 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # NVIDIA Triton Inference Server Dockerfile
5
+ # Backend: <%= backend %>
6
+
7
+ <% if (comments && comments.acceleratorInfo) { %>
8
+ <%= comments.acceleratorInfo %>
9
+ <% } %>
10
+
11
+ <% if (comments && comments.validationInfo) { %>
12
+ <%= comments.validationInfo %>
13
+ <% } %>
14
+
15
+ # Triton Inference Server base image from NVIDIA NGC
16
+ # Public image - no NGC authentication required
17
+ ARG BASE_IMAGE=<%= baseImage || 'nvcr.io/nvidia/tritonserver:24.08-py3' %>
18
+ FROM ${BASE_IMAGE}
19
+
20
+ # Set a docker label to name this project, suffixed with the build time
21
+ LABEL project.name="<%= projectName %>-<%= buildTimestamp %>" \
22
+ project.base-name="<%= projectName %>" \
23
+ project.build-time="<%= buildTimestamp %>"
24
+
25
+ # Set a docker label to advertise multi-model support on the container
26
+ LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
27
+ # Set a docker label to enable the container to use the SAGEMAKER_BIND_TO_PORT environment variable if present
28
+ LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
29
+
30
+ # Set working directory
31
+ WORKDIR /opt/ml
32
+
33
+ <% if (backend === 'vllm' || backend === 'tensorrtllm') { %>
34
+ # HuggingFace model configuration for LLM backends
35
+ ENV HF_MODEL_ID="<%= modelName %>"
36
+ <% if (hfToken) { %>
37
+ # Set HuggingFace authentication token for gated models
38
+ ENV HF_TOKEN="<%= hfToken %>"
39
+ <% } %>
40
+ <% } %>
41
+
42
+ <% if (backend === 'python') { %>
43
+ # Install Python backend dependencies
44
+ COPY triton/requirements.txt /tmp/triton_requirements.txt
45
+ RUN pip install --no-cache-dir -r /tmp/triton_requirements.txt && \
46
+ rm /tmp/triton_requirements.txt
47
+ <% } %>
48
+
49
+ # Set up model repository directory structure
50
+ # Triton expects models at: /opt/ml/model/model_repository/<model-name>/<version>/
51
+ RUN mkdir -p /opt/ml/model/model_repository/<%= modelName || 'model' %>/1
52
+
53
+ # Set permissions for model repository
54
+ RUN chmod -R 755 /opt/ml/model/model_repository
55
+
56
+ # Copy Triton model configuration
57
+ COPY triton/config.pbtxt /opt/ml/model/model_repository/<%= modelName || 'model' %>/config.pbtxt
58
+
59
+ <% if (backend === 'python') { %>
60
+ # Copy Python backend model implementation
61
+ COPY triton/model.py /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.py
62
+ <% } %>
63
+
64
+ <% if (includeSampleModel) { %>
65
+ # Copy sample model artifact
66
+ <% if (backend === 'fil') { %>
67
+ <% if (modelFormat === 'xgboost_json') { %>
68
+ COPY sample_model/abalone_model.json /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/xgboost.json
69
+ <% } else if (modelFormat === 'xgboost_ubj') { %>
70
+ COPY sample_model/abalone_model.ubj /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/xgboost.ubj
71
+ <% } else if (modelFormat === 'lightgbm_txt') { %>
72
+ COPY sample_model/abalone_model.txt /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.txt
73
+ <% } %>
74
+ <% } else if (backend === 'onnxruntime') { %>
75
+ COPY sample_model/abalone_model.onnx /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.onnx
76
+ <% } else if (backend === 'tensorflow') { %>
77
+ COPY sample_model/abalone_model.savedmodel /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.savedmodel/
78
+ <% } else if (backend === 'pytorch') { %>
79
+ COPY sample_model/abalone_model.pt /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.pt
80
+ <% } else if (backend === 'python') { %>
81
+ <% if (modelFormat === 'pkl') { %>
82
+ COPY sample_model/abalone_model.pkl /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.pkl
83
+ <% } else if (modelFormat === 'joblib') { %>
84
+ COPY sample_model/abalone_model.joblib /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/model.joblib
85
+ <% } %>
86
+ <% } %>
87
+ # Also copy training script for reference
88
+ COPY sample_model/ /opt/ml/sample_model/
89
+ <% } else { %>
90
+ # Model artifacts should be placed in:
91
+ # /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/
92
+ # COPY your_model_files /opt/ml/model/model_repository/<%= modelName || 'model' %>/1/
93
+ <% } %>
94
+
95
+ <% if (comments && comments.envVarExplanations && Object.keys(comments.envVarExplanations).length > 0) { %>
96
+ # Environment Variables Configuration
97
+ <% for (const [category, comment] of Object.entries(comments.envVarExplanations)) { %>
98
+ <%= comment %>
99
+ <% } %>
100
+ <% } %>
101
+
102
+ # Triton environment variables
103
+ ENV TRITON_MODEL_REPOSITORY=/opt/ml/model/model_repository
104
+
105
+ <% if (orderedEnvVars && orderedEnvVars.length > 0) { %>
106
+ # Additional environment variables from configuration
107
+ <% orderedEnvVars.forEach(({ key, value }) => { %>
108
+ ENV <%= key %>=<%= value %>
109
+ <% }); %>
110
+ <% } %>
111
+
112
+ # Expose port 8080 for SageMaker compatibility
113
+ # Triton default ports: 8000 (HTTP), 8001 (gRPC), 8002 (metrics)
114
+ # SageMaker requires port 8080
115
+ EXPOSE 8080
116
+
117
+ <% if (comments && comments.troubleshooting) { %>
118
+ <%= comments.troubleshooting %>
119
+ <% } %>
120
+
121
+ # Start Triton Inference Server
122
+ # --http-port=8080: SageMaker requires port 8080
123
+ # --model-repository: Path to model repository
124
+ # --strict-model-config=false: Allow Triton to auto-complete config for some backends
125
+ ENTRYPOINT ["tritonserver", \
126
+ "--http-port=8080", \
127
+ "--model-repository=/opt/ml/model/model_repository", \
128
+ "--strict-model-config=false"]
@@ -0,0 +1,163 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Triton Model Configuration
5
+ # Backend: <%= backend %>
6
+ # Model: <%= modelName || 'model' %>
7
+
8
+ name: "<%= modelName || 'model' %>"
9
+ backend: "<%= backend %>"
10
+
11
+ <% if (backend === 'vllm' || backend === 'tensorrtllm') { %>
12
+ # LLM backends (vllm, tensorrtllm) auto-configure most settings
13
+ # Minimal configuration is sufficient
14
+ max_batch_size: 0
15
+
16
+ <% if (backend === 'vllm') { %>
17
+ # vLLM backend parameters
18
+ parameters: {
19
+ key: "model"
20
+ value: { string_value: "<%= modelName %>" }
21
+ }
22
+ parameters: {
23
+ key: "gpu_memory_utilization"
24
+ value: { string_value: "0.9" }
25
+ }
26
+ <% } else if (backend === 'tensorrtllm') { %>
27
+ # TensorRT-LLM backend parameters
28
+ parameters: {
29
+ key: "model"
30
+ value: { string_value: "<%= modelName %>" }
31
+ }
32
+ <% } %>
33
+ <% } else { %>
34
+ # Maximum batch size (0 = batching disabled)
35
+ max_batch_size: 8
36
+
37
+ <% if (backend === 'fil') { %>
38
+ # FIL (Forest Inference Library) backend for tree-based models
39
+ # Supports XGBoost, LightGBM, and scikit-learn random forests
40
+ input [
41
+ {
42
+ name: "input__0"
43
+ data_type: TYPE_FP32
44
+ dims: [ -1 ] # Dynamic feature dimension
45
+ }
46
+ ]
47
+ output [
48
+ {
49
+ name: "output__0"
50
+ data_type: TYPE_FP32
51
+ dims: [ 1 ] # Single prediction value
52
+ }
53
+ ]
54
+
55
+ # FIL-specific parameters
56
+ parameters: {
57
+ key: "model_type"
58
+ <% if (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj') { %>
59
+ value: { string_value: "xgboost_json" }
60
+ <% } else if (modelFormat === 'lightgbm_txt') { %>
61
+ value: { string_value: "lightgbm" }
62
+ <% } %>
63
+ }
64
+ parameters: {
65
+ key: "output_class"
66
+ value: { string_value: "false" }
67
+ }
68
+ <% } else if (backend === 'onnxruntime') { %>
69
+ # ONNX Runtime backend
70
+ # Input/output shapes depend on your specific model
71
+ input [
72
+ {
73
+ name: "input"
74
+ data_type: TYPE_FP32
75
+ dims: [ -1 ] # Dynamic input dimension
76
+ }
77
+ ]
78
+ output [
79
+ {
80
+ name: "output"
81
+ data_type: TYPE_FP32
82
+ dims: [ -1 ] # Dynamic output dimension
83
+ }
84
+ ]
85
+ <% } else if (backend === 'tensorflow') { %>
86
+ # TensorFlow SavedModel backend
87
+ # Input/output shapes depend on your specific model
88
+ input [
89
+ {
90
+ name: "input"
91
+ data_type: TYPE_FP32
92
+ dims: [ -1 ] # Dynamic input dimension
93
+ }
94
+ ]
95
+ output [
96
+ {
97
+ name: "output"
98
+ data_type: TYPE_FP32
99
+ dims: [ -1 ] # Dynamic output dimension
100
+ }
101
+ ]
102
+
103
+ # TensorFlow-specific parameters
104
+ parameters: {
105
+ key: "TF_INTER_OP_PARALLELISM"
106
+ value: { string_value: "0" }
107
+ }
108
+ parameters: {
109
+ key: "TF_INTRA_OP_PARALLELISM"
110
+ value: { string_value: "0" }
111
+ }
112
+ <% } else if (backend === 'pytorch') { %>
113
+ # PyTorch TorchScript backend
114
+ # Input/output shapes depend on your specific model
115
+ input [
116
+ {
117
+ name: "INPUT__0"
118
+ data_type: TYPE_FP32
119
+ dims: [ -1 ] # Dynamic input dimension
120
+ }
121
+ ]
122
+ output [
123
+ {
124
+ name: "OUTPUT__0"
125
+ data_type: TYPE_FP32
126
+ dims: [ -1 ] # Dynamic output dimension
127
+ }
128
+ ]
129
+ <% } else if (backend === 'python') { %>
130
+ # Python backend
131
+ # Custom model implementation in model.py
132
+ input [
133
+ {
134
+ name: "INPUT"
135
+ data_type: TYPE_FP32
136
+ dims: [ -1 ] # Dynamic input dimension
137
+ }
138
+ ]
139
+ output [
140
+ {
141
+ name: "OUTPUT"
142
+ data_type: TYPE_FP32
143
+ dims: [ -1 ] # Dynamic output dimension
144
+ }
145
+ ]
146
+ <% } %>
147
+
148
+ # Instance group configuration
149
+ # Specifies how many instances of the model to run and on which device
150
+ instance_group [
151
+ {
152
+ count: 1
153
+ kind: KIND_AUTO # Automatically select CPU or GPU based on availability
154
+ }
155
+ ]
156
+
157
+ # Dynamic batching configuration
158
+ # Enables automatic request batching for improved throughput
159
+ dynamic_batching {
160
+ preferred_batch_size: [ 4, 8 ]
161
+ max_queue_delay_microseconds: 100
162
+ }
163
+ <% } %>
@@ -0,0 +1,130 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Triton Python Backend Model Implementation
6
+
7
+ This module implements the TritonPythonModel interface for serving
8
+ custom Python models via NVIDIA Triton Inference Server.
9
+
10
+ Backend: python
11
+ Model: <%= modelName || 'model' %>
12
+ """
13
+
14
+ import json
15
+ import os
16
+
17
+ import numpy as np
18
+ import triton_python_backend_utils as pb_utils
19
+
20
+ <% if (modelFormat === 'pkl') { %>
21
+ import pickle
22
+ <% } else if (modelFormat === 'joblib') { %>
23
+ import joblib
24
+ <% } %>
25
+
26
+
27
+ class TritonPythonModel:
28
+ """Triton Python backend model implementation.
29
+
30
+ This class implements the required interface for Triton's Python backend:
31
+ - initialize(): Called once when the model is loaded
32
+ - execute(): Called for each inference request batch
33
+ - finalize(): Called once when the model is unloaded
34
+ """
35
+
36
+ def initialize(self, args):
37
+ """Initialize the model.
38
+
39
+ Called once when the model is loaded by Triton. Use this method to
40
+ load model artifacts and set up any required resources.
41
+
42
+ Args:
43
+ args: Dictionary containing model configuration:
44
+ - model_config: JSON string of the model configuration
45
+ - model_instance_kind: Device type (CPU/GPU)
46
+ - model_instance_device_id: Device ID
47
+ - model_repository: Path to the model repository
48
+ - model_version: Model version being loaded
49
+ - model_name: Name of the model
50
+ """
51
+ self.model_config = json.loads(args['model_config'])
52
+ model_repository = args['model_repository']
53
+ model_version = args['model_version']
54
+
55
+ # Construct path to model artifact
56
+ model_dir = os.path.join(model_repository, model_version)
57
+
58
+ <% if (modelFormat === 'pkl') { %>
59
+ # Load pickle model
60
+ model_path = os.path.join(model_dir, 'model.pkl')
61
+ with open(model_path, 'rb') as f:
62
+ self.model = pickle.load(f)
63
+ <% } else if (modelFormat === 'joblib') { %>
64
+ # Load joblib model
65
+ model_path = os.path.join(model_dir, 'model.joblib')
66
+ self.model = joblib.load(model_path)
67
+ <% } else { %>
68
+ # Custom model loading
69
+ # TODO: Implement your model loading logic here
70
+ # model_path = os.path.join(model_dir, 'your_model_file')
71
+ self.model = None
72
+ <% } %>
73
+
74
+ # Get output configuration from model config
75
+ output_config = pb_utils.get_output_config_by_name(
76
+ self.model_config, 'OUTPUT'
77
+ )
78
+ self.output_dtype = pb_utils.triton_string_to_numpy(
79
+ output_config['data_type']
80
+ )
81
+
82
+ def execute(self, requests):
83
+ """Handle inference requests.
84
+
85
+ Called for each batch of inference requests. Processes input tensors
86
+ and returns output tensors.
87
+
88
+ Args:
89
+ requests: List of pb_utils.InferenceRequest objects
90
+
91
+ Returns:
92
+ List of pb_utils.InferenceResponse objects
93
+ """
94
+ responses = []
95
+
96
+ for request in requests:
97
+ # Get input tensor
98
+ input_tensor = pb_utils.get_input_tensor_by_name(request, 'INPUT')
99
+ input_data = input_tensor.as_numpy()
100
+
101
+ <% if (modelFormat === 'pkl' || modelFormat === 'joblib') { %>
102
+ # Run prediction
103
+ predictions = self.model.predict(input_data)
104
+ output_data = np.array(predictions, dtype=self.output_dtype)
105
+ <% } else { %>
106
+ # Custom inference logic
107
+ # TODO: Implement your inference logic here
108
+ output_data = np.zeros(
109
+ (input_data.shape[0], 1), dtype=self.output_dtype
110
+ )
111
+ <% } %>
112
+
113
+ # Create output tensor
114
+ output_tensor = pb_utils.Tensor('OUTPUT', output_data)
115
+
116
+ # Create inference response
117
+ inference_response = pb_utils.InferenceResponse(
118
+ output_tensors=[output_tensor]
119
+ )
120
+ responses.append(inference_response)
121
+
122
+ return responses
123
+
124
+ def finalize(self):
125
+ """Clean up resources.
126
+
127
+ Called once when the model is being unloaded. Use this method to
128
+ release any resources held by the model.
129
+ """
130
+ self.model = None
@@ -0,0 +1,11 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Python dependencies for Triton Python backend
5
+ numpy>=1.24.0
6
+ <% if (modelFormat === 'pkl') { %>
7
+ scikit-learn>=1.3.0
8
+ <% } else if (modelFormat === 'joblib') { %>
9
+ scikit-learn>=1.3.0
10
+ joblib>=1.3.0
11
+ <% } %>