@aws/ml-container-creator 0.6.1 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -111,6 +111,14 @@ export default class PromptRunner {
111
111
  framework: framework || deploymentConfigAnswers.framework,
112
112
  modelServer: modelServer || deploymentConfigAnswers.modelServer
113
113
  };
114
+
115
+ // ──────────────────────────────────────────────────────────────────────
116
+ // Marketplace fast-path: skip all container-related prompts
117
+ // Requirements: 2.3, 2.4, 2.5
118
+ // ──────────────────────────────────────────────────────────────────────
119
+ if (frameworkAnswers.architecture === 'marketplace') {
120
+ return this._runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp);
121
+ }
114
122
 
115
123
  // Engine prompt for http architecture
116
124
  const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
@@ -596,13 +604,27 @@ export default class PromptRunner {
596
604
  // Infer modelSource from model name prefix if not set by MCP
597
605
  const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
598
606
  if (!combinedAnswers.modelSource && modelName) {
599
- if (modelName.startsWith('s3://')) {
607
+ // Reject deprecated JumpStart prefixes with migration message
608
+ if (modelName.startsWith('jumpstart://') || modelName.startsWith('jumpstart-hub://')) {
609
+ const bareId = modelName.replace(/^jumpstart(-hub)?:\/\//, '');
610
+ console.error(`\n ⚠️ JumpStart is no longer supported. Use the HuggingFace model ID directly: ${bareId}`);
611
+ console.error(' JumpStart model sources have been removed. Use one of:');
612
+ console.error(' • HuggingFace model ID (e.g., meta-llama/Llama-2-7b-hf)');
613
+ console.error(' • s3://bucket/path/model.tar.gz');
614
+ console.error(' • registry://model-package-name');
615
+ console.error(' • marketplace://arn:aws:sagemaker:...\n');
616
+ process.exit(1);
617
+ }
618
+ if (modelName.startsWith('marketplace://')) {
619
+ // marketplace://arn:aws:sagemaker:... → set architecture to marketplace and store ARN
620
+ const arn = modelName.replace(/^marketplace:\/\//, '');
621
+ combinedAnswers.modelPackageArn = arn;
622
+ combinedAnswers.architecture = 'marketplace';
623
+ combinedAnswers.deploymentConfig = 'marketplace';
624
+ combinedAnswers.modelSource = undefined;
625
+ } else if (modelName.startsWith('s3://')) {
600
626
  combinedAnswers.modelSource = 's3';
601
627
  combinedAnswers.artifactUri = modelName;
602
- } else if (modelName.startsWith('jumpstart://')) {
603
- combinedAnswers.modelSource = 'jumpstart';
604
- } else if (modelName.startsWith('jumpstart-hub://')) {
605
- combinedAnswers.modelSource = 'jumpstart-hub';
606
628
  } else if (modelName.startsWith('registry://')) {
607
629
  combinedAnswers.modelSource = 'registry';
608
630
  }
@@ -613,7 +635,7 @@ export default class PromptRunner {
613
635
  combinedAnswers.artifactUri = modelName;
614
636
  }
615
637
  }
616
- const downloadSources = ['jumpstart', 's3'];
638
+ const downloadSources = ['s3'];
617
639
  if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
618
640
  console.log(`\n ⚠️ Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
619
641
  console.log(' The model-picker could not determine the download location.');
@@ -638,18 +660,7 @@ export default class PromptRunner {
638
660
  }
639
661
  }
640
662
 
641
- // Warn about jumpstart-hub:// models — private hub deployment requires
642
- // HubAccessConfig on CreateModel, which is not yet supported by the generator.
643
- if (combinedAnswers.modelSource === 'jumpstart-hub') {
644
- console.log('\n ⚠️ JumpStart Private Hub models are not yet fully supported.');
645
- console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
646
- console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
647
- console.log(' The generated project will not be able to download model artifacts at runtime.');
648
- console.log(' This feature is tracked for a future release.\n');
649
- console.log(' Falling back to HuggingFace source.\n');
650
- combinedAnswers.modelSource = 'huggingface';
651
- delete combinedAnswers.artifactUri;
652
- }
663
+
653
664
 
654
665
  // Apply auto-set model format for Triton backends with single format
655
666
  // Requirements: 3.3, 3.4, 3.5
@@ -731,6 +742,265 @@ export default class PromptRunner {
731
742
  return combinedAnswers;
732
743
  }
733
744
 
745
+ /**
746
+ * Marketplace-specific prompt flow.
747
+ * Skips all container-related prompts (framework, model server, base image, CUDA version)
748
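+ * and prompts for: model package ARN, region, deployment target (plus async / batch-transform settings when that target is selected), instance type, IAM role ARN, and the project and destination prompts.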
749
+ *
750
+ * Requirements: 2.3, 2.4, 2.5
751
+ * @private
752
+ */
753
+ async _runMarketplaceFlow(frameworkAnswers, explicitConfig, existingConfig, buildTimestamp) {
754
+ console.log('\n🏪 Marketplace Model Package Configuration');
755
+
756
+ // Query marketplace-picker MCP server for subscription discovery
757
+ // Requirements: 2.4, 6.1, 6.2
758
+ let mcpSubscriptions = [];
759
+ const cm = this.configManager;
760
+ if (cm && cm.getMcpServerNames && cm.getMcpServerNames().includes('marketplace-picker')) {
761
+ try {
762
+ console.log(' 🔍 Querying marketplace-picker for subscriptions...');
763
+ const result = await cm.queryMcpServer('marketplace-picker', {
764
+ region: explicitConfig.awsRegion || existingConfig.awsRegion || process.env.AWS_REGION || 'us-east-1'
765
+ });
766
+ if (result && result.metadata?.subscriptions?.length > 0) {
767
+ mcpSubscriptions = result.metadata.subscriptions;
768
+ console.log(` ✅ Found ${mcpSubscriptions.length} Marketplace subscription(s)`);
769
+ } else {
770
+ console.log(' ℹ️ No Marketplace subscriptions found — enter ARN manually');
771
+ }
772
+ } catch (err) {
773
+ console.log(` ⚠️ marketplace-picker unavailable: ${err.message}`);
774
+ console.log(' Falling back to manual ARN entry');
775
+ }
776
+ }
777
+
778
+ // Marketplace-specific prompts: model package ARN
779
+ const marketplacePrompts = [
780
+ {
781
+ type: mcpSubscriptions.length > 0 ? 'list' : 'input',
782
+ name: 'modelPackageArn',
783
+ message: mcpSubscriptions.length > 0
784
+ ? 'Select a Marketplace model package:'
785
+ : 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
786
+ ...(mcpSubscriptions.length > 0 ? {
787
+ choices: [
788
+ ...mcpSubscriptions.map(sub => ({
789
+ name: `${sub.modelName} (${sub.vendor}) — ${sub.arn}`,
790
+ value: sub.arn,
791
+ short: sub.modelName
792
+ })),
793
+ { type: 'separator', separator: '──────────────' },
794
+ { name: 'Enter ARN manually...', value: '__manual__', short: 'manual' }
795
+ ]
796
+ } : {
797
+ validate: (input) => {
798
+ if (!input || input.trim() === '') {
799
+ return 'Model package ARN is required';
800
+ }
801
+ const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
802
+ if (!arnPattern.test(input.trim())) {
803
+ return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
804
+ }
805
+ return true;
806
+ }
807
+ })
808
+ },
809
+ {
810
+ type: 'input',
811
+ name: 'modelPackageArnManual',
812
+ message: 'Model package ARN (arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>):',
813
+ when: (answers) => answers.modelPackageArn === '__manual__',
814
+ validate: (input) => {
815
+ if (!input || input.trim() === '') {
816
+ return 'Model package ARN is required';
817
+ }
818
+ const arnPattern = /^arn:aws:sagemaker:[a-z0-9-]+:\d{12}:model-package\/[\w-]+\/\d+$/;
819
+ if (!arnPattern.test(input.trim())) {
820
+ return 'Invalid ARN format. Expected: arn:aws:sagemaker:<region>:<account>:model-package/<name>/<version>';
821
+ }
822
+ return true;
823
+ }
824
+ }
825
+ ];
826
+ const marketplaceAnswers = await this._runPhase(marketplacePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
827
+
828
+ // Handle manual ARN entry fallback
829
+ if (marketplaceAnswers.modelPackageArn === '__manual__' && marketplaceAnswers.modelPackageArnManual) {
830
+ marketplaceAnswers.modelPackageArn = marketplaceAnswers.modelPackageArnManual;
831
+ delete marketplaceAnswers.modelPackageArnManual;
832
+ }
833
+
834
+ // Infrastructure prompts: region, deployment target, instance type
835
+ console.log('\n💪 Infrastructure & Deployment');
836
+ const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion;
837
+ const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
838
+
839
+ // Marketplace deployment targets (no HyperPod — vendor controls the container)
840
+ const marketplaceInfraPrompts = [
841
+ {
842
+ type: 'list',
843
+ name: 'awsRegion',
844
+ message: 'Target AWS region?',
845
+ choices: (answers) => {
846
+ const bootstrapReg = answers._bootstrapRegion;
847
+ const choices = ['us-east-1'];
848
+ if (bootstrapReg && bootstrapReg !== 'us-east-1') {
849
+ choices.unshift({ name: `${bootstrapReg} (from bootstrap profile)`, value: bootstrapReg });
850
+ }
851
+ choices.push({ name: 'Custom...', value: 'custom' });
852
+ return choices;
853
+ },
854
+ default: (answers) => answers._bootstrapRegion || 'us-east-1'
855
+ },
856
+ {
857
+ type: 'input',
858
+ name: 'customAwsRegion',
859
+ message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
860
+ when: answers => answers.awsRegion === 'custom'
861
+ },
862
+ {
863
+ type: 'list',
864
+ name: 'deploymentTarget',
865
+ message: 'Deployment target?',
866
+ choices: [
867
+ { name: 'SageMaker Real-Time Inference', value: 'realtime-inference' },
868
+ { name: 'SageMaker Async Inference', value: 'async-inference' },
869
+ { name: 'SageMaker Batch Transform', value: 'batch-transform' }
870
+ ],
871
+ default: 'realtime-inference'
872
+ },
873
+ {
874
+ type: 'list',
875
+ name: 'instanceType',
876
+ message: 'Instance type for deployment?',
877
+ choices: [
878
+ { name: 'ml.g5.xlarge (1 GPU, 24GB)', value: 'ml.g5.xlarge' },
879
+ { name: 'ml.g5.2xlarge (1 GPU, 24GB)', value: 'ml.g5.2xlarge' },
880
+ { name: 'ml.g5.4xlarge (1 GPU, 24GB)', value: 'ml.g5.4xlarge' },
881
+ { name: 'ml.g5.12xlarge (4 GPUs, 96GB)', value: 'ml.g5.12xlarge' },
882
+ { name: 'ml.p3.2xlarge (1 GPU, 16GB V100)', value: 'ml.p3.2xlarge' },
883
+ { name: 'ml.m5.xlarge (CPU, 16GB)', value: 'ml.m5.xlarge' },
884
+ { name: 'Custom...', value: 'custom' }
885
+ ],
886
+ default: 'ml.g5.xlarge'
887
+ },
888
+ {
889
+ type: 'input',
890
+ name: 'customInstanceType',
891
+ message: 'Enter instance type (e.g., ml.g5.xlarge):',
892
+ validate: (input) => {
893
+ if (!input || input.trim() === '') {
894
+ return 'Instance type is required';
895
+ }
896
+ if (!input.startsWith('ml.')) {
897
+ return 'Instance type must start with "ml." (e.g., ml.g5.xlarge)';
898
+ }
899
+ return true;
900
+ },
901
+ when: answers => answers.instanceType === 'custom'
902
+ }
903
+ ];
904
+ const infraAnswers = await this._runPhase(marketplaceInfraPrompts, { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);
905
+
906
+ // Async-specific prompts (only when deploymentTarget === 'async-inference')
907
+ let asyncAnswers = {};
908
+ if (infraAnswers.deploymentTarget === 'async-inference') {
909
+ asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
910
+ }
911
+
912
+ // Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
913
+ let batchTransformAnswers = {};
914
+ if (infraAnswers.deploymentTarget === 'batch-transform') {
915
+ batchTransformAnswers = await this._runPhase(
916
+ infraBatchTransformPrompts,
917
+ { ...infraAnswers },
918
+ explicitConfig,
919
+ existingConfig
920
+ );
921
+ }
922
+
923
+ // Role ARN prompt (always asked for marketplace deploys; the answer itself is optional and may be left empty)
924
+ const rolePrompts = [
925
+ {
926
+ type: 'input',
927
+ name: 'awsRoleArn',
928
+ message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
929
+ validate: (input) => {
930
+ if (!input || input.trim() === '') {
931
+ return true;
932
+ }
933
+ const arnPattern = /^arn:aws:iam::\d{12}:role\/[\w+=,.@-]+$/;
934
+ if (!arnPattern.test(input)) {
935
+ return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
936
+ }
937
+ return true;
938
+ }
939
+ }
940
+ ];
941
+ const roleAnswers = await this._runPhase(rolePrompts, { ...infraAnswers }, explicitConfig, existingConfig);
942
+
943
+ // Project name + destination
944
+ console.log('\n📋 Project Configuration');
945
+ const allTechnicalAnswers = {
946
+ ...frameworkAnswers,
947
+ ...marketplaceAnswers,
948
+ ...infraAnswers,
949
+ ...asyncAnswers,
950
+ ...batchTransformAnswers,
951
+ ...roleAnswers
952
+ };
953
+ const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
954
+ const destinationAnswers = await this._runPhase(destinationPrompts,
955
+ { ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
956
+
957
+ // Combine all marketplace answers
958
+ const combinedAnswers = {
959
+ ...frameworkAnswers,
960
+ ...marketplaceAnswers,
961
+ ...infraAnswers,
962
+ ...asyncAnswers,
963
+ ...batchTransformAnswers,
964
+ ...roleAnswers,
965
+ ...projectAnswers,
966
+ ...destinationAnswers,
967
+ buildTimestamp
968
+ };
969
+
970
+ // Handle custom instance type
971
+ if (combinedAnswers.customInstanceType) {
972
+ combinedAnswers.instanceType = combinedAnswers.customInstanceType;
973
+ delete combinedAnswers.customInstanceType;
974
+ }
975
+
976
+ // Handle custom AWS region
977
+ if (combinedAnswers.customAwsRegion) {
978
+ combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
979
+ delete combinedAnswers.customAwsRegion;
980
+ }
981
+
982
+ // Map awsRoleArn to roleArn for templates
983
+ if (combinedAnswers.awsRoleArn) {
984
+ combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
985
+ delete combinedAnswers.awsRoleArn;
986
+ }
987
+
988
+ // Ensure CLI-provided values are in combinedAnswers
989
+ if (explicitConfig.modelPackageArn && !combinedAnswers.modelPackageArn) {
990
+ combinedAnswers.modelPackageArn = explicitConfig.modelPackageArn;
991
+ }
992
+
993
+ // Handle marketplace:// prefix from --model-name CLI option
994
+ const modelName = explicitConfig.modelName || combinedAnswers.modelName;
995
+ if (modelName && modelName.startsWith('marketplace://')) {
996
+ const arn = modelName.replace(/^marketplace:\/\//, '');
997
+ combinedAnswers.modelPackageArn = arn;
998
+ delete combinedAnswers.modelName;
999
+ }
1000
+
1001
+ return combinedAnswers;
1002
+ }
1003
+
734
1004
  /**
735
1005
  * Checks if a parameter is promptable according to the parameter matrix
736
1006
  * @param {string} parameterName - Name of the parameter
@@ -1746,9 +2016,7 @@ export default class PromptRunner {
1746
2016
  const registryConfigManager = this.registryConfigManager;
1747
2017
  if (registryConfigManager) {
1748
2018
  // Only try HuggingFace API for bare model IDs (not prefixed URIs)
1749
- const isNonHfUri = modelId.startsWith('jumpstart://') ||
1750
- modelId.startsWith('jumpstart-hub://') ||
1751
- modelId.startsWith('s3://') ||
2019
+ const isNonHfUri = modelId.startsWith('s3://') ||
1752
2020
  modelId.startsWith('registry://');
1753
2021
 
1754
2022
  if (!isNonHfUri) {
@@ -1773,7 +2041,7 @@ export default class PromptRunner {
1773
2041
  console.log(' ⚠️ HuggingFace API unavailable');
1774
2042
  }
1775
2043
  } else {
1776
- // Non-HF URI (jumpstart://, s3://, etc.) — skip HF lookup silently
2044
+ // Non-HF URI (s3://, registry://, etc.) — skip HF lookup silently
1777
2045
  // The summary at the end of this function will report "No additional model information"
1778
2046
  }
1779
2047
 
@@ -232,6 +232,12 @@ const deploymentConfigPrompts = [
232
232
  name: 'Diffusors with vLLM Omni',
233
233
  value: 'diffusors-vllm-omni',
234
234
  short: 'diffusors-vllm-omni'
235
+ },
236
+ { type: 'separator', separator: '── AWS Marketplace ──' },
237
+ {
238
+ name: 'Marketplace Model Package',
239
+ value: 'marketplace',
240
+ short: 'marketplace'
235
241
  }
236
242
  ]
237
243
  }
@@ -469,9 +475,9 @@ const modelFormatPrompts = [
469
475
  if (!input || input.trim() === '') {
470
476
  return 'Model name is required';
471
477
  }
472
- // Basic validation - must contain a slash (org/model, hub/model, s3://path, etc.)
478
+ // Basic validation - must contain a slash (org/model, s3://path, etc.)
473
479
  if (!input.includes('/')) {
474
- return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, jumpstart-hub://my-hub/my-model)';
480
+ return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, s3://bucket/model, registry://my-package)';
475
481
  }
476
482
  return true;
477
483
  },
@@ -583,7 +589,7 @@ const hfTokenPrompts = [
583
589
  }
584
590
 
585
591
  // Skip HF token prompt for non-HuggingFace model sources
586
- // (S3, JumpStart, Private Hub, Registry models don't need HF auth)
592
+ // (S3, Registry models don't need HF auth)
587
593
  const modelSource = answers.modelSource;
588
594
  if (modelSource && modelSource !== 'huggingface') {
589
595
  return false;
@@ -50,7 +50,7 @@ export default class TemplateManager {
50
50
  */
51
51
  validate() {
52
52
  const supportedOptions = {
53
- // 15 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors)
53
+ // 16 canonical deployment-config values (2 http, 5 transformers, 7 triton, 1 diffusors, 1 marketplace)
54
54
  deploymentConfigs: [
55
55
  // HTTP architecture (2)
56
56
  'http-flask', 'http-fastapi',
@@ -61,7 +61,9 @@ export default class TemplateManager {
61
61
  'triton-fil', 'triton-onnxruntime', 'triton-tensorflow',
62
62
  'triton-pytorch', 'triton-vllm', 'triton-tensorrtllm', 'triton-python',
63
63
  // Diffusors architecture (1)
64
- 'diffusors-vllm-omni'
64
+ 'diffusors-vllm-omni',
65
+ // Marketplace architecture (1)
66
+ 'marketplace'
65
67
  ],
66
68
  buildTargets: ['codebuild'],
67
69
  deploymentTargets: ['realtime-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'],
@@ -82,7 +84,7 @@ export default class TemplateManager {
82
84
  this._validateGpuRequirement();
83
85
  } else {
84
86
  // Fallback: validate architecture and backend separately (new canonical format)
85
- const architectures = ['http', 'transformers', 'triton', 'diffusors'];
87
+ const architectures = ['http', 'transformers', 'triton', 'diffusors', 'marketplace'];
86
88
  const backends = [
87
89
  // http backends
88
90
  'flask', 'fastapi',
@@ -95,7 +97,11 @@ export default class TemplateManager {
95
97
  ];
96
98
 
97
99
  this._validateChoice('architecture', architectures);
98
- this._validateChoice('backend', backends);
100
+
101
+ // Marketplace has no backend — skip backend validation
102
+ if (this.answers.architecture !== 'marketplace') {
103
+ this._validateChoice('backend', backends);
104
+ }
99
105
 
100
106
  // Validate tensorrt-llm is only used with transformers architecture
101
107
  if (this.answers.backend === 'tensorrt-llm' && this.answers.architecture !== 'transformers') {
@@ -13,7 +13,7 @@
13
13
 
14
14
  /**
15
15
  * Look up a model entry in the catalog by model ID.
16
- * @param {string} modelId - The JumpStart model ID to look up
16
+ * @param {string} modelId - The model ID to look up
17
17
  * @param {Object} catalog - The tune catalog object with a `models` map
18
18
  * @returns {Object|null} The catalog entry for the model, or null if not found
19
19
  */
@@ -29,7 +29,7 @@ export function lookupModel(modelId, catalog) {
29
29
 
30
30
  /**
31
31
  * Check whether a model ID is present in the Supported Model Catalog.
32
- * @param {string} modelId - The JumpStart model ID to check
32
+ * @param {string} modelId - The model ID to check
33
33
  * @param {Object} catalog - The tune catalog object with a `models` map
34
34
  * @returns {boolean} True if the model is in the catalog
35
35
  */
@@ -41,7 +41,7 @@ export function isTuneSupported(modelId, catalog) {
41
41
  * Validate that a model ID exists in the catalog.
42
42
  * Returns a descriptive error when the model is not supported, including
43
43
  * the model name, supported families, and a reference to `do/train`.
44
- * @param {string} modelId - The JumpStart model ID to validate
44
+ * @param {string} modelId - The model ID to validate
45
45
  * @param {Object} catalog - The tune catalog object with a `models` map
46
46
  * @returns {{ valid: boolean, error?: string }}
47
47
  */
@@ -65,7 +65,7 @@ export function validateModel(modelId, catalog) {
65
65
  * Validate that a technique is supported for the given model.
66
66
  * Returns a descriptive error listing the supported techniques when
67
67
  * the requested technique is not available.
68
- * @param {string} modelId - The JumpStart model ID
68
+ * @param {string} modelId - The model ID
69
69
  * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
70
70
  * @param {Object} catalog - The tune catalog object with a `models` map
71
71
  * @returns {{ valid: boolean, error?: string }}
@@ -92,7 +92,7 @@ export function validateTechnique(modelId, technique, catalog) {
92
92
  * Validate that a training type is supported for the given model and technique.
93
93
  * Returns a descriptive error listing the supported training types when
94
94
  * the requested type is not available.
95
- * @param {string} modelId - The JumpStart model ID
95
+ * @param {string} modelId - The model ID
96
96
  * @param {string} technique - The technique (e.g., 'sft', 'dpo')
97
97
  * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
98
98
  * @param {Object} catalog - The tune catalog object with a `models` map
@@ -290,6 +290,7 @@ RUN chmod +x /usr/bin/serve_trtllm
290
290
 
291
291
  # Copy startup script
292
292
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
293
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
293
294
  COPY code/start_server.sh /usr/bin/start_server.sh
294
295
  RUN chmod +x /usr/bin/start_server.sh /usr/bin/cuda_compat.sh
295
296
 
@@ -307,6 +308,7 @@ COPY code/serving.properties /opt/ml/model/serving.properties
307
308
  # The container will automatically start DJL Serving with the configuration
308
309
  <% } else { %>
309
310
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
311
+ COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
310
312
  COPY code/serve /usr/bin/serve
311
313
  RUN chmod 777 /usr/bin/serve /usr/bin/cuda_compat.sh
312
314
 
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/env python3
2
+ """CloudWatch log forwarder — workaround for IC platform log routing gap.
3
+ Pipes stdin to a CW log stream while passing through to stderr.
4
+ Usage: exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1
5
+ """
6
+ import sys, os, time, threading
7
+ import boto3
8
+ from botocore.config import Config
9
+
10
+ LOG_GROUP = os.environ.get("CW_LOG_GROUP",
11
+ f"/aws/sagemaker/InferenceComponents/{os.environ.get('INFERENCE_COMPONENT_NAME', os.environ.get('HOSTNAME', 'unknown'))}")
12
+ LOG_STREAM = f"AllTraffic/{os.environ.get('HOSTNAME', 'container')}"
13
+ REGION = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-west-2"))
14
+
15
+ def main():
16
+ client = boto3.client("logs", region_name=REGION, config=Config(retries={"max_attempts": 2}))
17
+ try:
18
+ client.create_log_group(logGroupName=LOG_GROUP)
19
+ except Exception:
20
+ pass
21
+ try:
22
+ client.create_log_stream(logGroupName=LOG_GROUP, logStreamName=LOG_STREAM)
23
+ except Exception as e:
24
+ # Can't create stream — just passthrough
25
+ for line in sys.stdin:
26
+ sys.stderr.write(line)
27
+ return
28
+
29
+ buf, lock, seq = [], threading.Lock(), [None]
30
+
31
+ def flush():
32
+ with lock:
33
+ if not buf:
34
+ return
35
+ batch = buf[:50]
36
+ del buf[:50]
37
+ events = [{"timestamp": int(t * 1000), "message": m} for t, m in batch]
38
+ kw = {"logGroupName": LOG_GROUP, "logStreamName": LOG_STREAM, "logEvents": events}
39
+ if seq[0]:
40
+ kw["sequenceToken"] = seq[0]
41
+ try:
42
+ r = client.put_log_events(**kw)
43
+ seq[0] = r.get("nextSequenceToken")
44
+ except Exception:
45
+ pass
46
+
47
+ def loop():
48
+ while True:
49
+ time.sleep(2)
50
+ flush()
51
+
52
+ threading.Thread(target=loop, daemon=True).start()
53
+ try:
54
+ for line in sys.stdin:
55
+ sys.stderr.write(line)
56
+ with lock:
57
+ buf.append((time.time(), line.rstrip("\n")))
58
+ except (KeyboardInterrupt, BrokenPipeError):
59
+ pass
60
+ finally:
61
+ flush()
62
+
63
+ if __name__ == "__main__":
64
+ main()
@@ -2,6 +2,11 @@
2
2
  # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ # CloudWatch log forwarder — workaround for the SageMaker Inference Component (IC) platform log routing gap
6
+ exec > >(python3 /usr/bin/cw_log_forwarder.py) 2>&1
7
+
8
+ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
9
+
5
10
  # CUDA compatibility setup (required for newer SageMaker inference AMIs)
6
11
  source /usr/bin/cuda_compat.sh 2>/dev/null || true
7
12
 
@@ -108,7 +113,7 @@ resolve_model() {
108
113
  echo "${!_MODEL_VAR}"
109
114
  return
110
115
  ;;
111
- s3|jumpstart|jumpstart-hub|registry)
116
+ s3|registry)
112
117
  # Check for pre-mounted artifacts first
113
118
  if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
114
119
  echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
@@ -240,7 +245,7 @@ ARG_PREFIX="--"
240
245
 
241
246
  # Define environment variables to exclude (internal variables set by base images)
242
247
  <% if (modelServer === 'vllm') { %>
243
- EXCLUDE_VARS=("VLLM_USAGE_SOURCE")
248
+ EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
244
249
  <% } else if (modelServer === 'sglang') { %>
245
250
  EXCLUDE_VARS=()
246
251
  <% } else if (modelServer === 'tensorrt-llm') { %>
@@ -270,8 +275,14 @@ for var in "${env_vars[@]}"; do
270
275
 
271
276
  # Remove prefix, convert to lowercase, and replace underscores with dashes
272
277
  arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
278
+
279
+ # Boolean handling: true = flag only, false = skip entirely
280
+ if [ "$value" = "false" ]; then
281
+ continue
282
+ fi
283
+
273
284
  SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
274
- if [ -n "$value" ]; then
285
+ if [ -n "$value" ] && [ "$value" != "true" ]; then
275
286
  SERVER_ARGS+=("$value")
276
287
  fi
277
288
  done
@@ -15,7 +15,7 @@ option.model_id=<%= modelName %>
15
15
  option.model_id=<%= artifactUri %>
16
16
  <% } else { %>
17
17
  # Model will be loaded from /opt/ml/model at runtime
18
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
18
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
19
19
  # option.model_id=/opt/ml/model
20
20
  <% } %>
21
21
 
@@ -71,7 +71,7 @@ option.model_id=<%= modelName %>
71
71
  option.model_id=<%= artifactUri %>
72
72
  <% } else { %>
73
73
  # Model will be loaded from /opt/ml/model at runtime
74
- # (JumpStart model without artifact URI — requires SageMaker ModelDataUrl)
74
+ # (requires SageMaker ModelDataUrl or MODEL_ARTIFACT_URI)
75
75
  # option.model_id=/opt/ml/model
76
76
  <% } %>
77
77