@aws/ml-container-creator 1.0.4 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -0
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/package.json +4 -1
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/config +6 -1
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
|
@@ -1,120 +0,0 @@
|
|
|
1
|
-
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
-
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
/**
|
|
5
|
-
* Train Request Builder
|
|
6
|
-
*
|
|
7
|
-
* JavaScript module that replicates the Python helper's (.train_build_request.py)
|
|
8
|
-
* logic for constructing a CreateTrainingJob JSON request from a parsed config.
|
|
9
|
-
*
|
|
10
|
-
* This module mirrors the behavior of the Python build_request() function,
|
|
11
|
-
* providing a testable implementation of the config-to-API mapping logic.
|
|
12
|
-
*/
|
|
13
|
-
|
|
14
|
-
/**
 * Build a CreateTrainingJob request from a parsed training config.
 *
 * Maps config fields to the SageMaker CreateTrainingJob API structure:
 * - image → AlgorithmSpecification.TrainingImage (TrainingInputMode is always 'File')
 * - instance_type → ResourceConfig.InstanceType
 * - instance_count → ResourceConfig.InstanceCount (parsed as a base-10 integer)
 * - volume_size_gb → ResourceConfig.VolumeSizeInGB (parsed as a base-10 integer)
 * - dataset → InputDataConfig[0].DataSource.S3DataSource.S3Uri ('training' channel)
 * - output_path → OutputDataConfig.S3OutputPath
 * - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 * - hyperparameters → HyperParameters (string key-value pairs; omitted when empty)
 * - enable_spot=true (boolean or the string 'true') → EnableManagedSpotTraining = true
 *   and max_wait_seconds → StoppingCondition.MaxWaitTimeInSeconds
 * - checkpoint_path → CheckpointConfig.S3Uri (omitted when empty)
 * - metric_definitions → AlgorithmSpecification.MetricDefinitions ({name, regex} → {Name, Regex})
 * - environment → Environment (copied as string key-value pairs; omitted when empty)
 * - tags → Tags (converted from {k:v} to [{Key:k, Value:v}]; omitted when empty)
 *
 * Numeric fields are parsed with parseInt and are not validated here; a missing
 * or non-numeric value yields NaN in the request.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: {
      TrainingImage: config.image,
      TrainingInputMode: 'File'
    },
    InputDataConfig: [
      {
        ChannelName: 'training',
        DataSource: {
          S3DataSource: {
            S3DataType: 'S3Prefix',
            S3Uri: config.dataset,
            S3DataDistributionType: 'FullyReplicated'
          }
        }
      }
    ],
    OutputDataConfig: {
      S3OutputPath: config.output_path
    },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: parseInt(config.instance_count, 10),
      VolumeSizeInGB: parseInt(config.volume_size_gb, 10)
    },
    StoppingCondition: {
      MaxRuntimeInSeconds: parseInt(config.max_runtime_seconds, 10)
    }
  };

  // Hyperparameters — ensure all values are strings (SageMaker requirement)
  const hyperparams = config.hyperparameters || {};
  if (Object.keys(hyperparams).length > 0) {
    request.HyperParameters = {};
    for (const [k, v] of Object.entries(hyperparams)) {
      request.HyperParameters[String(k)] = String(v);
    }
  }

  // Managed spot training
  const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
  if (enableSpot) {
    request.EnableManagedSpotTraining = true;
    request.StoppingCondition.MaxWaitTimeInSeconds = parseInt(config.max_wait_seconds, 10);
  }

  // Checkpoint configuration (for spot training resumption)
  const checkpointPath = config.checkpoint_path || '';
  if (checkpointPath) {
    request.CheckpointConfig = {
      S3Uri: checkpointPath
    };
  }

  // Metric definitions (custom CloudWatch metrics)
  const metricDefs = config.metric_definitions || [];
  if (Array.isArray(metricDefs) && metricDefs.length > 0) {
    request.AlgorithmSpecification.MetricDefinitions = metricDefs.map(m => ({
      Name: m.name,
      Regex: m.regex
    }));
  }

  // Environment variables for the container. The API expects a string-to-string
  // map, so values are coerced the same way as HyperParameters and Tags. A fresh
  // object is built so the request never aliases the caller's config.
  const environment = config.environment || {};
  if (Object.keys(environment).length > 0) {
    request.Environment = Object.fromEntries(
      Object.entries(environment).map(([k, v]) => [String(k), String(v)])
    );
  }

  // Tags — convert from {key: value} map to [{Key: k, Value: v}] array
  const tags = config.tags || {};
  if (Object.keys(tags).length > 0) {
    request.Tags = Object.entries(tags).map(([k, v]) => ({
      Key: String(k),
      Value: String(v)
    }));
  }

  return request;
}
|
|
@@ -1,279 +0,0 @@
|
|
|
1
|
-
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
-
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
/**
|
|
5
|
-
* Tune Dataset Validator
|
|
6
|
-
*
|
|
7
|
-
* Parses dataset arguments (S3 URIs and Hugging Face references) and
|
|
8
|
-
* validates JSONL dataset lines against catalog-driven schemas.
|
|
9
|
-
*
|
|
10
|
-
* Requirements: 3.1, 3.5, 3.6, 3.7, 3.8, 3.10, 3.11, 3.12
|
|
11
|
-
*/
|
|
12
|
-
|
|
13
|
-
/**
 * Turn a user-supplied dataset argument into a structured description.
 *
 * Surrounding whitespace is trimmed, then the scheme decides the parser:
 * `s3://bucket/key` is split into bucket and key, and `hf://org/name[/split]`
 * is split into org, name and split. Anything else is reported as invalid.
 *
 * @param {string} datasetStr - The dataset argument string
 * @returns {{ valid: boolean, type?: string, bucket?: string, key?: string, org?: string, name?: string, split?: string, error?: string }}
 */
export function parseDatasetArg(datasetStr) {
  const isUsableString = typeof datasetStr === 'string' && datasetStr.length > 0;
  if (!isUsableString) {
    return {
      valid: false,
      error: 'Dataset argument is required and must be a non-empty string.'
    };
  }

  const candidate = datasetStr.trim();

  // Parsers are only reached once their scheme prefix matches.
  if (candidate.startsWith('s3://')) {
    return _parseS3Uri(candidate);
  }
  if (candidate.startsWith('hf://')) {
    return _parseHfReference(candidate);
  }

  return {
    valid: false,
    error: `Invalid dataset format: "${candidate}". Expected s3://bucket/key or hf://org/name[/split].`
  };
}
|
|
44
|
-
|
|
45
|
-
/**
 * Check JSONL dataset lines against a catalog dataset schema.
 *
 * Only the first 10 entries of `lines` are examined; blank entries among them
 * are skipped but still count toward the 10. For each remaining line, in order,
 * the function checks that:
 * 1. the line parses as JSON,
 * 2. the parsed value is a plain object (not null and not an array),
 * 3. every key in `schema.required` is an own property,
 * 4. every key listed in `schema.types` that is present has the declared type.
 * The first failing check is reported.
 *
 * @param {string[]} lines - Array of JSONL line strings
 * @param {Object} schema - The datasetSchema object from the catalog
 * @param {string[]} schema.required - Array of required top-level keys
 * @param {Object} schema.types - Object mapping key to expected type ("string", "array", "object", "number")
 * @returns {{ valid: boolean, error: string|null, lineNumber: number|null, malformedLine: string|null, expectedFormat: string|null }}
 */
export function validateDatasetFormat(lines, schema) {
  // Failures carry a human-readable description of the expected format.
  // It is computed lazily, only when a failure is actually reported.
  const failure = (error, lineNumber = null, malformedLine = null) => ({
    valid: false,
    error,
    lineNumber,
    malformedLine,
    expectedFormat: _buildExpectedFormat(schema)
  });

  if (!Array.isArray(lines)) {
    return failure('Lines must be provided as an array.');
  }

  if (!schema || !Array.isArray(schema.required)) {
    return {
      valid: false,
      error: 'Schema must include a "required" array of keys.',
      lineNumber: null,
      malformedLine: null,
      expectedFormat: null
    };
  }

  const sample = lines.slice(0, 10);

  for (const [index, rawLine] of sample.entries()) {
    const lineNumber = index + 1;

    const isBlank = !rawLine || rawLine.trim() === '';
    if (isBlank) {
      continue;
    }

    let record;
    try {
      record = JSON.parse(rawLine);
    } catch (parseError) {
      return failure(`Line ${lineNumber} is not valid JSON: ${parseError.message}`, lineNumber, rawLine);
    }

    const isPlainObject = record !== null && typeof record === 'object' && !Array.isArray(record);
    if (!isPlainObject) {
      return failure(`Line ${lineNumber} must be a JSON object.`, lineNumber, rawLine);
    }

    const missingIndex = schema.required.findIndex(requiredKey => !Object.hasOwn(record, requiredKey));
    if (missingIndex !== -1) {
      const missingKey = schema.required[missingIndex];
      return failure(`Line ${lineNumber} is missing required key "${missingKey}".`, lineNumber, rawLine);
    }

    if (!schema.types) {
      continue;
    }

    for (const [typedKey, expectedType] of Object.entries(schema.types)) {
      // Optional keys that are absent are not type-checked.
      if (!Object.hasOwn(record, typedKey)) {
        continue;
      }
      const fieldValue = record[typedKey];
      if (!_checkType(fieldValue, expectedType)) {
        return failure(
          `Line ${lineNumber} has key "${typedKey}" with wrong type. Expected "${expectedType}", got "${_getType(fieldValue)}".`,
          lineNumber,
          rawLine
        );
      }
    }
  }

  return {
    valid: true,
    error: null,
    lineNumber: null,
    malformedLine: null,
    expectedFormat: null
  };
}
|
|
154
|
-
|
|
155
|
-
/**
 * Parse an S3 URI into bucket and key components.
 *
 * The URI is split at the first '/' after the scheme. Everything before it is
 * the bucket and everything after it is the key, which may contain further
 * slashes. Errors are reported in this order:
 * - no '/' after the scheme → generic format error
 * - empty bucket (e.g. "s3:///key") → "Bucket name is empty."
 * - empty key (e.g. "s3://bucket/") → "Key path is empty."
 *
 * @param {string} uri - The S3 URI (e.g., "s3://bucket/path/to/file.jsonl"); assumed to start with "s3://"
 * @returns {Object} Parsed result: `{ valid: true, type: 's3', bucket, key }` or `{ valid: false, error }`
 * @private
 */
function _parseS3Uri(uri) {
  const withoutScheme = uri.slice('s3://'.length);
  const slashIndex = withoutScheme.indexOf('/');

  // Without a separator there is no way to tell the bucket from the key.
  // An empty bucket (slash at index 0) falls through to the specific check below.
  if (slashIndex === -1) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Expected format: s3://bucket/key.`
    };
  }

  const bucket = withoutScheme.slice(0, slashIndex);
  const key = withoutScheme.slice(slashIndex + 1);

  if (!bucket) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Bucket name is empty.`
    };
  }

  if (!key) {
    return {
      valid: false,
      error: `Invalid S3 URI: "${uri}". Key path is empty.`
    };
  }

  return {
    valid: true,
    type: 's3',
    bucket,
    key
  };
}
|
|
196
|
-
|
|
197
|
-
/**
 * Split a Hugging Face dataset reference into org, name and split.
 *
 * The text after "hf://" is split on '/'. The first two segments must be
 * non-empty (org and name). The optional third segment is the split and
 * falls back to 'train' when absent or empty. Any further segments are ignored.
 *
 * @param {string} ref - The HF reference (e.g., "hf://org/name" or "hf://org/name/split"); assumed to start with "hf://"
 * @returns {Object} Parsed result: `{ valid: true, type: 'hf', org, name, split }` or `{ valid: false, error }`
 * @private
 */
function _parseHfReference(ref) {
  const [org, name, requestedSplit] = ref.slice('hf://'.length).split('/');

  if (!org || !name) {
    return {
      valid: false,
      error: `Invalid Hugging Face reference: "${ref}". Expected format: hf://org/name[/split].`
    };
  }

  return {
    valid: true,
    type: 'hf',
    org,
    name,
    split: requestedSplit || 'train'
  };
}
|
|
227
|
-
|
|
228
|
-
/**
 * Report whether a value satisfies a schema type name.
 *
 * Recognized names are "string", "number", "array" and "object". Here
 * "object" means a non-null, non-array object. Any other type name is treated
 * as unconstrained and always matches.
 *
 * @param {*} value - The value to check
 * @param {string} expectedType - One of "string", "array", "object", "number"
 * @returns {boolean} True if the value matches the expected type
 * @private
 */
function _checkType(value, expectedType) {
  if (expectedType === 'string' || expectedType === 'number') {
    return typeof value === expectedType;
  }
  if (expectedType === 'array') {
    return Array.isArray(value);
  }
  if (expectedType === 'object') {
    return value !== null && typeof value === 'object' && !Array.isArray(value);
  }
  // Unknown type names are permissive.
  return true;
}
|
|
249
|
-
|
|
250
|
-
/**
 * Describe a value's type for error messages.
 *
 * Like `typeof`, except that arrays report "array" and null reports "null"
 * instead of both being "object".
 *
 * @param {*} value - The value to describe
 * @returns {string} The type name
 * @private
 */
function _getType(value) {
  if (Array.isArray(value)) {
    return 'array';
  }
  return value === null ? 'null' : typeof value;
}
|
|
261
|
-
|
|
262
|
-
/**
 * Build a human-readable expected format description from a schema.
 *
 * Each required key is listed together with its declared type from
 * `schema.types`, or "any" when no type is declared for that key.
 *
 * @param {Object} schema - The dataset schema
 * @returns {string|null} Description of expected format, or null when the schema
 *   has no `required` array (including when `required` is not an array).
 * @private
 */
function _buildExpectedFormat(schema) {
  // Callers may pass an unvalidated schema (e.g. the "lines is not an array"
  // path of validateDatasetFormat). Calling .map on a non-array `required` would throw.
  if (!schema || !Array.isArray(schema.required)) {
    return null;
  }

  const types = schema.types || {};
  const fields = schema.required.map(key => {
    // Only own properties count. Otherwise a key such as "constructor" would pick
    // up Object.prototype members and render as a function source string.
    const type = (Object.hasOwn(types, key) && types[key]) || 'any';
    return `"${key}": <${type}>`;
  });

  return `Each line must be a JSON object with: {${fields.join(', ')}}`;
}
|
|
@@ -1,66 +0,0 @@
|
|
|
1
|
-
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
-
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
/**
|
|
5
|
-
* Tune Output Resolver
|
|
6
|
-
*
|
|
7
|
-
* Detects output type from training type and generates context-aware
|
|
8
|
-
* next-step commands for deploying tune job artifacts.
|
|
9
|
-
*
|
|
10
|
-
* Requirements: 8.3, 8.11
|
|
11
|
-
*/
|
|
12
|
-
|
|
13
|
-
/**
 * Decide what kind of artifact a tune job produced from its training type.
 *
 * Only 'full-rank' training yields a full model. 'lora' yields adapter weights,
 * and so does any unrecognized value, including undefined.
 *
 * @param {string} trainingType - The training type ('lora' or 'full-rank')
 * @returns {string} 'full-model' for full-rank, otherwise 'adapter'
 */
export function detectOutputType(trainingType) {
  return trainingType === 'full-rank' ? 'full-model' : 'adapter';
}
|
|
29
|
-
|
|
30
|
-
/**
 * Suggest follow-up commands for deploying a tune job's output artifact.
 *
 * Adapter output (all commands use the name `tuned-${technique}`):
 * - ./do/adapter add tuned-${technique} --from-tune
 * - ./do/adapter add tuned-${technique} --from-tune ${technique}
 * - ./do/adapter add tuned-${technique} --weights ${artifactPath}
 *
 * Full-model output:
 * - ./do/add-ic tuned-v1 --from-tune
 * - ./do/add-ic tuned-v1 --model-data ${artifactPath}
 * - ./do/deploy --force-ic --model-data ${artifactPath}
 *
 * Any other output type yields no suggestions. Arguments are interpolated as-is
 * and are not quoted or validated.
 *
 * @param {string} outputType - The output type ('adapter' or 'full-model')
 * @param {string} technique - The technique used (e.g., 'sft', 'dpo')
 * @param {string} artifactPath - The S3 path to the output artifact
 * @returns {string[]} Array of suggested next-step commands
 */
export function generateNextStepCommands(outputType, technique, artifactPath) {
  switch (outputType) {
    case 'adapter': {
      const addAdapter = `./do/adapter add tuned-${technique}`;
      return [
        `${addAdapter} --from-tune`,
        `${addAdapter} --from-tune ${technique}`,
        `${addAdapter} --weights ${artifactPath}`
      ];
    }
    case 'full-model':
      return [
        './do/add-ic tuned-v1 --from-tune',
        `./do/add-ic tuned-v1 --model-data ${artifactPath}`,
        `./do/deploy --force-ic --model-data ${artifactPath}`
      ];
    default:
      return [];
  }
}
|