@aws/ml-container-creator 0.7.1 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -111,6 +111,14 @@ export default class PromptRunner {
111
111
  framework: framework || deploymentConfigAnswers.framework,
112
112
  modelServer: modelServer || deploymentConfigAnswers.modelServer
113
113
  };
114
+
115
+ // ──────────────────────────────────────────────────────────────────────
116
+ // Marketplace fast-path: skip all container-related prompts
117
+ // Requirements: 2.3, 2.4, 2.5
118
+ // ──────────────────────────────────────────────────────────────────────
119
+ if (frameworkAnswers.architecture === 'marketplace') {
120
+ return this._runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp);
121
+ }
114
122
 
115
123
  // Engine prompt for http architecture
116
124
  const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
@@ -596,13 +604,27 @@ export default class PromptRunner {
596
604
  // Infer modelSource from model name prefix if not set by MCP
597
605
  const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
598
606
  if (!combinedAnswers.modelSource && modelName) {
599
- if (modelName.startsWith('s3://')) {
607
+ // Reject deprecated JumpStart prefixes with migration message
608
+ if (modelName.startsWith('jumpstart://') || modelName.startsWith('jumpstart-hub://')) {
609
+ const bareId = modelName.replace(/^jumpstart(-hub)?:\/\//, '');
610
+ console.error(`\n ⚠️ JumpStart is no longer supported. Use the HuggingFace model ID directly: ${bareId}`);
611
+ console.error(' JumpStart model sources have been removed. Use one of:');
612
+ console.error(' • HuggingFace model ID (e.g., meta-llama/Llama-2-7b-hf)');
613
+ console.error(' • s3://bucket/path/model.tar.gz');
614
+ console.error(' • registry://model-package-name');
615
+ console.error(' • marketplace://arn:aws:sagemaker:...\n');
616
+ process.exit(1);
617
+ }
618
+ if (modelName.startsWith('marketplace://')) {
619
+ // marketplace://arn:aws:sagemaker:... → set architecture to marketplace and store ARN
620
+ const arn = modelName.replace(/^marketplace:\/\//, '');
621
+ combinedAnswers.modelPackageArn = arn;
622
+ combinedAnswers.architecture = 'marketplace';
623
+ combinedAnswers.deploymentConfig = 'marketplace';
624
+ combinedAnswers.modelSource = undefined;
625
+ } else if (modelName.startsWith('s3://')) {
600
626
  combinedAnswers.modelSource = 's3';
601
627
  combinedAnswers.artifactUri = modelName;
602
- } else if (modelName.startsWith('jumpstart://')) {
603
- combinedAnswers.modelSource = 'jumpstart';
604
- } else if (modelName.startsWith('jumpstart-hub://')) {
605
- combinedAnswers.modelSource = 'jumpstart-hub';
606
628
  } else if (modelName.startsWith('registry://')) {
607
629
  combinedAnswers.modelSource = 'registry';
608
630
  }
@@ -613,7 +635,7 @@ export default class PromptRunner {
613
635
  combinedAnswers.artifactUri = modelName;
614
636
  }
615
637
  }
616
- const downloadSources = ['jumpstart', 's3'];
638
+ const downloadSources = ['s3'];
617
639
  if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
618
640
  console.log(`\n ⚠️ Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
619
641
  console.log(' The model-picker could not determine the download location.');
@@ -638,18 +660,7 @@ export default class PromptRunner {
638
660
  }
639
661
  }
640
662
 
641
- // Warn about jumpstart-hub:// models — private hub deployment requires
642
- // HubAccessConfig on CreateModel, which is not yet supported by the generator.
643
- if (combinedAnswers.modelSource === 'jumpstart-hub') {
644
- console.log('\n ⚠️ JumpStart Private Hub models are not yet fully supported.');
645
- console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
646
- console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
647
- console.log(' The generated project will not be able to download model artifacts at runtime.');
648
- console.log(' This feature is tracked for a future release.\n');
649
- console.log(' Falling back to HuggingFace source.\n');
650
- combinedAnswers.modelSource = 'huggingface';
651
- delete combinedAnswers.artifactUri;
652
- }
663
+
653
664
 
654
665
  // Apply auto-set model format for Triton backends with single format
655
666
  // Requirements: 3.3, 3.4, 3.5
@@ -731,6 +742,265 @@ export default class PromptRunner {
731
742
  return combinedAnswers;
732
743
  }
733
744
 
745
+ /**
746
+ * Marketplace-specific prompt flow.
747
+ * Skips all container-related prompts (framework, model server, base image, CUDA version)
748
+ * and prompts only for: model package ARN, instance type, deployment target, region.
749
+ *
750
+ * Requirements: 2.3, 2.4, 2.5
751
+ * @private
752
+ */
753
+ async _runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp) {
754
+ console.log('\n🏪 Marketplace Model Package Configuration');
755
+
756
+ // Query marketplace-picker MCP server for subscription discovery
757
+ // Requirements: 2.4, 6.1, 6.2
758
+ let mcpSubscriptions = [];
759
+ const cm = this.configManager;
760
+ if (cm && cm.getMcpServerNames && cm.getMcpServerNames().includes('marketplace-picker')) {
761
+ try {
762
+ console.log(' 🔍 Querying marketplace-picker for subscriptions...');
763
+ const result = await cm.queryMcpServer('marketplace-picker', {
764
+ region: explicitConfig.awsRegion || existingConfig.awsRegion || process.env.AWS_REGION || 'us-east-1'
765
+ });
766
+ if (result && result.metadata?.subscriptions?.length > 0) {
767
+ mcpSubscriptions = result.metadata.subscriptions;
768
+ console.log(` ✅ Found ${mcpSubscriptions.length} Marketplace subscription(s)`);
769
+ } else {
770
+ console.log(' ℹ️ No Marketplace subscriptions found — enter ARN manually');
771
+ }
772
+ } catch (err) {
773
+ console.log(` ⚠️ marketplace-picker unavailable: ${err.message}`);
774
+ console.log(' Falling back to manual ARN entry');
775
+ }
776
+ }
777
+
778
+ // Marketplace-specific prompts: model package ARN
779
+ const marketplacePrompts = [
780
+ {
781
+ type: mcpSubscriptions.length > 0 ? 'list' : 'input',
782
+ name: 'modelPackageArn',
783
+ message: mcpSubscriptions.length > 0
784
+ ? 'Select a Marketplace model package:'
785
+ : 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
786
+ ...(mcpSubscriptions.length > 0 ? {
787
+ choices: [
788
+ ...mcpSubscriptions.map(sub => ({
789
+ name: `${sub.modelName} (${sub.vendor}) — ${sub.arn}`,
790
+ value: sub.arn,
791
+ short: sub.modelName
792
+ })),
793
+ { type: 'separator', separator: '──────────────' },
794
+ { name: 'Enter ARN manually...', value: '__manual__', short: 'manual' }
795
+ ]
796
+ } : {
797
+ validate: (input) => {
798
+ if (!input || input.trim() === '') {
799
+ return 'Model package ARN is required';
800
+ }
801
+ const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
802
+ if (!arnPattern.test(input.trim())) {
803
+ return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
804
+ }
805
+ return true;
806
+ }
807
+ })
808
+ },
809
+ {
810
+ type: 'input',
811
+ name: 'modelPackageArnManual',
812
+ message: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
813
+ when: (answers) => answers.modelPackageArn === '__manual__',
814
+ validate: (input) => {
815
+ if (!input || input.trim() === '') {
816
+ return 'Model package ARN is required';
817
+ }
818
+ const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
819
+ if (!arnPattern.test(input.trim())) {
820
+ return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
821
+ }
822
+ return true;
823
+ }
824
+ }
825
+ ];
826
+ const marketplaceAnswers = await this._runPhase(marketplacePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
827
+
828
+ // Handle manual ARN entry fallback
829
+ if (marketplaceAnswers.modelPackageArn === '__manual__' && marketplaceAnswers.modelPackageArnManual) {
830
+ marketplaceAnswers.modelPackageArn = marketplaceAnswers.modelPackageArnManual;
831
+ delete marketplaceAnswers.modelPackageArnManual;
832
+ }
833
+
834
+ // Infrastructure prompts: region, deployment target, instance type
835
+ console.log('\n💪 Infrastructure & Deployment');
836
+ const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion;
837
+ const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
838
+
839
+ // Marketplace deployment targets (no HyperPod — vendor controls the container)
840
+ const marketplaceInfraPrompts = [
841
+ {
842
+ type: 'list',
843
+ name: 'awsRegion',
844
+ message: 'Target AWS region?',
845
+ choices: (answers) => {
846
+ const bootstrapReg = answers._bootstrapRegion;
847
+ const choices = ['us-east-1'];
848
+ if (bootstrapReg && bootstrapReg !== 'us-east-1') {
849
+ choices.unshift({ name: `${bootstrapReg} (from bootstrap profile)`, value: bootstrapReg });
850
+ }
851
+ choices.push({ name: 'Custom...', value: 'custom' });
852
+ return choices;
853
+ },
854
+ default: (answers) => answers._bootstrapRegion || 'us-east-1'
855
+ },
856
+ {
857
+ type: 'input',
858
+ name: 'customAwsRegion',
859
+ message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
860
+ when: answers => answers.awsRegion === 'custom'
861
+ },
862
+ {
863
+ type: 'list',
864
+ name: 'deploymentTarget',
865
+ message: 'Deployment target?',
866
+ choices: [
867
+ { name: 'SageMaker Real-Time Inference', value: 'realtime-inference' },
868
+ { name: 'SageMaker Async Inference', value: 'async-inference' },
869
+ { name: 'SageMaker Batch Transform', value: 'batch-transform' }
870
+ ],
871
+ default: 'realtime-inference'
872
+ },
873
+ {
874
+ type: 'list',
875
+ name: 'instanceType',
876
+ message: 'Instance type for deployment?',
877
+ choices: [
878
+ { name: 'ml.g5.xlarge (1 GPU, 24GB)', value: 'ml.g5.xlarge' },
879
+ { name: 'ml.g5.2xlarge (1 GPU, 24GB)', value: 'ml.g5.2xlarge' },
880
+ { name: 'ml.g5.4xlarge (1 GPU, 24GB)', value: 'ml.g5.4xlarge' },
881
+ { name: 'ml.g5.12xlarge (4 GPUs, 96GB)', value: 'ml.g5.12xlarge' },
882
+ { name: 'ml.p3.2xlarge (1 GPU, 16GB V100)', value: 'ml.p3.2xlarge' },
883
+ { name: 'ml.m5.xlarge (CPU, 16GB)', value: 'ml.m5.xlarge' },
884
+ { name: 'Custom...', value: 'custom' }
885
+ ],
886
+ default: 'ml.g5.xlarge'
887
+ },
888
+ {
889
+ type: 'input',
890
+ name: 'customInstanceType',
891
+ message: 'Enter instance type (e.g., ml.g5.xlarge):',
892
+ validate: (input) => {
893
+ if (!input || input.trim() === '') {
894
+ return 'Instance type is required';
895
+ }
896
+ if (!input.startsWith('ml.')) {
897
+ return 'Instance type must start with "ml." (e.g., ml.g5.xlarge)';
898
+ }
899
+ return true;
900
+ },
901
+ when: answers => answers.instanceType === 'custom'
902
+ }
903
+ ];
904
+ const infraAnswers = await this._runPhase(marketplaceInfraPrompts, { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);
905
+
906
+ // Async-specific prompts (only when deploymentTarget === 'async-inference')
907
+ let asyncAnswers = {};
908
+ if (infraAnswers.deploymentTarget === 'async-inference') {
909
+ asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
910
+ }
911
+
912
+ // Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
913
+ let batchTransformAnswers = {};
914
+ if (infraAnswers.deploymentTarget === 'batch-transform') {
915
+ batchTransformAnswers = await this._runPhase(
916
+ infraBatchTransformPrompts,
917
+ { ...infraAnswers },
918
+ explicitConfig,
919
+ existingConfig
920
+ );
921
+ }
922
+
923
+ // Role ARN prompt (always needed for marketplace deploy)
924
+ const rolePrompts = [
925
+ {
926
+ type: 'input',
927
+ name: 'awsRoleArn',
928
+ message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
929
+ validate: (input) => {
930
+ if (!input || input.trim() === '') {
931
+ return true;
932
+ }
933
+ const arnPattern = /^arn:aws:iam::\d{12}:role\/[\w+=,.@-]+$/;
934
+ if (!arnPattern.test(input)) {
935
+ return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
936
+ }
937
+ return true;
938
+ }
939
+ }
940
+ ];
941
+ const roleAnswers = await this._runPhase(rolePrompts, { ...infraAnswers }, explicitConfig, existingConfig);
942
+
943
+ // Project name + destination
944
+ console.log('\n📋 Project Configuration');
945
+ const allTechnicalAnswers = {
946
+ ...frameworkAnswers,
947
+ ...marketplaceAnswers,
948
+ ...infraAnswers,
949
+ ...asyncAnswers,
950
+ ...batchTransformAnswers,
951
+ ...roleAnswers
952
+ };
953
+ const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
954
+ const destinationAnswers = await this._runPhase(destinationPrompts,
955
+ { ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
956
+
957
+ // Combine all marketplace answers
958
+ const combinedAnswers = {
959
+ ...frameworkAnswers,
960
+ ...marketplaceAnswers,
961
+ ...infraAnswers,
962
+ ...asyncAnswers,
963
+ ...batchTransformAnswers,
964
+ ...roleAnswers,
965
+ ...projectAnswers,
966
+ ...destinationAnswers,
967
+ buildTimestamp
968
+ };
969
+
970
+ // Handle custom instance type
971
+ if (combinedAnswers.customInstanceType) {
972
+ combinedAnswers.instanceType = combinedAnswers.customInstanceType;
973
+ delete combinedAnswers.customInstanceType;
974
+ }
975
+
976
+ // Handle custom AWS region
977
+ if (combinedAnswers.customAwsRegion) {
978
+ combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
979
+ delete combinedAnswers.customAwsRegion;
980
+ }
981
+
982
+ // Map awsRoleArn to roleArn for templates
983
+ if (combinedAnswers.awsRoleArn) {
984
+ combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
985
+ delete combinedAnswers.awsRoleArn;
986
+ }
987
+
988
+ // Ensure CLI-provided values are in combinedAnswers
989
+ if (explicitConfig.modelPackageArn && !combinedAnswers.modelPackageArn) {
990
+ combinedAnswers.modelPackageArn = explicitConfig.modelPackageArn;
991
+ }
992
+
993
+ // Handle marketplace:// prefix from --model-name CLI option
994
+ const modelName = explicitConfig.modelName || combinedAnswers.modelName;
995
+ if (modelName && modelName.startsWith('marketplace://')) {
996
+ const arn = modelName.replace(/^marketplace:\/\//, '');
997
+ combinedAnswers.modelPackageArn = arn;
998
+ delete combinedAnswers.modelName;
999
+ }
1000
+
1001
+ return combinedAnswers;
1002
+ }
1003
+
734
1004
  /**
735
1005
  * Checks if a parameter is promptable according to the parameter matrix
736
1006
  * @param {string} parameterName - Name of the parameter
@@ -1746,9 +2016,7 @@ export default class PromptRunner {
1746
2016
  const registryConfigManager = this.registryConfigManager;
1747
2017
  if (registryConfigManager) {
1748
2018
  // Only try HuggingFace API for bare model IDs (not prefixed URIs)
1749
- const isNonHfUri = modelId.startsWith('jumpstart://') ||
1750
- modelId.startsWith('jumpstart-hub://') ||
1751
- modelId.startsWith('s3://') ||
2019
+ const isNonHfUri = modelId.startsWith('s3://') ||
1752
2020
  modelId.startsWith('registry://');
1753
2021
 
1754
2022
  if (!isNonHfUri) {
@@ -1773,7 +2041,7 @@ export default class PromptRunner {
1773
2041
  console.log(' ⚠️ HuggingFace API unavailable');
1774
2042
  }
1775
2043
  } else {
1776
- // Non-HF URI (jumpstart://, s3://, etc.) — skip HF lookup silently
2044
+ // Non-HF URI (s3://, registry://, etc.) — skip HF lookup silently
1777
2045
  // The summary at the end of this function will report "No additional model information"
1778
2046
  }
1779
2047
 
@@ -232,6 +232,12 @@ const deploymentConfigPrompts = [
232
232
  name: 'Diffusors with vLLM Omni',
233
233
  value: 'diffusors-vllm-omni',
234
234
  short: 'diffusors-vllm-omni'
235
+ },
236
+ { type: 'separator', separator: '── AWS Marketplace ──' },
237
+ {
238
+ name: 'Marketplace Model Package',
239
+ value: 'marketplace',
240
+ short: 'marketplace'
235
241
  }
236
242
  ]
237
243
  }
@@ -469,9 +475,9 @@ const modelFormatPrompts = [
469
475
  if (!input || input.trim() === '') {
470
476
  return 'Model name is required';
471
477
  }
472
- // Basic validation - must contain a slash (org/model, hub/model, s3://path, etc.)
478
+ // Basic validation - must contain a slash (org/model, s3://path, etc.)
473
479
  if (!input.includes('/')) {
474
- return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, jumpstart-hub://my-hub/my-model)';
480
+ return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, s3://bucket/model, registry://my-package)';
475
481
  }
476
482
  return true;
477
483
  },
@@ -583,7 +589,7 @@ const hfTokenPrompts = [
583
589
  }
584
590
 
585
591
  // Skip HF token prompt for non-HuggingFace model sources
586
- // (S3, JumpStart, Private Hub, Registry models don't need HF auth)
592
+ // (S3, Registry models don't need HF auth)
587
593
  const modelSource = answers.modelSource;
588
594
  if (modelSource && modelSource !== 'huggingface') {
589
595
  return false;
@@ -50,7 +50,7 @@ export default class TemplateManager {
50
50
  */
51
51
  validate() {
52
52
  const supportedOptions = {
53
- // 15 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors)
53
+ // 16 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors, 1 marketplace)
54
54
  deploymentConfigs: [
55
55
  // HTTP architecture (2)
56
56
  'http-flask', 'http-fastapi',
@@ -61,7 +61,9 @@ export default class TemplateManager {
61
61
  'triton-fil', 'triton-onnxruntime', 'triton-tensorflow',
62
62
  'triton-pytorch', 'triton-vllm', 'triton-tensorrtllm', 'triton-python',
63
63
  // Diffusors architecture (1)
64
- 'diffusors-vllm-omni'
64
+ 'diffusors-vllm-omni',
65
+ // Marketplace architecture (1)
66
+ 'marketplace'
65
67
  ],
66
68
  buildTargets: ['codebuild'],
67
69
  deploymentTargets: ['realtime-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'],
@@ -82,7 +84,7 @@ export default class TemplateManager {
82
84
  this._validateGpuRequirement();
83
85
  } else {
84
86
  // Fallback: validate architecture and backend separately (new canonical format)
85
- const architectures = ['http', 'transformers', 'triton', 'diffusors'];
87
+ const architectures = ['http', 'transformers', 'triton', 'diffusors', 'marketplace'];
86
88
  const backends = [
87
89
  // http backends
88
90
  'flask', 'fastapi',
@@ -95,7 +97,11 @@ export default class TemplateManager {
95
97
  ];
96
98
 
97
99
  this._validateChoice('architecture', architectures);
98
- this._validateChoice('backend', backends);
100
+
101
+ // Marketplace has no backend — skip backend validation
102
+ if (this.answers.architecture !== 'marketplace') {
103
+ this._validateChoice('backend', backends);
104
+ }
99
105
 
100
106
  // Validate tensorrt-llm is only used with transformers architecture
101
107
  if (this.answers.backend === 'tensorrt-llm' && this.answers.architecture !== 'transformers') {
@@ -13,7 +13,7 @@
13
13
 
14
14
  /**
15
15
  * Look up a model entry in the catalog by model ID.
16
- * @param {string} modelId - The JumpStart model ID to look up
16
+ * @param {string} modelId - The model ID to look up
17
17
  * @param {Object} catalog - The tune catalog object with a `models` map
18
18
  * @returns {Object|null} The catalog entry for the model, or null if not found
19
19
  */
@@ -29,7 +29,7 @@ export function lookupModel(modelId, catalog) {
29
29
 
30
30
  /**
31
31
  * Check whether a model ID is present in the Supported Model Catalog.
32
- * @param {string} modelId - The JumpStart model ID to check
32
+ * @param {string} modelId - The model ID to check
33
33
  * @param {Object} catalog - The tune catalog object with a `models` map
34
34
  * @returns {boolean} True if the model is in the catalog
35
35
  */
@@ -41,7 +41,7 @@ export function isTuneSupported(modelId, catalog) {
41
41
  * Validate that a model ID exists in the catalog.
42
42
  * Returns a descriptive error when the model is not supported, including
43
43
  * the model name, supported families, and a reference to `do/train`.
44
- * @param {string} modelId - The JumpStart model ID to validate
44
+ * @param {string} modelId - The model ID to validate
45
45
  * @param {Object} catalog - The tune catalog object with a `models` map
46
46
  * @returns {{ valid: boolean, error?: string }}
47
47
  */
@@ -65,7 +65,7 @@ export function validateModel(modelId, catalog) {
65
65
  * Validate that a technique is supported for the given model.
66
66
  * Returns a descriptive error listing the supported techniques when
67
67
  * the requested technique is not available.
68
- * @param {string} modelId - The JumpStart model ID
68
+ * @param {string} modelId - The model ID
69
69
  * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
70
70
  * @param {Object} catalog - The tune catalog object with a `models` map
71
71
  * @returns {{ valid: boolean, error?: string }}
@@ -92,7 +92,7 @@ export function validateTechnique(modelId, technique, catalog) {
92
92
  * Validate that a training type is supported for the given model and technique.
93
93
  * Returns a descriptive error listing the supported training types when
94
94
  * the requested type is not available.
95
- * @param {string} modelId - The JumpStart model ID
95
+ * @param {string} modelId - The model ID
96
96
  * @param {string} technique - The technique (e.g., 'sft', 'dpo')
97
97
  * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
98
98
  * @param {Object} catalog - The tune catalog object with a `models` map
@@ -113,7 +113,7 @@ resolve_model() {
113
113
  echo "${!_MODEL_VAR}"
114
114
  return
115
115
  ;;
116
- s3|jumpstart|jumpstart-hub|registry)
116
+ s3|registry)
117
117
  # Check for pre-mounted artifacts first
118
118
  if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
119
119
  echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
@@ -245,7 +245,7 @@ ARG_PREFIX="--"
245
245
 
246
246
  # Define environment variables to exclude (internal variables set by base images)
247
247
  <% if (modelServer === 'vllm') { %>
248
- EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
248
+ EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
249
249
  <% } else if (modelServer === 'sglang') { %>
250
250
  EXCLUDE_VARS=()
251
251
  <% } else if (modelServer === 'tensorrt-llm') { %>
@@ -15,7 +15,7 @@ option.model_id=<%= modelName %>
15
15
  option.model_id=<%= artifactUri %>
16
16
  <% } else { %>
17
17
  # Model will be loaded from /opt/ml/model at runtime
18
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
18
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
19
19
  # option.model_id=/opt/ml/model
20
20
  <% } %>
21
21
 
@@ -71,7 +71,7 @@ option.model_id=<%= modelName %>
71
71
  option.model_id=<%= artifactUri %>
72
72
  <% } else { %>
73
73
  # Model will be loaded from /opt/ml/model at runtime
74
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
74
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
75
75
  # option.model_id=/opt/ml/model
76
76
  <% } %>
77
77
 
@@ -9,10 +9,10 @@ echo "Starting vLLM-Omni server (diffusion model serving)"
9
9
 
10
10
  # Resolve model URI prefixes that engines cannot handle natively.
11
11
  # The generator's model-picker may store provider-specific URIs
12
- # (e.g. jumpstart://model-txt2img-stabilityai-stable-diffusion-v2-1-base)
13
- # as the model identifier. vLLM expects a HuggingFace repo ID or local path.
12
+ # (e.g. registry://my-model-group/1) as the model identifier.
13
+ # vLLM expects a HuggingFace repo ID or local path.
14
14
  _RAW_MODEL="${VLLM_MODEL:-}"
15
- if [[ "$_RAW_MODEL" == jumpstart://* ]] || [[ "$_RAW_MODEL" == jumpstart-hub://* ]] || [[ "$_RAW_MODEL" == registry://* ]]; then
15
+ if [[ "$_RAW_MODEL" == registry://* ]]; then
16
16
  if [ -d /opt/ml/model ] && [ "$(ls -A /opt/ml/model 2>/dev/null)" ]; then
17
17
  echo "Resolved VLLM_MODEL='${_RAW_MODEL}' → /opt/ml/model (local artifacts found)"
18
18
  export VLLM_MODEL="/opt/ml/model"
@@ -176,7 +176,7 @@ def cmd_submit(args):
176
176
  )
177
177
  elif "ValidationException" in error_msg and "license" in error_msg.lower():
178
178
  _error_exit(
179
- f"Model license not accepted. Accept the license in JumpStart before "
179
+ f"Model license not accepted. Accept the model license before "
180
180
  f"using this model for customization. Details: {error_msg}"
181
181
  )
182
182
  else:
@@ -660,7 +660,7 @@ def main():
660
660
 
661
661
  # ── submit ────────────────────────────────────────────────────────────────
662
662
  submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
663
- submit_parser.add_argument("--model-id", required=True, help="JumpStart model ID")
663
+ submit_parser.add_argument("--model-id", required=True, help="Model ID")
664
664
  submit_parser.add_argument("--technique", required=True,
665
665
  choices=["sft", "dpo", "rlaif", "rlvr"],
666
666
  help="Customization technique")
@@ -191,8 +191,14 @@ fi
191
191
  # ============================================================
192
192
 
193
193
  # DEPLOYMENT_CONFIG format: <architecture>-<backend> (e.g., transformers-vllm, http-flask, triton-fil)
194
- ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
195
- BACKEND="${DEPLOYMENT_CONFIG#*-}"
194
+ # Special case: marketplace has no backend
195
+ if [ "${DEPLOYMENT_CONFIG}" = "marketplace" ]; then
196
+ ARCHITECTURE="marketplace"
197
+ BACKEND=""
198
+ else
199
+ ARCHITECTURE="${DEPLOYMENT_CONFIG%%-*}"
200
+ BACKEND="${DEPLOYMENT_CONFIG#*-}"
201
+ fi
196
202
 
197
203
  echo "📋 Registering deployment to registry"
198
204
  echo " Project: ${PROJECT_NAME}"
package/templates/do/test CHANGED
@@ -103,9 +103,9 @@ case "${FRAMEWORK}" in
103
103
  case "${MODEL_SERVER}" in
104
104
  vllm|sglang)
105
105
  # OpenAI-compatible chat completions format
106
- # For S3/JumpStart models, vLLM registers the model under the local path
106
+ # For S3/registry models, vLLM registers the model under the local path
107
107
  VLLM_MODEL_NAME="${MODEL_NAME}"
108
- if [[ "${MODEL_NAME}" == jumpstart://* ]] || [[ "${MODEL_NAME}" == jumpstart-hub://* ]] || [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
108
+ if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
109
109
  VLLM_MODEL_NAME="/opt/ml/model"
110
110
  fi
111
111
  TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
@@ -431,7 +431,7 @@ case "${FRAMEWORK}" in
431
431
  case "${MODEL_SERVER}" in
432
432
  vllm|sglang)
433
433
  VLLM_MODEL_NAME="${MODEL_NAME}"
434
- if [[ "${MODEL_NAME}" == jumpstart://* ]] || [[ "${MODEL_NAME}" == jumpstart-hub://* ]] || [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
434
+ if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
435
435
  VLLM_MODEL_NAME="/opt/ml/model"
436
436
  fi
437
437
  TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
@@ -808,7 +808,7 @@ case "${FRAMEWORK}" in
808
808
  vllm|sglang)
809
809
  # OpenAI-compatible chat completions format
810
810
  VLLM_MODEL_NAME="${MODEL_NAME}"
811
- if [[ "${MODEL_NAME}" == jumpstart://* ]] || [[ "${MODEL_NAME}" == jumpstart-hub://* ]] || [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
811
+ if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
812
812
  VLLM_MODEL_NAME="/opt/ml/model"
813
813
  fi
814
814
  TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
@@ -1095,7 +1095,7 @@ case "${FRAMEWORK}" in
1095
1095
  case "${MODEL_SERVER}" in
1096
1096
  vllm|sglang)
1097
1097
  VLLM_MODEL_NAME="${MODEL_NAME}"
1098
- if [[ "${MODEL_NAME}" == jumpstart://* ]] || [[ "${MODEL_NAME}" == jumpstart-hub://* ]] || [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
1098
+ if [[ "${MODEL_NAME}" == s3://* ]] || [[ "${MODEL_NAME}" == /opt/ml/* ]]; then
1099
1099
  VLLM_MODEL_NAME="/opt/ml/model"
1100
1100
  fi
1101
1101
  TEST_PAYLOAD='{"model": "'"${VLLM_MODEL_NAME}"'", "messages": [{"role": "user", "content": "What is machine learning?"}], "max_tokens": 50, "temperature": 0.7}'
package/templates/do/tune CHANGED
@@ -67,7 +67,7 @@ _parse_args() {
67
67
  ARG_TRAINING_TYPE="$2"; shift 2 ;;
68
68
  --model)
69
69
  if [ -z "${2:-}" ]; then
70
- echo "❌ --model requires a JumpStart model ID"
70
+ echo "❌ --model requires a model ID"
71
71
  exit 1
72
72
  fi
73
73
  ARG_MODEL="$2"; shift 2 ;;
@@ -287,7 +287,7 @@ for family in sorted(families.keys()):
287
287
  for entry in entries:
288
288
  techniques = list(entry.get('techniques', {}).keys())
289
289
  print(f' • {entry[\"displayName\"]}')
290
- print(f' ID: {entry[\"jumpStartModelId\"]}')
290
+ print(f' ID: {entry[\"modelId\"]}')
291
291
  for t in techniques:
292
292
  tc = entry['techniques'][t]
293
293
  types = ', '.join(tc.get('trainingTypes', []))