@aws/ml-container-creator 0.9.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +31 -137
- package/config/parameter-schema-v2.json +2065 -0
- package/package.json +6 -3
- package/servers/lib/catalogs/jumpstart-public.json +101 -16
- package/servers/lib/catalogs/models.json +182 -26
- package/src/app.js +6 -389
- package/src/lib/bootstrap-command-handler.js +75 -1078
- package/src/lib/bootstrap-profile-manager.js +634 -0
- package/src/lib/bootstrap-provisioners.js +421 -0
- package/src/lib/config-loader.js +405 -0
- package/src/lib/config-manager.js +59 -1668
- package/src/lib/config-mcp-client.js +118 -0
- package/src/lib/config-validator.js +634 -0
- package/src/lib/cuda-resolver.js +140 -0
- package/src/lib/e2e-catalog-validator.js +251 -3
- package/src/lib/e2e-ci-recorder.js +103 -0
- package/src/lib/generated/cli-options.js +471 -0
- package/src/lib/generated/parameter-matrix.js +671 -0
- package/src/lib/generated/validation-rules.js +202 -0
- package/src/lib/marketplace-flow.js +276 -0
- package/src/lib/mcp-query-runner.js +768 -0
- package/src/lib/parameter-schema-validator.js +62 -18
- package/src/lib/prompt-runner.js +41 -1504
- package/src/lib/prompts/feature-prompts.js +172 -0
- package/src/lib/prompts/index.js +48 -0
- package/src/lib/prompts/infrastructure-prompts.js +690 -0
- package/src/lib/prompts/model-prompts.js +552 -0
- package/src/lib/prompts/project-prompts.js +70 -0
- package/src/lib/prompts.js +2 -1446
- package/src/lib/registry-command-handler.js +135 -3
- package/src/lib/secrets-prompt-runner.js +251 -0
- package/src/lib/template-variable-resolver.js +398 -0
- package/templates/code/serve +5 -134
- package/templates/code/serve.d/lmi.ejs +19 -0
- package/templates/code/serve.d/sglang.ejs +47 -0
- package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
- package/templates/code/serve.d/vllm.ejs +48 -0
- package/templates/do/clean +1 -1387
- package/templates/do/clean.d/async-inference.ejs +508 -0
- package/templates/do/clean.d/batch-transform.ejs +512 -0
- package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
- package/templates/do/clean.d/managed-inference.ejs +1043 -0
- package/templates/do/deploy +1 -1766
- package/templates/do/deploy.d/async-inference.ejs +501 -0
- package/templates/do/deploy.d/batch-transform.ejs +529 -0
- package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
- package/templates/do/deploy.d/managed-inference.ejs +726 -0
- package/config/parameter-schema.json +0 -88
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
// AUTO-GENERATED by scripts/codegen-validator.js — DO NOT EDIT
|
|
2
|
+
// Source: config/parameter-schema-v2.json
|
|
3
|
+
// Generated: 2026-05-23T12:02:19.548Z
|
|
4
|
+
|
|
5
|
+
/**
 * Validation rules derived from parameter-schema-v2.json.
 * Each key maps to a function that returns null (valid) or an error string.
 *
 * NOTE(review): this file is emitted by scripts/codegen-validator.js. The rule
 * factories below (and the number/string type guards they add) must also be
 * emitted by the generator, otherwise the next regeneration will drop them.
 */

/**
 * Numeric rule with optional inclusive bounds.
 *
 * The value is coerced with Number() so numeric strings (e.g. raw CLI flag
 * values) validate exactly as before. Values that coerce to NaN are rejected
 * explicitly: every `<` / `>` comparison against NaN is false, so without this
 * guard a value such as "abc" would pass both bound checks and be reported valid.
 *
 * @param {string} name - Parameter name used in error messages.
 * @param {number} [min] - Inclusive lower bound, if any.
 * @param {number} [max] - Inclusive upper bound, if any.
 * @returns {(value: unknown) => string|null}
 */
const numberRule = (name, min, max) => (value) => {
    const num = Number(value);
    if (Number.isNaN(num)) return `${name} must be a number, got ${value}`;
    if (min !== undefined && num < min) return `${name} must be >= ${min}, got ${value}`;
    if (max !== undefined && num > max) return `${name} must be <= ${max}, got ${value}`;
    return null;
};

/**
 * String rule with optional length bounds and pattern.
 * Non-strings are rejected with a message instead of throwing on `.length`.
 *
 * @param {string} name - Parameter name used in error messages.
 * @param {{minLength?: number, maxLength?: number, pattern?: RegExp}} [opts]
 * @returns {(value: unknown) => string|null}
 */
const stringRule = (name, { minLength, maxLength, pattern } = {}) => (value) => {
    if (typeof value !== 'string') return `${name} must be a string`;
    if (minLength !== undefined && value.length < minLength) return `${name} must be at least ${minLength} characters`;
    if (maxLength !== undefined && value.length > maxLength) return `${name} must be at most ${maxLength} characters`;
    if (pattern !== undefined && !pattern.test(value)) return `${name} does not match required pattern`;
    return null;
};

/**
 * Rule accepting only one of a fixed set of values.
 * @param {string} name - Parameter name used in error messages.
 * @param {string[]} allowed - Permitted values (exact, case-sensitive match).
 * @returns {(value: unknown) => string|null}
 */
const enumRule = (name, allowed) => {
    const validList = allowed.join(', ');
    return (value) => (allowed.includes(value) ? null : `Invalid value "${value}" for ${name}. Valid: ${validList}`);
};

/**
 * Rule requiring the value to match a precompiled pattern.
 * The patterns used here have no `g`/`y` flag, so sharing them between rules is
 * safe (`test` keeps no `lastIndex` state).
 * @param {string} name - Parameter name used in error messages.
 * @param {RegExp} pattern - Pattern the value must match.
 * @returns {(value: unknown) => string|null}
 */
const patternRule = (name, pattern) => (value) => (
    pattern.test(value) ? null : `${name} does not match required pattern`
);

// Patterns shared by several parameters, compiled once at module load.
const S3_URI_PATTERN = /^s3:\/\//;
const SNS_ARN_PATTERN = /^arn:aws:sns:/;
const SECRETS_MANAGER_ARN_PATTERN = /^arn:aws:secretsmanager:/;

export const validationRules = {
    projectName: stringRule('projectName', { minLength: 2, maxLength: 63, pattern: /^[a-z0-9][a-z0-9-]*[a-z0-9]$/ }),
    deploymentConfig: enumRule('deploymentConfig', [
        'http-flask', 'http-fastapi', 'transformers-vllm', 'transformers-sglang', 'transformers-tensorrt-llm',
        'transformers-lmi', 'transformers-djl', 'triton-fil', 'triton-onnxruntime', 'triton-tensorflow',
        'triton-pytorch', 'triton-vllm', 'triton-tensorrtllm', 'triton-python', 'diffusors-vllm-omni', 'marketplace'
    ]),
    modelName: stringRule('modelName', { minLength: 1 }),
    deploymentTarget: enumRule('deploymentTarget', [
        'managed-inference', 'realtime-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'
    ]),
    instanceType: patternRule('instanceType', /^ml\.[a-z0-9]+\.[a-z0-9]+$/),
    icGpuCount: numberRule('icGpuCount', 0, 8),
    icCopyCount: numberRule('icCopyCount', 0, 100),
    icMemorySize: numberRule('icMemorySize', 128, 3145728),
    maxLoras: numberRule('maxLoras', 1, 256),
    maxLoraRank: numberRule('maxLoraRank', 8, 512),
    benchmarkConcurrency: numberRule('benchmarkConcurrency', 1, 1000),
    benchmarkInputTokens: numberRule('benchmarkInputTokens', 1, 128000),
    benchmarkOutputTokens: numberRule('benchmarkOutputTokens', 1, 128000),
    benchmarkRequestCount: numberRule('benchmarkRequestCount', 1),
    benchmarkS3OutputPath: patternRule('benchmarkS3OutputPath', S3_URI_PATTERN),
    framework: enumRule('framework', ['sklearn', 'xgboost', 'tensorflow', 'transformers']),
    modelFormat: enumRule('modelFormat', ['pkl', 'joblib', 'json', 'model', 'ubj', 'keras', 'h5', 'SavedModel']),
    modelServer: enumRule('modelServer', ['flask', 'fastapi', 'vllm', 'sglang']),
    region: patternRule('region', /^[a-z]{2}-[a-z]+-\d+$/),
    roleArn: patternRule('roleArn', /^arn:aws:iam::/),
    buildTarget: enumRule('buildTarget', ['codebuild']),
    codebuildComputeType: enumRule('codebuildComputeType', [
        'SMALL', 'MEDIUM', 'LARGE', 'BUILD_GENERAL1_SMALL', 'BUILD_GENERAL1_MEDIUM',
        'BUILD_GENERAL1_LARGE', 'BUILD_GENERAL1_2XLARGE'
    ]),
    hfTokenArn: patternRule('hfTokenArn', SECRETS_MANAGER_ARN_PATTERN),
    ngcTokenArn: patternRule('ngcTokenArn', SECRETS_MANAGER_ARN_PATTERN),
    endpointInitialInstanceCount: numberRule('endpointInitialInstanceCount', 1, 100),
    endpointDataCapturePercent: numberRule('endpointDataCapturePercent', 0, 100),
    endpointVariantName: patternRule('endpointVariantName', /^[a-zA-Z0-9]([\w-]{0,62}[a-zA-Z0-9])?$/),
    endpointVolumeSize: numberRule('endpointVolumeSize', 1, 16384),
    icCpuCount: numberRule('icCpuCount', 0.25, 768),
    icModelWeight: numberRule('icModelWeight', 0, 1),
    asyncS3OutputPath: patternRule('asyncS3OutputPath', S3_URI_PATTERN),
    asyncSnsSuccessTopic: patternRule('asyncSnsSuccessTopic', SNS_ARN_PATTERN),
    asyncSnsErrorTopic: patternRule('asyncSnsErrorTopic', SNS_ARN_PATTERN),
    asyncMaxConcurrent: numberRule('asyncMaxConcurrent', 1, 100),
    batchInputPath: patternRule('batchInputPath', S3_URI_PATTERN),
    batchOutputPath: patternRule('batchOutputPath', S3_URI_PATTERN),
    batchInstanceCount: numberRule('batchInstanceCount', 1, 100),
    batchSplitType: enumRule('batchSplitType', ['Line', 'RecordIO', 'None']),
    batchStrategy: enumRule('batchStrategy', ['MultiRecord', 'SingleRecord']),
    batchJoinSource: enumRule('batchJoinSource', ['Input', 'None']),
    batchMaxConcurrent: numberRule('batchMaxConcurrent', 1),
    batchMaxPayload: numberRule('batchMaxPayload', 0, 100),
    hyperpodReplicas: numberRule('hyperpodReplicas', 1),
};
|
|
201
|
+
|
|
202
|
+
// 43 parameters have validation rules
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Marketplace Flow - Handles the marketplace-specific prompt flow.
|
|
6
|
+
* Uses delegation pattern: receives parent PromptRunner reference to access shared state.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
import {
|
|
10
|
+
infraAsyncPrompts,
|
|
11
|
+
infraBatchTransformPrompts,
|
|
12
|
+
projectPrompts,
|
|
13
|
+
destinationPrompts
|
|
14
|
+
} from './prompts/index.js';
|
|
15
|
+
|
|
16
|
+
/** Choice value for "Enter ARN manually..." in the subscription picker. */
const MANUAL_ARN_SENTINEL = '__manual__';

/** Prompt text used whenever the model package ARN is typed by hand. */
const MODEL_PACKAGE_ARN_MESSAGE = 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>[/<version>]):';

// SageMaker model package ARN. The "/<version>" suffix is optional: AWS
// Marketplace model packages are unversioned (model-package/<name>), while
// packages registered in a model package group carry a numeric version
// (model-package/<group>/<version>). "aws[a-z-]*" also admits aws-cn / aws-us-gov.
const MODEL_PACKAGE_ARN_PATTERN = /^arn:aws[a-z-]*:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+(?:\/\d+)?$/;

// IAM role ARN. Roles may live under a path, e.g.
// role/service-role/AmazonSageMaker-ExecutionRole-... (the form the SageMaker
// console creates), so "/" is allowed after "role/".
const ROLE_ARN_PATTERN = /^arn:aws[a-z-]*:iam::\d{12}:role\/[\w+=,.@\/-]+$/;

// Region codes such as us-east-1, ap-southeast-2 or us-gov-west-1.
const REGION_PATTERN = /^[a-z]{2}(?:-[a-z]+)+-\d+$/;

/** Trim string values; pass anything else through unchanged. */
const trimmed = (value) => (typeof value === 'string' ? value.trim() : value);

/**
 * Prompt validator for a model package ARN (shared by the typed-ARN prompts).
 * @param {string} input - Raw user input.
 * @returns {true|string} true when valid, otherwise an error message.
 */
function validateModelPackageArn(input) {
    if (!input || input.trim() === '') {
        return 'Model package ARN is required';
    }
    if (!MODEL_PACKAGE_ARN_PATTERN.test(input.trim())) {
        return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>[/<version>]';
    }
    return true;
}

/**
 * Build the model-package prompts: a picker over discovered subscriptions (with
 * a manual-entry escape hatch) when any exist, otherwise a validated text input.
 * @param {Array<{modelName: string, vendor: string, arn: string}>} subscriptions
 * @returns {object[]} Prompt definitions for `runner._runPhase`.
 */
function buildModelPackagePrompts(subscriptions) {
    const arnPrompt = subscriptions.length > 0
        ? {
            type: 'list',
            name: 'modelPackageArn',
            message: 'Select a Marketplace model package:',
            choices: [
                ...subscriptions.map(sub => ({
                    name: `${sub.modelName} (${sub.vendor}) ā ${sub.arn}`,
                    value: sub.arn,
                    short: sub.modelName
                })),
                { type: 'separator', separator: 'āāāāāāāāāāāāāā' },
                { name: 'Enter ARN manually...', value: MANUAL_ARN_SENTINEL, short: 'manual' }
            ]
        }
        : {
            type: 'input',
            name: 'modelPackageArn',
            message: MODEL_PACKAGE_ARN_MESSAGE,
            validate: validateModelPackageArn
        };

    return [
        arnPrompt,
        {
            type: 'input',
            name: 'modelPackageArnManual',
            message: MODEL_PACKAGE_ARN_MESSAGE,
            when: (answers) => answers.modelPackageArn === MANUAL_ARN_SENTINEL,
            validate: validateModelPackageArn
        }
    ];
}

/**
 * Build the region / deployment target / instance type prompts.
 * The region list reads `answers._bootstrapRegion`, which the caller seeds.
 * @returns {object[]} Prompt definitions for `runner._runPhase`.
 */
function buildInfraPrompts() {
    return [
        {
            type: 'list',
            name: 'awsRegion',
            message: 'Target AWS region?',
            choices: (answers) => {
                const bootstrapReg = answers._bootstrapRegion;
                const choices = ['us-east-1'];
                if (bootstrapReg && bootstrapReg !== 'us-east-1') {
                    choices.unshift({ name: `${bootstrapReg} (from bootstrap profile)`, value: bootstrapReg });
                }
                choices.push({ name: 'Custom...', value: 'custom' });
                return choices;
            },
            default: (answers) => answers._bootstrapRegion || 'us-east-1'
        },
        {
            type: 'input',
            name: 'customAwsRegion',
            message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
            // Without this, an empty entry left awsRegion === 'custom' in the output.
            validate: (input) => {
                if (!input || input.trim() === '') {
                    return 'AWS region is required';
                }
                if (!REGION_PATTERN.test(input.trim())) {
                    return 'Invalid region format (e.g., us-west-2, eu-west-1)';
                }
                return true;
            },
            when: answers => answers.awsRegion === 'custom'
        },
        {
            type: 'list',
            name: 'deploymentTarget',
            message: 'Deployment target?',
            choices: [
                { name: 'SageMaker Real-Time Inference', value: 'realtime-inference' },
                { name: 'SageMaker Async Inference', value: 'async-inference' },
                { name: 'SageMaker Batch Transform', value: 'batch-transform' }
            ],
            default: 'realtime-inference'
        },
        {
            type: 'list',
            name: 'instanceType',
            message: 'Instance type for deployment?',
            choices: [
                { name: 'ml.g5.xlarge (1 GPU, 24GB)', value: 'ml.g5.xlarge' },
                { name: 'ml.g5.2xlarge (1 GPU, 24GB)', value: 'ml.g5.2xlarge' },
                { name: 'ml.g5.4xlarge (1 GPU, 24GB)', value: 'ml.g5.4xlarge' },
                { name: 'ml.g5.12xlarge (4 GPUs, 96GB)', value: 'ml.g5.12xlarge' },
                { name: 'ml.p3.2xlarge (1 GPU, 16GB V100)', value: 'ml.p3.2xlarge' },
                { name: 'ml.m5.xlarge (CPU, 16GB)', value: 'ml.m5.xlarge' },
                { name: 'Custom...', value: 'custom' }
            ],
            default: 'ml.g5.xlarge'
        },
        {
            type: 'input',
            name: 'customInstanceType',
            message: 'Enter instance type (e.g., ml.g5.xlarge):',
            validate: (input) => {
                if (!input || input.trim() === '') {
                    return 'Instance type is required';
                }
                if (!input.trim().startsWith('ml.')) {
                    return 'Instance type must start with "ml." (e.g., ml.g5.xlarge)';
                }
                return true;
            },
            when: answers => answers.instanceType === 'custom'
        }
    ];
}

/**
 * Build the optional execution-role prompt (empty input is accepted).
 * @returns {object[]} Prompt definitions for `runner._runPhase`.
 */
function buildRolePrompts() {
    return [
        {
            type: 'input',
            name: 'awsRoleArn',
            message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
            validate: (input) => {
                if (!input || input.trim() === '') {
                    return true;
                }
                if (!ROLE_ARN_PATTERN.test(input.trim())) {
                    return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
                }
                return true;
            }
        }
    ];
}

/**
 * Fold prompt-only fields into the names templates use, and apply
 * explicitly-configured overrides. Mutates and returns `answers`.
 * @param {object} answers - Combined answers from every phase.
 * @param {object} explicitConfig - Explicit configuration (presumably CLI-provided).
 * @returns {object} The same `answers` object.
 */
function normalizeCombinedAnswers(answers, explicitConfig) {
    if (answers.customInstanceType) {
        answers.instanceType = trimmed(answers.customInstanceType);
        delete answers.customInstanceType;
    }

    if (answers.customAwsRegion) {
        answers.awsRegion = trimmed(answers.customAwsRegion);
        delete answers.customAwsRegion;
    }

    // Templates expect roleArn; validation trims, so store the trimmed value.
    const roleArn = trimmed(answers.awsRoleArn);
    if (roleArn) {
        answers.roleArn = roleArn;
        delete answers.awsRoleArn;
    }

    if (explicitConfig.modelPackageArn && !answers.modelPackageArn) {
        answers.modelPackageArn = explicitConfig.modelPackageArn;
    }

    // A "marketplace://<arn>" model name is shorthand for the model package ARN.
    const modelName = explicitConfig.modelName || answers.modelName;
    if (typeof modelName === 'string' && modelName.startsWith('marketplace://')) {
        answers.modelPackageArn = modelName.replace(/^marketplace:\/\//, '');
        delete answers.modelName;
    }

    return answers;
}

export default class MarketplaceFlow {
    /**
     * @param {object} runner - Parent PromptRunner. This flow reads
     *   `runner.configManager` and delegates every prompt phase to `runner._runPhase`.
     */
    constructor(runner) {
        this.runner = runner;
    }

    /**
     * Ask the marketplace-picker MCP server (when configured) for Marketplace
     * subscriptions. Best-effort: failures are logged and an empty list is
     * returned so the caller falls back to manual ARN entry.
     * @param {object} explicitConfig - Explicit configuration (awsRegion is read).
     * @param {object} existingConfig - Existing configuration (awsRegion is read).
     * @returns {Promise<Array<{modelName: string, vendor: string, arn: string}>>}
     */
    async _discoverSubscriptions(explicitConfig, existingConfig) {
        const cm = this.runner.configManager;
        if (!cm?.getMcpServerNames?.()?.includes('marketplace-picker')) {
            return [];
        }

        try {
            console.log(' š Querying marketplace-picker for subscriptions...');
            const result = await cm.queryMcpServer('marketplace-picker', {
                region: explicitConfig.awsRegion || existingConfig.awsRegion || process.env.AWS_REGION || 'us-east-1'
            });
            const subscriptions = result?.metadata?.subscriptions;
            if (subscriptions?.length > 0) {
                console.log(` ā\nFound ${subscriptions.length} Marketplace subscription(s)`);
                return subscriptions;
            }
            console.log(' ā¹ļø No Marketplace subscriptions found ā enter ARN manually');
        } catch (err) {
            console.log(` ā ļø marketplace-picker unavailable: ${err?.message ?? err}`);
            console.log(' Falling back to manual ARN entry');
        }
        return [];
    }

    /**
     * Marketplace-specific prompt flow.
     * Skips all container-related prompts and asks only for the model package
     * ARN, region, deployment target (plus async / batch-transform follow-ups),
     * instance type, optional execution role, project name and destination.
     *
     * Requirements: 2.3, 2.4, 2.5
     *
     * @param {object} frameworkAnswers - Earlier answers, seeded into each phase.
     * @param {object} explicitConfig - Explicit configuration (presumably CLI-provided).
     * @param {object} existingConfig - Existing / bootstrap configuration.
     * @param {*} buildTimestamp - Copied verbatim into the result.
     * @returns {Promise<object>} Combined answers, with customInstanceType /
     *   customAwsRegion / awsRoleArn folded into instanceType / awsRegion / roleArn.
     */
    async _runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp) {
        console.log('\nšŖ Marketplace Model Package Configuration');

        const subscriptions = await this._discoverSubscriptions(explicitConfig, existingConfig);
        const marketplaceAnswers = await this.runner._runPhase(
            buildModelPackagePrompts(subscriptions), { ...frameworkAnswers }, explicitConfig, existingConfig);

        // Resolve the manual-entry sentinel. If no typed ARN came back, drop the
        // sentinel so it can never reach templates and the explicitConfig
        // fallback in normalizeCombinedAnswers can still supply the ARN.
        if (marketplaceAnswers.modelPackageArn === MANUAL_ARN_SENTINEL) {
            if (marketplaceAnswers.modelPackageArnManual) {
                marketplaceAnswers.modelPackageArn = marketplaceAnswers.modelPackageArnManual;
            } else {
                delete marketplaceAnswers.modelPackageArn;
            }
        }
        delete marketplaceAnswers.modelPackageArnManual;
        if (marketplaceAnswers.modelPackageArn !== undefined) {
            // Validation trims before matching, so store the trimmed form.
            marketplaceAnswers.modelPackageArn = trimmed(marketplaceAnswers.modelPackageArn);
        }

        console.log('\nšŖ Infrastructure & Deployment');
        const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion;
        const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
        const infraAnswers = await this.runner._runPhase(
            buildInfraPrompts(), { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);

        // Deployment-target-specific follow-up prompts.
        const asyncAnswers = infraAnswers.deploymentTarget === 'async-inference'
            ? await this.runner._runPhase(infraAsyncPrompts, { ...infraAnswers }, explicitConfig, existingConfig)
            : {};
        const batchTransformAnswers = infraAnswers.deploymentTarget === 'batch-transform'
            ? await this.runner._runPhase(infraBatchTransformPrompts, { ...infraAnswers }, explicitConfig, existingConfig)
            : {};

        const roleAnswers = await this.runner._runPhase(
            buildRolePrompts(), { ...infraAnswers }, explicitConfig, existingConfig);

        console.log('\nš Project Configuration');
        const allTechnicalAnswers = {
            ...frameworkAnswers,
            ...marketplaceAnswers,
            ...infraAnswers,
            ...asyncAnswers,
            ...batchTransformAnswers,
            ...roleAnswers
        };
        const projectAnswers = await this.runner._runPhase(
            projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
        const destinationAnswers = await this.runner._runPhase(
            destinationPrompts, { ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);

        const combinedAnswers = {
            ...allTechnicalAnswers,
            ...projectAnswers,
            ...destinationAnswers,
            buildTimestamp
        };
        return normalizeCombinedAnswers(combinedAnswers, explicitConfig);
    }
}
|