@aws/ml-container-creator 0.6.0 → 0.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +9 -0
- package/config/bootstrap-stack.json +69 -3
- package/package.json +1 -1
- package/src/app.js +29 -1
- package/src/lib/bootstrap-command-handler.js +68 -26
- package/src/lib/tune-catalog-validator.js +143 -0
- package/src/lib/tune-config-state.js +116 -0
- package/src/lib/tune-dataset-validator.js +279 -0
- package/src/lib/tune-output-resolver.js +66 -0
- package/templates/Dockerfile +2 -0
- package/templates/code/cw_log_forwarder.py +64 -0
- package/templates/code/serve +12 -1
- package/templates/deploy_notebook_generator.py +897 -0
- package/templates/do/.tune_helper.py +768 -0
- package/templates/do/adapter +107 -12
- package/templates/do/add-ic +155 -19
- package/templates/do/config +6 -0
- package/templates/do/export +19 -2
- package/templates/do/lib/endpoint-config.sh +3 -1
- package/templates/do/lib/inference-component.sh +5 -1
- package/templates/do/tune +1143 -0
package/bin/cli.js
CHANGED
|
@@ -102,6 +102,15 @@ program
|
|
|
102
102
|
.addOption(new Option('--max-loras <n>', 'Maximum concurrent LoRA adapters in GPU memory (default: 30)'))
|
|
103
103
|
.addOption(new Option('--max-lora-rank <n>', 'Maximum LoRA rank (default: 64)'))
|
|
104
104
|
|
|
105
|
+
// --- Benchmarking ---
|
|
106
|
+
.addOption(new Option('--include-benchmark', 'Include SageMaker AI Benchmarking (transformers/diffusors only)'))
|
|
107
|
+
.addOption(new Option('--benchmark-concurrency <n>', 'Benchmark concurrent requests (default: 10)'))
|
|
108
|
+
.addOption(new Option('--benchmark-input-tokens <n>', 'Benchmark mean input tokens (default: 550)'))
|
|
109
|
+
.addOption(new Option('--benchmark-output-tokens <n>', 'Benchmark mean output tokens (default: 150)'))
|
|
110
|
+
.addOption(new Option('--benchmark-streaming', 'Enable streaming in benchmark (default: true)'))
|
|
111
|
+
.addOption(new Option('--benchmark-request-count <n>', 'Total benchmark requests (optional)'))
|
|
112
|
+
.addOption(new Option('--benchmark-s3-output-path <path>', 'S3 path for benchmark results'))
|
|
113
|
+
|
|
105
114
|
// --- MCP & Discovery ---
|
|
106
115
|
.addOption(new Option('--smart', 'Enable Bedrock-powered smart mode on MCP servers'))
|
|
107
116
|
.addOption(new Option('--discover', 'Enable live registry lookups via MCP discovery'))
|
|
@@ -62,6 +62,7 @@
|
|
|
62
62
|
"sagemaker:DescribeEndpointConfig",
|
|
63
63
|
"sagemaker:DescribeModel",
|
|
64
64
|
"sagemaker:DescribeInferenceComponent",
|
|
65
|
+
"sagemaker:ListInferenceComponents",
|
|
65
66
|
"sagemaker:InvokeEndpoint",
|
|
66
67
|
"sagemaker:InvokeEndpointAsync"
|
|
67
68
|
],
|
|
@@ -131,11 +132,14 @@
|
|
|
131
132
|
"Action": [
|
|
132
133
|
"s3:GetObject",
|
|
133
134
|
"s3:PutObject",
|
|
135
|
+
"s3:AbortMultipartUpload",
|
|
134
136
|
"s3:ListBucket"
|
|
135
137
|
],
|
|
136
138
|
"Resource": [
|
|
137
139
|
"arn:aws:s3:::mlcc-*",
|
|
138
|
-
"arn:aws:s3:::mlcc-*/*"
|
|
140
|
+
"arn:aws:s3:::mlcc-*/*",
|
|
141
|
+
"arn:aws:s3:::ml-container-creator-*",
|
|
142
|
+
"arn:aws:s3:::ml-container-creator-*/*"
|
|
139
143
|
]
|
|
140
144
|
},
|
|
141
145
|
{
|
|
@@ -163,18 +167,55 @@
|
|
|
163
167
|
"arn:aws:secretsmanager:*:*:secret:ml-container-creator/*"
|
|
164
168
|
]
|
|
165
169
|
},
|
|
170
|
+
{
|
|
171
|
+
"Sid": "SNSPublish",
|
|
172
|
+
"Effect": "Allow",
|
|
173
|
+
"Action": "sns:Publish",
|
|
174
|
+
"Resource": [
|
|
175
|
+
{ "Fn::Sub": "arn:aws:sns:*:${AWS::AccountId}:mlcc-*" },
|
|
176
|
+
{ "Fn::Sub": "arn:aws:sns:*:${AWS::AccountId}:ml-container-creator-*" }
|
|
177
|
+
]
|
|
178
|
+
},
|
|
166
179
|
{
|
|
167
180
|
"Sid": "QuotaAndAvailability",
|
|
168
181
|
"Effect": "Allow",
|
|
169
182
|
"Action": [
|
|
170
183
|
"service-quotas:GetServiceQuota",
|
|
171
184
|
"service-quotas:ListServiceQuotas",
|
|
172
|
-
"ec2:DescribeCapacityReservations",
|
|
173
185
|
"sagemaker:ListTrainingPlans",
|
|
174
186
|
"sagemaker:DescribeTrainingPlan",
|
|
175
187
|
"sagemaker:ListEndpoints"
|
|
176
188
|
],
|
|
177
189
|
"Resource": "*"
|
|
190
|
+
},
|
|
191
|
+
{
|
|
192
|
+
"Sid": "SageMakerModelCustomization",
|
|
193
|
+
"Effect": "Allow",
|
|
194
|
+
"Action": [
|
|
195
|
+
"sagemaker:CreateTrainingJob",
|
|
196
|
+
"sagemaker:DescribeTrainingJob",
|
|
197
|
+
"sagemaker:ListTrainingJobs",
|
|
198
|
+
"sagemaker:StopTrainingJob",
|
|
199
|
+
"sagemaker:CreateModelPackage",
|
|
200
|
+
"sagemaker:CreateModelPackageGroup",
|
|
201
|
+
"sagemaker:DescribeModelPackage",
|
|
202
|
+
"sagemaker:DescribeModelPackageGroup",
|
|
203
|
+
"sagemaker:ListModelPackages",
|
|
204
|
+
"sagemaker:CallMlflowAppApi"
|
|
205
|
+
],
|
|
206
|
+
"Resource": "*"
|
|
207
|
+
},
|
|
208
|
+
{
|
|
209
|
+
"Sid": "SageMakerMLflow",
|
|
210
|
+
"Effect": "Allow",
|
|
211
|
+
"Action": "sagemaker-mlflow:*",
|
|
212
|
+
"Resource": "*"
|
|
213
|
+
},
|
|
214
|
+
{
|
|
215
|
+
"Sid": "LambdaInvokeForReward",
|
|
216
|
+
"Effect": "Allow",
|
|
217
|
+
"Action": "lambda:InvokeFunction",
|
|
218
|
+
"Resource": { "Fn::Sub": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:*" }
|
|
178
219
|
}
|
|
179
220
|
]
|
|
180
221
|
}
|
|
@@ -285,6 +326,26 @@
|
|
|
285
326
|
{ "Key": "mlcc:purpose", "Value": "benchmark-results" }
|
|
286
327
|
]
|
|
287
328
|
}
|
|
329
|
+
},
|
|
330
|
+
|
|
331
|
+
"TuneS3Bucket": {
|
|
332
|
+
"Type": "AWS::S3::Bucket",
|
|
333
|
+
"Condition": "ShouldCreateS3Buckets",
|
|
334
|
+
"DeletionPolicy": "Retain",
|
|
335
|
+
"UpdateReplacePolicy": "Retain",
|
|
336
|
+
"Properties": {
|
|
337
|
+
"BucketName": { "Fn::Sub": "mlcc-tune-${AWS::AccountId}-${AWS::Region}" },
|
|
338
|
+
"VersioningConfiguration": { "Status": "Enabled" },
|
|
339
|
+
"BucketEncryption": {
|
|
340
|
+
"ServerSideEncryptionConfiguration": [
|
|
341
|
+
{ "ServerSideEncryptionByDefault": { "SSEAlgorithm": "AES256" } }
|
|
342
|
+
]
|
|
343
|
+
},
|
|
344
|
+
"Tags": [
|
|
345
|
+
{ "Key": "mlcc:managed-by", "Value": "ml-container-creator" },
|
|
346
|
+
{ "Key": "mlcc:purpose", "Value": "tune-datasets-and-output" }
|
|
347
|
+
]
|
|
348
|
+
}
|
|
288
349
|
}
|
|
289
350
|
},
|
|
290
351
|
|
|
@@ -327,9 +388,14 @@
|
|
|
327
388
|
"Description": "S3 bucket for benchmark results output",
|
|
328
389
|
"Value": { "Ref": "BenchmarkS3Bucket" }
|
|
329
390
|
},
|
|
391
|
+
"TuneS3BucketName": {
|
|
392
|
+
"Condition": "ShouldCreateS3Buckets",
|
|
393
|
+
"Description": "S3 bucket for tune datasets and output",
|
|
394
|
+
"Value": { "Ref": "TuneS3Bucket" }
|
|
395
|
+
},
|
|
330
396
|
"StackVersion": {
|
|
331
397
|
"Description": "Bootstrap stack template version for forward compatibility tracking",
|
|
332
|
-
"Value": "2026-05-
|
|
398
|
+
"Value": "2026-05-18"
|
|
333
399
|
}
|
|
334
400
|
}
|
|
335
401
|
}
|
package/package.json
CHANGED
package/src/app.js
CHANGED
|
@@ -16,6 +16,7 @@ import CommentGenerator from './lib/comment-generator.js';
|
|
|
16
16
|
import ConfigurationManager from './lib/configuration-manager.js';
|
|
17
17
|
import RegistryLoader from './lib/registry-loader.js';
|
|
18
18
|
import { resolvePrefixedEnvVars } from './lib/engine-prefix-resolver.js';
|
|
19
|
+
import { isTuneSupported } from './lib/tune-catalog-validator.js';
|
|
19
20
|
import ejs from 'ejs';
|
|
20
21
|
|
|
21
22
|
const __filename = fileURLToPath(import.meta.url);
|
|
@@ -350,6 +351,12 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
350
351
|
ignorePatterns.push('**/do/adapters/**');
|
|
351
352
|
}
|
|
352
353
|
|
|
354
|
+
// Exclude tune files when framework is NOT transformers OR deploymentTarget is batch-transform
|
|
355
|
+
if (architecture !== 'transformers' || answers.deploymentTarget === 'batch-transform') {
|
|
356
|
+
ignorePatterns.push('**/do/tune');
|
|
357
|
+
ignorePatterns.push('**/do/.tune_helper.py');
|
|
358
|
+
}
|
|
359
|
+
|
|
353
360
|
// Exclude do/test when hosted-model-endpoint is not selected
|
|
354
361
|
const testTypes = answers.testTypes || [];
|
|
355
362
|
if (!testTypes.includes('hosted-model-endpoint')) {
|
|
@@ -452,6 +459,13 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
452
459
|
_copyFile(path.join(LIB_DIR, 'asset-manager.js'), path.join(doLibDir, 'asset-manager.js'));
|
|
453
460
|
_copyFile(path.join(LIB_DIR, 'bootstrap-config.js'), path.join(doLibDir, 'bootstrap-config.js'));
|
|
454
461
|
|
|
462
|
+
// Copy tune catalog to generated project when tune is included
|
|
463
|
+
if (architecture === 'transformers' && answers.deploymentTarget !== 'batch-transform') {
|
|
464
|
+
const tuneCatalogSrc = path.join(GENERATOR_ROOT, 'config', 'tune-catalog.json');
|
|
465
|
+
const tuneCatalogDest = path.join(destDir, 'do', '.tune_catalog.json');
|
|
466
|
+
_copyFile(tuneCatalogSrc, tuneCatalogDest);
|
|
467
|
+
}
|
|
468
|
+
|
|
455
469
|
// Generate .gitignore with benchmarks/ when benchmarking is enabled
|
|
456
470
|
if (answers.includeBenchmark) {
|
|
457
471
|
const gitignorePath = path.join(destDir, '.gitignore');
|
|
@@ -742,6 +756,19 @@ async function _ensureTemplateVariables(answers, registryConfigManager = null) {
|
|
|
742
756
|
}
|
|
743
757
|
}
|
|
744
758
|
}
|
|
759
|
+
|
|
760
|
+
// Determine tune support based on model presence in the tune catalog.
|
|
761
|
+
// Used by the do/config template to write TUNE_SUPPORTED=true|false.
|
|
762
|
+
if (answers.tuneSupported === undefined) {
|
|
763
|
+
try {
|
|
764
|
+
const tuneCatalogPath = path.resolve(__dirname, '..', 'config', 'tune-catalog.json');
|
|
765
|
+
const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
|
|
766
|
+
const modelId = answers.modelName || '';
|
|
767
|
+
answers.tuneSupported = isTuneSupported(modelId, tuneCatalog);
|
|
768
|
+
} catch {
|
|
769
|
+
answers.tuneSupported = false;
|
|
770
|
+
}
|
|
771
|
+
}
|
|
745
772
|
}
|
|
746
773
|
|
|
747
774
|
/**
|
|
@@ -1083,7 +1110,8 @@ function _setExecutablePermissions(destDir) {
|
|
|
1083
1110
|
'do/optimize',
|
|
1084
1111
|
'do/status',
|
|
1085
1112
|
'do/add-ic',
|
|
1086
|
-
'do/adapter'
|
|
1113
|
+
'do/adapter',
|
|
1114
|
+
'do/tune'
|
|
1087
1115
|
];
|
|
1088
1116
|
|
|
1089
1117
|
shellScripts.forEach(script => {
|
|
@@ -182,38 +182,80 @@ export default class BootstrapCommandHandler {
|
|
|
182
182
|
this._displayProgress('☁️', 'Deploying bootstrap infrastructure stack...');
|
|
183
183
|
const stackName = `${STACK_NAME_PREFIX}-${profileName}`;
|
|
184
184
|
|
|
185
|
+
// Check for existing bootstrap stack in this account-region (resources are singletons)
|
|
185
186
|
try {
|
|
186
|
-
const
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
187
|
+
const existingStacks = this._execAws(
|
|
188
|
+
`cloudformation list-stacks --stack-status-filter CREATE_COMPLETE UPDATE_COMPLETE UPDATE_ROLLBACK_COMPLETE --query "StackSummaries[?starts_with(StackName,'${STACK_NAME_PREFIX}-')].StackName" --output json`,
|
|
189
|
+
awsProfile
|
|
190
|
+
);
|
|
191
|
+
const stacks = Array.isArray(existingStacks) ? existingStacks : [];
|
|
192
|
+
const otherStack = stacks.find(s => s !== stackName);
|
|
193
|
+
if (otherStack) {
|
|
194
|
+
console.log(` ℹ️ Bootstrap infrastructure already exists in ${accountId}/${region} (stack: ${otherStack})`);
|
|
195
|
+
console.log(' Reusing existing resources (IAM role, ECR repo are singletons per account-region).');
|
|
196
|
+
console.log(' Use `ml-container-creator bootstrap update` to apply latest permissions.\n');
|
|
197
|
+
|
|
198
|
+
// Read outputs from existing stack
|
|
199
|
+
const outputs = this._execAws(
|
|
200
|
+
`cloudformation describe-stacks --stack-name ${otherStack} --query "Stacks[0].Outputs" --output json`,
|
|
201
|
+
awsProfile
|
|
202
|
+
);
|
|
203
|
+
const stackOutputs = {};
|
|
204
|
+
if (Array.isArray(outputs)) {
|
|
205
|
+
for (const o of outputs) {
|
|
206
|
+
stackOutputs[o.OutputKey] = o.OutputValue;
|
|
207
|
+
}
|
|
208
|
+
}
|
|
190
209
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
210
|
+
profileData.roleArn = stackOutputs.RoleArn;
|
|
211
|
+
profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
|
|
212
|
+
profileData.stackName = otherStack;
|
|
213
|
+
if (stackOutputs.AsyncS3BucketName) profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
|
|
214
|
+
if (stackOutputs.BatchS3BucketName) profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
215
|
+
if (stackOutputs.AdapterS3BucketName) profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
|
|
216
|
+
if (stackOutputs.BenchmarkS3BucketName) profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
|
|
195
217
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
}
|
|
199
|
-
if (stackOutputs.BatchS3BucketName) {
|
|
200
|
-
profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
218
|
+
// Skip stack deployment, continue to CI setup and profile save
|
|
219
|
+
console.log(' ✅ Existing bootstrap infrastructure reused');
|
|
201
220
|
}
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
}
|
|
205
|
-
if (stackOutputs.BenchmarkS3BucketName) {
|
|
206
|
-
profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
|
|
207
|
-
}
|
|
208
|
-
|
|
209
|
-
console.log(' ✅ Bootstrap stack deployed successfully');
|
|
210
|
-
} catch (error) {
|
|
211
|
-
console.log(` ❌ Stack deployment failed: ${error.message}`);
|
|
212
|
-
console.log(' Check the CloudFormation console for details:');
|
|
213
|
-
console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
|
|
214
|
-
return;
|
|
221
|
+
} catch (_) {
|
|
222
|
+
// If list-stacks fails, proceed with normal deployment
|
|
215
223
|
}
|
|
216
224
|
|
|
225
|
+
if (!profileData.stackName) {
|
|
226
|
+
try {
|
|
227
|
+
const stackOutputs = this._deployStack(stackName, {
|
|
228
|
+
CreateS3Buckets: createS3Buckets ? 'true' : 'false',
|
|
229
|
+
UseExistingRoleArn: useExistingRoleArn
|
|
230
|
+
}, awsProfile, region);
|
|
231
|
+
|
|
232
|
+
// Read outputs into profile data
|
|
233
|
+
profileData.roleArn = stackOutputs.RoleArn;
|
|
234
|
+
profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
|
|
235
|
+
profileData.stackName = stackName;
|
|
236
|
+
|
|
237
|
+
if (stackOutputs.AsyncS3BucketName) {
|
|
238
|
+
profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
|
|
239
|
+
}
|
|
240
|
+
if (stackOutputs.BatchS3BucketName) {
|
|
241
|
+
profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
242
|
+
}
|
|
243
|
+
if (stackOutputs.AdapterS3BucketName) {
|
|
244
|
+
profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
|
|
245
|
+
}
|
|
246
|
+
if (stackOutputs.BenchmarkS3BucketName) {
|
|
247
|
+
profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
console.log(' ✅ Bootstrap stack deployed successfully');
|
|
251
|
+
} catch (error) {
|
|
252
|
+
console.log(` ❌ Stack deployment failed: ${error.message}`);
|
|
253
|
+
console.log(' Check the CloudFormation console for details:');
|
|
254
|
+
console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
|
|
255
|
+
return;
|
|
256
|
+
}
|
|
257
|
+
} // end if (!profileData.stackName)
|
|
258
|
+
|
|
217
259
|
// Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
|
|
218
260
|
this._displayProgress('🧪', 'CI Testing Infrastructure...');
|
|
219
261
|
try {
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Tune Catalog Validator
|
|
6
|
+
*
|
|
7
|
+
* Validates model IDs, techniques, and training types against the
|
|
8
|
+
* Supported Model Catalog. Provides descriptive error messages when
|
|
9
|
+
* a requested configuration is not supported.
|
|
10
|
+
*
|
|
11
|
+
* Requirements: 1.3, 1.4, 1.5, 1.6, 4.1, 4.2, 4.3, 4.4
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
/**
 * Find the catalog entry for a model.
 *
 * Only own properties of `catalog.models` are considered, so inherited keys
 * such as "toString" never match. Falsy entries are reported as absent.
 *
 * @param {string} modelId - JumpStart model ID to look up
 * @param {Object} catalog - Tune catalog object holding a `models` map
 * @returns {Object|null} The model's catalog entry, or null when the catalog,
 *   its `models` map, or the entry is missing
 */
export function lookupModel(modelId, catalog) {
  const models = catalog?.models;
  if (!models || !Object.hasOwn(models, modelId)) {
    return null;
  }
  return models[modelId] || null;
}
|
|
29
|
+
|
|
30
|
+
/**
 * Report whether a model has an entry in the Supported Model Catalog.
 *
 * @param {string} modelId - JumpStart model ID to check
 * @param {Object} catalog - Tune catalog object holding a `models` map
 * @returns {boolean} true when lookupModel finds an entry for the model
 */
export function isTuneSupported(modelId, catalog) {
  const entry = lookupModel(modelId, catalog);
  return entry !== null;
}
|
|
39
|
+
|
|
40
|
+
/**
 * Validate that a model ID exists in the catalog.
 *
 * When the model is not supported, the returned error names the model, lists
 * the supported model families ("none" if the catalog has none), and points
 * the user at `do/train` for custom training workflows.
 *
 * @param {string} modelId - The JumpStart model ID to validate
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateModel(modelId, catalog) {
  if (isTuneSupported(modelId, catalog)) {
    return { valid: true };
  }

  const families = _getSupportedFamilies(catalog);
  // An empty or unreadable catalog has no families; say so explicitly rather
  // than emitting "Supported model families: . ".
  const familyList = families.length > 0 ? families.join(', ') : 'none';

  return {
    valid: false,
    error: `Model "${modelId}" is not yet supported for managed serverless customization. ` +
      `Supported model families: ${familyList}. ` +
      'For custom training workflows, see `do/train`.'
  };
}
|
|
63
|
+
|
|
64
|
+
/**
 * Validate that a technique is supported for the given model.
 *
 * Unknown models are delegated to validateModel so the caller gets the
 * model-level error. Otherwise the error lists the model's supported
 * techniques ("none" when the entry declares no techniques).
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTechnique(modelId, technique, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  // An entry without a `techniques` map supports nothing. Fall back to {} so
  // Object.keys does not throw a TypeError on undefined/null.
  const supportedTechniques = Object.keys(entry.techniques ?? {});
  if (supportedTechniques.includes(technique)) {
    return { valid: true };
  }

  const techniqueList = supportedTechniques.length > 0 ? supportedTechniques.join(', ') : 'none';
  return {
    valid: false,
    error: `Technique "${technique}" is not supported for model "${modelId}". ` +
      `Supported techniques: ${techniqueList}.`
  };
}
|
|
90
|
+
|
|
91
|
+
/**
 * Validate that a training type is supported for the given model and technique.
 *
 * Unknown models are delegated to validateModel. Techniques the model does not
 * declare as its own key are delegated to validateTechnique. Otherwise the
 * error lists the technique's supported training types ("none" when the
 * technique entry has no `trainingTypes` array).
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique (e.g., 'sft', 'dpo')
 * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTrainingType(modelId, technique, trainingType, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  const techniques = entry.techniques ?? {};
  // Object.hasOwn (as in lookupModel) stops inherited keys such as "toString"
  // or "constructor" from resolving to Object.prototype members. Those members
  // previously passed the truthiness check and then crashed on
  // `.trainingTypes.includes`.
  if (!Object.hasOwn(techniques, technique)) {
    return validateTechnique(modelId, technique, catalog);
  }

  // A technique entry that is null or lacks a `trainingTypes` array supports no
  // training types. Previously a null entry was reported as valid for any type.
  const techniqueEntry = techniques[technique];
  const supportedTypes = Array.isArray(techniqueEntry?.trainingTypes) ? techniqueEntry.trainingTypes : [];
  if (supportedTypes.includes(trainingType)) {
    return { valid: true };
  }

  const typeList = supportedTypes.length > 0 ? supportedTypes.join(', ') : 'none';
  return {
    valid: false,
    error: `Training type "${trainingType}" is not supported for model "${modelId}" ` +
      `with technique "${technique}". ` +
      `Supported training types: ${typeList}.`
  };
}
|
|
124
|
+
|
|
125
|
+
/**
 * Extract unique model family names from the catalog.
 *
 * Families are returned in first-seen order. Entries that are null or have no
 * truthy `family` are skipped rather than throwing, consistent with
 * lookupModel treating falsy entries as absent.
 *
 * @param {Object} catalog - The tune catalog object
 * @returns {string[]} Array of unique family names (empty when the catalog has no `models`)
 * @private
 */
function _getSupportedFamilies(catalog) {
  if (!catalog?.models) {
    return [];
  }

  const families = new Set();
  for (const entry of Object.values(catalog.models)) {
    if (entry?.family) {
      families.add(entry.family);
    }
  }
  return [...families];
}
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Tune Config State Manager
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that mimics the bash _update_config_var() behavior
|
|
8
|
+
* from do/tune for testing purposes. Manages config variables written
|
|
9
|
+
* after job submission.
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
import { readFileSync, writeFileSync } from 'node:fs';
|
|
13
|
+
|
|
14
|
+
/**
 * Update or add a config variable in a do/config-style file.
 *
 * Mimics the bash _update_config_var() function:
 * - If a line `export VAR_NAME=...` exists, the first such line is replaced.
 * - Otherwise a new line is appended, adding a newline separator first if the
 *   file does not already end with one.
 *
 * The value is written verbatim inside double quotes. NOTE(review): values
 * containing `"`, `$` or backticks are not shell-escaped.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
 * @param {string} varValue - Variable value
 */
export function updateConfigVar(configPath, varName, varValue) {
  const content = readFileSync(configPath, 'utf8');
  // Escape regex metacharacters so the name is matched literally.
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const pattern = new RegExp(`^export ${escapedName}=.*$`, 'm');
  const line = `export ${varName}="${varValue}"`;

  let updated;
  if (pattern.test(content)) {
    // Use a replacer function: a replacement *string* would expand `$&`, `$1`,
    // `$$`, ... sequences inside varValue and corrupt the written value.
    updated = content.replace(pattern, () => line);
  } else {
    const separator = content.length > 0 && !content.endsWith('\n') ? '\n' : '';
    updated = `${content}${separator}${line}\n`;
  }

  writeFileSync(configPath, updated, 'utf8');
}
|
|
39
|
+
|
|
40
|
+
/**
 * Read a config variable from a do/config-style file.
 *
 * Only double-quoted assignments of the form `export VAR_NAME="value"` are
 * recognized. The first matching line wins.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The variable value, or null if not found
 */
export function readConfigVar(configPath, varName) {
  const content = readFileSync(configPath, 'utf8');
  // Escape regex metacharacters so the name is matched literally (e.g. "A.B"
  // must not match "AXB").
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const pattern = new RegExp(`^export ${escapedName}="([^"]*)"`, 'm');
  const match = content.match(pattern);
  return match ? match[1] : null;
}
|
|
53
|
+
|
|
54
|
+
/**
 * Record the config variables written after a successful job submission.
 *
 * Intended to mirror do/tune's _submit_job() behavior. Writes, in order:
 * TUNE_JOB_NAME_<TECHNIQUE>, TUNE_TECHNIQUE, TUNE_TRAINING_TYPE and
 * TUNE_DATASET_PATH.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Submission parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
 * @param {string} params.jobName - Generated job name
 */
export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
  const assignments = [
    [`TUNE_JOB_NAME_${technique.toUpperCase()}`, jobName],
    ['TUNE_TECHNIQUE', technique],
    ['TUNE_TRAINING_TYPE', trainingType],
    ['TUNE_DATASET_PATH', datasetPath]
  ];
  for (const [name, value] of assignments) {
    updateConfigVar(configPath, name, value);
  }
}
|
|
72
|
+
|
|
73
|
+
/**
 * Record the config variables written after a job completes successfully.
 *
 * Intended to mirror do/tune's _handle_completion() behavior:
 * - 'lora' writes TUNE_ADAPTER_PATH_<TECHNIQUE>.
 * - 'full-rank' writes TUNE_MODEL_PATH_<TECHNIQUE>.
 * - Any other training type writes no per-technique variable.
 * TUNE_OUTPUT_PATH_LATEST and TUNE_OUTPUT_TYPE_LATEST are always written.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Completion parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.artifactPath - S3 path to the output artifact
 * @param {string} params.outputType - Output type (adapter, full-model)
 */
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType }) {
  const suffix = technique.toUpperCase();

  let perTechniqueVar = null;
  switch (trainingType) {
    case 'lora':
      perTechniqueVar = `TUNE_ADAPTER_PATH_${suffix}`;
      break;
    case 'full-rank':
      perTechniqueVar = `TUNE_MODEL_PATH_${suffix}`;
      break;
    default:
      break;
  }
  if (perTechniqueVar !== null) {
    updateConfigVar(configPath, perTechniqueVar, artifactPath);
  }

  updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
  updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
}
|
|
96
|
+
|
|
97
|
+
/**
 * Build a job name following the pattern used by do/tune:
 * `${projectName}-tune-${technique}-YYYYMMDD-HHMMSS`.
 *
 * Date and time fields are taken from the timestamp's local-time getters.
 * Month, day, hour, minute and second are zero-padded to two digits; the
 * year is not padded.
 * NOTE(review): no length or character-set check is applied. SageMaker job
 * names are length-limited (63 chars), so long project names may need
 * truncation — confirm against do/tune.
 *
 * @param {string} projectName - Project name
 * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {Date} [timestamp] - Optional timestamp (defaults to now)
 * @returns {string} Generated job name
 */
export function generateJobName(projectName, technique, timestamp = new Date()) {
  const pad2 = (value) => String(value).padStart(2, '0');
  const datePart = `${timestamp.getFullYear()}${pad2(timestamp.getMonth() + 1)}${pad2(timestamp.getDate())}`;
  const timePart = [timestamp.getHours(), timestamp.getMinutes(), timestamp.getSeconds()]
    .map((field) => pad2(field))
    .join('');
  return `${projectName}-tune-${technique}-${datePart}-${timePart}`;
}
|