@aws/ml-container-creator 0.5.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +9 -0
- package/config/bootstrap-stack.json +106 -9
- package/infra/ci-harness/package-lock.json +5 -1
- package/package.json +1 -1
- package/servers/instance-sizer/index.js +4 -4
- package/servers/instance-sizer/lib/model-resolver.js +1 -1
- package/servers/lib/catalogs/model-sizes.json +135 -90
- package/servers/lib/catalogs/models.json +483 -411
- package/src/app.js +29 -1
- package/src/lib/bootstrap-command-handler.js +71 -23
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +1 -1
- package/src/lib/mcp-client.js +3 -3
- package/src/lib/prompt-runner.js +5 -5
- package/src/lib/prompts.js +31 -5
- package/src/lib/tune-catalog-validator.js +143 -0
- package/src/lib/tune-config-state.js +116 -0
- package/src/lib/tune-dataset-validator.js +279 -0
- package/src/lib/tune-output-resolver.js +66 -0
- package/templates/do/.tune_helper.py +768 -0
- package/templates/do/adapter +128 -17
- package/templates/do/add-ic +155 -19
- package/templates/do/config +11 -4
- package/templates/do/tune +1143 -0
package/src/app.js
CHANGED
|
@@ -16,6 +16,7 @@ import CommentGenerator from './lib/comment-generator.js';
|
|
|
16
16
|
import ConfigurationManager from './lib/configuration-manager.js';
|
|
17
17
|
import RegistryLoader from './lib/registry-loader.js';
|
|
18
18
|
import { resolvePrefixedEnvVars } from './lib/engine-prefix-resolver.js';
|
|
19
|
+
import { isTuneSupported } from './lib/tune-catalog-validator.js';
|
|
19
20
|
import ejs from 'ejs';
|
|
20
21
|
|
|
21
22
|
const __filename = fileURLToPath(import.meta.url);
|
|
@@ -350,6 +351,12 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
350
351
|
ignorePatterns.push('**/do/adapters/**');
|
|
351
352
|
}
|
|
352
353
|
|
|
354
|
+
// Exclude tune files when framework is NOT transformers OR deploymentTarget is batch-transform
|
|
355
|
+
if (architecture !== 'transformers' || answers.deploymentTarget === 'batch-transform') {
|
|
356
|
+
ignorePatterns.push('**/do/tune');
|
|
357
|
+
ignorePatterns.push('**/do/.tune_helper.py');
|
|
358
|
+
}
|
|
359
|
+
|
|
353
360
|
// Exclude do/test when hosted-model-endpoint is not selected
|
|
354
361
|
const testTypes = answers.testTypes || [];
|
|
355
362
|
if (!testTypes.includes('hosted-model-endpoint')) {
|
|
@@ -452,6 +459,13 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
452
459
|
_copyFile(path.join(LIB_DIR, 'asset-manager.js'), path.join(doLibDir, 'asset-manager.js'));
|
|
453
460
|
_copyFile(path.join(LIB_DIR, 'bootstrap-config.js'), path.join(doLibDir, 'bootstrap-config.js'));
|
|
454
461
|
|
|
462
|
+
// Copy tune catalog to generated project when tune is included
|
|
463
|
+
if (architecture === 'transformers' && answers.deploymentTarget !== 'batch-transform') {
|
|
464
|
+
const tuneCatalogSrc = path.join(GENERATOR_ROOT, 'config', 'tune-catalog.json');
|
|
465
|
+
const tuneCatalogDest = path.join(destDir, 'do', '.tune_catalog.json');
|
|
466
|
+
_copyFile(tuneCatalogSrc, tuneCatalogDest);
|
|
467
|
+
}
|
|
468
|
+
|
|
455
469
|
// Generate .gitignore with benchmarks/ when benchmarking is enabled
|
|
456
470
|
if (answers.includeBenchmark) {
|
|
457
471
|
const gitignorePath = path.join(destDir, '.gitignore');
|
|
@@ -742,6 +756,19 @@ async function _ensureTemplateVariables(answers, registryConfigManager = null) {
|
|
|
742
756
|
}
|
|
743
757
|
}
|
|
744
758
|
}
|
|
759
|
+
|
|
760
|
+
// Determine tune support based on model presence in the tune catalog.
|
|
761
|
+
// Used by the do/config template to write TUNE_SUPPORTED=true|false.
|
|
762
|
+
if (answers.tuneSupported === undefined) {
|
|
763
|
+
try {
|
|
764
|
+
const tuneCatalogPath = path.resolve(__dirname, '..', 'config', 'tune-catalog.json');
|
|
765
|
+
const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
|
|
766
|
+
const modelId = answers.modelName || '';
|
|
767
|
+
answers.tuneSupported = isTuneSupported(modelId, tuneCatalog);
|
|
768
|
+
} catch {
|
|
769
|
+
answers.tuneSupported = false;
|
|
770
|
+
}
|
|
771
|
+
}
|
|
745
772
|
}
|
|
746
773
|
|
|
747
774
|
/**
|
|
@@ -1083,7 +1110,8 @@ function _setExecutablePermissions(destDir) {
|
|
|
1083
1110
|
'do/optimize',
|
|
1084
1111
|
'do/status',
|
|
1085
1112
|
'do/add-ic',
|
|
1086
|
-
'do/adapter'
|
|
1113
|
+
'do/adapter',
|
|
1114
|
+
'do/tune'
|
|
1087
1115
|
];
|
|
1088
1116
|
|
|
1089
1117
|
shellScripts.forEach(script => {
|
|
@@ -182,35 +182,80 @@ export default class BootstrapCommandHandler {
|
|
|
182
182
|
this._displayProgress('☁️', 'Deploying bootstrap infrastructure stack...');
|
|
183
183
|
const stackName = `${STACK_NAME_PREFIX}-${profileName}`;
|
|
184
184
|
|
|
185
|
+
// Check for existing bootstrap stack in this account-region (resources are singletons)
|
|
185
186
|
try {
|
|
186
|
-
const
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
187
|
+
const existingStacks = this._execAws(
|
|
188
|
+
`cloudformation list-stacks --stack-status-filter CREATE_COMPLETE UPDATE_COMPLETE UPDATE_ROLLBACK_COMPLETE --query "StackSummaries[?starts_with(StackName,'${STACK_NAME_PREFIX}-')].StackName" --output json`,
|
|
189
|
+
awsProfile
|
|
190
|
+
);
|
|
191
|
+
const stacks = Array.isArray(existingStacks) ? existingStacks : [];
|
|
192
|
+
const otherStack = stacks.find(s => s !== stackName);
|
|
193
|
+
if (otherStack) {
|
|
194
|
+
console.log(` ℹ️ Bootstrap infrastructure already exists in ${accountId}/${region} (stack: ${otherStack})`);
|
|
195
|
+
console.log(' Reusing existing resources (IAM role, ECR repo are singletons per account-region).');
|
|
196
|
+
console.log(' Use `ml-container-creator bootstrap update` to apply latest permissions.\n');
|
|
197
|
+
|
|
198
|
+
// Read outputs from existing stack
|
|
199
|
+
const outputs = this._execAws(
|
|
200
|
+
`cloudformation describe-stacks --stack-name ${otherStack} --query "Stacks[0].Outputs" --output json`,
|
|
201
|
+
awsProfile
|
|
202
|
+
);
|
|
203
|
+
const stackOutputs = {};
|
|
204
|
+
if (Array.isArray(outputs)) {
|
|
205
|
+
for (const o of outputs) {
|
|
206
|
+
stackOutputs[o.OutputKey] = o.OutputValue;
|
|
207
|
+
}
|
|
208
|
+
}
|
|
190
209
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
210
|
+
profileData.roleArn = stackOutputs.RoleArn;
|
|
211
|
+
profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
|
|
212
|
+
profileData.stackName = otherStack;
|
|
213
|
+
if (stackOutputs.AsyncS3BucketName) profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
|
|
214
|
+
if (stackOutputs.BatchS3BucketName) profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
215
|
+
if (stackOutputs.AdapterS3BucketName) profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
|
|
216
|
+
if (stackOutputs.BenchmarkS3BucketName) profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
|
|
195
217
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
}
|
|
199
|
-
if (stackOutputs.BatchS3BucketName) {
|
|
200
|
-
profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
218
|
+
// Skip stack deployment, continue to CI setup and profile save
|
|
219
|
+
console.log(' ✅ Existing bootstrap infrastructure reused');
|
|
201
220
|
}
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
}
|
|
205
|
-
|
|
206
|
-
console.log(' ✅ Bootstrap stack deployed successfully');
|
|
207
|
-
} catch (error) {
|
|
208
|
-
console.log(` ❌ Stack deployment failed: ${error.message}`);
|
|
209
|
-
console.log(' Check the CloudFormation console for details:');
|
|
210
|
-
console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
|
|
211
|
-
return;
|
|
221
|
+
} catch (_) {
|
|
222
|
+
// If list-stacks fails, proceed with normal deployment
|
|
212
223
|
}
|
|
213
224
|
|
|
225
|
+
if (!profileData.stackName) {
|
|
226
|
+
try {
|
|
227
|
+
const stackOutputs = this._deployStack(stackName, {
|
|
228
|
+
CreateS3Buckets: createS3Buckets ? 'true' : 'false',
|
|
229
|
+
UseExistingRoleArn: useExistingRoleArn
|
|
230
|
+
}, awsProfile, region);
|
|
231
|
+
|
|
232
|
+
// Read outputs into profile data
|
|
233
|
+
profileData.roleArn = stackOutputs.RoleArn;
|
|
234
|
+
profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
|
|
235
|
+
profileData.stackName = stackName;
|
|
236
|
+
|
|
237
|
+
if (stackOutputs.AsyncS3BucketName) {
|
|
238
|
+
profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
|
|
239
|
+
}
|
|
240
|
+
if (stackOutputs.BatchS3BucketName) {
|
|
241
|
+
profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
|
|
242
|
+
}
|
|
243
|
+
if (stackOutputs.AdapterS3BucketName) {
|
|
244
|
+
profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
|
|
245
|
+
}
|
|
246
|
+
if (stackOutputs.BenchmarkS3BucketName) {
|
|
247
|
+
profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
console.log(' ✅ Bootstrap stack deployed successfully');
|
|
251
|
+
} catch (error) {
|
|
252
|
+
console.log(` ❌ Stack deployment failed: ${error.message}`);
|
|
253
|
+
console.log(' Check the CloudFormation console for details:');
|
|
254
|
+
console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
|
|
255
|
+
return;
|
|
256
|
+
}
|
|
257
|
+
} // end if (!profileData.stackName)
|
|
258
|
+
|
|
214
259
|
// Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
|
|
215
260
|
this._displayProgress('🧪', 'CI Testing Infrastructure...');
|
|
216
261
|
try {
|
|
@@ -390,6 +435,9 @@ export default class BootstrapCommandHandler {
|
|
|
390
435
|
if (outputs.BatchS3BucketName) {
|
|
391
436
|
console.log(` ✅ S3 bucket (batch): ${outputs.BatchS3BucketName}`);
|
|
392
437
|
}
|
|
438
|
+
if (outputs.AdapterS3BucketName) {
|
|
439
|
+
console.log(` ✅ S3 bucket (adapters): ${outputs.AdapterS3BucketName}`);
|
|
440
|
+
}
|
|
393
441
|
if (outputs.BenchmarkS3BucketName) {
|
|
394
442
|
console.log(` ✅ S3 bucket (benchmark): ${outputs.BenchmarkS3BucketName}`);
|
|
395
443
|
}
|
package/src/lib/cli-handler.js
CHANGED
|
@@ -204,7 +204,7 @@ VALIDATION OPTIONS:
|
|
|
204
204
|
|
|
205
205
|
MCP OPTIONS:
|
|
206
206
|
--smart Enable Bedrock-powered smart mode on all MCP servers
|
|
207
|
-
--discover
|
|
207
|
+
--no-discover Disable live registry lookups (HuggingFace API, quota checks) — catalog-only mode
|
|
208
208
|
|
|
209
209
|
REGISTRY SYSTEM:
|
|
210
210
|
The generator includes built-in registries for frameworks, models, and instance types:
|
|
@@ -1631,7 +1631,7 @@ export default class ConfigManager {
|
|
|
1631
1631
|
if (!mcpServerConfigs || !mcpServerConfigs[serverName]) return null;
|
|
1632
1632
|
|
|
1633
1633
|
const smart = this.options.smart === true;
|
|
1634
|
-
const discover = this.options.discover
|
|
1634
|
+
const discover = this.options.discover !== false;
|
|
1635
1635
|
const serverConfig = mcpServerConfigs[serverName];
|
|
1636
1636
|
|
|
1637
1637
|
// Build a custom McpClient that passes context through
|
package/src/lib/mcp-client.js
CHANGED
|
@@ -32,7 +32,7 @@ class McpClient {
|
|
|
32
32
|
this.timeout = options.timeout || DEFAULT_TIMEOUT;
|
|
33
33
|
this.parameterMatrix = options.parameterMatrix || {};
|
|
34
34
|
this.smart = options.smart || false;
|
|
35
|
-
this.discover = options.discover
|
|
35
|
+
this.discover = options.discover !== undefined ? options.discover : true;
|
|
36
36
|
this._transport = null;
|
|
37
37
|
this._client = null;
|
|
38
38
|
this._diagnosticMessage = null;
|
|
@@ -98,10 +98,10 @@ class McpClient {
|
|
|
98
98
|
|
|
99
99
|
// Build environment: merge process.env with server-specific env
|
|
100
100
|
// When --smart flag is active, inject BEDROCK_SMART=true for this run
|
|
101
|
-
//
|
|
101
|
+
// Discover mode is now default; inject DISCOVER_MODE=false only when explicitly disabled
|
|
102
102
|
// Always pass process.env so child processes inherit AWS credentials, profiles, etc.
|
|
103
103
|
const smartEnv = this.smart ? { BEDROCK_SMART: 'true' } : {};
|
|
104
|
-
const discoverEnv = this.discover ? {
|
|
104
|
+
const discoverEnv = this.discover === false ? { DISCOVER_MODE: 'false' } : {};
|
|
105
105
|
const serverEnv = env && Object.keys(env).length > 0 ? env : {};
|
|
106
106
|
const spawnEnv = { ...process.env, ...smartEnv, ...discoverEnv, ...serverEnv };
|
|
107
107
|
|
package/src/lib/prompt-runner.js
CHANGED
|
@@ -1098,9 +1098,9 @@ export default class PromptRunner {
|
|
|
1098
1098
|
if (!modelName || modelName === 'Custom (enter manually)') return;
|
|
1099
1099
|
|
|
1100
1100
|
const smart = this.options.smart === true;
|
|
1101
|
-
const discover = this.options.discover
|
|
1101
|
+
const discover = this.options.discover !== false;
|
|
1102
1102
|
|
|
1103
|
-
const modeLabel = [smart && '[smart]', discover && '[discover]'].filter(Boolean).join(' ');
|
|
1103
|
+
const modeLabel = [smart && '[smart]', !discover && '[no-discover]'].filter(Boolean).join(' ');
|
|
1104
1104
|
console.log(` 🔍 Querying instance-sizer${modeLabel ? ` ${modeLabel}` : ''}...`);
|
|
1105
1105
|
|
|
1106
1106
|
try {
|
|
@@ -1115,8 +1115,8 @@ export default class PromptRunner {
|
|
|
1115
1115
|
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
|
1116
1116
|
|
|
1117
1117
|
const serverArgs = [...(serverConfig.args || [])];
|
|
1118
|
-
if (discover && !serverArgs.includes('--discover')) {
|
|
1119
|
-
serverArgs.push('--discover');
|
|
1118
|
+
if (!discover && !serverArgs.includes('--no-discover')) {
|
|
1119
|
+
serverArgs.push('--no-discover');
|
|
1120
1120
|
}
|
|
1121
1121
|
|
|
1122
1122
|
const transport = new StdioClientTransport({
|
|
@@ -1375,7 +1375,7 @@ export default class PromptRunner {
|
|
|
1375
1375
|
if (!mcpServers.includes('base-image-picker')) return;
|
|
1376
1376
|
|
|
1377
1377
|
const smart = this.options.smart === true;
|
|
1378
|
-
const discover = this.options.discover
|
|
1378
|
+
const discover = this.options.discover !== false;
|
|
1379
1379
|
const framework = frameworkAnswers.framework;
|
|
1380
1380
|
const modelServer = frameworkAnswers.modelServer;
|
|
1381
1381
|
const architecture = frameworkAnswers.architecture || frameworkAnswers.deploymentConfig?.split('-')[0];
|
package/src/lib/prompts.js
CHANGED
|
@@ -399,9 +399,33 @@ const modelFormatPrompts = [
|
|
|
399
399
|
];
|
|
400
400
|
}
|
|
401
401
|
return [
|
|
402
|
-
'
|
|
403
|
-
'meta-llama/Llama-3.2-3B-Instruct',
|
|
402
|
+
{ type: 'separator', separator: '── Meta Llama ──' },
|
|
404
403
|
'meta-llama/Llama-3.2-1B-Instruct',
|
|
404
|
+
'meta-llama/Llama-3.2-3B-Instruct',
|
|
405
|
+
'meta-llama/Llama-3.1-8B-Instruct',
|
|
406
|
+
'meta-llama/Llama-3.3-70B-Instruct',
|
|
407
|
+
{ type: 'separator', separator: '── Qwen (Alibaba) ──' },
|
|
408
|
+
'Qwen/Qwen3-0.6B',
|
|
409
|
+
'Qwen/Qwen3-1.7B',
|
|
410
|
+
'Qwen/Qwen3-4B',
|
|
411
|
+
'Qwen/Qwen3-8B',
|
|
412
|
+
'Qwen/Qwen3-14B',
|
|
413
|
+
'Qwen/Qwen3-32B',
|
|
414
|
+
'Qwen/Qwen2.5-7B-Instruct',
|
|
415
|
+
'Qwen/Qwen2.5-14B-Instruct',
|
|
416
|
+
'Qwen/Qwen2.5-32B-Instruct',
|
|
417
|
+
'Qwen/Qwen2.5-72B-Instruct',
|
|
418
|
+
{ type: 'separator', separator: '── DeepSeek ──' },
|
|
419
|
+
'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
|
|
420
|
+
'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
|
|
421
|
+
'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
|
|
422
|
+
'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
|
|
423
|
+
'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
|
|
424
|
+
'deepseek-ai/DeepSeek-R1-Distill-Llama-70B',
|
|
425
|
+
{ type: 'separator', separator: '── OpenAI ──' },
|
|
426
|
+
'openai/gpt-oss-20b',
|
|
427
|
+
'openai/gpt-oss-120b',
|
|
428
|
+
{ type: 'separator', separator: '──────────────' },
|
|
405
429
|
'Custom (enter manually)'
|
|
406
430
|
];
|
|
407
431
|
},
|
|
@@ -413,7 +437,7 @@ const modelFormatPrompts = [
|
|
|
413
437
|
if (architecture === 'diffusors') {
|
|
414
438
|
return 'stabilityai/stable-diffusion-3.5-medium';
|
|
415
439
|
}
|
|
416
|
-
return '
|
|
440
|
+
return 'meta-llama/Llama-3.1-8B-Instruct';
|
|
417
441
|
},
|
|
418
442
|
when: answers => {
|
|
419
443
|
const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
|
|
@@ -528,9 +552,11 @@ const modelProfilePrompts = [
|
|
|
528
552
|
*/
|
|
529
553
|
// eslint-disable-next-line no-unused-vars -- reference list for future use
|
|
530
554
|
const EXAMPLE_MODEL_IDS = [
|
|
531
|
-
'
|
|
555
|
+
'meta-llama/Llama-3.1-8B-Instruct',
|
|
532
556
|
'meta-llama/Llama-3.2-3B-Instruct',
|
|
533
|
-
'
|
|
557
|
+
'Qwen/Qwen3-8B',
|
|
558
|
+
'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
|
|
559
|
+
'openai/gpt-oss-20b'
|
|
534
560
|
];
|
|
535
561
|
|
|
536
562
|
const hfTokenPrompts = [
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Tune Catalog Validator
|
|
6
|
+
*
|
|
7
|
+
* Validates model IDs, techniques, and training types against the
|
|
8
|
+
* Supported Model Catalog. Provides descriptive error messages when
|
|
9
|
+
* a requested configuration is not supported.
|
|
10
|
+
*
|
|
11
|
+
* Requirements: 1.3, 1.4, 1.5, 1.6, 4.1, 4.2, 4.3, 4.4
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
/**
 * Find the catalog entry for a model.
 *
 * Only own properties of `catalog.models` are considered, so inherited keys
 * such as "constructor" never resolve to an entry.
 *
 * @param {string} modelId - The JumpStart model ID to look up
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {Object|null} The model's catalog entry, or null when the catalog or
 *   its `models` map is missing, the ID is not an own key, or the stored entry is falsy
 */
export function lookupModel(modelId, catalog) {
    const models = catalog ? catalog.models : undefined;
    const found = Boolean(models) && Object.hasOwn(models, modelId);
    return (found && models[modelId]) || null;
}
|
|
29
|
+
|
|
30
|
+
/**
 * Report whether managed tuning is available for a model, i.e. whether the
 * model ID resolves to an entry in the Supported Model Catalog via `lookupModel`.
 *
 * @param {string} modelId - The JumpStart model ID to check
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {boolean} True when `lookupModel` finds an entry for the model
 */
export function isTuneSupported(modelId, catalog) {
    const entry = lookupModel(modelId, catalog);
    return entry !== null;
}
|
|
39
|
+
|
|
40
|
+
/**
 * Validate that a model ID exists in the catalog.
 *
 * Returns a descriptive error when the model is not supported. The error
 * includes the model name, the supported families (or "none" when the catalog
 * declares no families), and a reference to `do/train`.
 *
 * @param {string} modelId - The JumpStart model ID to validate
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateModel(modelId, catalog) {
    if (isTuneSupported(modelId, catalog)) {
        return { valid: true };
    }

    const families = _getSupportedFamilies(catalog);
    // Avoid rendering "Supported model families: ." when the catalog is
    // missing or none of its entries declares a family.
    const familyList = families.length > 0 ? families.join(', ') : 'none';

    return {
        valid: false,
        error: `Model "${modelId}" is not yet supported for managed serverless customization. ` +
            `Supported model families: ${familyList}. ` +
            'For custom training workflows, see `do/train`.'
    };
}
|
|
63
|
+
|
|
64
|
+
/**
 * Validate that a technique is supported for the given model.
 *
 * When the model itself is not in the catalog, the result of `validateModel`
 * is returned instead. A catalog entry without a `techniques` map is treated
 * as supporting no techniques rather than throwing.
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTechnique(modelId, technique, catalog) {
    const entry = lookupModel(modelId, catalog);
    if (!entry) {
        return validateModel(modelId, catalog);
    }

    // Object.keys(undefined) throws a TypeError; tolerate entries that omit `techniques`.
    const supportedTechniques = Object.keys(entry.techniques ?? {});
    if (supportedTechniques.includes(technique)) {
        return { valid: true };
    }

    return {
        valid: false,
        error: `Technique "${technique}" is not supported for model "${modelId}". ` +
            `Supported techniques: ${supportedTechniques.join(', ') || 'none'}.`
    };
}
|
|
90
|
+
|
|
91
|
+
/**
 * Validate that a training type is supported for the given model and technique.
 *
 * Falls back to `validateModel` when the model is unknown and to
 * `validateTechnique` when the technique is not an own key of the model's
 * `techniques` map. Otherwise, it checks `trainingType` against that
 * technique's `trainingTypes` list. A missing list is treated as empty.
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique (e.g., 'sft', 'dpo')
 * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTrainingType(modelId, technique, trainingType, catalog) {
    const entry = lookupModel(modelId, catalog);
    if (!entry) {
        return validateModel(modelId, catalog);
    }

    // Own-property lookup: a bare `techniques[technique]` resolves inherited
    // keys such as "toString" to a function, which then crashes below.
    const techniques = entry.techniques ?? {};
    const techniqueEntry = Object.hasOwn(techniques, technique) ? techniques[technique] : null;
    if (!techniqueEntry) {
        return validateTechnique(modelId, technique, catalog);
    }

    const supportedTypes = Array.isArray(techniqueEntry.trainingTypes)
        ? techniqueEntry.trainingTypes
        : [];
    if (supportedTypes.includes(trainingType)) {
        return { valid: true };
    }

    return {
        valid: false,
        error: `Training type "${trainingType}" is not supported for model "${modelId}" ` +
            `with technique "${technique}". ` +
            `Supported training types: ${supportedTypes.join(', ') || 'none'}.`
    };
}
|
|
124
|
+
|
|
125
|
+
/**
 * Collect the distinct, non-empty `family` values declared by catalog entries,
 * in first-seen order.
 *
 * @param {Object} catalog - The tune catalog object
 * @returns {string[]} Unique family names; empty when the catalog or its
 *   `models` map is missing
 * @private
 */
function _getSupportedFamilies(catalog) {
    const models = catalog && catalog.models;
    if (!models) {
        return [];
    }

    const familyNames = Object.values(models)
        .map(modelEntry => modelEntry.family)
        .filter(Boolean);
    return Array.from(new Set(familyNames));
}
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Tune Config State Manager
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that mimics the bash _update_config_var() behavior
|
|
8
|
+
* from do/tune for testing purposes. Manages config variables written
|
|
9
|
+
* after job submission.
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
import { readFileSync, writeFileSync } from 'node:fs';
|
|
13
|
+
|
|
14
|
+
/**
 * Update or add a config variable in a do/config-style file.
 *
 * Mimics the bash _update_config_var() function:
 * - If a line starts with `export VAR_NAME=`, the first such line is replaced
 *   with `export VAR_NAME="value"`
 * - Otherwise, `export VAR_NAME="value"` is appended on a new line (a missing
 *   trailing newline in the existing content is added first)
 *
 * The value is written verbatim between double quotes; it is not shell-escaped.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
 * @param {string} varValue - Variable value
 */
export function updateConfigVar(configPath, varName, varValue) {
    const content = readFileSync(configPath, 'utf8');
    const line = `export ${varName}="${varValue}"`;
    // Escape the name so any regex metacharacters in it match literally.
    const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    const pattern = new RegExp(`^export ${escapedName}=.*$`, 'm');

    let updated;
    if (pattern.test(content)) {
        // Use a function replacer: a string replacement would expand `$&`,
        // `$$`, `` $` `` and `$'` sequences appearing inside the value.
        updated = content.replace(pattern, () => line);
    } else {
        const separator = content.length > 0 && !content.endsWith('\n') ? '\n' : '';
        updated = `${content}${separator}${line}\n`;
    }

    writeFileSync(configPath, updated, 'utf8');
}
|
|
39
|
+
|
|
40
|
+
/**
 * Read the double-quoted value of `export VAR_NAME="..."` from a
 * do/config-style file.
 *
 * Only the first matching line counts, and unquoted assignments are not
 * recognized.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The text between the quotes, or null when no matching line exists
 */
export function readConfigVar(configPath, varName) {
    const matcher = new RegExp(`^export ${varName}="([^"]*)"`, 'm');
    const result = matcher.exec(readFileSync(configPath, 'utf8'));
    if (result === null) {
        return null;
    }
    return result[1];
}
|
|
53
|
+
|
|
54
|
+
/**
 * Write the config variables recorded after a successful job submission,
 * mirroring do/tune's _submit_job() function.
 *
 * Writes, in order: TUNE_JOB_NAME_<TECHNIQUE>, TUNE_TECHNIQUE,
 * TUNE_TRAINING_TYPE and TUNE_DATASET_PATH.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Submission parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
 * @param {string} params.jobName - Generated job name
 */
export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
    const writes = [
        [`TUNE_JOB_NAME_${technique.toUpperCase()}`, jobName],
        ['TUNE_TECHNIQUE', technique],
        ['TUNE_TRAINING_TYPE', trainingType],
        ['TUNE_DATASET_PATH', datasetPath]
    ];
    for (const [name, value] of writes) {
        updateConfigVar(configPath, name, value);
    }
}
|
|
72
|
+
|
|
73
|
+
/**
 * Write the config variables recorded after a job completes successfully,
 * mirroring do/tune's _handle_completion() function.
 *
 * For 'lora' the artifact path goes to TUNE_ADAPTER_PATH_<TECHNIQUE>. For
 * 'full-rank' it goes to TUNE_MODEL_PATH_<TECHNIQUE>. Any other training type
 * skips the per-technique variable. TUNE_OUTPUT_PATH_LATEST and
 * TUNE_OUTPUT_TYPE_LATEST are always written.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Completion parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.artifactPath - S3 path to the output artifact
 * @param {string} params.outputType - Output type (adapter, full-model)
 */
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType }) {
    const suffix = technique.toUpperCase();
    const perTypeVar = new Map([
        ['lora', `TUNE_ADAPTER_PATH_${suffix}`],
        ['full-rank', `TUNE_MODEL_PATH_${suffix}`]
    ]).get(trainingType);

    if (perTypeVar !== undefined) {
        updateConfigVar(configPath, perTypeVar, artifactPath);
    }

    updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
    updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
}
|
|
96
|
+
|
|
97
|
+
/**
 * Build a job name in the format used by do/tune:
 * `${projectName}-tune-${technique}-YYYYMMDD-HHMMSS`.
 *
 * Date and time fields come from the timestamp's local-time getters. Month,
 * day, hour, minute and second are zero-padded to two digits; the year is not padded.
 *
 * @param {string} projectName - Project name
 * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {Date} [timestamp] - Optional timestamp (defaults to now)
 * @returns {string} Generated job name
 */
export function generateJobName(projectName, technique, timestamp = new Date()) {
    const pad2 = value => String(value).padStart(2, '0');

    const datePart = `${timestamp.getFullYear()}${pad2(timestamp.getMonth() + 1)}${pad2(timestamp.getDate())}`;
    const timePart = [timestamp.getHours(), timestamp.getMinutes(), timestamp.getSeconds()]
        .map(pad2)
        .join('');

    return `${projectName}-tune-${technique}-${datePart}-${timePart}`;
}
|