@aws/ml-container-creator 0.7.1 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +1 -1
- package/infra/ci-harness/buildspec.yml +4 -0
- package/package.json +1 -1
- package/servers/lib/catalogs/model-servers.json +80 -0
- package/servers/model-picker/index.js +27 -16
- package/src/app.js +89 -21
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +39 -2
- package/src/lib/cross-cutting-checker.js +146 -33
- package/src/lib/deployment-config-resolver.js +10 -4
- package/src/lib/e2e-bootstrap.js +227 -0
- package/src/lib/e2e-catalog-validator.js +103 -0
- package/src/lib/e2e-quota-validator.js +135 -0
- package/src/lib/prompt-runner.js +290 -22
- package/src/lib/prompts.js +9 -3
- package/src/lib/template-manager.js +10 -4
- package/src/lib/tune-catalog-validator.js +5 -5
- package/templates/code/serve +2 -2
- package/templates/code/serving.properties +2 -2
- package/templates/diffusors/serve +3 -3
- package/templates/do/.tune_helper.py +2 -2
- package/templates/do/register +8 -2
- package/templates/do/test +5 -5
- package/templates/do/tune +2 -2
- package/templates/marketplace/config +118 -0
- package/templates/marketplace/deploy +890 -0
- package/templates/marketplace/test +453 -0
package/src/lib/prompt-runner.js
CHANGED
|
@@ -111,6 +111,14 @@ export default class PromptRunner {
|
|
|
111
111
|
framework: framework || deploymentConfigAnswers.framework,
|
|
112
112
|
modelServer: modelServer || deploymentConfigAnswers.modelServer
|
|
113
113
|
};
|
|
114
|
+
|
|
115
|
+
// ──────────────────────────────────────────────────────────────────────
|
|
116
|
+
// Marketplace fast-path: skip all container-related prompts
|
|
117
|
+
// Requirements: 2.3, 2.4, 2.5
|
|
118
|
+
// ──────────────────────────────────────────────────────────────────────
|
|
119
|
+
if (frameworkAnswers.architecture === 'marketplace') {
|
|
120
|
+
return this._runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp);
|
|
121
|
+
}
|
|
114
122
|
|
|
115
123
|
// Engine prompt for http architecture
|
|
116
124
|
const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
|
|
@@ -596,13 +604,27 @@ export default class PromptRunner {
|
|
|
596
604
|
// Infer modelSource from model name prefix if not set by MCP
|
|
597
605
|
const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
|
|
598
606
|
if (!combinedAnswers.modelSource && modelName) {
|
|
599
|
-
|
|
607
|
+
// Reject deprecated JumpStart prefixes with migration message
|
|
608
|
+
if (modelName.startsWith('jumpstart://') || modelName.startsWith('jumpstart-hub://')) {
|
|
609
|
+
const bareId = modelName.replace(/^jumpstart(-hub)?:\/\//, '');
|
|
610
|
+
console.error(`\n ⚠️ JumpStart is no longer supported. Use the HuggingFace model ID directly: ${bareId}`);
|
|
611
|
+
console.error(' JumpStart model sources have been removed. Use one of:');
|
|
612
|
+
console.error(' • HuggingFace model ID (e.g., meta-llama/Llama-2-7b-hf)');
|
|
613
|
+
console.error(' • s3://bucket/path/model.tar.gz');
|
|
614
|
+
console.error(' • registry://model-package-name');
|
|
615
|
+
console.error(' • marketplace://arn:aws:sagemaker:...\n');
|
|
616
|
+
process.exit(1);
|
|
617
|
+
}
|
|
618
|
+
if (modelName.startsWith('marketplace://')) {
|
|
619
|
+
// marketplace://arn:aws:sagemaker:... → set architecture to marketplace and store ARN
|
|
620
|
+
const arn = modelName.replace(/^marketplace:\/\//, '');
|
|
621
|
+
combinedAnswers.modelPackageArn = arn;
|
|
622
|
+
combinedAnswers.architecture = 'marketplace';
|
|
623
|
+
combinedAnswers.deploymentConfig = 'marketplace';
|
|
624
|
+
combinedAnswers.modelSource = undefined;
|
|
625
|
+
} else if (modelName.startsWith('s3://')) {
|
|
600
626
|
combinedAnswers.modelSource = 's3';
|
|
601
627
|
combinedAnswers.artifactUri = modelName;
|
|
602
|
-
} else if (modelName.startsWith('jumpstart://')) {
|
|
603
|
-
combinedAnswers.modelSource = 'jumpstart';
|
|
604
|
-
} else if (modelName.startsWith('jumpstart-hub://')) {
|
|
605
|
-
combinedAnswers.modelSource = 'jumpstart-hub';
|
|
606
628
|
} else if (modelName.startsWith('registry://')) {
|
|
607
629
|
combinedAnswers.modelSource = 'registry';
|
|
608
630
|
}
|
|
@@ -613,7 +635,7 @@ export default class PromptRunner {
|
|
|
613
635
|
combinedAnswers.artifactUri = modelName;
|
|
614
636
|
}
|
|
615
637
|
}
|
|
616
|
-
const downloadSources = ['
|
|
638
|
+
const downloadSources = ['s3'];
|
|
617
639
|
if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
|
|
618
640
|
console.log(`\n ⚠️ Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
|
|
619
641
|
console.log(' The model-picker could not determine the download location.');
|
|
@@ -638,18 +660,7 @@ export default class PromptRunner {
|
|
|
638
660
|
}
|
|
639
661
|
}
|
|
640
662
|
|
|
641
|
-
|
|
642
|
-
// HubAccessConfig on CreateModel, which is not yet supported by the generator.
|
|
643
|
-
if (combinedAnswers.modelSource === 'jumpstart-hub') {
|
|
644
|
-
console.log('\n ⚠️ JumpStart Private Hub models are not yet fully supported.');
|
|
645
|
-
console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
|
|
646
|
-
console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
|
|
647
|
-
console.log(' The generated project will not be able to download model artifacts at runtime.');
|
|
648
|
-
console.log(' This feature is tracked for a future release.\n');
|
|
649
|
-
console.log(' Falling back to HuggingFace source.\n');
|
|
650
|
-
combinedAnswers.modelSource = 'huggingface';
|
|
651
|
-
delete combinedAnswers.artifactUri;
|
|
652
|
-
}
|
|
663
|
+
|
|
653
664
|
|
|
654
665
|
// Apply auto-set model format for Triton backends with single format
|
|
655
666
|
// Requirements: 3.3, 3.4, 3.5
|
|
@@ -731,6 +742,265 @@ export default class PromptRunner {
|
|
|
731
742
|
return combinedAnswers;
|
|
732
743
|
}
|
|
733
744
|
|
|
745
|
+
/**
|
|
746
|
+
* Marketplace-specific prompt flow.
|
|
747
|
+
* Skips all container-related prompts (framework, model server, base image, CUDA version)
|
|
748
|
+
* and prompts only for: model package ARN, instance type, deployment target, region.
|
|
749
|
+
*
|
|
750
|
+
* Requirements: 2.3, 2.4, 2.5
|
|
751
|
+
* @private
|
|
752
|
+
*/
|
|
753
|
+
async _runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp) {
|
|
754
|
+
console.log('\n🏪 Marketplace Model Package Configuration');
|
|
755
|
+
|
|
756
|
+
// Query marketplace-picker MCP server for subscription discovery
|
|
757
|
+
// Requirements: 2.4, 6.1, 6.2
|
|
758
|
+
let mcpSubscriptions = [];
|
|
759
|
+
const cm = this.configManager;
|
|
760
|
+
if (cm && cm.getMcpServerNames && cm.getMcpServerNames().includes('marketplace-picker')) {
|
|
761
|
+
try {
|
|
762
|
+
console.log(' 🔍 Querying marketplace-picker for subscriptions...');
|
|
763
|
+
const result = await cm.queryMcpServer('marketplace-picker', {
|
|
764
|
+
region: explicitConfig.awsRegion || existingConfig.awsRegion || process.env.AWS_REGION || 'us-east-1'
|
|
765
|
+
});
|
|
766
|
+
if (result && result.metadata?.subscriptions?.length > 0) {
|
|
767
|
+
mcpSubscriptions = result.metadata.subscriptions;
|
|
768
|
+
console.log(` ✅ Found ${mcpSubscriptions.length} Marketplace subscription(s)`);
|
|
769
|
+
} else {
|
|
770
|
+
console.log(' ℹ️ No Marketplace subscriptions found — enter ARN manually');
|
|
771
|
+
}
|
|
772
|
+
} catch (err) {
|
|
773
|
+
console.log(` ⚠️ marketplace-picker unavailable: ${err.message}`);
|
|
774
|
+
console.log(' Falling back to manual ARN entry');
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
|
|
778
|
+
// Marketplace-specific prompts: model package ARN
|
|
779
|
+
const marketplacePrompts = [
|
|
780
|
+
{
|
|
781
|
+
type: mcpSubscriptions.length > 0 ? 'list' : 'input',
|
|
782
|
+
name: 'modelPackageArn',
|
|
783
|
+
message: mcpSubscriptions.length > 0
|
|
784
|
+
? 'Select a Marketplace model package:'
|
|
785
|
+
: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
|
|
786
|
+
...(mcpSubscriptions.length > 0 ? {
|
|
787
|
+
choices: [
|
|
788
|
+
...mcpSubscriptions.map(sub => ({
|
|
789
|
+
name: `${sub.modelName} (${sub.vendor}) — ${sub.arn}`,
|
|
790
|
+
value: sub.arn,
|
|
791
|
+
short: sub.modelName
|
|
792
|
+
})),
|
|
793
|
+
{ type: 'separator', separator: '──────────────' },
|
|
794
|
+
{ name: 'Enter ARN manually...', value: '__manual__', short: 'manual' }
|
|
795
|
+
]
|
|
796
|
+
} : {
|
|
797
|
+
validate: (input) => {
|
|
798
|
+
if (!input || input.trim() === '') {
|
|
799
|
+
return 'Model package ARN is required';
|
|
800
|
+
}
|
|
801
|
+
const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
|
|
802
|
+
if (!arnPattern.test(input.trim())) {
|
|
803
|
+
return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
|
|
804
|
+
}
|
|
805
|
+
return true;
|
|
806
|
+
}
|
|
807
|
+
})
|
|
808
|
+
},
|
|
809
|
+
{
|
|
810
|
+
type: 'input',
|
|
811
|
+
name: 'modelPackageArnManual',
|
|
812
|
+
message: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
|
|
813
|
+
when: (answers) => answers.modelPackageArn === '__manual__',
|
|
814
|
+
validate: (input) => {
|
|
815
|
+
if (!input || input.trim() === '') {
|
|
816
|
+
return 'Model package ARN is required';
|
|
817
|
+
}
|
|
818
|
+
const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
|
|
819
|
+
if (!arnPattern.test(input.trim())) {
|
|
820
|
+
return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
|
|
821
|
+
}
|
|
822
|
+
return true;
|
|
823
|
+
}
|
|
824
|
+
}
|
|
825
|
+
];
|
|
826
|
+
const marketplaceAnswers = await this._runPhase(marketplacePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
|
|
827
|
+
|
|
828
|
+
// Handle manual ARN entry fallback
|
|
829
|
+
if (marketplaceAnswers.modelPackageArn === '__manual__' && marketplaceAnswers.modelPackageArnManual) {
|
|
830
|
+
marketplaceAnswers.modelPackageArn = marketplaceAnswers.modelPackageArnManual;
|
|
831
|
+
delete marketplaceAnswers.modelPackageArnManual;
|
|
832
|
+
}
|
|
833
|
+
|
|
834
|
+
// Infrastructure prompts: region, deployment target, instance type
|
|
835
|
+
console.log('\n💪 Infrastructure & Deployment');
|
|
836
|
+
const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion;
|
|
837
|
+
const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
|
|
838
|
+
|
|
839
|
+
// Marketplace deployment targets (no HyperPod — vendor controls the container)
|
|
840
|
+
const marketplaceInfraPrompts = [
|
|
841
|
+
{
|
|
842
|
+
type: 'list',
|
|
843
|
+
name: 'awsRegion',
|
|
844
|
+
message: 'Target AWS region?',
|
|
845
|
+
choices: (answers) => {
|
|
846
|
+
const bootstrapReg = answers._bootstrapRegion;
|
|
847
|
+
const choices = ['us-east-1'];
|
|
848
|
+
if (bootstrapReg && bootstrapReg !== 'us-east-1') {
|
|
849
|
+
choices.unshift({ name: `${bootstrapReg} (from bootstrap profile)`, value: bootstrapReg });
|
|
850
|
+
}
|
|
851
|
+
choices.push({ name: 'Custom...', value: 'custom' });
|
|
852
|
+
return choices;
|
|
853
|
+
},
|
|
854
|
+
default: (answers) => answers._bootstrapRegion || 'us-east-1'
|
|
855
|
+
},
|
|
856
|
+
{
|
|
857
|
+
type: 'input',
|
|
858
|
+
name: 'customAwsRegion',
|
|
859
|
+
message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
|
|
860
|
+
when: answers => answers.awsRegion === 'custom'
|
|
861
|
+
},
|
|
862
|
+
{
|
|
863
|
+
type: 'list',
|
|
864
|
+
name: 'deploymentTarget',
|
|
865
|
+
message: 'Deployment target?',
|
|
866
|
+
choices: [
|
|
867
|
+
{ name: 'SageMaker Real-Time Inference', value: 'realtime-inference' },
|
|
868
|
+
{ name: 'SageMaker Async Inference', value: 'async-inference' },
|
|
869
|
+
{ name: 'SageMaker Batch Transform', value: 'batch-transform' }
|
|
870
|
+
],
|
|
871
|
+
default: 'realtime-inference'
|
|
872
|
+
},
|
|
873
|
+
{
|
|
874
|
+
type: 'list',
|
|
875
|
+
name: 'instanceType',
|
|
876
|
+
message: 'Instance type for deployment?',
|
|
877
|
+
choices: [
|
|
878
|
+
{ name: 'ml.g5.xlarge (1 GPU, 24GB)', value: 'ml.g5.xlarge' },
|
|
879
|
+
{ name: 'ml.g5.2xlarge (1 GPU, 24GB)', value: 'ml.g5.2xlarge' },
|
|
880
|
+
{ name: 'ml.g5.4xlarge (1 GPU, 24GB)', value: 'ml.g5.4xlarge' },
|
|
881
|
+
{ name: 'ml.g5.12xlarge (4 GPUs, 96GB)', value: 'ml.g5.12xlarge' },
|
|
882
|
+
{ name: 'ml.p3.2xlarge (1 GPU, 16GB V100)', value: 'ml.p3.2xlarge' },
|
|
883
|
+
{ name: 'ml.m5.xlarge (CPU, 16GB)', value: 'ml.m5.xlarge' },
|
|
884
|
+
{ name: 'Custom...', value: 'custom' }
|
|
885
|
+
],
|
|
886
|
+
default: 'ml.g5.xlarge'
|
|
887
|
+
},
|
|
888
|
+
{
|
|
889
|
+
type: 'input',
|
|
890
|
+
name: 'customInstanceType',
|
|
891
|
+
message: 'Enter instance type (e.g., ml.g5.xlarge):',
|
|
892
|
+
validate: (input) => {
|
|
893
|
+
if (!input || input.trim() === '') {
|
|
894
|
+
return 'Instance type is required';
|
|
895
|
+
}
|
|
896
|
+
if (!input.startsWith('ml.')) {
|
|
897
|
+
return 'Instance type must start with "ml." (e.g., ml.g5.xlarge)';
|
|
898
|
+
}
|
|
899
|
+
return true;
|
|
900
|
+
},
|
|
901
|
+
when: answers => answers.instanceType === 'custom'
|
|
902
|
+
}
|
|
903
|
+
];
|
|
904
|
+
const infraAnswers = await this._runPhase(marketplaceInfraPrompts, { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);
|
|
905
|
+
|
|
906
|
+
// Async-specific prompts (only when deploymentTarget === 'async-inference')
|
|
907
|
+
let asyncAnswers = {};
|
|
908
|
+
if (infraAnswers.deploymentTarget === 'async-inference') {
|
|
909
|
+
asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
|
|
910
|
+
}
|
|
911
|
+
|
|
912
|
+
// Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
|
|
913
|
+
let batchTransformAnswers = {};
|
|
914
|
+
if (infraAnswers.deploymentTarget === 'batch-transform') {
|
|
915
|
+
batchTransformAnswers = await this._runPhase(
|
|
916
|
+
infraBatchTransformPrompts,
|
|
917
|
+
{ ...infraAnswers },
|
|
918
|
+
explicitConfig,
|
|
919
|
+
existingConfig
|
|
920
|
+
);
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
// Role ARN prompt (always needed for marketplace deploy)
|
|
924
|
+
const rolePrompts = [
|
|
925
|
+
{
|
|
926
|
+
type: 'input',
|
|
927
|
+
name: 'awsRoleArn',
|
|
928
|
+
message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
|
|
929
|
+
validate: (input) => {
|
|
930
|
+
if (!input || input.trim() === '') {
|
|
931
|
+
return true;
|
|
932
|
+
}
|
|
933
|
+
const arnPattern = /^arn:aws:iam::\d{12}:role\/[\w+=,.@-]+$/;
|
|
934
|
+
if (!arnPattern.test(input)) {
|
|
935
|
+
return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
|
|
936
|
+
}
|
|
937
|
+
return true;
|
|
938
|
+
}
|
|
939
|
+
}
|
|
940
|
+
];
|
|
941
|
+
const roleAnswers = await this._runPhase(rolePrompts, { ...infraAnswers }, explicitConfig, existingConfig);
|
|
942
|
+
|
|
943
|
+
// Project name + destination
|
|
944
|
+
console.log('\n📋 Project Configuration');
|
|
945
|
+
const allTechnicalAnswers = {
|
|
946
|
+
...frameworkAnswers,
|
|
947
|
+
...marketplaceAnswers,
|
|
948
|
+
...infraAnswers,
|
|
949
|
+
...asyncAnswers,
|
|
950
|
+
...batchTransformAnswers,
|
|
951
|
+
...roleAnswers
|
|
952
|
+
};
|
|
953
|
+
const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
|
|
954
|
+
const destinationAnswers = await this._runPhase(destinationPrompts,
|
|
955
|
+
{ ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
|
|
956
|
+
|
|
957
|
+
// Combine all marketplace answers
|
|
958
|
+
const combinedAnswers = {
|
|
959
|
+
...frameworkAnswers,
|
|
960
|
+
...marketplaceAnswers,
|
|
961
|
+
...infraAnswers,
|
|
962
|
+
...asyncAnswers,
|
|
963
|
+
...batchTransformAnswers,
|
|
964
|
+
...roleAnswers,
|
|
965
|
+
...projectAnswers,
|
|
966
|
+
...destinationAnswers,
|
|
967
|
+
buildTimestamp
|
|
968
|
+
};
|
|
969
|
+
|
|
970
|
+
// Handle custom instance type
|
|
971
|
+
if (combinedAnswers.customInstanceType) {
|
|
972
|
+
combinedAnswers.instanceType = combinedAnswers.customInstanceType;
|
|
973
|
+
delete combinedAnswers.customInstanceType;
|
|
974
|
+
}
|
|
975
|
+
|
|
976
|
+
// Handle custom AWS region
|
|
977
|
+
if (combinedAnswers.customAwsRegion) {
|
|
978
|
+
combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
|
|
979
|
+
delete combinedAnswers.customAwsRegion;
|
|
980
|
+
}
|
|
981
|
+
|
|
982
|
+
// Map awsRoleArn to roleArn for templates
|
|
983
|
+
if (combinedAnswers.awsRoleArn) {
|
|
984
|
+
combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
|
|
985
|
+
delete combinedAnswers.awsRoleArn;
|
|
986
|
+
}
|
|
987
|
+
|
|
988
|
+
// Ensure CLI-provided values are in combinedAnswers
|
|
989
|
+
if (explicitConfig.modelPackageArn && !combinedAnswers.modelPackageArn) {
|
|
990
|
+
combinedAnswers.modelPackageArn = explicitConfig.modelPackageArn;
|
|
991
|
+
}
|
|
992
|
+
|
|
993
|
+
// Handle marketplace:// prefix from --model-name CLI option
|
|
994
|
+
const modelName = explicitConfig.modelName || combinedAnswers.modelName;
|
|
995
|
+
if (modelName && modelName.startsWith('marketplace://')) {
|
|
996
|
+
const arn = modelName.replace(/^marketplace:\/\//, '');
|
|
997
|
+
combinedAnswers.modelPackageArn = arn;
|
|
998
|
+
delete combinedAnswers.modelName;
|
|
999
|
+
}
|
|
1000
|
+
|
|
1001
|
+
return combinedAnswers;
|
|
1002
|
+
}
|
|
1003
|
+
|
|
734
1004
|
/**
|
|
735
1005
|
* Checks if a parameter is promptable according to the parameter matrix
|
|
736
1006
|
* @param {string} parameterName - Name of the parameter
|
|
@@ -1746,9 +2016,7 @@ export default class PromptRunner {
|
|
|
1746
2016
|
const registryConfigManager = this.registryConfigManager;
|
|
1747
2017
|
if (registryConfigManager) {
|
|
1748
2018
|
// Only try HuggingFace API for bare model IDs (not prefixed URIs)
|
|
1749
|
-
const isNonHfUri = modelId.startsWith('
|
|
1750
|
-
modelId.startsWith('jumpstart-hub://') ||
|
|
1751
|
-
modelId.startsWith('s3://') ||
|
|
2019
|
+
const isNonHfUri = modelId.startsWith('s3://') ||
|
|
1752
2020
|
modelId.startsWith('registry://');
|
|
1753
2021
|
|
|
1754
2022
|
if (!isNonHfUri) {
|
|
@@ -1773,7 +2041,7 @@ export default class PromptRunner {
|
|
|
1773
2041
|
console.log(' ⚠️ HuggingFace API unavailable');
|
|
1774
2042
|
}
|
|
1775
2043
|
} else {
|
|
1776
|
-
// Non-HF URI (
|
|
2044
|
+
// Non-HF URI (s3://, registry://, etc.) — skip HF lookup silently
|
|
1777
2045
|
// The summary at the end of this function will report "No additional model information"
|
|
1778
2046
|
}
|
|
1779
2047
|
|
package/src/lib/prompts.js
CHANGED
|
@@ -232,6 +232,12 @@ const deploymentConfigPrompts = [
|
|
|
232
232
|
name: 'Diffusors with vLLM Omni',
|
|
233
233
|
value: 'diffusors-vllm-omni',
|
|
234
234
|
short: 'diffusors-vllm-omni'
|
|
235
|
+
},
|
|
236
|
+
{ type: 'separator', separator: '── AWS Marketplace ──' },
|
|
237
|
+
{
|
|
238
|
+
name: 'Marketplace Model Package',
|
|
239
|
+
value: 'marketplace',
|
|
240
|
+
short: 'marketplace'
|
|
235
241
|
}
|
|
236
242
|
]
|
|
237
243
|
}
|
|
@@ -469,9 +475,9 @@ const modelFormatPrompts = [
|
|
|
469
475
|
if (!input || input.trim() === '') {
|
|
470
476
|
return 'Model name is required';
|
|
471
477
|
}
|
|
472
|
-
// Basic validation - must contain a slash (org/model,
|
|
478
|
+
// Basic validation - must contain a slash (org/model, s3://path, etc.)
|
|
473
479
|
if (!input.includes('/')) {
|
|
474
|
-
return 'Please use the full model path (e.g., microsoft/DialoGPT-medium,
|
|
480
|
+
return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, s3://bucket/model, registry://my-package)';
|
|
475
481
|
}
|
|
476
482
|
return true;
|
|
477
483
|
},
|
|
@@ -583,7 +589,7 @@ const hfTokenPrompts = [
|
|
|
583
589
|
}
|
|
584
590
|
|
|
585
591
|
// Skip HF token prompt for non-HuggingFace model sources
|
|
586
|
-
// (S3,
|
|
592
|
+
// (S3, Registry models don't need HF auth)
|
|
587
593
|
const modelSource = answers.modelSource;
|
|
588
594
|
if (modelSource && modelSource !== 'huggingface') {
|
|
589
595
|
return false;
|
|
@@ -50,7 +50,7 @@ export default class TemplateManager {
|
|
|
50
50
|
*/
|
|
51
51
|
validate() {
|
|
52
52
|
const supportedOptions = {
|
|
53
|
-
//
|
|
53
|
+
// 16 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors, 1 marketplace)
|
|
54
54
|
deploymentConfigs: [
|
|
55
55
|
// HTTP architecture (2)
|
|
56
56
|
'http-flask', 'http-fastapi',
|
|
@@ -61,7 +61,9 @@ export default class TemplateManager {
|
|
|
61
61
|
'triton-fil', 'triton-onnxruntime', 'triton-tensorflow',
|
|
62
62
|
'triton-pytorch', 'triton-vllm', 'triton-tensorrtllm', 'triton-python',
|
|
63
63
|
// Diffusors architecture (1)
|
|
64
|
-
'diffusors-vllm-omni'
|
|
64
|
+
'diffusors-vllm-omni',
|
|
65
|
+
// Marketplace architecture (1)
|
|
66
|
+
'marketplace'
|
|
65
67
|
],
|
|
66
68
|
buildTargets: ['codebuild'],
|
|
67
69
|
deploymentTargets: ['realtime-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'],
|
|
@@ -82,7 +84,7 @@ export default class TemplateManager {
|
|
|
82
84
|
this._validateGpuRequirement();
|
|
83
85
|
} else {
|
|
84
86
|
// Fallback: validate architecture and backend separately (new canonical format)
|
|
85
|
-
const architectures = ['http', 'transformers', 'triton', 'diffusors'];
|
|
87
|
+
const architectures = ['http', 'transformers', 'triton', 'diffusors', 'marketplace'];
|
|
86
88
|
const backends = [
|
|
87
89
|
// http backends
|
|
88
90
|
'flask', 'fastapi',
|
|
@@ -95,7 +97,11 @@ export default class TemplateManager {
|
|
|
95
97
|
];
|
|
96
98
|
|
|
97
99
|
this._validateChoice('architecture', architectures);
|
|
98
|
-
|
|
100
|
+
|
|
101
|
+
// Marketplace has no backend — skip backend validation
|
|
102
|
+
if (this.answers.architecture !== 'marketplace') {
|
|
103
|
+
this._validateChoice('backend', backends);
|
|
104
|
+
}
|
|
99
105
|
|
|
100
106
|
// Validate tensorrt-llm is only used with transformers architecture
|
|
101
107
|
if (this.answers.backend === 'tensorrt-llm' && this.answers.architecture !== 'transformers') {
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
|
|
14
14
|
/**
|
|
15
15
|
* Look up a model entry in the catalog by model ID.
|
|
16
|
-
* @param {string} modelId - The
|
|
16
|
+
* @param {string} modelId - The model ID to look up
|
|
17
17
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
18
18
|
* @returns {Object|null} The catalog entry for the model, or null if not found
|
|
19
19
|
*/
|
|
@@ -29,7 +29,7 @@ export function lookupModel(modelId, catalog) {
|
|
|
29
29
|
|
|
30
30
|
/**
|
|
31
31
|
* Check whether a model ID is present in the Supported Model Catalog.
|
|
32
|
-
* @param {string} modelId - The
|
|
32
|
+
* @param {string} modelId - The model ID to check
|
|
33
33
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
34
34
|
* @returns {boolean} True if the model is in the catalog
|
|
35
35
|
*/
|
|
@@ -41,7 +41,7 @@ export function isTuneSupported(modelId, catalog) {
|
|
|
41
41
|
* Validate that a model ID exists in the catalog.
|
|
42
42
|
* Returns a descriptive error when the model is not supported, including
|
|
43
43
|
* the model name, supported families, and a reference to `do/train`.
|
|
44
|
-
* @param {string} modelId - The
|
|
44
|
+
* @param {string} modelId - The model ID to validate
|
|
45
45
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
46
46
|
* @returns {{ valid: boolean, error?: string }}
|
|
47
47
|
*/
|
|
@@ -65,7 +65,7 @@ export function validateModel(modelId, catalog) {
|
|
|
65
65
|
* Validate that a technique is supported for the given model.
|
|
66
66
|
* Returns a descriptive error listing the supported techniques when
|
|
67
67
|
* the requested technique is not available.
|
|
68
|
-
* @param {string} modelId - The
|
|
68
|
+
* @param {string} modelId - The model ID
|
|
69
69
|
* @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
|
|
70
70
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
|
71
71
|
* @returns {{ valid: boolean, error?: string }}
|
|
@@ -92,7 +92,7 @@ export function validateTechnique(modelId, technique, catalog) {
|
|
|
92
92
|
* Validate that a training type is supported for the given model and technique.
|
|
93
93
|
* Returns a descriptive error listing the supported training types when
|
|
94
94
|
* the requested type is not available.
|
|
95
|
-
* @param {string} modelId - The
|
|
95
|
+
* @param {string} modelId - The model ID
|
|
96
96
|
* @param {string} technique - The technique (e.g., 'sft', 'dpo')
|
|
97
97
|
* @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
|
|
98
98
|
* @param {Object} catalog - The tune catalog object with a `models` map
|
package/templates/code/serve
CHANGED
|
@@ -113,7 +113,7 @@ resolve_model() {
|
|
|
113
113
|
echo "${!_MODEL_VAR}"
|
|
114
114
|
return
|
|
115
115
|
;;
|
|
116
|
-
s3|
|
|
116
|
+
s3|registry)
|
|
117
117
|
# Check for pre-mounted artifacts first
|
|
118
118
|
if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
|
|
119
119
|
echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
|
|
@@ -245,7 +245,7 @@ ARG_PREFIX="--"
|
|
|
245
245
|
|
|
246
246
|
# Define environment variables to exclude (internal variables set by base images)
|
|
247
247
|
<% if (modelServer === 'vllm') { %>
|
|
248
|
-
EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
|
|
248
|
+
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
|
|
249
249
|
<% } else if (modelServer === 'sglang') { %>
|
|
250
250
|
EXCLUDE_VARS=()
|
|
251
251
|
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
@@ -15,7 +15,7 @@ option.model_id=<%= modelName %>
|
|
|
15
15
|
option.model_id=<%= artifactUri %>
|
|
16
16
|
<% } else { %>
|
|
17
17
|
# Model will be loaded from /opt/ml/model at runtime
|
|
18
|
-
# (
|
|
18
|
+
# (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
|
|
19
19
|
# option.model_id=/opt/ml/model
|
|
20
20
|
<% } %>
|
|
21
21
|
|
|
@@ -71,7 +71,7 @@ option.model_id=<%= modelName %>
|
|
|
71
71
|
option.model_id=<%= artifactUri %>
|
|
72
72
|
<% } else { %>
|
|
73
73
|
# Model will be loaded from /opt/ml/model at runtime
|
|
74
|
-
# (
|
|
74
|
+
# (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
|
|
75
75
|
# option.model_id=/opt/ml/model
|
|
76
76
|
<% } %>
|
|
77
77
|
|
|
@@ -9,10 +9,10 @@ echo "Starting vLLM-Omni server (diffusion model serving)"
|
|
|
9
9
|
|
|
10
10
|
# Resolve model URI prefixes that engines cannot handle natively.
|
|
11
11
|
# The generator's model-picker may store provider-specific URIs
|
|
12
|
-
# (e.g.
|
|
13
|
-
#
|
|
12
|
+
# (e.g. registry://my-model-group/1) as the model identifier.
|
|
13
|
+
# vLLM expects a HuggingFace repo ID or local path.
|
|
14
14
|
_RAW_MODEL="${VLLM_MODEL:-}"
|
|
15
|
-
if [[ "$_RAW_MODEL" ==
|
|
15
|
+
if [[ "$_RAW_MODEL" == registry://* ]]; then
|
|
16
16
|
if [ -d /opt/ml/model ] && [ "$(ls -A /opt/ml/model 2>/dev/null)" ]; then
|
|
17
17
|
echo "Resolved VLLM_MODEL='${_RAW_MODEL}' → /opt/ml/model (local artifacts found)"
|
|
18
18
|
export VLLM_MODEL="/opt/ml/model"
|
|
@@ -176,7 +176,7 @@ def cmd_submit(args):
|
|
|
176
176
|
)
|
|
177
177
|
elif "ValidationException" in error_msg and "license" in error_msg.lower():
|
|
178
178
|
_error_exit(
|
|
179
|
-
f"Model license not accepted. Accept the license
|
|
179
|
+
f"Model license not accepted. Accept the model license before "
|
|
180
180
|
f"using this model for customization. Details: {error_msg}"
|
|
181
181
|
)
|
|
182
182
|
else:
|
|
@@ -660,7 +660,7 @@ def main():
|
|
|
660
660
|
|
|
661
661
|
# ── submit ────────────────────────────────────────────────────────────────
|
|
662
662
|
submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
|
|
663
|
-
submit_parser.add_argument("--model-id", required=True, help="
|
|
663
|
+
submit_parser.add_argument("--model-id", required=True, help="Model ID")
|
|
664
664
|
submit_parser.add_argument("--technique", required=True,
|
|
665
665
|
choices=["sft", "dpo", "rlaif", "rlvr"],
|
|
666
666
|
help="Customization technique")
|
package/templates/do/register
CHANGED
|
@@ -191,8 +191,14 @@ fi
|
|
|
191
191
|
# ============================================================
|
|
192
192
|
|
|
193
193
|
# DEPLOYMENT_CONFIG format: <architecture>-<backend> (e.g., transformers-vllm, http-flask, triton-fil)
|
|
194
|
-
|
|
195
|
-
|
|
194
|
+
# Special case: marketplace has no backend
|
|
195
|
+
if [ "${DEPLOYMENT_CONFIG}" = "marketplace" ]; then
|
|
196
|
+
ARCHITECTURE="marketplace"
|
|
197
|
+
BACKEND=""
|
|
198
|
+
else
|
|
199
|
+
ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
|
|
200
|
+
BACKEND="${DEPLOYMENT_CONFIG#*-}"
|
|
201
|
+
fi
|
|
196
202
|
|
|
197
203
|
echo "📋 Registering deployment to registry"
|
|
198
204
|
echo " Project: ${PROJECT_NAME}"
|
package/templates/do/test
CHANGED
|
@@ -103,9 +103,9 @@ case "${FRAMEWORK}" in
|
|
|
103
103
|
case "${MODEL_SERVER}" in
|
|
104
104
|
vllm|sglang)
|
|
105
105
|
# OpenAI-compatible chat completions format
|
|
106
|
-
# For S3/
|
|
106
|
+
# For S3/registry models, vLLM registers the model under the local path
|
|
107
107
|
VLLM_MODEL_NAME="${MODEL_NAME}"
|
|
108
|
-
if [[ "${MODEL_NAME}" ==
|
|
108
|
+
if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
|
|
109
109
|
VLLM_MODEL_NAME="/opt/ml/model"
|
|
110
110
|
fi
|
|
111
111
|
TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
|
|
@@ -431,7 +431,7 @@ case "${FRAMEWORK}" in
|
|
|
431
431
|
case "${MODEL_SERVER}" in
|
|
432
432
|
vllm|sglang)
|
|
433
433
|
VLLM_MODEL_NAME="${MODEL_NAME}"
|
|
434
|
-
if [[ "${MODEL_NAME}" ==
|
|
434
|
+
if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
|
|
435
435
|
VLLM_MODEL_NAME="/opt/ml/model"
|
|
436
436
|
fi
|
|
437
437
|
TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
|
|
@@ -808,7 +808,7 @@ case "${FRAMEWORK}" in
|
|
|
808
808
|
vllm|sglang)
|
|
809
809
|
# OpenAI-compatible chat completions format
|
|
810
810
|
VLLM_MODEL_NAME="${MODEL_NAME}"
|
|
811
|
-
if [[ "${MODEL_NAME}" ==
|
|
811
|
+
if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
|
|
812
812
|
VLLM_MODEL_NAME="/opt/ml/model"
|
|
813
813
|
fi
|
|
814
814
|
TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
|
|
@@ -1095,7 +1095,7 @@ case "${FRAMEWORK}" in
|
|
|
1095
1095
|
case "${MODEL_SERVER}" in
|
|
1096
1096
|
vllm|sglang)
|
|
1097
1097
|
VLLM_MODEL_NAME="${MODEL_NAME}"
|
|
1098
|
-
if [[ "${MODEL_NAME}" ==
|
|
1098
|
+
if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
|
|
1099
1099
|
VLLM_MODEL_NAME="/opt/ml/model"
|
|
1100
1100
|
fi
|
|
1101
1101
|
TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
|
package/templates/do/tune
CHANGED
|
@@ -67,7 +67,7 @@ _parse_args() {
|
|
|
67
67
|
ARG_TRAINING_TYPE="$2"; shift 2 ;;
|
|
68
68
|
--model)
|
|
69
69
|
if [ -z "${2:-}" ]; then
|
|
70
|
-
echo "❌ --model requires a
|
|
70
|
+
echo "❌ --model requires a model ID"
|
|
71
71
|
exit 1
|
|
72
72
|
fi
|
|
73
73
|
ARG_MODEL="$2"; shift 2 ;;
|
|
@@ -287,7 +287,7 @@ for family in sorted(families.keys()):
|
|
|
287
287
|
for entry in entries:
|
|
288
288
|
techniques = list(entry.get('techniques', {}).keys())
|
|
289
289
|
print(f' • {entry[\"displayName\"]}')
|
|
290
|
-
print(f' ID: {entry[\"
|
|
290
|
+
print(f' ID: {entry[\"modelId\"]}')
|
|
291
291
|
for t in techniques:
|
|
292
292
|
tc = entry['techniques'][t]
|
|
293
293
|
types = ', '.join(tc.get('trainingTypes', []))
|