@aws/ml-container-creator 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/bin/cli.js CHANGED
@@ -102,6 +102,15 @@ program
102
102
  .addOption(new Option('--max-loras <n>', 'Maximum concurrent LoRA adapters in GPU memory (default: 30)'))
103
103
  .addOption(new Option('--max-lora-rank <n>', 'Maximum LoRA rank (default: 64)'))
104
104
 
105
+ // --- Benchmarking ---
106
+ .addOption(new Option('--include-benchmark', 'Include SageMaker AI Benchmarking (transformers/diffusors only)'))
107
+ .addOption(new Option('--benchmark-concurrency <n>', 'Benchmark concurrent requests (default: 10)'))
108
+ .addOption(new Option('--benchmark-input-tokens <n>', 'Benchmark mean input tokens (default: 550)'))
109
+ .addOption(new Option('--benchmark-output-tokens <n>', 'Benchmark mean output tokens (default: 150)'))
110
+ .addOption(new Option('--benchmark-streaming', 'Enable streaming in benchmark (default: true)'))
111
+ .addOption(new Option('--benchmark-request-count <n>', 'Total benchmark requests (optional)'))
112
+ .addOption(new Option('--benchmark-s3-output-path <path>', 'S3 path for benchmark results'))
113
+
105
114
  // --- MCP & Discovery ---
106
115
  .addOption(new Option('--smart', 'Enable Bedrock-powered smart mode on MCP servers'))
107
116
  .addOption(new Option('--discover', 'Enable live registry lookups via MCP discovery'))
@@ -62,6 +62,7 @@
62
62
  "sagemaker:DescribeEndpointConfig",
63
63
  "sagemaker:DescribeModel",
64
64
  "sagemaker:DescribeInferenceComponent",
65
+ "sagemaker:ListInferenceComponents",
65
66
  "sagemaker:InvokeEndpoint",
66
67
  "sagemaker:InvokeEndpointAsync"
67
68
  ],
@@ -131,11 +132,14 @@
131
132
  "Action": [
132
133
  "s3:GetObject",
133
134
  "s3:PutObject",
135
+ "s3:AbortMultipartUpload",
134
136
  "s3:ListBucket"
135
137
  ],
136
138
  "Resource": [
137
139
  "arn:aws:s3:::mlcc-*",
138
- "arn:aws:s3:::mlcc-*/*"
140
+ "arn:aws:s3:::mlcc-*/*",
141
+ "arn:aws:s3:::ml-container-creator-*",
142
+ "arn:aws:s3:::ml-container-creator-*/*"
139
143
  ]
140
144
  },
141
145
  {
@@ -163,18 +167,55 @@
163
167
  "arn:aws:secretsmanager:*:*:secret:ml-container-creator/*"
164
168
  ]
165
169
  },
170
+ {
171
+ "Sid": "SNSPublish",
172
+ "Effect": "Allow",
173
+ "Action": "sns:Publish",
174
+ "Resource": [
175
+ { "Fn::Sub": "arn:aws:sns:*:${AWS::AccountId}:mlcc-*" },
176
+ { "Fn::Sub": "arn:aws:sns:*:${AWS::AccountId}:ml-container-creator-*" }
177
+ ]
178
+ },
166
179
  {
167
180
  "Sid": "QuotaAndAvailability",
168
181
  "Effect": "Allow",
169
182
  "Action": [
170
183
  "service-quotas:GetServiceQuota",
171
184
  "service-quotas:ListServiceQuotas",
172
- "ec2:DescribeCapacityReservations",
173
185
  "sagemaker:ListTrainingPlans",
174
186
  "sagemaker:DescribeTrainingPlan",
175
187
  "sagemaker:ListEndpoints"
176
188
  ],
177
189
  "Resource": "*"
190
+ },
191
+ {
192
+ "Sid": "SageMakerModelCustomization",
193
+ "Effect": "Allow",
194
+ "Action": [
195
+ "sagemaker:CreateTrainingJob",
196
+ "sagemaker:DescribeTrainingJob",
197
+ "sagemaker:ListTrainingJobs",
198
+ "sagemaker:StopTrainingJob",
199
+ "sagemaker:CreateModelPackage",
200
+ "sagemaker:CreateModelPackageGroup",
201
+ "sagemaker:DescribeModelPackage",
202
+ "sagemaker:DescribeModelPackageGroup",
203
+ "sagemaker:ListModelPackages",
204
+ "sagemaker:CallMlflowAppApi"
205
+ ],
206
+ "Resource": "*"
207
+ },
208
+ {
209
+ "Sid": "SageMakerMLflow",
210
+ "Effect": "Allow",
211
+ "Action": "sagemaker-mlflow:*",
212
+ "Resource": "*"
213
+ },
214
+ {
215
+ "Sid": "LambdaInvokeForReward",
216
+ "Effect": "Allow",
217
+ "Action": "lambda:InvokeFunction",
218
+ "Resource": { "Fn::Sub": "arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:*" }
178
219
  }
179
220
  ]
180
221
  }
@@ -285,6 +326,26 @@
285
326
  { "Key": "mlcc:purpose", "Value": "benchmark-results" }
286
327
  ]
287
328
  }
329
+ },
330
+
331
+ "TuneS3Bucket": {
332
+ "Type": "AWS::S3::Bucket",
333
+ "Condition": "ShouldCreateS3Buckets",
334
+ "DeletionPolicy": "Retain",
335
+ "UpdateReplacePolicy": "Retain",
336
+ "Properties": {
337
+ "BucketName": { "Fn::Sub": "mlcc-tune-${AWS::AccountId}-${AWS::Region}" },
338
+ "VersioningConfiguration": { "Status": "Enabled" },
339
+ "BucketEncryption": {
340
+ "ServerSideEncryptionConfiguration": [
341
+ { "ServerSideEncryptionByDefault": { "SSEAlgorithm": "AES256" } }
342
+ ]
343
+ },
344
+ "Tags": [
345
+ { "Key": "mlcc:managed-by", "Value": "ml-container-creator" },
346
+ { "Key": "mlcc:purpose", "Value": "tune-datasets-and-output" }
347
+ ]
348
+ }
288
349
  }
289
350
  },
290
351
 
@@ -327,9 +388,14 @@
327
388
  "Description": "S3 bucket for benchmark results output",
328
389
  "Value": { "Ref": "BenchmarkS3Bucket" }
329
390
  },
391
+ "TuneS3BucketName": {
392
+ "Condition": "ShouldCreateS3Buckets",
393
+ "Description": "S3 bucket for tune datasets and output",
394
+ "Value": { "Ref": "TuneS3Bucket" }
395
+ },
330
396
  "StackVersion": {
331
397
  "Description": "Bootstrap stack template version for forward compatibility tracking",
332
- "Value": "2026-05-04"
398
+ "Value": "2026-05-18"
333
399
  }
334
400
  }
335
401
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aws/ml-container-creator",
3
- "version": "0.6.0",
3
+ "version": "0.6.1",
4
4
  "description": "Generator for SageMaker AI BYOC paradigm for predictive inference use-cases.",
5
5
  "type": "module",
6
6
  "main": "src/app.js",
package/src/app.js CHANGED
@@ -16,6 +16,7 @@ import CommentGenerator from './lib/comment-generator.js';
16
16
  import ConfigurationManager from './lib/configuration-manager.js';
17
17
  import RegistryLoader from './lib/registry-loader.js';
18
18
  import { resolvePrefixedEnvVars } from './lib/engine-prefix-resolver.js';
19
+ import { isTuneSupported } from './lib/tune-catalog-validator.js';
19
20
  import ejs from 'ejs';
20
21
 
21
22
  const __filename = fileURLToPath(import.meta.url);
@@ -350,6 +351,12 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
350
351
  ignorePatterns.push('**/do/adapters/**');
351
352
  }
352
353
 
354
+ // Exclude tune files when framework is NOT transformers OR deploymentTarget is batch-transform
355
+ if (architecture !== 'transformers' || answers.deploymentTarget === 'batch-transform') {
356
+ ignorePatterns.push('**/do/tune');
357
+ ignorePatterns.push('**/do/.tune_helper.py');
358
+ }
359
+
353
360
  // Exclude do/test when hosted-model-endpoint is not selected
354
361
  const testTypes = answers.testTypes || [];
355
362
  if (!testTypes.includes('hosted-model-endpoint')) {
@@ -452,6 +459,13 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
452
459
  _copyFile(path.join(LIB_DIR, 'asset-manager.js'), path.join(doLibDir, 'asset-manager.js'));
453
460
  _copyFile(path.join(LIB_DIR, 'bootstrap-config.js'), path.join(doLibDir, 'bootstrap-config.js'));
454
461
 
462
+ // Copy tune catalog to generated project when tune is included
463
+ if (architecture === 'transformers' && answers.deploymentTarget !== 'batch-transform') {
464
+ const tuneCatalogSrc = path.join(GENERATOR_ROOT, 'config', 'tune-catalog.json');
465
+ const tuneCatalogDest = path.join(destDir, 'do', '.tune_catalog.json');
466
+ _copyFile(tuneCatalogSrc, tuneCatalogDest);
467
+ }
468
+
455
469
  // Generate .gitignore with benchmarks/ when benchmarking is enabled
456
470
  if (answers.includeBenchmark) {
457
471
  const gitignorePath = path.join(destDir, '.gitignore');
@@ -742,6 +756,19 @@ async function _ensureTemplateVariables(answers, registryConfigManager = null) {
742
756
  }
743
757
  }
744
758
  }
759
+
760
+ // Determine tune support based on model presence in the tune catalog.
761
+ // Used by the do/config template to write TUNE_SUPPORTED=true|false.
762
+ if (answers.tuneSupported === undefined) {
763
+ try {
764
+ const tuneCatalogPath = path.resolve(__dirname, '..', 'config', 'tune-catalog.json');
765
+ const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
766
+ const modelId = answers.modelName || '';
767
+ answers.tuneSupported = isTuneSupported(modelId, tuneCatalog);
768
+ } catch {
769
+ answers.tuneSupported = false;
770
+ }
771
+ }
745
772
  }
746
773
 
747
774
  /**
@@ -1083,7 +1110,8 @@ function _setExecutablePermissions(destDir) {
1083
1110
  'do/optimize',
1084
1111
  'do/status',
1085
1112
  'do/add-ic',
1086
- 'do/adapter'
1113
+ 'do/adapter',
1114
+ 'do/tune'
1087
1115
  ];
1088
1116
 
1089
1117
  shellScripts.forEach(script => {
@@ -182,38 +182,80 @@ export default class BootstrapCommandHandler {
182
182
  this._displayProgress('☁️', 'Deploying bootstrap infrastructure stack...');
183
183
  const stackName = `${STACK_NAME_PREFIX}-${profileName}`;
184
184
 
185
+ // Check for existing bootstrap stack in this account-region (resources are singletons)
185
186
  try {
186
- const stackOutputs = this._deployStack(stackName, {
187
- CreateS3Buckets: createS3Buckets ? 'true' : 'false',
188
- UseExistingRoleArn: useExistingRoleArn
189
- }, awsProfile, region);
187
+ const existingStacks = this._execAws(
188
+ `cloudformation list-stacks --stack-status-filter CREATE_COMPLETE UPDATE_COMPLETE UPDATE_ROLLBACK_COMPLETE --query "StackSummaries[?starts_with(StackName,'${STACK_NAME_PREFIX}-')].StackName" --output json`,
189
+ awsProfile
190
+ );
191
+ const stacks = Array.isArray(existingStacks) ? existingStacks : [];
192
+ const otherStack = stacks.find(s => s !== stackName);
193
+ if (otherStack) {
194
+ console.log(` ℹ️ Bootstrap infrastructure already exists in ${accountId}/${region} (stack: ${otherStack})`);
195
+ console.log(' Reusing existing resources (IAM role, ECR repo are singletons per account-region).');
196
+ console.log(' Use `ml-container-creator bootstrap update` to apply latest permissions.\n');
197
+
198
+ // Read outputs from existing stack
199
+ const outputs = this._execAws(
200
+ `cloudformation describe-stacks --stack-name ${otherStack} --query "Stacks[0].Outputs" --output json`,
201
+ awsProfile
202
+ );
203
+ const stackOutputs = {};
204
+ if (Array.isArray(outputs)) {
205
+ for (const o of outputs) {
206
+ stackOutputs[o.OutputKey] = o.OutputValue;
207
+ }
208
+ }
190
209
 
191
- // Read outputs into profile data
192
- profileData.roleArn = stackOutputs.RoleArn;
193
- profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
194
- profileData.stackName = stackName;
210
+ profileData.roleArn = stackOutputs.RoleArn;
211
+ profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
212
+ profileData.stackName = otherStack;
213
+ if (stackOutputs.AsyncS3BucketName) profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
214
+ if (stackOutputs.BatchS3BucketName) profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
215
+ if (stackOutputs.AdapterS3BucketName) profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
216
+ if (stackOutputs.BenchmarkS3BucketName) profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
195
217
 
196
- if (stackOutputs.AsyncS3BucketName) {
197
- profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
198
- }
199
- if (stackOutputs.BatchS3BucketName) {
200
- profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
218
+ // Skip stack deployment, continue to CI setup and profile save
219
+ console.log(' ✅ Existing bootstrap infrastructure reused');
201
220
  }
202
- if (stackOutputs.AdapterS3BucketName) {
203
- profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
204
- }
205
- if (stackOutputs.BenchmarkS3BucketName) {
206
- profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
207
- }
208
-
209
- console.log(' ✅ Bootstrap stack deployed successfully');
210
- } catch (error) {
211
- console.log(` ❌ Stack deployment failed: ${error.message}`);
212
- console.log(' Check the CloudFormation console for details:');
213
- console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
214
- return;
221
+ } catch (_) {
222
+ // If list-stacks fails, proceed with normal deployment
215
223
  }
216
224
 
225
+ if (!profileData.stackName) {
226
+ try {
227
+ const stackOutputs = this._deployStack(stackName, {
228
+ CreateS3Buckets: createS3Buckets ? 'true' : 'false',
229
+ UseExistingRoleArn: useExistingRoleArn
230
+ }, awsProfile, region);
231
+
232
+ // Read outputs into profile data
233
+ profileData.roleArn = stackOutputs.RoleArn;
234
+ profileData.ecrRepositoryName = stackOutputs.EcrRepositoryName;
235
+ profileData.stackName = stackName;
236
+
237
+ if (stackOutputs.AsyncS3BucketName) {
238
+ profileData.asyncS3Bucket = stackOutputs.AsyncS3BucketName;
239
+ }
240
+ if (stackOutputs.BatchS3BucketName) {
241
+ profileData.batchS3Bucket = stackOutputs.BatchS3BucketName;
242
+ }
243
+ if (stackOutputs.AdapterS3BucketName) {
244
+ profileData.adapterS3Bucket = stackOutputs.AdapterS3BucketName;
245
+ }
246
+ if (stackOutputs.BenchmarkS3BucketName) {
247
+ profileData.benchmarkS3Bucket = stackOutputs.BenchmarkS3BucketName;
248
+ }
249
+
250
+ console.log(' ✅ Bootstrap stack deployed successfully');
251
+ } catch (error) {
252
+ console.log(` ❌ Stack deployment failed: ${error.message}`);
253
+ console.log(' Check the CloudFormation console for details:');
254
+ console.log(` https://console.aws.amazon.com/cloudformation/home?region=${region}#/stacks`);
255
+ return;
256
+ }
257
+ } // end if (!profileData.stackName)
258
+
217
259
  // Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
218
260
  this._displayProgress('🧪', 'CI Testing Infrastructure...');
219
261
  try {
@@ -0,0 +1,143 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Catalog Validator
6
+ *
7
+ * Validates model IDs, techniques, and training types against the
8
+ * Supported Model Catalog. Provides descriptive error messages when
9
+ * a requested configuration is not supported.
10
+ *
11
+ * Requirements: 1.3, 1.4, 1.5, 1.6, 4.1, 4.2, 4.3, 4.4
12
+ */
13
+
14
/**
 * Fetch the catalog entry registered under a model ID.
 *
 * Only own properties of `catalog.models` are treated as entries, so keys
 * inherited from `Object.prototype` (e.g. `toString`) never count as models.
 * A falsy entry value is reported as "not found".
 * @param {string} modelId - The JumpStart model ID to look up
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {Object|null} The matching catalog entry, or null when the catalog
 *   is missing, has no `models` map, or holds no usable entry for the ID
 */
export function lookupModel(modelId, catalog) {
  const models = catalog ? catalog.models : null;
  if (!models || !Object.hasOwn(models, modelId)) {
    return null;
  }

  const entry = models[modelId];
  return entry ? entry : null;
}
29
+
30
/**
 * Report whether the Supported Model Catalog has an entry for a model ID.
 * @param {string} modelId - The JumpStart model ID to check
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {boolean} True when `lookupModel` finds an entry for the model
 */
export function isTuneSupported(modelId, catalog) {
  const entry = lookupModel(modelId, catalog);
  return entry !== null;
}
39
+
40
/**
 * Check that a model ID is present in the catalog.
 *
 * On failure the result carries a message naming the rejected model, the
 * model families found in the catalog, and a pointer to `do/train`.
 * @param {string} modelId - The JumpStart model ID to validate
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }} `{ valid: true }` when the
 *   model is supported, otherwise `{ valid: false, error }`
 */
export function validateModel(modelId, catalog) {
  if (!isTuneSupported(modelId, catalog)) {
    const familyList = _getSupportedFamilies(catalog).join(', ');
    const error = `Model "${modelId}" is not yet supported for managed serverless customization. ` +
      `Supported model families: ${familyList}. ` +
      'For custom training workflows, see `do/train`.';
    return { valid: false, error };
  }

  return { valid: true };
}
63
+
64
/**
 * Validate that a technique is supported for the given model.
 *
 * When the model itself is not in the catalog, the result of
 * `validateModel` is returned instead. Otherwise the technique must be one
 * of the keys of the entry's `techniques` map; the error message lists the
 * keys that are available.
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique to validate (e.g., 'sft', 'dpo')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTechnique(modelId, technique, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  // A catalog entry without a `techniques` map supports no techniques;
  // report that as a validation error instead of throwing a TypeError.
  const supportedTechniques = Object.keys(entry.techniques ?? {});
  if (supportedTechniques.includes(technique)) {
    return { valid: true };
  }

  return {
    valid: false,
    error: `Technique "${technique}" is not supported for model "${modelId}". ` +
      `Supported techniques: ${supportedTechniques.join(', ')}.`
  };
}
90
+
91
/**
 * Validate that a training type is supported for the given model and technique.
 *
 * Validation cascades: an unknown model yields the `validateModel` result,
 * an unknown technique yields the `validateTechnique` result, and only then
 * is the training type checked against the technique's `trainingTypes` list.
 * @param {string} modelId - The JumpStart model ID
 * @param {string} technique - The technique (e.g., 'sft', 'dpo')
 * @param {string} trainingType - The training type to validate (e.g., 'lora', 'full-rank')
 * @param {Object} catalog - The tune catalog object with a `models` map
 * @returns {{ valid: boolean, error?: string }}
 */
export function validateTrainingType(modelId, technique, trainingType, catalog) {
  const entry = lookupModel(modelId, catalog);
  if (!entry) {
    return validateModel(modelId, catalog);
  }

  // Own-property check (as in lookupModel): a technique named like an
  // Object.prototype member (`toString`, `constructor`, ...) must be
  // rejected as unknown rather than resolved through the prototype chain.
  const techniques = entry.techniques ?? {};
  const techniqueEntry = Object.hasOwn(techniques, technique) ? techniques[technique] : null;
  if (!techniqueEntry) {
    return validateTechnique(modelId, technique, catalog);
  }

  // A technique entry without a `trainingTypes` list supports none.
  const supportedTypes = techniqueEntry.trainingTypes ?? [];
  if (supportedTypes.includes(trainingType)) {
    return { valid: true };
  }

  return {
    valid: false,
    error: `Training type "${trainingType}" is not supported for model "${modelId}" ` +
      `with technique "${technique}". ` +
      `Supported training types: ${supportedTypes.join(', ')}.`
  };
}
124
+
125
/**
 * Extract unique model family names from the catalog.
 *
 * Families are returned in first-seen order. Entries that are null/undefined
 * or that carry no `family` value are skipped.
 * @param {Object} catalog - The tune catalog object
 * @returns {string[]} Array of unique family names (empty when the catalog
 *   is missing or has no `models` map)
 * @private
 */
function _getSupportedFamilies(catalog) {
  if (!catalog || !catalog.models) {
    return [];
  }

  const families = new Set();
  for (const entry of Object.values(catalog.models)) {
    // Optional chaining: a null entry must not crash the error-message path.
    if (entry?.family) {
      families.add(entry.family);
    }
  }
  return [...families];
}
@@ -0,0 +1,116 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Tune Config State Manager
6
+ *
7
+ * JavaScript module that mimics the bash _update_config_var() behavior
8
+ * from do/tune for testing purposes. Manages config variables written
9
+ * after job submission.
10
+ */
11
+
12
+ import { readFileSync, writeFileSync } from 'node:fs';
13
+
14
/**
 * Update or add a config variable in a do/config-style file.
 * Mimics the bash _update_config_var() function:
 * - If a line starting with `export VAR_NAME=` exists, the first such line
 *   is replaced
 * - Otherwise, a new `export VAR_NAME="value"` line is appended (a trailing
 *   newline is added to existing content first when it is missing)
 *
 * The value is written verbatim between double quotes; it is not
 * shell-escaped.
 *
 * @param {string} configPath - Path to the config file (must already exist)
 * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
 * @param {string} varValue - Variable value
 */
export function updateConfigVar(configPath, varName, varValue) {
  const assignment = `export ${varName}="${varValue}"`;

  // Escape regex metacharacters so the variable name is matched literally.
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const pattern = new RegExp(`^export ${escapedName}=.*$`, 'm');

  let content = readFileSync(configPath, 'utf8');
  if (pattern.test(content)) {
    // A replacer function keeps `$&`, `$1`, `$$` etc. in the value literal;
    // a replacement *string* would expand them and corrupt the written value.
    content = content.replace(pattern, () => assignment);
  } else {
    if (content.length > 0 && !content.endsWith('\n')) {
      content += '\n';
    }
    content += `${assignment}\n`;
  }

  writeFileSync(configPath, content, 'utf8');
}
39
+
40
/**
 * Read a config variable from a do/config-style file.
 *
 * Only lines of the exact form `export VAR_NAME="value"` are recognised;
 * the first matching line wins and the text between the double quotes is
 * returned.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The variable value, or null if not found
 */
export function readConfigVar(configPath, varName) {
  const content = readFileSync(configPath, 'utf8');

  // Escape regex metacharacters so the name is matched literally and cannot
  // accidentally match a different variable.
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const match = content.match(new RegExp(`^export ${escapedName}="([^"]*)"`, 'm'));
  return match ? match[1] : null;
}
53
+
54
/**
 * Record the config variables written after a successful job submission.
 * Mirrors the behavior of do/tune's _submit_job() function.
 *
 * Writes, in order: `TUNE_JOB_NAME_<TECHNIQUE>`, `TUNE_TECHNIQUE`,
 * `TUNE_TRAINING_TYPE`, and `TUNE_DATASET_PATH`.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Submission parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
 * @param {string} params.jobName - Generated job name
 */
export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
  const assignments = [
    [`TUNE_JOB_NAME_${technique.toUpperCase()}`, jobName],
    ['TUNE_TECHNIQUE', technique],
    ['TUNE_TRAINING_TYPE', trainingType],
    ['TUNE_DATASET_PATH', datasetPath]
  ];

  for (const [name, value] of assignments) {
    updateConfigVar(configPath, name, value);
  }
}
72
+
73
/**
 * Record the config variables written after a job completes successfully.
 * Mirrors the behavior of do/tune's _handle_completion() function.
 *
 * A `lora` job stores the artifact under `TUNE_ADAPTER_PATH_<TECHNIQUE>`, a
 * `full-rank` job under `TUNE_MODEL_PATH_<TECHNIQUE>`; any other training
 * type writes no per-technique variable. `TUNE_OUTPUT_PATH_LATEST` and
 * `TUNE_OUTPUT_TYPE_LATEST` are always written.
 *
 * @param {string} configPath - Path to the config file
 * @param {object} params - Completion parameters
 * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {string} params.trainingType - Training type (lora, full-rank)
 * @param {string} params.artifactPath - S3 path to the output artifact
 * @param {string} params.outputType - Output type (adapter, full-model)
 */
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType }) {
  const suffix = technique.toUpperCase();

  switch (trainingType) {
    case 'lora':
      updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${suffix}`, artifactPath);
      break;
    case 'full-rank':
      updateConfigVar(configPath, `TUNE_MODEL_PATH_${suffix}`, artifactPath);
      break;
    default:
      break;
  }

  updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
  updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
}
96
+
97
/**
 * Build a job name in the format used by do/tune.
 * Pattern: ${projectName}-tune-${technique}-YYYYMMDD-HHMMSS
 *
 * Date and time components are taken from the timestamp's local-time
 * getters; month, day, hours, minutes and seconds are zero-padded to two
 * digits.
 *
 * @param {string} projectName - Project name
 * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
 * @param {Date} [timestamp] - Optional timestamp (defaults to now)
 * @returns {string} Generated job name
 */
export function generateJobName(projectName, technique, timestamp = new Date()) {
  const twoDigits = (value) => value.toString().padStart(2, '0');

  const datePart = [
    timestamp.getFullYear().toString(),
    twoDigits(timestamp.getMonth() + 1),
    twoDigits(timestamp.getDate())
  ].join('');

  const timePart = [
    timestamp.getHours(),
    timestamp.getMinutes(),
    timestamp.getSeconds()
  ].map(twoDigits).join('');

  return `${projectName}-tune-${technique}-${datePart}-${timePart}`;
}