@aws/ml-container-creator 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/app.js CHANGED
@@ -16,6 +16,7 @@ import CommentGenerator from './lib/comment-generator.js';
16
16
  import ConfigurationManager from './lib/configuration-manager.js';
17
17
  import RegistryLoader from './lib/registry-loader.js';
18
18
  import { resolvePrefixedEnvVars } from './lib/engine-prefix-resolver.js';
19
+ import { isTuneSupported } from './lib/tune-catalog-validator.js';
19
20
  import ejs from 'ejs';
20
21
 
21
22
  const __filename = fileURLToPath(import.meta.url);
@@ -350,6 +351,12 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
350
351
  ignorePatterns.push('**/do/adapters/**');
351
352
  }
352
353
 
354
+ // Exclude tune files when framework is NOT transformers OR deploymentTarget is batch-transform
355
+ if (architecture !== 'transformers' || answers.deploymentTarget === 'batch-transform') {
356
+ ignorePatterns.push('**/do/tune');
357
+ ignorePatterns.push('**/do/.tune_helper.py');
358
+ }
359
+
353
360
  // Exclude do/test when hosted-model-endpoint is not selected
354
361
  const testTypes = answers.testTypes || [];
355
362
  if (!testTypes.includes('hosted-model-endpoint')) {
@@ -452,6 +459,13 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
452
459
  _copyFile(path.join(LIB_DIR, 'asset-manager.js'), path.join(doLibDir, 'asset-manager.js'));
453
460
  _copyFile(path.join(LIB_DIR, 'bootstrap-config.js'), path.join(doLibDir, 'bootstrap-config.js'));
454
461
 
462
+ // Copy tune catalog to generated project when tune is included
463
+ if (architecture === 'transformers' && answers.deploymentTarget !== 'batch-transform') {
464
+ const tuneCatalogSrc = path.join(GENERATOR_ROOT, 'config', 'tune-catalog.json');
465
+ const tuneCatalogDest = path.join(destDir, 'do', '.tune_catalog.json');
466
+ _copyFile(tuneCatalogSrc, tuneCatalogDest);
467
+ }
468
+
455
469
  // Generate .gitignore with benchmarks/ when benchmarking is enabled
456
470
  if (answers.includeBenchmark) {
457
471
  const gitignorePath = path.join(destDir, '.gitignore');
@@ -742,6 +756,19 @@ async function _ensureTemplateVariables(answers, registryConfigManager = null) {
742
756
  }
743
757
  }
744
758
  }
759
+
760
+ // Determine tune support based on model presence in the tune catalog.
761
+ // Used by the do/config template to write TUNE_SUPPORTED=true|false.
762
+ if (answers.tuneSupported === undefined) {
763
+ try {
764
+ const tuneCatalogPath = path.resolve(__dirname, '..', 'config', 'tune-catalog.json');
765
+ const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
766
+ const modelId = answers.modelName || '';
767
+ answers.tuneSupported = isTuneSupported(modelId, tuneCatalog);
768
+ } catch {
769
+ answers.tuneSupported = false;
770
+ }
771
+ }
745
772
  }
746
773
 
747
774
  /**
@@ -1083,7 +1110,8 @@ function _setExecutablePermissions(destDir) {
1083
1110
  'do/optimize',
1084
1111
  'do/status',
1085
1112
  'do/add-ic',
1086
- 'do/adapter'
1113
+ 'do/adapter',
1114
+ 'do/tune'
1087
1115
  ];
1088
1116
 
1089
1117
  shellScripts.forEach(script => {
@@ -182,35 +182,80 @@ export default class BootstrapCommandHandler {
182
182
  this._displayProgress('☁️', 'Deploying bootstrap infrastructure stack...');
183
183
  const stackName = `${STACK_NAME_PREFIX}-${profileName}`;
184
184
 
185
+ // Check for existing bootstrap stack in this account-region (resources are singletons)
185
186
  try {
186
- const stackOutputs = this._deployStack(stackName, {
187
- CreateS3Buckets: createS3Buckets ? 'true' : 'false',
188
- UseExistingRoleArn: useExistingRoleArn
189
- }, awsProfile, region);
187
+ const existingStacks = this._execAws(
188
+ `cloudformation list-stacks --stack-status-filter CREATE_COMPLETE UPDATE_COMPLETE UPDATE_ROLLBACK_COMPLETE --query "StackSummaries[?starts_with(StackName,'${STACK_NAME_PREFIX}-')].StackName" --output json`,
189
+ awsProfile
190
+ );
191
+ const stacks = Array.isArray(existingStacks) ? existingStacks : [];
192
+ const otherStack = stacks.find(s => s !== stackName);
193
+ if (otherStack) {
194
+ console.log(` ℹ️ Bootstrap infrastructure already exists in ${accountId}/${region} (stack: ${otherStack})`);
195
+ console.log(' Reusing existing resources (IAM role, ECR repo are singletons per account-region).');
196
+ console.log(' Use `ml-container-creator bootstrap update` to apply latest permissions.\n');
197
+
198
+ // Read outputs from existing stack
199
+ const outputs = this._execAws(
200
+ `cloudformation describe-stacks --stack-name ${otherStack} --query "Stacks[0].Outputs" --output json`,
201
+ awsProfile
202
+ );
203
+ const stackOutputs = {};
204
+ if (Array.isArray(outputs)) {
205
+ for (const o of outputs) {
206
+ stackOutputs[o.OutputKey] = o.OutputValue;
207
+ }
208
+ }
190
209
 
191
- // Read outputs into profile data
192
- profileData.roleArn = stackOutputs.RoleArn;
193
- profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
194
- profileData.stackName = stackName;
210
+ profileData.roleArn = stackOutputs.RoleArn;
211
+ profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
212
+ profileData.stackName = otherStack;
213
+ if (stackOutputs.AsyncS3BucketName) profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
214
+ if (stackOutputs.BatchS3BucketName) profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
215
+ if (stackOutputs.AdapterS3BucketName) profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
216
+ if (stackOutputs.BenchmarkS3BucketName) profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
195
217
 
196
- if (stackOutputs.AsyncS3BucketName) {
197
- profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
198
- }
199
- if (stackOutputs.BatchS3BucketName) {
200
- profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
218
+ // Skip stack deployment, continue to CI setup and profile save
219
+ console.log(' ✅ Existing bootstrap infrastructure reused');
201
220
  }
202
- if (stackOutputs.BenchmarkS3BucketName) {
203
- profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
204
- }
205
-
206
- console.log(' ✅ Bootstrap stack deployed successfully');
207
- } catch (error) {
208
- console.log(` ❌ Stack deployment failed: ${error.message}`);
209
- console.log(' Check the CloudFormation console for details:');
210
- console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
211
- return;
221
+ } catch (_) {
222
+ // If list-stacks fails, proceed with normal deployment
212
223
  }
213
224
 
225
+ if (!profileData.stackName) {
226
+ try {
227
+ const stackOutputs = this._deployStack(stackName, {
228
+ CreateS3Buckets: createS3Buckets ? 'true' : 'false',
229
+ UseExistingRoleArn: useExistingRoleArn
230
+ }, awsProfile, region);
231
+
232
+ // Read outputs into profile data
233
+ profileData.roleArn = stackOutputs.RoleArn;
234
+ profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
235
+ profileData.stackName = stackName;
236
+
237
+ if (stackOutputs.AsyncS3BucketName) {
238
+ profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
239
+ }
240
+ if (stackOutputs.BatchS3BucketName) {
241
+ profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
242
+ }
243
+ if (stackOutputs.AdapterS3BucketName) {
244
+ profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
245
+ }
246
+ if (stackOutputs.BenchmarkS3BucketName) {
247
+ profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
248
+ }
249
+
250
+ console.log(' ✅ Bootstrap stack deployed successfully');
251
+ } catch (error) {
252
+ console.log(` ❌ Stack deployment failed: ${error.message}`);
253
+ console.log(' Check the CloudFormation console for details:');
254
+ console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
255
+ return;
256
+ }
257
+ } // end if (!profileData.stackName)
258
+
214
259
  // Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
215
260
  this._displayProgress('🧪', 'CI Testing Infrastructure...');
216
261
  try {
@@ -390,6 +435,9 @@ export default class BootstrapCommandHandler {
390
435
  if (outputs.BatchS3BucketName) {
391
436
  console.log(` ✅ S3 bucket (batch): ${outputs.BatchS3BucketName}`);
392
437
  }
438
+ if (outputs.AdapterS3BucketName) {
439
+ console.log(` ✅ S3 bucket (adapters): ${outputs.AdapterS3BucketName}`);
440
+ }
393
441
  if (outputs.BenchmarkS3BucketName) {
394
442
  console.log(` ✅ S3 bucket (benchmark): ${outputs.BenchmarkS3BucketName}`);
395
443
  }
@@ -204,7 +204,7 @@ VALIDATION OPTIONS:
204
204
 
205
205
  MCP OPTIONS:
206
206
  --smart Enable Bedrock-powered smart mode on all MCP servers
207
- --discover Enable live registry lookups (e.g. Docker Hub) on MCP servers that support it
207
+ --no-discover Disable live registry lookups (HuggingFace API, quota checks) catalog-only mode
208
208
 
209
209
  REGISTRY SYSTEM:
210
210
  The generator includes built-in registries for frameworks, models, and instance types:
@@ -1631,7 +1631,7 @@ export default class ConfigManager {
1631
1631
  if (!mcpServerConfigs || !mcpServerConfigs[serverName]) return null;
1632
1632
 
1633
1633
  const smart = this.options.smart === true;
1634
- const discover = this.options.discover === true;
1634
+ const discover = this.options.discover !== false;
1635
1635
  const serverConfig = mcpServerConfigs[serverName];
1636
1636
 
1637
1637
  // Build a custom McpClient that passes context through
@@ -32,7 +32,7 @@ class McpClient {
32
32
  this.timeout = options.timeout || DEFAULT_TIMEOUT;
33
33
  this.parameterMatrix = options.parameterMatrix || {};
34
34
  this.smart = options.smart || false;
35
- this.discover = options.discover || false;
35
+ this.discover = options.discover !== undefined ? options.discover : true;
36
36
  this._transport = null;
37
37
  this._client = null;
38
38
  this._diagnosticMessage = null;
@@ -98,10 +98,10 @@ class McpClient {
98
98
 
99
99
  // Build environment: merge process.env with server-specific env
100
100
  // When --smart flag is active, inject BEDROCK_SMART=true for this run
101
- // When --discover flag is active, inject MCP_DISCOVER=true for this run
101
+ // Discover mode is now default; inject DISCOVER_MODE=false only when explicitly disabled
102
102
  // Always pass process.env so child processes inherit AWS credentials, profiles, etc.
103
103
  const smartEnv = this.smart ? { BEDROCK_SMART: 'true' } : {};
104
- const discoverEnv = this.discover ? { MCP_DISCOVER: 'true' } : {};
104
+ const discoverEnv = this.discover === false ? { DISCOVER_MODE: 'false' } : {};
105
105
  const serverEnv = env && Object.keys(env).length > 0 ? env : {};
106
106
  const spawnEnv = { ...process.env, ...smartEnv, ...discoverEnv, ...serverEnv };
107
107
 
@@ -1098,9 +1098,9 @@ export default class PromptRunner {
1098
1098
  if (!modelName || modelName === 'Custom (enter manually)') return;
1099
1099
 
1100
1100
  const smart = this.options.smart === true;
1101
- const discover = this.options.discover === true;
1101
+ const discover = this.options.discover !== false;
1102
1102
 
1103
- const modeLabel = [smart && '[smart]', discover && '[discover]'].filter(Boolean).join(' ');
1103
+ const modeLabel = [smart && '[smart]', !discover && '[no-discover]'].filter(Boolean).join(' ');
1104
1104
  console.log(` 🔍 Querying instance-sizer${modeLabel ? ` ${modeLabel}` : ''}...`);
1105
1105
 
1106
1106
  try {
@@ -1115,8 +1115,8 @@ export default class PromptRunner {
1115
1115
  const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
1116
1116
 
1117
1117
  const serverArgs = [...(serverConfig.args || [])];
1118
- if (discover && !serverArgs.includes('--discover')) {
1119
- serverArgs.push('--discover');
1118
+ if (!discover && !serverArgs.includes('--no-discover')) {
1119
+ serverArgs.push('--no-discover');
1120
1120
  }
1121
1121
 
1122
1122
  const transport = new StdioClientTransport({
@@ -1375,7 +1375,7 @@ export default class PromptRunner {
1375
1375
  if (!mcpServers.includes('base-image-picker')) return;
1376
1376
 
1377
1377
  const smart = this.options.smart === true;
1378
- const discover = this.options.discover === true;
1378
+ const discover = this.options.discover !== false;
1379
1379
  const framework = frameworkAnswers.framework;
1380
1380
  const modelServer = frameworkAnswers.modelServer;
1381
1381
  const architecture = frameworkAnswers.architecture || frameworkAnswers.deploymentConfig?.split('-')[0];
@@ -399,9 +399,33 @@ const modelFormatPrompts = [
399
399
  ];
400
400
  }
401
401
  return [
402
- 'openai/gpt-oss-20b',
403
- 'meta-llama/Llama-3.2-3B-Instruct',
402
+ { type: 'separator', separator: '── Meta Llama ──' },
404
403
  'meta-llama/Llama-3.2-1B-Instruct',
404
+ 'meta-llama/Llama-3.2-3B-Instruct',
405
+ 'meta-llama/Llama-3.1-8B-Instruct',
406
+ 'meta-llama/Llama-3.3-70B-Instruct',
407
+ { type: 'separator', separator: '── Qwen (Alibaba) ──' },
408
+ 'Qwen/Qwen3-0.6B',
409
+ 'Qwen/Qwen3-1.7B',
410
+ 'Qwen/Qwen3-4B',
411
+ 'Qwen/Qwen3-8B',
412
+ 'Qwen/Qwen3-14B',
413
+ 'Qwen/Qwen3-32B',
414
+ 'Qwen/Qwen2.5-7B-Instruct',
415
+ 'Qwen/Qwen2.5-14B-Instruct',
416
+ 'Qwen/Qwen2.5-32B-Instruct',
417
+ 'Qwen/Qwen2.5-72B-Instruct',
418
+ { type: 'separator', separator: '── DeepSeek ──' },
419
+ 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
420
+ 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
421
+ 'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
422
+ 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
423
+ 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
424
+ 'deepseek-ai/DeepSeek-R1-Distill-Llama-70B',
425
+ { type: 'separator', separator: '── OpenAI ──' },
426
+ 'openai/gpt-oss-20b',
427
+ 'openai/gpt-oss-120b',
428
+ { type: 'separator', separator: '──────────────' },
405
429
  'Custom (enter manually)'
406
430
  ];
407
431
  },
@@ -413,7 +437,7 @@ const modelFormatPrompts = [
413
437
  if (architecture === 'diffusors') {
414
438
  return 'stabilityai/stable-diffusion-3.5-medium';
415
439
  }
416
- return 'openai/gpt-oss-20b';
440
+ return 'meta-llama/Llama-3.1-8B-Instruct';
417
441
  },
418
442
  when: answers => {
419
443
  const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
@@ -528,9 +552,11 @@ const modelProfilePrompts = [
528
552
  */
529
553
  // eslint-disable-next-line no-unused-vars -- reference list for future use
530
554
  const EXAMPLE_MODEL_IDS = [
531
- 'openai/gpt-oss-20b',
555
+ 'meta-llama/Llama-3.1-8B-Instruct',
532
556
  'meta-llama/Llama-3.2-3B-Instruct',
533
- 'meta-llama/Llama-3.2-1B-Instruct'
557
+ 'Qwen/Qwen3-8B',
558
+ 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
559
+ 'openai/gpt-oss-20b'
534
560
  ];
535
561
 
536
562
  const hfTokenPrompts = [
@@ -0,0 +1,143 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Catalog Validator
6
+ *
7
+ * Validates model IDs, techniques, and training types against the
8
+ * Supported Model Catalog. Provides descriptive error messages when
9
+ * a requested configuration is not supported.
10
+ *
11
+ * Requirements: 1.3, 1.4, 1.5, 1.6, 4.1, 4.2, 4.3, 4.4
12
+ */
13
+
14
/**
 * Find the catalog entry for a model ID.
 *
 * Only own properties of `catalog.models` are consulted, so inherited
 * object keys such as "toString" never resolve to an entry.
 *
 * @param {string} modelId - The JumpStart model ID to look up
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {Object|null} The catalog entry, or null when the catalog has no
 *   `models` map, the ID is not an own key of it, or the stored entry is falsy
 */
export function lookupModel(modelId, catalog) {
  const models = catalog?.models;
  if (!models || !Object.hasOwn(models, modelId)) {
    return null;
  }
  return models[modelId] || null;
}
29
+
30
/**
 * Report whether the tune catalog contains an entry for a model.
 *
 * @param {string} modelId - The JumpStart model ID to check
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {boolean} True when lookupModel() returns something other than null
 */
export function isTuneSupported(modelId, catalog) {
  const entry = lookupModel(modelId, catalog);
  return entry !== null;
}
39
+
40
/**
 * Check that a model ID is present in the tune catalog.
 *
 * When the model is missing, the returned error names the model and lists
 * the distinct model families found in the catalog ("none" when the catalog
 * has no families). It also points the user at `do/train` for custom
 * training workflows.
 *
 * @param {string} modelId - The JumpStart model ID to validate
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateModel(modelId, catalog) {
  if (isTuneSupported(modelId, catalog)) {
    return { valid: true };
  }

  const families = _getSupportedFamilies(catalog);
  // Avoid rendering "Supported model families: ." for an empty or missing catalog.
  const familyList = families.length > 0 ? families.join(', ') : 'none';

  return {
    valid: false,
    error: `Model "${modelId}" is not yet supported for managed serverless customization. ` +
      `Supported model families: ${familyList}. ` +
      'For custom training workflows, see `do/train`.'
  };
}
63
+
64
/**
 * Validate that a technique is supported for the given model.
 *
 * If the model itself is not in the catalog, the result of validateModel()
 * is returned instead. Otherwise the technique must be a key of the entry's
 * `techniques` map. An entry without a `techniques` map is treated as
 * supporting no techniques rather than throwing.
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTechnique(modelId, technique, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  // Object.keys(undefined) throws; a catalog entry may omit `techniques`.
  const supportedTechniques = Object.keys(entry.techniques ?? {});
  if (supportedTechniques.includes(technique)) {
    return { valid: true };
  }

  const listed = supportedTechniques.length > 0 ? supportedTechniques.join(', ') : 'none';
  return {
    valid: false,
    error: `Technique "${technique}" is not supported for model "${modelId}". ` +
      `Supported techniques: ${listed}.`
  };
}
90
+
91
/**
 * Validate that a training type is supported for the given model and technique.
 *
 * Validation is delegated in order:
 * - validateModel() when the model is not in the catalog;
 * - validateTechnique() when the technique is not an own key of the entry's
 *   `techniques` map;
 * - otherwise the entry's `trainingTypes` array is checked.
 *
 * Missing `techniques` or `trainingTypes` are treated as empty rather than
 * throwing.
 *
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique (e.g., 'sft', 'dpo')
 * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTrainingType(modelId, technique, trainingType, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  // Own-property lookup so inherited keys like "constructor" or "toString"
  // are not mistaken for technique entries (mirrors lookupModel's guard).
  const techniques = entry.techniques ?? {};
  const techniqueEntry = Object.hasOwn(techniques, technique) ? techniques[technique] : null;
  if (!techniqueEntry) {
    return validateTechnique(modelId, technique, catalog);
  }

  const supportedTypes = Array.isArray(techniqueEntry.trainingTypes)
    ? techniqueEntry.trainingTypes
    : [];
  if (supportedTypes.includes(trainingType)) {
    return { valid: true };
  }

  const listed = supportedTypes.length > 0 ? supportedTypes.join(', ') : 'none';
  return {
    valid: false,
    error: `Training type "${trainingType}" is not supported for model "${modelId}" ` +
      `with technique "${technique}". ` +
      `Supported training types: ${listed}.`
  };
}
124
+
125
/**
 * Extract the distinct model family names from the catalog.
 *
 * Families are returned in first-seen order of `catalog.models`. Entries
 * that are null/falsy, or that have no truthy `family`, are skipped.
 *
 * @param {Object} catalog - The tune catalog object
 * @returns {string[]} Array of unique family names (empty if the catalog has no models)
 * @private
 */
function _getSupportedFamilies(catalog) {
  if (!catalog || !catalog.models) {
    return [];
  }

  const families = new Set();
  for (const entry of Object.values(catalog.models)) {
    // lookupModel treats falsy entries as absent; skip them here too instead
    // of throwing on `null.family`.
    if (entry?.family) {
      families.add(entry.family);
    }
  }
  return [...families];
}
@@ -0,0 +1,116 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Config State Manager
6
+ *
7
+ * JavaScript module that mimics the bash _update_config_var() behavior
8
+ * from do/tune for testing purposes. Manages config variables written
9
+ * after job submission.
10
+ */
11
+
12
+ import { readFileSync, writeFileSync } from 'node:fs';
13
+
14
/**
 * Update or add a config variable in a do/config-style file.
 * Mimics the bash _update_config_var() function:
 * - If a line `export VAR_NAME=...` exists, the first such line is replaced
 * - Otherwise a new line is appended (adding a separating newline if the file
 *   does not already end with one)
 *
 * The written line is always `export VAR_NAME="value"`. The value is written
 * literally: `$` sequences in it are not interpreted.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
 * @param {string} varValue - Variable value
 */
export function updateConfigVar(configPath, varName, varValue) {
  const line = `export ${varName}="${varValue}"`;
  // Escape regex metacharacters so varName is matched literally.
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const pattern = new RegExp(`^export ${escapedName}=.*$`, 'm');

  let content = readFileSync(configPath, 'utf8');

  if (pattern.test(content)) {
    // Use a function replacer: a string replacement would expand `$&`, `$1`,
    // `$$`, etc. appearing inside varValue and corrupt the written value.
    content = content.replace(pattern, () => line);
  } else {
    if (content.length > 0 && !content.endsWith('\n')) {
      content += '\n';
    }
    content += `${line}\n`;
  }

  writeFileSync(configPath, content, 'utf8');
}
39
+
40
/**
 * Read a double-quoted config variable from a do/config-style file.
 *
 * Only lines of the form `export VAR_NAME="value"` are recognised. Unquoted
 * assignments are not matched.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The text between the quotes, or null if no matching line exists
 */
export function readConfigVar(configPath, varName) {
  const text = readFileSync(configPath, 'utf8');
  const found = new RegExp(`^export ${varName}="([^"]*)"`, 'm').exec(text);
  if (found === null) {
    return null;
  }
  return found[1];
}
53
+
54
/**
 * Record the config variables written after a successful job submission,
 * mirroring do/tune's _submit_job().
 *
 * Writes, in order:
 * 1. TUNE_JOB_NAME_<TECHNIQUE>
 * 2. TUNE_TECHNIQUE
 * 3. TUNE_TRAINING_TYPE
 * 4. TUNE_DATASET_PATH
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Submission parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
 * @param {string} params.jobName - Generated job name
 */
export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
  const writes = [
    [`TUNE_JOB_NAME_${technique.toUpperCase()}`, jobName],
    ['TUNE_TECHNIQUE', technique],
    ['TUNE_TRAINING_TYPE', trainingType],
    ['TUNE_DATASET_PATH', datasetPath]
  ];
  for (const [name, value] of writes) {
    updateConfigVar(configPath, name, value);
  }
}
72
+
73
/**
 * Record the config variables written after a job completes successfully,
 * mirroring do/tune's _handle_completion().
 *
 * The per-technique variable depends on the training type:
 * - 'lora' writes TUNE_ADAPTER_PATH_<TECHNIQUE>;
 * - 'full-rank' writes TUNE_MODEL_PATH_<TECHNIQUE>;
 * - any other value writes no per-technique variable.
 *
 * TUNE_OUTPUT_PATH_LATEST and TUNE_OUTPUT_TYPE_LATEST are always written
 * afterwards.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Completion parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.artifactPath - S3 path to the output artifact
 * @param {string} params.outputType - Output type (adapter, full-model)
 */
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType }) {
  const suffix = technique.toUpperCase();

  let perTechniqueVar = null;
  switch (trainingType) {
    case 'lora':
      perTechniqueVar = `TUNE_ADAPTER_PATH_${suffix}`;
      break;
    case 'full-rank':
      perTechniqueVar = `TUNE_MODEL_PATH_${suffix}`;
      break;
    default:
      break;
  }
  if (perTechniqueVar !== null) {
    updateConfigVar(configPath, perTechniqueVar, artifactPath);
  }

  updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
  updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
}
96
+
97
/**
 * Build a job name following the pattern used by do/tune:
 * `${projectName}-tune-${technique}-YYYYMMDD-HHMMSS`.
 *
 * Date and time fields come from the timestamp's local-time getters.
 * Month, day, hour, minute and second are zero-padded to two digits; the
 * year is written as-is.
 *
 * @param {string} projectName - Project name
 * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {Date} [timestamp] - Optional timestamp (defaults to now)
 * @returns {string} Generated job name
 */
export function generateJobName(projectName, technique, timestamp = new Date()) {
  const pad2 = (value) => String(value).padStart(2, '0');

  const ymd = `${timestamp.getFullYear()}${pad2(timestamp.getMonth() + 1)}${pad2(timestamp.getDate())}`;
  const hms = [timestamp.getHours(), timestamp.getMinutes(), timestamp.getSeconds()]
    .map(pad2)
    .join('');

  return `${projectName}-tune-${technique}-${ymd}-${hms}`;
}