@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
@@ -0,0 +1,24 @@
1
+ apiVersion: v1  # EJS template for the ConfigMap that carries the container's environment settings
2
+ kind: ConfigMap
3
+ metadata:
4
+ name: <%= projectName %>-config  # the Deployment template's envFrom.configMapRef uses this same "<projectName>-config" name
5
+ namespace: <%= hyperPodNamespace %>
6
+ labels:
7
+ app: <%= projectName %>
8
+ managed-by: ml-container-creator
9
+ data:
10
+ DEPLOYMENT_CONFIG: "<%= deploymentConfig %>"  # NOTE(review): values are interpolated inside double quotes with no visible YAML escaping - confirm inputs cannot contain quotes
11
+ FRAMEWORK: "<%= framework %>"
12
+ MODEL_SERVER: "<%= modelServer %>"
13
+ <% if (framework === 'transformers') { %>  # keys below are emitted only for the transformers framework
14
+ MODEL_NAME: "<%= modelName %>"
15
+ <% if (hfToken) { %>
16
+ HF_TOKEN: "<%= hfToken %>"  # NOTE(review): credential stored in a ConfigMap in plain text; a Secret (and a secretRef in the Deployment) would be the usual home
17
+ <% } %>
18
+ <% if (ngcApiKey) { %>
19
+ NGC_API_KEY: "<%= ngcApiKey %>"  # NOTE(review): same concern as HF_TOKEN - this is a credential
20
+ <% } %>
21
+ <% } %>
22
+ <% if (modelFormat) { %>  # MODEL_FORMAT is emitted for any framework, but only when a format was supplied
23
+ MODEL_FORMAT: "<%= modelFormat %>"
24
+ <% } %>
@@ -0,0 +1,71 @@
1
+ apiVersion: apps/v1  # EJS template for the Deployment that runs the generated inference container
2
+ kind: Deployment
3
+ metadata:
4
+ name: <%= projectName %>
5
+ namespace: <%= hyperPodNamespace %>
6
+ labels:
7
+ app: <%= projectName %>
8
+ framework: <%= framework %>
9
+ managed-by: ml-container-creator
10
+ spec:
11
+ replicas: <%= hyperPodReplicas %>
12
+ selector:
13
+ matchLabels:
14
+ app: <%= projectName %>  # same value as the pod template's app label below
15
+ template:
16
+ metadata:
17
+ labels:
18
+ app: <%= projectName %>
19
+ framework: <%= framework %>
20
+ spec:
21
+ containers:
22
+ - name: <%= projectName %>
23
+ image: ${AWS_ACCOUNT_ID}.dkr.ecr.<%= awsRegion %>.amazonaws.com/ml-container-creator:<%= projectName %>-latest  # NOTE(review): ${AWS_ACCOUNT_ID} is not an EJS tag, so it is presumably substituted by a later step before the manifest is applied - confirm
24
+ ports:
25
+ - containerPort: 8080  # same port the probes below use
26
+ protocol: TCP
27
+ envFrom:
28
+ - configMapRef:
29
+ name: <%= projectName %>-config  # imports every key of the "<projectName>-config" ConfigMap as environment variables
30
+ resources:
31
+ requests:
32
+ cpu: "4"
33
+ memory: "16Gi"
34
+ nvidia.com/gpu: "1"  # NOTE(review): resource sizes are fixed and always ask for one GPU, independent of instanceType - confirm this suits every target
35
+ limits:
36
+ cpu: "8"
37
+ memory: "32Gi"
38
+ nvidia.com/gpu: "1"
39
+ <% if (fsxVolumeHandle) { %>  # the mount is rendered only when an FSx volume handle was provided
40
+ volumeMounts:
41
+ - name: fsx-storage  # refers to the fsx-storage volume declared at the end of this template
42
+ mountPath: /opt/ml/model
43
+ <% } %>
44
+ readinessProbe:
45
+ httpGet:
46
+ path: /ping
47
+ port: 8080
48
+ initialDelaySeconds: 120  # first readiness check after 2 minutes, then every 10 seconds
49
+ periodSeconds: 10
50
+ livenessProbe:
51
+ httpGet:
52
+ path: /ping
53
+ port: 8080
54
+ initialDelaySeconds: 900  # liveness checks start after 15 minutes
55
+ periodSeconds: 30
56
+ failureThreshold: 5  # five consecutive failures at 30-second intervals before a restart
57
+ nodeSelector:
58
+ node.kubernetes.io/instance-type: <%= instanceType %>  # restricts scheduling to nodes of the chosen instance type
59
+ tolerations:
60
+ - key: "nvidia.com/gpu"  # tolerate a NoSchedule taint with this key, whatever its value
61
+ operator: "Exists"
62
+ effect: "NoSchedule"
63
+ - key: "sagemaker.amazonaws.com/hyperpod"
64
+ operator: "Exists"
65
+ effect: "NoSchedule"
66
+ <% if (fsxVolumeHandle) { %>  # same condition as the volumeMounts section above
67
+ volumes:
68
+ - name: fsx-storage
69
+ persistentVolumeClaim:
70
+ claimName: <%= projectName %>-fsx-pvc  # same "<projectName>-fsx-pvc" name the PVC template uses
71
+ <% } %>
@@ -0,0 +1,43 @@
1
+ <% if (fsxVolumeHandle) { %>
2
+ apiVersion: v1  # Static FSx for Lustre storage: a PersistentVolumeClaim followed by the PersistentVolume it binds to
3
+ kind: PersistentVolumeClaim
4
+ metadata:
5
+ name: <%= projectName %>-fsx-pvc  # the Deployment template's claimName uses this same "<projectName>-fsx-pvc" name
6
+ namespace: <%= hyperPodNamespace %>
7
+ labels:
8
+ app: <%= projectName %>
9
+ managed-by: ml-container-creator
10
+ spec:
11
+ accessModes:
12
+ - ReadWriteMany
13
+ storageClassName: fsx-lustre  # must equal the PersistentVolume's storageClassName below for the two to bind
14
+ volumeName: <%= projectName %>-fsx-pv  # pin the claim to the PersistentVolume defined below so it cannot bind to another volume or trigger dynamic provisioning
15
+ resources:
16
+ requests:
17
+ storage: 1200Gi
18
+ ---
19
+ apiVersion: v1
20
+ kind: PersistentVolume
21
+ metadata:
22
+ name: <%= projectName %>-fsx-pv  # referenced by the claim's volumeName above
23
+ labels:
24
+ app: <%= projectName %>
25
+ managed-by: ml-container-creator
26
+ spec:
27
+ capacity:
28
+ storage: 1200Gi
29
+ volumeMode: Filesystem
30
+ accessModes:
31
+ - ReadWriteMany
32
+ persistentVolumeReclaimPolicy: Retain
33
+ storageClassName: fsx-lustre
34
+ csi:
35
+ driver: fsx.csi.aws.com
36
+ volumeHandle: <%= fsxVolumeHandle %>
37
+ volumeAttributes:
38
+ dnsname: <%= fsxVolumeHandle %>.fsx.<%= awsRegion %>.amazonaws.com  # DNS name is built from the volume handle and the region
39
+ mountname: fsx  # NOTE(review): hard-coded; presumably must match the file system's actual mount name - confirm for the FSx deployment types in use
40
+ <% } else { %>
41
+ # PVC not generated - no FSx volume handle provided
42
+ # To enable FSx for Lustre storage, provide fsxVolumeHandle during generation
43
+ <% } %>
@@ -0,0 +1,17 @@
1
+ apiVersion: v1  # EJS template for the ClusterIP Service in front of the pods labelled app=<projectName>
2
+ kind: Service
3
+ metadata:
4
+ name: <%= projectName %>
5
+ namespace: <%= hyperPodNamespace %>
6
+ labels:
7
+ app: <%= projectName %>
8
+ managed-by: ml-container-creator
9
+ spec:
10
+ type: ClusterIP  # reachable only from inside the cluster
11
+ ports:
12
+ - port: 8080
13
+ targetPort: 8080  # same number as the containerPort in the Deployment template
14
+ protocol: TCP
15
+ name: http
16
+ selector:
17
+ app: <%= projectName %>  # routes to pods carrying this app label
@@ -0,0 +1,74 @@
1
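+ # nginx configuration (not a shell script): listens on port 8080 and proxies to the vllm_omni upstream on 127.0.0.1:8081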
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ worker_processes auto;  # let nginx choose the number of worker processes
6
+ daemon off;  # stay in the foreground
7
+ pid /tmp/nginx.pid;
8
+ error_log /dev/stderr info;  # errors go to stderr at "info" level
9
+
10
+ events {
11
+ worker_connections 1024;
12
+ }
13
+
14
+ http {
15
+ access_log /dev/stdout;
16
+ client_max_body_size 100M;  # request bodies up to 100 MB are accepted
17
+
18
+ upstream vllm_omni {  # the backend every location below proxies to
19
+ server 127.0.0.1:8081;
20
+ }
21
+
22
+ server {
23
+ listen 8080;
24
+ server_name _;
25
+
26
+ # SageMaker health check endpoint
27
+ location /ping {
28
+ proxy_pass http://vllm_omni/health;  # /ping is rewritten to the upstream's /health
29
+ proxy_set_header Host $host;
30
+ proxy_set_header X-Real-IP $remote_addr;
31
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
32
+ proxy_connect_timeout 2s;  # short timeouts so a hung backend fails the health check quickly
33
+ proxy_read_timeout 2s;
34
+ }
35
+
36
+ # SageMaker inference endpoint - map to OpenAI image generation
37
+ location /invocations {
38
+ proxy_pass http://vllm_omni/v1/images/generations;
39
+ proxy_set_header Host $host;
40
+ proxy_set_header X-Real-IP $remote_addr;
41
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
42
+ proxy_connect_timeout 300s;  # five-minute connect/read/send timeouts for inference requests
43
+ proxy_read_timeout 300s;
44
+ proxy_send_timeout 300s;
45
+ }
46
+
47
+ # Pass through all OpenAI API endpoints for direct access
48
+ location /v1/ {
49
+ proxy_pass http://vllm_omni;  # no URI part, so the original request URI is forwarded unchanged
50
+ proxy_set_header Host $host;
51
+ proxy_set_header X-Real-IP $remote_addr;
52
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
53
+ proxy_connect_timeout 300s;
54
+ proxy_read_timeout 300s;
55
+ proxy_send_timeout 300s;
56
+ }
57
+
58
+ # Health endpoint passthrough
59
+ location /health {
60
+ proxy_pass http://vllm_omni/health;  # no explicit timeouts here, so nginx defaults apply
61
+ proxy_set_header Host $host;
62
+ proxy_set_header X-Real-IP $remote_addr;
63
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
64
+ }
65
+
66
+ # Metrics endpoint passthrough
67
+ location /metrics {
68
+ proxy_pass http://vllm_omni/metrics;
69
+ proxy_set_header Host $host;
70
+ proxy_set_header X-Real-IP $remote_addr;
71
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
72
+ }
73
+ }
74
+ }
@@ -0,0 +1,47 @@
1
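+ # nginx configuration (not a shell script): listens on port 8080 and proxies /ping and /invocations to the backend upstream on 127.0.0.1:8000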
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ worker_processes auto;  # let nginx choose the number of worker processes
6
+ daemon off;  # stay in the foreground
7
+ pid /tmp/nginx.pid;
8
+ error_log /dev/stderr info;  # errors go to stderr at "info" level
9
+
10
+ events {
11
+ worker_connections 1024;
12
+ }
13
+
14
+ http {
15
+ access_log /dev/stdout;
16
+ client_max_body_size 100M;  # request bodies up to 100 MB are accepted
17
+
18
+ upstream backend {  # the application server both locations below proxy to
19
+ server 127.0.0.1:8000;
20
+ }
21
+
22
+ server {
23
+ listen 8080;
24
+ server_name _;
25
+
26
+ # SageMaker health check endpoint
27
+ location /ping {
28
+ proxy_pass http://backend/ping;
29
+ proxy_set_header Host $host;
30
+ proxy_set_header X-Real-IP $remote_addr;
31
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
32
+ proxy_connect_timeout 2s;  # short timeouts so a hung backend fails the health check quickly
33
+ proxy_read_timeout 2s;
34
+ }
35
+
36
+ # SageMaker inference endpoint
37
+ location /invocations {
38
+ proxy_pass http://backend/invocations;
39
+ proxy_set_header Host $host;
40
+ proxy_set_header X-Real-IP $remote_addr;
41
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
42
+ proxy_connect_timeout 300s;  # five-minute connect/read/send timeouts for inference requests
43
+ proxy_read_timeout 300s;
44
+ proxy_send_timeout 300s;
45
+ }
46
+ }
47
+ }
@@ -0,0 +1,74 @@
1
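+ # nginx configuration (not a shell script): listens on port 8080 and proxies to the trtllm upstream on 127.0.0.1:8081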
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ worker_processes auto;  # let nginx choose the number of worker processes
6
+ daemon off;  # stay in the foreground
7
+ pid /tmp/nginx.pid;
8
+ error_log /dev/stderr info;  # errors go to stderr at "info" level
9
+
10
+ events {
11
+ worker_connections 1024;
12
+ }
13
+
14
+ http {
15
+ access_log /dev/stdout;
16
+ client_max_body_size 100M;  # request bodies up to 100 MB are accepted
17
+
18
+ upstream trtllm {  # the backend every location below proxies to
19
+ server 127.0.0.1:8081;
20
+ }
21
+
22
+ server {
23
+ listen 8080;
24
+ server_name _;
25
+
26
+ # SageMaker health check endpoint
27
+ location /ping {
28
+ proxy_pass http://trtllm/health;  # /ping is rewritten to the upstream's /health
29
+ proxy_set_header Host $host;
30
+ proxy_set_header X-Real-IP $remote_addr;
31
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
32
+ proxy_connect_timeout 2s;  # short timeouts so a hung backend fails the health check quickly
33
+ proxy_read_timeout 2s;
34
+ }
35
+
36
+ # SageMaker inference endpoint - map to OpenAI chat completions
37
+ location /invocations {
38
+ proxy_pass http://trtllm/v1/chat/completions;
39
+ proxy_set_header Host $host;
40
+ proxy_set_header X-Real-IP $remote_addr;
41
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
42
+ proxy_connect_timeout 300s;  # five-minute connect/read/send timeouts for inference requests
43
+ proxy_read_timeout 300s;
44
+ proxy_send_timeout 300s;
45
+ }
46
+
47
+ # Pass through all OpenAI API endpoints for direct access
48
+ location /v1/ {
49
+ proxy_pass http://trtllm;  # no URI part, so the original request URI is forwarded unchanged
50
+ proxy_set_header Host $host;
51
+ proxy_set_header X-Real-IP $remote_addr;
52
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
53
+ proxy_connect_timeout 300s;
54
+ proxy_read_timeout 300s;
55
+ proxy_send_timeout 300s;
56
+ }
57
+
58
+ # Health endpoint passthrough
59
+ location /health {
60
+ proxy_pass http://trtllm/health;  # no explicit timeouts here, so nginx defaults apply
61
+ proxy_set_header Host $host;
62
+ proxy_set_header X-Real-IP $remote_addr;
63
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
64
+ }
65
+
66
+ # Metrics endpoint passthrough
67
+ location /metrics {
68
+ proxy_pass http://trtllm/metrics;
69
+ proxy_set_header Host $host;
70
+ proxy_set_header X-Real-IP $remote_addr;
71
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
72
+ }
73
+ }
74
+ }
@@ -0,0 +1,61 @@
1
+ <% if (framework === 'sklearn') { %># Scikit-learn dependencies
2
+ scikit-learn==1.7.1
3
+ joblib==1.4.2
4
+ <% } else if (framework === 'xgboost') { %># XGBoost dependencies
5
+ xgboost==2.1.3
6
+ <% } else if (framework === 'tensorflow') { %># Tensorflow dependencies
7
+ setuptools>=65.0.0
8
+ tensorflow==2.20.0
9
+ <% } else if (framework === 'transformers' && modelServer === 'sglang') { %># SGLang dependencies
10
+ sglang[all]==0.5.10
11
+ torch>=2.1.0
12
+ transformers>=4.36.0
13
+ accelerate>=0.25.0
14
+ httpx==0.27.0 # For async HTTP client if needed
15
+ fastapi==0.116.1  # same pins as the fastapi web-framework branch below
16
+ uvicorn==0.35.0
17
+ <% } %>
18
+ numpy==1.26.4  # outside the framework conditional above, so it is emitted for every configuration
19
+
20
+ # Web framework for inference endpoint
21
+ <% if (modelServer === 'flask') { %>
22
+ flask==3.0.3
23
+ gunicorn==23.0.0
24
+ <% } else if (modelServer === 'fastapi') { %>
25
+ fastapi==0.116.1
26
+ uvicorn==0.35.0
27
+ <% } %>
28
+
29
+ # SageMaker dependencies
30
+ sagemaker-inference==1.10.1
31
+ sagemaker-training==5.1.0
32
+
33
+ # Optional utilities
34
+ psutil==6.0.0 # For system monitoring
35
+ prometheus-client==0.20.0 # For metrics (if needed)
36
+
37
+ <% if (includeSampleModel && architecture !== 'triton') { %>
38
+ # Sample model dependencies
39
+ ucimlrepo  # NOTE(review): the sample-model packages are unpinned, unlike the pinned runtime packages above
40
+ pandas
41
+ <% } %>
42
+
43
+ <% if (includeSampleModel && architecture === 'triton') { %>
44
+ # Triton sample model training dependencies
45
+ ucimlrepo
46
+ pandas
47
+ # numpy is not repeated here: numpy==1.26.4 is always emitted earlier in this file, and a second unpinned entry is redundant (older pip releases reject duplicate requirements)
48
+ <% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
49
+ xgboost  # used to train the XGBoost sample model for the FIL backend
50
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
51
+ lightgbm
52
+ <% } else if (backend === 'onnxruntime') { %>
53
+ scikit-learn
54
+ skl2onnx
55
+ onnxruntime
56
+ <% } else if (backend === 'tensorflow') { %>
57
+ tensorflow
58
+ <% } else if (backend === 'python') { %>
59
+ scikit-learn
60
+ <% } %>
61
+ <% } %>
@@ -0,0 +1,123 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ <% if (architecture === 'triton') { %>
6
+ <% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
7
+ import xgboost as xgb
8
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
9
+ import lightgbm as lgb
10
+ <% } else if (backend === 'onnxruntime') { %>
11
+ import onnxruntime as ort
12
+ <% } else if (backend === 'tensorflow') { %>
13
+ import tensorflow as tf
14
+ <% } else if (backend === 'python' && modelFormat === 'pkl') { %>
15
+ import pickle
16
+ <% } else if (backend === 'python' && modelFormat === 'joblib') { %>
17
+ import joblib
18
+ <% } %>
19
+ <% } else { %>
20
+ <% if (framework === 'sklearn') { %>
21
+ <% if (modelFormat === 'joblib') { %>import joblib
22
+ <% } else if (modelFormat === 'pkl') { %>import pickle
23
+ <% } %>
24
+ <% } else if (framework === 'xgboost') { %>import xgboost as xgb
25
+ <% } else if (framework === 'tensorflow') { %>import tensorflow as tf
26
+ <% } %>
27
+ <% } %>
28
+
29
+ # Load the trained model
30
+ <% if (architecture === 'triton') { %>  # Triton: the loader is chosen by backend and modelFormat
31
+ <% if (backend === 'fil' && modelFormat === 'xgboost_json') { %>
32
+ model = xgb.Booster()
33
+ model.load_model('./abalone_model.json')
34
+ <% } else if (backend === 'fil' && modelFormat === 'xgboost_ubj') { %>
35
+ model = xgb.Booster()
36
+ model.load_model('./abalone_model.ubj')
37
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
38
+ model = lgb.Booster(model_file='./abalone_model.txt')
39
+ <% } else if (backend === 'onnxruntime') { %>
40
+ session = ort.InferenceSession('./abalone_model.onnx')  # this branch defines `session`, not `model`
41
+ <% } else if (backend === 'tensorflow') { %>
42
+ model = tf.saved_model.load('./abalone_model.savedmodel')
43
+ <% } else if (backend === 'python' && modelFormat === 'pkl') { %>
44
+ with open('./abalone_model.pkl', 'rb') as f:  # NOTE: unpickling runs arbitrary code - only load artifacts you produced yourself
45
+ model = pickle.load(f)
46
+ <% } else if (backend === 'python' && modelFormat === 'joblib') { %>
47
+ model = joblib.load('./abalone_model.joblib')
48
+ <% } %>
49
+ <% } else { %>  # non-Triton: the loader is chosen by framework and modelFormat
50
+ <% if (framework === 'sklearn') { %>
51
+ <% if (modelFormat === 'joblib') { %>model = joblib.load('./abalone_model.joblib')
52
+ <% } else if (modelFormat === 'pkl') { -%>
53
+ with open('./abalone_model.pkl', 'rb') as f:  # NOTE: unpickling runs arbitrary code - only load artifacts you produced yourself
54
+ model = pickle.load(f)
55
+ <% } %>
56
+ <% } else if (framework === 'xgboost') { %>model = xgb.Booster()
57
+ <% if (modelFormat === 'json') { %>model.load_model('./abalone_model.json')
58
+ <% } else if (modelFormat === 'model') { %>model.load_model('./abalone_model.model')
59
+ <% } else if (modelFormat === 'ubj') { %>model.load_model('./abalone_model.ubj')
60
+ <% } %>
61
+ <% } else if (framework === 'tensorflow') { %>
62
+ <% if (modelFormat === 'keras') { %>model = tf.keras.models.load_model('./abalone_model.keras')
63
+ <% } else if (modelFormat === 'h5') { %>model = tf.keras.models.load_model('./abalone_model.h5')
64
+ <% } else if (modelFormat === 'SavedModel') { %>model = tf.saved_model.load('./abalone_model')
65
+ model = model.signatures['serving_default']  # from here on `model` is the serving signature, not the loaded SavedModel object
66
+ <% } %>
67
+ <% } %>
68
+ <% } %>
69
+
70
+ # Create synthetic input array for abalone prediction
71
+ # Features: [Sex, Length, Diameter, Height, Whole_weight, Shucked_weight, Viscera_weight, Shell_weight]
72
+ # Sex: 0=M, 1=F, 2=I (Infant)
73
+ synthetic_input = np.array([[
74
+ 1, # Sex: Female
75
+ 0.455, # Length
76
+ 0.365, # Diameter
77
+ 0.095, # Height
78
+ 0.514, # Whole weight
79
+ 0.2245, # Shucked weight
80
+ 0.101, # Viscera weight
81
+ 0.15 # Shell weight
82
+ ]])  # one row of eight features, i.e. shape (1, 8); NumPy infers float64 from the float literals
83
+
84
+ # Make prediction
85
+ <% if (architecture === 'triton') { %>
86
+ <% if (backend === 'fil' && (modelFormat === 'xgboost_json' || modelFormat === 'xgboost_ubj')) { %>
87
+ feature_names = ['Sex', 'Length', 'Diameter', 'Height', 'Whole_weight', 'Shucked_weight', 'Viscera_weight', 'Shell_weight']
88
+ dtest = xgb.DMatrix(synthetic_input, feature_names=feature_names)
89
+ prediction = model.predict(dtest)
90
+ print(f"Predicted rings: {prediction[0]:.1f}")
91
+ <% } else if (backend === 'fil' && modelFormat === 'lightgbm_txt') { %>
92
+ prediction = model.predict(synthetic_input)
93
+ print(f"Predicted rings: {prediction[0]:.1f}")
94
+ <% } else if (backend === 'onnxruntime') { %>
95
+ input_name = session.get_inputs()[0].name
96
+ prediction = session.run(None, {input_name: synthetic_input.astype(np.float32)})
97
+ print(f"Predicted rings: {prediction[0][0]:.1f}")
98
+ <% } else if (backend === 'tensorflow') { %>
99
+ infer = model.signatures['serving_default']
100
+ input_tensor = tf.constant(synthetic_input, dtype=tf.float32)
101
+ prediction = infer(input_tensor)
102
+ output_key = list(prediction.keys())[0]
103
+ result = prediction[output_key].numpy()
104
+ print(f"Predicted rings: {result[0][0]:.1f}")
105
+ <% } else if (backend === 'python') { %>
106
+ prediction = model.predict(synthetic_input)
107
+ print(f"Predicted rings: {prediction[0]:.1f}")
108
+ <% } %>
109
+ <% } else { %>
110
+ <% if (framework === 'sklearn') { %>prediction = model.predict(synthetic_input)
111
+ print(f"Predicted rings: {prediction[0]:.1f}")
112
+ <% } else if (framework === 'xgboost') { %>feature_names = ['Sex', 'Length', 'Diameter', 'Height', 'Whole_weight', 'Shucked_weight', 'Viscera_weight', 'Shell_weight']
113
+ dtest = xgb.DMatrix(synthetic_input, feature_names=feature_names)
114
+ prediction = model.predict(dtest)
115
+ print(f"Predicted rings: {prediction[0]:.1f}")
116
+ <% } else if (framework === 'tensorflow') { %>
117
+ <% if (modelFormat === 'SavedModel') { %>prediction = model(synthetic_input)
118
+ output_key = list(prediction.keys())[0]
119
+ result = prediction[output_key].numpy()
120
+ print(f"Prediction: {result[0][0]:.2f} rings")<% } else { %>prediction = model.predict(synthetic_input)
121
+ print(f"Prediction: {prediction[0][0]:.2f} rings")<% } %>
122
+ <% } %>
123
+ <% } %>