@aws/ml-container-creator 0.7.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/bin/cli.js +1 -1
- package/infra/ci-harness/buildspec.yml +4 -0
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/model-servers.json +80 -0
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/servers/model-picker/index.js +27 -16
- package/src/app.js +113 -23
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +39 -2
- package/src/lib/cross-cutting-checker.js +146 -33
- package/src/lib/deployment-config-resolver.js +10 -4
- package/src/lib/e2e-bootstrap.js +227 -0
- package/src/lib/e2e-catalog-validator.js +103 -0
- package/src/lib/e2e-quota-validator.js +135 -0
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +306 -24
- package/src/lib/prompts.js +9 -3
- package/src/lib/template-manager.js +10 -4
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/src/lib/tune-catalog-validator.js +5 -5
- package/templates/code/serve +2 -2
- package/templates/code/serving.properties +2 -2
- package/templates/diffusors/serve +3 -3
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/.tune_helper.py +2 -2
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/register +8 -2
- package/templates/do/test +5 -5
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
- package/templates/do/tune +2 -2
- package/templates/marketplace/config +118 -0
- package/templates/marketplace/deploy +890 -0
- package/templates/marketplace/test +453 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* E2E Quota Validator
|
|
6
|
+
*
|
|
7
|
+
* Validates that the AWS account has sufficient service quotas for the
|
|
8
|
+
* instance types required by a given tier in the e2e catalog.
|
|
9
|
+
*
|
|
10
|
+
* Requirements: 3.3, 3.4
|
|
11
|
+
*/
|
|
12
|
+
|
|
13
|
+
import { ServiceQuotasClient, GetServiceQuotaCommand } from '@aws-sdk/client-service-quotas';
|
|
14
|
+
import { filterByTier } from './e2e-catalog-validator.js';
|
|
15
|
+
|
|
16
|
+
/**
 * Instance type to Service Quotas quota code mapping.
 * SageMaker real-time endpoint instance quotas follow a naming pattern.
 * This map covers the instance types used in the e2e catalog.
 *
 * Instance types that are missing from this map cannot be quota-checked;
 * validateQuotas reports them with an available quota of 0.
 *
 * NOTE(review): every ml.g6e.* size shares one code here, and both ml.g5.*
 * sizes share another. Service Quotas usually exposes a separate
 * "<instance type> for endpoint usage" quota per instance type. Confirm these
 * codes against `aws service-quotas list-service-quotas --service-code sagemaker`.
 */
const INSTANCE_QUOTA_CODES = {
    'ml.g6e.xlarge': 'L-2D6591FA',
    'ml.g6e.2xlarge': 'L-2D6591FA',
    'ml.g6e.4xlarge': 'L-2D6591FA',
    'ml.g6e.12xlarge': 'L-2D6591FA',
    'ml.g5.xlarge': 'L-0100B498',
    'ml.g5.2xlarge': 'L-0100B498',
    'ml.m5.xlarge': 'L-ABB2FAC3',
    'ml.p5.48xlarge': 'L-E89A212B'
};

// Service Quotas service code passed as `ServiceCode` when looking up the quota codes above.
const SAGEMAKER_SERVICE_CODE = 'sagemaker';
|
|
33
|
+
|
|
34
|
+
/**
 * Parse the instance type from a CLI args string.
 *
 * Recognizes `--instance-type=<value>` and `--instance-type <value>`.
 * - The flag must begin the string or follow whitespace, so it is never
 *   matched inside a longer token.
 * - In the space-separated form, a following token that starts with `-` is
 *   the next flag, not a value, so it is not returned.
 * - A value wrapped in one pair of matching single or double quotes is
 *   returned without the quotes.
 * - If both forms appear, the `=` form wins because it is checked first.
 *
 * @param {string} args - The CLI args string
 * @returns {string|null} The instance type value, or null if not found
 */
export function parseInstanceType(args) {
    if (!args || typeof args !== 'string') {
        return null;
    }

    const match = args.match(/(?:^|\s)--instance-type=(\S+)/)
        || args.match(/(?:^|\s)--instance-type\s+([^\s-]\S*)/);
    if (!match) {
        return null;
    }

    // Shell-style args may quote the value: --instance-type="ml.g5.xlarge"
    const value = match[1].replace(/^(['"])(.*)\1$/, '$2');
    return value === '' ? null : value;
}
|
|
60
|
+
|
|
61
|
+
/**
 * Count how many catalog configs in a tier target each instance type.
 *
 * Every config returned by filterByTier whose `args` string names an
 * instance type (via parseInstanceType) adds one to that type's count.
 * Configs with no parseable `--instance-type` are ignored.
 *
 * @param {string} tier - The tier to filter by
 * @param {Object} catalog - The catalog object
 * @returns {Map<string, number>} Instance type → number of configs requiring it,
 *   in first-seen order
 */
export function sumInstanceRequirements(tier, catalog) {
    const tally = new Map();

    for (const { args } of filterByTier(catalog, tier)) {
        const type = parseInstanceType(args);
        if (!type) {
            continue;
        }
        const seenSoFar = tally.get(type) || 0;
        tally.set(type, seenSoFar + 1);
    }

    return tally;
}
|
|
81
|
+
|
|
82
|
+
/**
 * Validate that the AWS account has sufficient quotas for the instance types
 * required by a given tier.
 *
 * Each distinct quota code is fetched at most once, and all lookups run
 * concurrently. If an instance type's quota cannot be determined, the type is
 * reported conservatively with `available: 0` and `sufficient: false`, and a
 * warning says why. This happens when:
 * - the type has no entry in INSTANCE_QUOTA_CODES, or
 * - the GetServiceQuota call fails.
 *
 * The "quota is X, need Y" warning is logged only when a quota value was
 * actually read and is too low.
 *
 * @param {string} tier - The tier to validate quotas for
 * @param {Object} catalog - The catalog object
 * @param {string} region - The AWS region to check quotas in
 * @param {Object} [options] - Optional configuration
 * @param {Object} [options.client] - Pre-configured ServiceQuotasClient (for testing)
 * @returns {Promise<Array<{instanceType: string, required: number, available: number, sufficient: boolean}>>}
 *   One entry per required instance type, in the order produced by sumInstanceRequirements.
 */
export async function validateQuotas(tier, catalog, region, options = {}) {
    const instanceRequirements = sumInstanceRequirements(tier, catalog);
    if (instanceRequirements.size === 0) {
        return [];
    }

    const client = options.client || new ServiceQuotasClient({ region });

    // One lookup per quota code. Several instance types map to the same code,
    // so sharing the promise avoids redundant API calls.
    // NOTE(review): if types sharing a code really draw from one pooled quota,
    // their required counts should be summed per code before comparing.
    const lookupsByCode = new Map();
    const lookupQuota = (quotaCode) => {
        if (!lookupsByCode.has(quotaCode)) {
            // Starting from Promise.resolve() also captures synchronous throws
            // from the command constructor or client.send().
            const lookup = Promise.resolve()
                .then(() => client.send(new GetServiceQuotaCommand({
                    ServiceCode: SAGEMAKER_SERVICE_CODE,
                    QuotaCode: quotaCode
                })))
                .then((response) => ({ value: response.Quota?.Value ?? 0 }))
                .catch((error) => ({ error }));
            lookupsByCode.set(quotaCode, lookup);
        }
        return lookupsByCode.get(quotaCode);
    };

    const requirements = [...instanceRequirements];
    const lookups = await Promise.all(requirements.map(([instanceType]) => (
        Object.hasOwn(INSTANCE_QUOTA_CODES, instanceType)
            ? lookupQuota(INSTANCE_QUOTA_CODES[instanceType])
            : null
    )));

    return requirements.map(([instanceType, required], index) => {
        const lookup = lookups[index];
        let available = 0;
        let quotaKnown = false;

        if (lookup === null) {
            console.warn(`⚠️ No quota code mapping for ${instanceType}, cannot verify quota — treating as insufficient`);
        } else if ('error' in lookup) {
            // If we can't fetch the quota, assume 0 and warn
            console.warn(`⚠️ Could not fetch quota for ${instanceType}: ${lookup.error?.message ?? lookup.error}`);
        } else {
            available = lookup.value;
            quotaKnown = true;
        }

        const sufficient = available >= required;
        if (!sufficient && quotaKnown) {
            console.warn(`⚠️ ${instanceType} quota is ${available}, need ${required} for ${tier} tier`);
        }

        return { instanceType, required, available, sufficient };
    });
}
|
package/src/lib/mcp-client.js
CHANGED
|
@@ -14,6 +14,12 @@
|
|
|
14
14
|
|
|
15
15
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
|
16
16
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
|
17
|
+
import path from 'path';
|
|
18
|
+
import { fileURLToPath } from 'url';
|
|
19
|
+
|
|
20
|
+
const __mcp_filename = fileURLToPath(import.meta.url);
|
|
21
|
+
const __mcp_dirname = path.dirname(__mcp_filename);
|
|
22
|
+
const PACKAGE_ROOT = path.resolve(__mcp_dirname, '../..');
|
|
17
23
|
|
|
18
24
|
const DEFAULT_TOOL_NAME = 'get_ml_config';
|
|
19
25
|
const DEFAULT_LIMIT = 10;
|
|
@@ -96,6 +102,15 @@ class McpClient {
|
|
|
96
102
|
async _executeQuery() {
|
|
97
103
|
const { command, args = [], env } = this.serverConfig;
|
|
98
104
|
|
|
105
|
+
// Resolve relative paths in args against the package root
|
|
106
|
+
const resolvedArgs = args.map(arg => {
|
|
107
|
+
if (arg && !path.isAbsolute(arg) && !arg.startsWith('-')) {
|
|
108
|
+
const resolved = path.resolve(PACKAGE_ROOT, arg);
|
|
109
|
+
return resolved;
|
|
110
|
+
}
|
|
111
|
+
return arg;
|
|
112
|
+
});
|
|
113
|
+
|
|
99
114
|
// Build environment: merge process.env with server-specific env
|
|
100
115
|
// When --smart flag is active, inject BEDROCK_SMART=true for this run
|
|
101
116
|
// Discover mode is now default; inject DISCOVER_MODE=false only when explicitly disabled
|
|
@@ -108,7 +123,7 @@ class McpClient {
|
|
|
108
123
|
// Create stdio transport — spawns the server process
|
|
109
124
|
this._transport = new StdioClientTransport({
|
|
110
125
|
command,
|
|
111
|
-
args,
|
|
126
|
+
args: resolvedArgs,
|
|
112
127
|
env: spawnEnv,
|
|
113
128
|
stderr: 'pipe'
|
|
114
129
|
});
|
|
@@ -91,8 +91,12 @@ export default class McpCommandHandler {
|
|
|
91
91
|
const installed = await this._installBundledDependencies(resolved.serverDir, name);
|
|
92
92
|
if (!installed) return;
|
|
93
93
|
|
|
94
|
+
// Store path relative to package root for portability
|
|
95
|
+
const packageRoot = path.resolve(__dirname, '../..');
|
|
96
|
+
const relativePath = path.relative(packageRoot, resolved.entryPoint);
|
|
97
|
+
|
|
94
98
|
command = 'node';
|
|
95
|
-
commandArgs = [
|
|
99
|
+
commandArgs = [relativePath];
|
|
96
100
|
} else {
|
|
97
101
|
// Find the '--' separator to split name from command
|
|
98
102
|
const separatorIndex = positionalArgs.indexOf('--');
|
|
@@ -195,9 +199,13 @@ export default class McpCommandHandler {
|
|
|
195
199
|
const installed = await this._installBundledDependencies(resolved.serverDir, server.name);
|
|
196
200
|
if (!installed) continue;
|
|
197
201
|
|
|
202
|
+
// Store path relative to package root for portability across machines
|
|
203
|
+
const packageRoot = path.resolve(__dirname, '../..');
|
|
204
|
+
const relativePath = path.relative(packageRoot, resolved.entryPoint);
|
|
205
|
+
|
|
198
206
|
config.mcpServers[server.name] = {
|
|
199
207
|
command: 'node',
|
|
200
|
-
args: [
|
|
208
|
+
args: [relativePath]
|
|
201
209
|
};
|
|
202
210
|
added++;
|
|
203
211
|
}
|
package/src/lib/prompt-runner.js
CHANGED
|
@@ -50,6 +50,20 @@ const __pr_filename = fileURLToPath(import.meta.url);
|
|
|
50
50
|
const __pr_dirname = path.dirname(__pr_filename);
|
|
51
51
|
const GENERATOR_ROOT = path.resolve(__pr_dirname, '..', '..');
|
|
52
52
|
|
|
53
|
+
/**
 * Resolve MCP server args by converting relative script paths to absolute paths under GENERATOR_ROOT.
 *
 * mcp.json stores bundled server entry points relative to the package root
 * (e.g. `servers/model-picker/index.js`) so the config stays portable.
 *
 * Only args that look like a relative script path are rewritten:
 * - args starting with `./` or `../` (either separator), or
 * - args ending in a script extension (.js, .mjs, .cjs, .ts, .py).
 *
 * Everything else is passed through unchanged, because resolving it against
 * the package root would corrupt it:
 * - flags and flag values (`--port 8080`)
 * - npm package specs (`@scope/pkg`, `some-package`)
 * - container images (`mcp/fetch`)
 * - URLs, absolute paths, empty strings and non-string values
 *
 * @param {Array<*>|null|undefined} args - The args array from mcp.json serverConfig
 * @returns {Array<*>} A new array with relative script paths resolved
 */
function resolveMcpArgs(args) {
    const relativePrefix = /^\.{1,2}[\\/]/;
    const scriptExtension = /\.(?:js|mjs|cjs|ts|py)$/i;

    return (args || []).map((arg) => {
        const isRelativeScriptPath = typeof arg === 'string'
            && arg !== ''
            && !path.isAbsolute(arg)
            && !arg.startsWith('-')
            && !arg.startsWith('@')
            && !arg.includes('://')
            && (relativePrefix.test(arg) || scriptExtension.test(arg));
        return isRelativeScriptPath ? path.resolve(GENERATOR_ROOT, arg) : arg;
    });
}
|
|
66
|
+
|
|
53
67
|
export default class PromptRunner {
|
|
54
68
|
constructor({ configManager, options, registryConfigManager, baseConfig, promptFn }) {
|
|
55
69
|
this.configManager = configManager;
|
|
@@ -111,6 +125,14 @@ export default class PromptRunner {
|
|
|
111
125
|
framework: framework || deploymentConfigAnswers.framework,
|
|
112
126
|
modelServer: modelServer || deploymentConfigAnswers.modelServer
|
|
113
127
|
};
|
|
128
|
+
|
|
129
|
+
// ──────────────────────────────────────────────────────────────────────
|
|
130
|
+
// Marketplace fast-path: skip all container-related prompts
|
|
131
|
+
// Requirements: 2.3, 2.4, 2.5
|
|
132
|
+
// ──────────────────────────────────────────────────────────────────────
|
|
133
|
+
if (frameworkAnswers.architecture === 'marketplace') {
|
|
134
|
+
return this._runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp);
|
|
135
|
+
}
|
|
114
136
|
|
|
115
137
|
// Engine prompt for http architecture
|
|
116
138
|
const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
|
|
@@ -596,13 +618,27 @@ export default class PromptRunner {
|
|
|
596
618
|
// Infer modelSource from model name prefix if not set by MCP
|
|
597
619
|
const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
|
|
598
620
|
if (!combinedAnswers.modelSource && modelName) {
|
|
599
|
-
|
|
621
|
+
// Reject deprecated JumpStart prefixes with migration message
|
|
622
|
+
if (modelName.startsWith('jumpstart://') || modelName.startsWith('jumpstart-hub://')) {
|
|
623
|
+
const bareId = modelName.replace(/^jumpstart(-hub)?:\/\//, '');
|
|
624
|
+
console.error(`\n ⚠️ JumpStart is no longer supported. Use the HuggingFace model ID directly: ${bareId}`);
|
|
625
|
+
console.error(' JumpStart model sources have been removed. Use one of:');
|
|
626
|
+
console.error(' • HuggingFace model ID (e.g., meta-llama/Llama-2-7b-hf)');
|
|
627
|
+
console.error(' • s3://bucket/path/model.tar.gz');
|
|
628
|
+
console.error(' • registry://model-package-name');
|
|
629
|
+
console.error(' • marketplace://arn:aws:sagemaker:...\n');
|
|
630
|
+
process.exit(1);
|
|
631
|
+
}
|
|
632
|
+
if (modelName.startsWith('marketplace://')) {
|
|
633
|
+
// marketplace://arn:aws:sagemaker:... → set architecture to marketplace and store ARN
|
|
634
|
+
const arn = modelName.replace(/^marketplace:\/\//, '');
|
|
635
|
+
combinedAnswers.modelPackageArn = arn;
|
|
636
|
+
combinedAnswers.architecture = 'marketplace';
|
|
637
|
+
combinedAnswers.deploymentConfig = 'marketplace';
|
|
638
|
+
combinedAnswers.modelSource = undefined;
|
|
639
|
+
} else if (modelName.startsWith('s3://')) {
|
|
600
640
|
combinedAnswers.modelSource = 's3';
|
|
601
641
|
combinedAnswers.artifactUri = modelName;
|
|
602
|
-
} else if (modelName.startsWith('jumpstart://')) {
|
|
603
|
-
combinedAnswers.modelSource = 'jumpstart';
|
|
604
|
-
} else if (modelName.startsWith('jumpstart-hub://')) {
|
|
605
|
-
combinedAnswers.modelSource = 'jumpstart-hub';
|
|
606
642
|
} else if (modelName.startsWith('registry://')) {
|
|
607
643
|
combinedAnswers.modelSource = 'registry';
|
|
608
644
|
}
|
|
@@ -613,7 +649,7 @@ export default class PromptRunner {
|
|
|
613
649
|
combinedAnswers.artifactUri = modelName;
|
|
614
650
|
}
|
|
615
651
|
}
|
|
616
|
-
const downloadSources = ['
|
|
652
|
+
const downloadSources = ['s3'];
|
|
617
653
|
if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
|
|
618
654
|
console.log(`\n ⚠️ Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
|
|
619
655
|
console.log(' The model-picker could not determine the download location.');
|
|
@@ -638,18 +674,7 @@ export default class PromptRunner {
|
|
|
638
674
|
}
|
|
639
675
|
}
|
|
640
676
|
|
|
641
|
-
|
|
642
|
-
// HubAccessConfig on CreateModel, which is not yet supported by the generator.
|
|
643
|
-
if (combinedAnswers.modelSource === 'jumpstart-hub') {
|
|
644
|
-
console.log('\n ⚠️ JumpStart Private Hub models are not yet fully supported.');
|
|
645
|
-
console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
|
|
646
|
-
console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
|
|
647
|
-
console.log(' The generated project will not be able to download model artifacts at runtime.');
|
|
648
|
-
console.log(' This feature is tracked for a future release.\n');
|
|
649
|
-
console.log(' Falling back to HuggingFace source.\n');
|
|
650
|
-
combinedAnswers.modelSource = 'huggingface';
|
|
651
|
-
delete combinedAnswers.artifactUri;
|
|
652
|
-
}
|
|
677
|
+
|
|
653
678
|
|
|
654
679
|
// Apply auto-set model format for Triton backends with single format
|
|
655
680
|
// Requirements: 3.3, 3.4, 3.5
|
|
@@ -731,6 +756,265 @@ export default class PromptRunner {
|
|
|
731
756
|
return combinedAnswers;
|
|
732
757
|
}
|
|
733
758
|
|
|
759
|
+
/**
|
|
760
|
+
* Marketplace-specific prompt flow.
|
|
761
|
+
* Skips all container-related prompts (framework, model server, base image, CUDA version)
|
|
762
|
+
* and prompts only for: model package ARN, instance type, deployment target, region.
|
|
763
|
+
*
|
|
764
|
+
* Requirements: 2.3, 2.4, 2.5
|
|
765
|
+
* @private
|
|
766
|
+
*/
|
|
767
|
+
async _runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp) {
|
|
768
|
+
console.log('\n🏪 Marketplace Model Package Configuration');
|
|
769
|
+
|
|
770
|
+
// Query marketplace-picker MCP server for subscription discovery
|
|
771
|
+
// Requirements: 2.4, 6.1, 6.2
|
|
772
|
+
let mcpSubscriptions = [];
|
|
773
|
+
const cm = this.configManager;
|
|
774
|
+
if (cm && cm.getMcpServerNames && cm.getMcpServerNames().includes('marketplace-picker')) {
|
|
775
|
+
try {
|
|
776
|
+
console.log(' 🔍 Querying marketplace-picker for subscriptions...');
|
|
777
|
+
const result = await cm.queryMcpServer('marketplace-picker', {
|
|
778
|
+
region: explicitConfig.awsRegion || existingConfig.awsRegion || process.env.AWS_REGION || 'us-east-1'
|
|
779
|
+
});
|
|
780
|
+
if (result && result.metadata?.subscriptions?.length > 0) {
|
|
781
|
+
mcpSubscriptions = result.metadata.subscriptions;
|
|
782
|
+
console.log(` ✅ Found ${mcpSubscriptions.length} Marketplace subscription(s)`);
|
|
783
|
+
} else {
|
|
784
|
+
console.log(' ℹ️ No Marketplace subscriptions found — enter ARN manually');
|
|
785
|
+
}
|
|
786
|
+
} catch (err) {
|
|
787
|
+
console.log(` ⚠️ marketplace-picker unavailable: ${err.message}`);
|
|
788
|
+
console.log(' Falling back to manual ARN entry');
|
|
789
|
+
}
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
// Marketplace-specific prompts: model package ARN
|
|
793
|
+
const marketplacePrompts = [
|
|
794
|
+
{
|
|
795
|
+
type: mcpSubscriptions.length > 0 ? 'list' : 'input',
|
|
796
|
+
name: 'modelPackageArn',
|
|
797
|
+
message: mcpSubscriptions.length > 0
|
|
798
|
+
? 'Select a Marketplace model package:'
|
|
799
|
+
: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
|
|
800
|
+
...(mcpSubscriptions.length > 0 ? {
|
|
801
|
+
choices: [
|
|
802
|
+
...mcpSubscriptions.map(sub => ({
|
|
803
|
+
name: `${sub.modelName} (${sub.vendor}) — ${sub.arn}`,
|
|
804
|
+
value: sub.arn,
|
|
805
|
+
short: sub.modelName
|
|
806
|
+
})),
|
|
807
|
+
{ type: 'separator', separator: '──────────────' },
|
|
808
|
+
{ name: 'Enter ARN manually...', value: '__manual__', short: 'manual' }
|
|
809
|
+
]
|
|
810
|
+
} : {
|
|
811
|
+
validate: (input) => {
|
|
812
|
+
if (!input || input.trim() === '') {
|
|
813
|
+
return 'Model package ARN is required';
|
|
814
|
+
}
|
|
815
|
+
const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
|
|
816
|
+
if (!arnPattern.test(input.trim())) {
|
|
817
|
+
return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
|
|
818
|
+
}
|
|
819
|
+
return true;
|
|
820
|
+
}
|
|
821
|
+
})
|
|
822
|
+
},
|
|
823
|
+
{
|
|
824
|
+
type: 'input',
|
|
825
|
+
name: 'modelPackageArnManual',
|
|
826
|
+
message: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
|
|
827
|
+
when: (answers) => answers.modelPackageArn === '__manual__',
|
|
828
|
+
validate: (input) => {
|
|
829
|
+
if (!input || input.trim() === '') {
|
|
830
|
+
return 'Model package ARN is required';
|
|
831
|
+
}
|
|
832
|
+
const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
|
|
833
|
+
if (!arnPattern.test(input.trim())) {
|
|
834
|
+
return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
|
|
835
|
+
}
|
|
836
|
+
return true;
|
|
837
|
+
}
|
|
838
|
+
}
|
|
839
|
+
];
|
|
840
|
+
const marketplaceAnswers = await this._runPhase(marketplacePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
|
|
841
|
+
|
|
842
|
+
// Handle manual ARN entry fallback
|
|
843
|
+
if (marketplaceAnswers.modelPackageArn === '__manual__' && marketplaceAnswers.modelPackageArnManual) {
|
|
844
|
+
marketplaceAnswers.modelPackageArn = marketplaceAnswers.modelPackageArnManual;
|
|
845
|
+
delete marketplaceAnswers.modelPackageArnManual;
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
// Infrastructure prompts: region, deployment target, instance type
|
|
849
|
+
console.log('\n💪 Infrastructure & Deployment');
|
|
850
|
+
const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion;
|
|
851
|
+
const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
|
|
852
|
+
|
|
853
|
+
// Marketplace deployment targets (no HyperPod — vendor controls the container)
|
|
854
|
+
const marketplaceInfraPrompts = [
|
|
855
|
+
{
|
|
856
|
+
type: 'list',
|
|
857
|
+
name: 'awsRegion',
|
|
858
|
+
message: 'Target AWS region?',
|
|
859
|
+
choices: (answers) => {
|
|
860
|
+
const bootstrapReg = answers._bootstrapRegion;
|
|
861
|
+
const choices = ['us-east-1'];
|
|
862
|
+
if (bootstrapReg && bootstrapReg !== 'us-east-1') {
|
|
863
|
+
choices.unshift({ name: `${bootstrapReg} (from bootstrap profile)`, value: bootstrapReg });
|
|
864
|
+
}
|
|
865
|
+
choices.push({ name: 'Custom...', value: 'custom' });
|
|
866
|
+
return choices;
|
|
867
|
+
},
|
|
868
|
+
default: (answers) => answers._bootstrapRegion || 'us-east-1'
|
|
869
|
+
},
|
|
870
|
+
{
|
|
871
|
+
type: 'input',
|
|
872
|
+
name: 'customAwsRegion',
|
|
873
|
+
message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
|
|
874
|
+
when: answers => answers.awsRegion === 'custom'
|
|
875
|
+
},
|
|
876
|
+
{
|
|
877
|
+
type: 'list',
|
|
878
|
+
name: 'deploymentTarget',
|
|
879
|
+
message: 'Deployment target?',
|
|
880
|
+
choices: [
|
|
881
|
+
{ name: 'SageMaker Real-Time Inference', value: 'realtime-inference' },
|
|
882
|
+
{ name: 'SageMaker Async Inference', value: 'async-inference' },
|
|
883
|
+
{ name: 'SageMaker Batch Transform', value: 'batch-transform' }
|
|
884
|
+
],
|
|
885
|
+
default: 'realtime-inference'
|
|
886
|
+
},
|
|
887
|
+
{
|
|
888
|
+
type: 'list',
|
|
889
|
+
name: 'instanceType',
|
|
890
|
+
message: 'Instance type for deployment?',
|
|
891
|
+
choices: [
|
|
892
|
+
{ name: 'ml.g5.xlarge (1 GPU, 24GB)', value: 'ml.g5.xlarge' },
|
|
893
|
+
{ name: 'ml.g5.2xlarge (1 GPU, 24GB)', value: 'ml.g5.2xlarge' },
|
|
894
|
+
{ name: 'ml.g5.4xlarge (1 GPU, 24GB)', value: 'ml.g5.4xlarge' },
|
|
895
|
+
{ name: 'ml.g5.12xlarge (4 GPUs, 96GB)', value: 'ml.g5.12xlarge' },
|
|
896
|
+
{ name: 'ml.p3.2xlarge (1 GPU, 16GB V100)', value: 'ml.p3.2xlarge' },
|
|
897
|
+
{ name: 'ml.m5.xlarge (CPU, 16GB)', value: 'ml.m5.xlarge' },
|
|
898
|
+
{ name: 'Custom...', value: 'custom' }
|
|
899
|
+
],
|
|
900
|
+
default: 'ml.g5.xlarge'
|
|
901
|
+
},
|
|
902
|
+
{
|
|
903
|
+
type: 'input',
|
|
904
|
+
name: 'customInstanceType',
|
|
905
|
+
message: 'Enter instance type (e.g., ml.g5.xlarge):',
|
|
906
|
+
validate: (input) => {
|
|
907
|
+
if (!input || input.trim() === '') {
|
|
908
|
+
return 'Instance type is required';
|
|
909
|
+
}
|
|
910
|
+
if (!input.startsWith('ml.')) {
|
|
911
|
+
return 'Instance type must start with "ml." (e.g., ml.g5.xlarge)';
|
|
912
|
+
}
|
|
913
|
+
return true;
|
|
914
|
+
},
|
|
915
|
+
when: answers => answers.instanceType === 'custom'
|
|
916
|
+
}
|
|
917
|
+
];
|
|
918
|
+
const infraAnswers = await this._runPhase(marketplaceInfraPrompts, { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);
|
|
919
|
+
|
|
920
|
+
// Async-specific prompts (only when deploymentTarget === 'async-inference')
|
|
921
|
+
let asyncAnswers = {};
|
|
922
|
+
if (infraAnswers.deploymentTarget === 'async-inference') {
|
|
923
|
+
asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
|
|
924
|
+
}
|
|
925
|
+
|
|
926
|
+
// Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
|
|
927
|
+
let batchTransformAnswers = {};
|
|
928
|
+
if (infraAnswers.deploymentTarget === 'batch-transform') {
|
|
929
|
+
batchTransformAnswers = await this._runPhase(
|
|
930
|
+
infraBatchTransformPrompts,
|
|
931
|
+
{ ...infraAnswers },
|
|
932
|
+
explicitConfig,
|
|
933
|
+
existingConfig
|
|
934
|
+
);
|
|
935
|
+
}
|
|
936
|
+
|
|
937
|
+
// Role ARN prompt (always needed for marketplace deploy)
|
|
938
|
+
const rolePrompts = [
|
|
939
|
+
{
|
|
940
|
+
type: 'input',
|
|
941
|
+
name: 'awsRoleArn',
|
|
942
|
+
message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
|
|
943
|
+
validate: (input) => {
|
|
944
|
+
if (!input || input.trim() === '') {
|
|
945
|
+
return true;
|
|
946
|
+
}
|
|
947
|
+
const arnPattern = /^arn:aws:iam::\d{12}:role\/[\w+=,.@-]+$/;
|
|
948
|
+
if (!arnPattern.test(input)) {
|
|
949
|
+
return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
|
|
950
|
+
}
|
|
951
|
+
return true;
|
|
952
|
+
}
|
|
953
|
+
}
|
|
954
|
+
];
|
|
955
|
+
const roleAnswers = await this._runPhase(rolePrompts, { ...infraAnswers }, explicitConfig, existingConfig);
|
|
956
|
+
|
|
957
|
+
// Project name + destination
|
|
958
|
+
console.log('\n📋 Project Configuration');
|
|
959
|
+
const allTechnicalAnswers = {
|
|
960
|
+
...frameworkAnswers,
|
|
961
|
+
...marketplaceAnswers,
|
|
962
|
+
...infraAnswers,
|
|
963
|
+
...asyncAnswers,
|
|
964
|
+
...batchTransformAnswers,
|
|
965
|
+
...roleAnswers
|
|
966
|
+
};
|
|
967
|
+
const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
|
|
968
|
+
const destinationAnswers = await this._runPhase(destinationPrompts,
|
|
969
|
+
{ ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
|
|
970
|
+
|
|
971
|
+
// Combine all marketplace answers
|
|
972
|
+
const combinedAnswers = {
|
|
973
|
+
...frameworkAnswers,
|
|
974
|
+
...marketplaceAnswers,
|
|
975
|
+
...infraAnswers,
|
|
976
|
+
...asyncAnswers,
|
|
977
|
+
...batchTransformAnswers,
|
|
978
|
+
...roleAnswers,
|
|
979
|
+
...projectAnswers,
|
|
980
|
+
...destinationAnswers,
|
|
981
|
+
buildTimestamp
|
|
982
|
+
};
|
|
983
|
+
|
|
984
|
+
// Handle custom instance type
|
|
985
|
+
if (combinedAnswers.customInstanceType) {
|
|
986
|
+
combinedAnswers.instanceType = combinedAnswers.customInstanceType;
|
|
987
|
+
delete combinedAnswers.customInstanceType;
|
|
988
|
+
}
|
|
989
|
+
|
|
990
|
+
// Handle custom AWS region
|
|
991
|
+
if (combinedAnswers.customAwsRegion) {
|
|
992
|
+
combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
|
|
993
|
+
delete combinedAnswers.customAwsRegion;
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
// Map awsRoleArn to roleArn for templates
|
|
997
|
+
if (combinedAnswers.awsRoleArn) {
|
|
998
|
+
combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
|
|
999
|
+
delete combinedAnswers.awsRoleArn;
|
|
1000
|
+
}
|
|
1001
|
+
|
|
1002
|
+
// Ensure CLI-provided values are in combinedAnswers
|
|
1003
|
+
if (explicitConfig.modelPackageArn && !combinedAnswers.modelPackageArn) {
|
|
1004
|
+
combinedAnswers.modelPackageArn = explicitConfig.modelPackageArn;
|
|
1005
|
+
}
|
|
1006
|
+
|
|
1007
|
+
// Handle marketplace:// prefix from --model-name CLI option
|
|
1008
|
+
const modelName = explicitConfig.modelName || combinedAnswers.modelName;
|
|
1009
|
+
if (modelName && modelName.startsWith('marketplace://')) {
|
|
1010
|
+
const arn = modelName.replace(/^marketplace:\/\//, '');
|
|
1011
|
+
combinedAnswers.modelPackageArn = arn;
|
|
1012
|
+
delete combinedAnswers.modelName;
|
|
1013
|
+
}
|
|
1014
|
+
|
|
1015
|
+
return combinedAnswers;
|
|
1016
|
+
}
|
|
1017
|
+
|
|
734
1018
|
/**
|
|
735
1019
|
* Checks if a parameter is promptable according to the parameter matrix
|
|
736
1020
|
* @param {string} parameterName - Name of the parameter
|
|
@@ -1114,7 +1398,7 @@ export default class PromptRunner {
|
|
|
1114
1398
|
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
|
1115
1399
|
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
|
1116
1400
|
|
|
1117
|
-
const serverArgs = [...(serverConfig.args
|
|
1401
|
+
const serverArgs = [...resolveMcpArgs(serverConfig.args)];
|
|
1118
1402
|
if (!discover && !serverArgs.includes('--no-discover')) {
|
|
1119
1403
|
serverArgs.push('--no-discover');
|
|
1120
1404
|
}
|
|
@@ -1669,7 +1953,7 @@ export default class PromptRunner {
|
|
|
1669
1953
|
|
|
1670
1954
|
const transport = new StdioClientTransport({
|
|
1671
1955
|
command: serverConfig.command,
|
|
1672
|
-
args: serverConfig.args
|
|
1956
|
+
args: resolveMcpArgs(serverConfig.args),
|
|
1673
1957
|
env: { ...process.env, ...(serverConfig.env || {}) },
|
|
1674
1958
|
stderr: 'pipe'
|
|
1675
1959
|
});
|
|
@@ -1746,9 +2030,7 @@ export default class PromptRunner {
|
|
|
1746
2030
|
const registryConfigManager = this.registryConfigManager;
|
|
1747
2031
|
if (registryConfigManager) {
|
|
1748
2032
|
// Only try HuggingFace API for bare model IDs (not prefixed URIs)
|
|
1749
|
-
const isNonHfUri = modelId.startsWith('
|
|
1750
|
-
modelId.startsWith('jumpstart-hub://') ||
|
|
1751
|
-
modelId.startsWith('s3://') ||
|
|
2033
|
+
const isNonHfUri = modelId.startsWith('s3://') ||
|
|
1752
2034
|
modelId.startsWith('registry://');
|
|
1753
2035
|
|
|
1754
2036
|
if (!isNonHfUri) {
|
|
@@ -1773,7 +2055,7 @@ export default class PromptRunner {
|
|
|
1773
2055
|
console.log(' ⚠️ HuggingFace API unavailable');
|
|
1774
2056
|
}
|
|
1775
2057
|
} else {
|
|
1776
|
-
// Non-HF URI (
|
|
2058
|
+
// Non-HF URI (s3://, registry://, etc.) — skip HF lookup silently
|
|
1777
2059
|
// The summary at the end of this function will report "No additional model information"
|
|
1778
2060
|
}
|
|
1779
2061
|
|
package/src/lib/prompts.js
CHANGED
|
@@ -232,6 +232,12 @@ const deploymentConfigPrompts = [
|
|
|
232
232
|
name: 'Diffusors with vLLM Omni',
|
|
233
233
|
value: 'diffusors-vllm-omni',
|
|
234
234
|
short: 'diffusors-vllm-omni'
|
|
235
|
+
},
|
|
236
|
+
{ type: 'separator', separator: '── AWS Marketplace ──' },
|
|
237
|
+
{
|
|
238
|
+
name: 'Marketplace Model Package',
|
|
239
|
+
value: 'marketplace',
|
|
240
|
+
short: 'marketplace'
|
|
235
241
|
}
|
|
236
242
|
]
|
|
237
243
|
}
|
|
@@ -469,9 +475,9 @@ const modelFormatPrompts = [
|
|
|
469
475
|
if (!input || input.trim() === '') {
|
|
470
476
|
return 'Model name is required';
|
|
471
477
|
}
|
|
472
|
-
// Basic validation - must contain a slash (org/model,
|
|
478
|
+
// Basic validation - must contain a slash (org/model, s3://path, etc.)
|
|
473
479
|
if (!input.includes('/')) {
|
|
474
|
-
return 'Please use the full model path (e.g., microsoft/DialoGPT-medium,
|
|
480
|
+
return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, s3://bucket/model, registry://my-package)';
|
|
475
481
|
}
|
|
476
482
|
return true;
|
|
477
483
|
},
|
|
@@ -583,7 +589,7 @@ const hfTokenPrompts = [
|
|
|
583
589
|
}
|
|
584
590
|
|
|
585
591
|
// Skip HF token prompt for non-HuggingFace model sources
|
|
586
|
-
// (S3,
|
|
592
|
+
// (S3, Registry models don't need HF auth)
|
|
587
593
|
const modelSource = answers.modelSource;
|
|
588
594
|
if (modelSource && modelSource !== 'huggingface') {
|
|
589
595
|
return false;
|
|
@@ -50,7 +50,7 @@ export default class TemplateManager {
|
|
|
50
50
|
*/
|
|
51
51
|
validate() {
|
|
52
52
|
const supportedOptions = {
|
|
53
|
-
//
|
|
53
|
+
// 16 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors, 1 marketplace)
|
|
54
54
|
deploymentConfigs: [
|
|
55
55
|
// HTTP architecture (2)
|
|
56
56
|
'http-flask', 'http-fastapi',
|
|
@@ -61,7 +61,9 @@ export default class TemplateManager {
|
|
|
61
61
|
'triton-fil', 'triton-onnxruntime', 'triton-tensorflow',
|
|
62
62
|
'triton-pytorch', 'triton-vllm', 'triton-tensorrtllm', 'triton-python',
|
|
63
63
|
// Diffusors architecture (1)
|
|
64
|
-
'diffusors-vllm-omni'
|
|
64
|
+
'diffusors-vllm-omni',
|
|
65
|
+
// Marketplace architecture (1)
|
|
66
|
+
'marketplace'
|
|
65
67
|
],
|
|
66
68
|
buildTargets: ['codebuild'],
|
|
67
69
|
deploymentTargets: ['realtime-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'],
|
|
@@ -82,7 +84,7 @@ export default class TemplateManager {
|
|
|
82
84
|
this._validateGpuRequirement();
|
|
83
85
|
} else {
|
|
84
86
|
// Fallback: validate architecture and backend separately (new canonical format)
|
|
85
|
-
const architectures = ['http', 'transformers', 'triton', 'diffusors'];
|
|
87
|
+
const architectures = ['http', 'transformers', 'triton', 'diffusors', 'marketplace'];
|
|
86
88
|
const backends = [
|
|
87
89
|
// http backends
|
|
88
90
|
'flask', 'fastapi',
|
|
@@ -95,7 +97,11 @@ export default class TemplateManager {
|
|
|
95
97
|
];
|
|
96
98
|
|
|
97
99
|
this._validateChoice('architecture', architectures);
|
|
98
|
-
|
|
100
|
+
|
|
101
|
+
// Marketplace has no backend — skip backend validation
|
|
102
|
+
if (this.answers.architecture !== 'marketplace') {
|
|
103
|
+
this._validateChoice('backend', backends);
|
|
104
|
+
}
|
|
99
105
|
|
|
100
106
|
// Validate tensorrt-llm is only used with transformers architecture
|
|
101
107
|
if (this.answers.backend === 'tensorrt-llm' && this.answers.architecture !== 'transformers') {
|