@aws/ml-container-creator 1.0.4 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. package/README.md +9 -0
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/package.json +4 -1
  5. package/pyproject.toml +3 -0
  6. package/servers/agent-knowledge/index.js +592 -0
  7. package/servers/agent-knowledge/package.json +15 -0
  8. package/src/agent/__init__.py +2 -0
  9. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  10. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  11. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  12. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  13. package/src/agent/agent.py +513 -0
  14. package/src/agent/config_loader.py +215 -0
  15. package/src/agent/context.py +380 -0
  16. package/src/agent/data/capability-matrix.json +106 -0
  17. package/src/agent/health_check.py +341 -0
  18. package/src/agent/prompts/system.md +173 -0
  19. package/src/agent/requirements-agent.txt +3 -0
  20. package/src/lib/generated/cli-options.js +1 -1
  21. package/src/lib/generated/parameter-matrix.js +1 -1
  22. package/src/lib/generated/validation-rules.js +1 -1
  23. package/src/lib/tune-config-state.js +89 -68
  24. package/templates/do/config +6 -1
  25. package/src/lib/auto-prompt-builder.js +0 -172
  26. package/src/lib/cli-handler.js +0 -529
  27. package/src/lib/community-reports-validator.js +0 -91
  28. package/src/lib/configuration-exporter.js +0 -204
  29. package/src/lib/dataset-slug.js +0 -152
  30. package/src/lib/docker-introspection-validator.js +0 -51
  31. package/src/lib/known-flags-validator.js +0 -200
  32. package/src/lib/schema-validator.js +0 -157
  33. package/src/lib/train-config-parser.js +0 -136
  34. package/src/lib/train-config-persistence.js +0 -143
  35. package/src/lib/train-config-validator.js +0 -112
  36. package/src/lib/train-feedback.js +0 -46
  37. package/src/lib/train-idempotency.js +0 -97
  38. package/src/lib/train-request-builder.js +0 -120
  39. package/src/lib/tune-dataset-validator.js +0 -279
  40. package/src/lib/tune-output-resolver.js +0 -66
@@ -1,120 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Train Request Builder
6
- *
7
- * JavaScript module that replicates the Python helper's (.train_build_request.py)
8
- * logic for constructing a CreateTrainingJob JSON request from a parsed config.
9
- *
10
- * This module mirrors the behavior of the Python build_request() function,
11
- * providing a testable implementation of the config-to-API mapping logic.
12
- */
13
-
14
/**
 * Build a SageMaker CreateTrainingJob request body from a parsed training config.
 *
 * Field mapping:
 *   - image               → AlgorithmSpecification.TrainingImage
 *   - instance_type       → ResourceConfig.InstanceType
 *   - instance_count      → ResourceConfig.InstanceCount (base-10 integer)
 *   - volume_size_gb      → ResourceConfig.VolumeSizeInGB (base-10 integer)
 *   - dataset             → InputDataConfig[0].DataSource.S3DataSource.S3Uri
 *   - output_path         → OutputDataConfig.S3OutputPath
 *   - hyperparameters     → HyperParameters (keys and values coerced to strings)
 *   - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 *   - enable_spot         → EnableManagedSpotTraining = true, and
 *                           max_wait_seconds → StoppingCondition.MaxWaitTimeInSeconds
 *   - checkpoint_path     → CheckpointConfig.S3Uri
 *   - metric_definitions  → AlgorithmSpecification.MetricDefinitions
 *   - environment         → Environment (values coerced to strings, copied)
 *   - tags                → Tags (converted from {k: v} to [{Key: k, Value: v}])
 *
 * Optional sections (HyperParameters, CheckpointConfig, MetricDefinitions,
 * Environment, Tags) are only added when the config supplies a non-empty value.
 * Numeric fields are parsed with parseInt, so a missing or non-numeric config
 * value results in NaN in the request.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const toInt = (value) => Number.parseInt(value, 10);

  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: {
      TrainingImage: config.image,
      TrainingInputMode: 'File',
    },
    InputDataConfig: [
      {
        ChannelName: 'training',
        DataSource: {
          S3DataSource: {
            S3DataType: 'S3Prefix',
            S3Uri: config.dataset,
            S3DataDistributionType: 'FullyReplicated',
          },
        },
      },
    ],
    OutputDataConfig: {
      S3OutputPath: config.output_path,
    },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: toInt(config.instance_count),
      VolumeSizeInGB: toInt(config.volume_size_gb),
    },
    StoppingCondition: {
      MaxRuntimeInSeconds: toInt(config.max_runtime_seconds),
    },
  };

  // Hyperparameters — every value must be a string.
  const hyperparams = config.hyperparameters || {};
  if (Object.keys(hyperparams).length > 0) {
    request.HyperParameters = {};
    for (const [k, v] of Object.entries(hyperparams)) {
      request.HyperParameters[String(k)] = String(v);
    }
  }

  // Managed spot training; the flag may arrive as a boolean or the string 'true'.
  const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
  if (enableSpot) {
    request.EnableManagedSpotTraining = true;
    request.StoppingCondition.MaxWaitTimeInSeconds = toInt(config.max_wait_seconds);
  }

  // Checkpoint configuration (for spot training resumption).
  const checkpointPath = config.checkpoint_path || '';
  if (checkpointPath) {
    request.CheckpointConfig = {
      S3Uri: checkpointPath,
    };
  }

  // Metric definitions (custom CloudWatch metrics); only name/regex are forwarded.
  const metricDefs = config.metric_definitions || [];
  if (Array.isArray(metricDefs) && metricDefs.length > 0) {
    request.AlgorithmSpecification.MetricDefinitions = metricDefs.map((m) => ({
      Name: m.name,
      Regex: m.regex,
    }));
  }

  // Environment variables for the container. Build a fresh string-valued map
  // so the request never aliases the caller's config object and values are
  // coerced to strings, consistent with HyperParameters and Tags.
  const environmentEntries = Object.entries(config.environment || {});
  if (environmentEntries.length > 0) {
    request.Environment = Object.fromEntries(
      environmentEntries.map(([k, v]) => [k, String(v)]),
    );
  }

  // Tags — convert from {key: value} map to [{Key: k, Value: v}] array.
  const tags = config.tags || {};
  if (Object.keys(tags).length > 0) {
    request.Tags = Object.entries(tags).map(([k, v]) => ({
      Key: String(k),
      Value: String(v),
    }));
  }

  return request;
}
@@ -1,279 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Tune Dataset Validator
6
- *
7
- * Parses dataset arguments (S3 URIs and Hugging Face references) and
8
- * validates JSONL dataset lines against catalog-driven schemas.
9
- *
10
- * Requirements: 3.1, 3.5, 3.6, 3.7, 3.8, 3.10, 3.11, 3.12
11
- */
12
-
13
- /**
14
- * Parse a dataset argument string into a structured object.
15
- * Accepts S3 URIs (`s3://bucket/key`) or Hugging Face references
16
- * (`hf://org/name` or `hf://org/name/split`).
17
- *
18
- * @param {string} datasetStr - The dataset argument string
19
- * @returns {{ valid: boolean, type?: string, bucket?: string, key?: string, org?: string, name?: string, split?: string, error?: string }}
20
- */
21
- export function parseDatasetArg(datasetStr) {
22
- if (!datasetStr || typeof datasetStr !== 'string') {
23
- return {
24
- valid: false,
25
- error: 'Dataset argument is required and must be a non-empty string.'
26
- };
27
- }
28
-
29
- const trimmed = datasetStr.trim();
30
-
31
- if (trimmed.startsWith('s3://')) {
32
- return _parseS3Uri(trimmed);
33
- }
34
-
35
- if (trimmed.startsWith('hf://')) {
36
- return _parseHfReference(trimmed);
37
- }
38
-
39
- return {
40
- valid: false,
41
- error: `Invalid dataset format: "${trimmed}". Expected s3://bucket/key or hf://org/name[/split].`
42
- };
43
- }
44
-
45
- /**
46
- * Validate JSONL lines against a dataset schema from the catalog.
47
- * Inspects only the first 10 lines per requirement.
48
- *
49
- * @param {string[]} lines - Array of JSONL line strings
50
- * @param {Object} schema - The datasetSchema object from the catalog
51
- * @param {string[]} schema.required - Array of required top-level keys
52
- * @param {Object} schema.types - Object mapping key to expected type ("string", "array", "object", "number")
53
- * @returns {{ valid: boolean, error: string|null, lineNumber: number|null, malformedLine: string|null, expectedFormat: string|null }}
54
- */
55
- export function validateDatasetFormat(lines, schema) {
56
- if (!lines || !Array.isArray(lines)) {
57
- return {
58
- valid: false,
59
- error: 'Lines must be provided as an array.',
60
- lineNumber: null,
61
- malformedLine: null,
62
- expectedFormat: _buildExpectedFormat(schema)
63
- };
64
- }
65
-
66
- if (!schema || !schema.required || !Array.isArray(schema.required)) {
67
- return {
68
- valid: false,
69
- error: 'Schema must include a "required" array of keys.',
70
- lineNumber: null,
71
- malformedLine: null,
72
- expectedFormat: null
73
- };
74
- }
75
-
76
- const linesToInspect = lines.slice(0, 10);
77
-
78
- for (let i = 0; i < linesToInspect.length; i++) {
79
- const line = linesToInspect[i];
80
- const lineNumber = i + 1;
81
-
82
- // Skip empty lines
83
- if (!line || line.trim() === '') {
84
- continue;
85
- }
86
-
87
- // Try to parse as JSON
88
- let parsed;
89
- try {
90
- parsed = JSON.parse(line);
91
- } catch (e) {
92
- return {
93
- valid: false,
94
- error: `Line ${lineNumber} is not valid JSON: ${e.message}`,
95
- lineNumber,
96
- malformedLine: line,
97
- expectedFormat: _buildExpectedFormat(schema)
98
- };
99
- }
100
-
101
- // Check that parsed value is an object
102
- if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
103
- return {
104
- valid: false,
105
- error: `Line ${lineNumber} must be a JSON object.`,
106
- lineNumber,
107
- malformedLine: line,
108
- expectedFormat: _buildExpectedFormat(schema)
109
- };
110
- }
111
-
112
- // Check required keys
113
- for (const key of schema.required) {
114
- if (!Object.hasOwn(parsed, key)) {
115
- return {
116
- valid: false,
117
- error: `Line ${lineNumber} is missing required key "${key}".`,
118
- lineNumber,
119
- malformedLine: line,
120
- expectedFormat: _buildExpectedFormat(schema)
121
- };
122
- }
123
- }
124
-
125
- // Check types if specified
126
- if (schema.types) {
127
- for (const [key, expectedType] of Object.entries(schema.types)) {
128
- if (!Object.hasOwn(parsed, key)) {
129
- continue;
130
- }
131
-
132
- const value = parsed[key];
133
- if (!_checkType(value, expectedType)) {
134
- return {
135
- valid: false,
136
- error: `Line ${lineNumber} has key "${key}" with wrong type. Expected "${expectedType}", got "${_getType(value)}".`,
137
- lineNumber,
138
- malformedLine: line,
139
- expectedFormat: _buildExpectedFormat(schema)
140
- };
141
- }
142
- }
143
- }
144
- }
145
-
146
- return {
147
- valid: true,
148
- error: null,
149
- lineNumber: null,
150
- malformedLine: null,
151
- expectedFormat: null
152
- };
153
- }
154
-
155
- /**
156
- * Parse an S3 URI into bucket and key components.
157
- * @param {string} uri - The S3 URI (e.g., "s3://bucket/path/to/file.jsonl")
158
- * @returns {Object} Parsed result
159
- * @private
160
- */
161
- function _parseS3Uri(uri) {
162
- const withoutScheme = uri.slice(5); // Remove "s3://"
163
- const slashIndex = withoutScheme.indexOf('/');
164
-
165
- if (slashIndex === -1 || slashIndex === 0) {
166
- return {
167
- valid: false,
168
- error: `Invalid S3 URI: "${uri}". Expected format: s3://bucket/key.`
169
- };
170
- }
171
-
172
- const bucket = withoutScheme.slice(0, slashIndex);
173
- const key = withoutScheme.slice(slashIndex + 1);
174
-
175
- if (!bucket) {
176
- return {
177
- valid: false,
178
- error: `Invalid S3 URI: "${uri}". Bucket name is empty.`
179
- };
180
- }
181
-
182
- if (!key) {
183
- return {
184
- valid: false,
185
- error: `Invalid S3 URI: "${uri}". Key path is empty.`
186
- };
187
- }
188
-
189
- return {
190
- valid: true,
191
- type: 's3',
192
- bucket,
193
- key
194
- };
195
- }
196
-
197
- /**
198
- * Parse a Hugging Face dataset reference into org, name, and split.
199
- * Defaults to 'train' split if not specified.
200
- * @param {string} ref - The HF reference (e.g., "hf://org/name" or "hf://org/name/split")
201
- * @returns {Object} Parsed result
202
- * @private
203
- */
204
- function _parseHfReference(ref) {
205
- const withoutScheme = ref.slice(5); // Remove "hf://"
206
- const parts = withoutScheme.split('/');
207
-
208
- if (parts.length < 2 || !parts[0] || !parts[1]) {
209
- return {
210
- valid: false,
211
- error: `Invalid Hugging Face reference: "${ref}". Expected format: hf://org/name[/split].`
212
- };
213
- }
214
-
215
- const org = parts[0];
216
- const name = parts[1];
217
- const split = parts.length >= 3 && parts[2] ? parts[2] : 'train';
218
-
219
- return {
220
- valid: true,
221
- type: 'hf',
222
- org,
223
- name,
224
- split
225
- };
226
- }
227
-
228
- /**
229
- * Check if a value matches the expected schema type.
230
- * @param {*} value - The value to check
231
- * @param {string} expectedType - One of "string", "array", "object", "number"
232
- * @returns {boolean} True if the value matches the expected type
233
- * @private
234
- */
235
- function _checkType(value, expectedType) {
236
- switch (expectedType) {
237
- case 'string':
238
- return typeof value === 'string';
239
- case 'number':
240
- return typeof value === 'number';
241
- case 'array':
242
- return Array.isArray(value);
243
- case 'object':
244
- return typeof value === 'object' && value !== null && !Array.isArray(value);
245
- default:
246
- return true;
247
- }
248
- }
249
-
250
- /**
251
- * Get a human-readable type name for a value.
252
- * @param {*} value - The value to describe
253
- * @returns {string} The type name
254
- * @private
255
- */
256
- function _getType(value) {
257
- if (value === null) return 'null';
258
- if (Array.isArray(value)) return 'array';
259
- return typeof value;
260
- }
261
-
262
- /**
263
- * Build a human-readable expected format description from a schema.
264
- * @param {Object} schema - The dataset schema
265
- * @returns {string|null} Description of expected format
266
- * @private
267
- */
268
- function _buildExpectedFormat(schema) {
269
- if (!schema || !schema.required) {
270
- return null;
271
- }
272
-
273
- const fields = schema.required.map(key => {
274
- const type = schema.types && schema.types[key] ? schema.types[key] : 'any';
275
- return `"${key}": <${type}>`;
276
- });
277
-
278
- return `Each line must be a JSON object with: {${fields.join(', ')}}`;
279
- }
@@ -1,66 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Tune Output Resolver
6
- *
7
- * Detects output type from training type and generates context-aware
8
- * next-step commands for deploying tune job artifacts.
9
- *
10
- * Requirements: 8.3, 8.11
11
- */
12
-
13
- /**
14
- * Detect the output type based on the training type used for the job.
15
- * LoRA training produces adapter weights; full-rank produces a full model.
16
- *
17
- * @param {string} trainingType - The training type ('lora' or 'full-rank')
18
- * @returns {string} The output type: 'adapter' for lora, 'full-model' for full-rank
19
- */
20
- export function detectOutputType(trainingType) {
21
- if (trainingType === 'lora') {
22
- return 'adapter';
23
- }
24
- if (trainingType === 'full-rank') {
25
- return 'full-model';
26
- }
27
- return 'adapter';
28
- }
29
-
30
- /**
31
- * Generate context-aware next-step commands based on the output type.
32
- *
33
- * For adapter output:
34
- * - Quick path: ./do/adapter add tuned-${technique} --from-tune
35
- * - Technique-specific: ./do/adapter add tuned-${technique} --from-tune ${technique}
36
- * - Explicit path: ./do/adapter add tuned-${technique} --weights ${artifactPath}
37
- *
38
- * For full-model output:
39
- * - Deploy as new IC: ./do/add-ic tuned-v1 --from-tune
40
- * - Explicit path: ./do/add-ic tuned-v1 --model-data ${artifactPath}
41
- * - Replace current base: ./do/deploy --force-ic --model-data ${artifactPath}
42
- *
43
- * @param {string} outputType - The output type ('adapter' or 'full-model')
44
- * @param {string} technique - The technique used (e.g., 'sft', 'dpo')
45
- * @param {string} artifactPath - The S3 path to the output artifact
46
- * @returns {string[]} Array of suggested next-step commands
47
- */
48
- export function generateNextStepCommands(outputType, technique, artifactPath) {
49
- if (outputType === 'adapter') {
50
- return [
51
- `./do/adapter add tuned-${technique} --from-tune`,
52
- `./do/adapter add tuned-${technique} --from-tune ${technique}`,
53
- `./do/adapter add tuned-${technique} --weights ${artifactPath}`
54
- ];
55
- }
56
-
57
- if (outputType === 'full-model') {
58
- return [
59
- './do/add-ic tuned-v1 --from-tune',
60
- `./do/add-ic tuned-v1 --model-data ${artifactPath}`,
61
- `./do/deploy --force-ic --model-data ${artifactPath}`
62
- ];
63
- }
64
-
65
- return [];
66
- }