@aws/ml-container-creator 1.0.4 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +9 -0
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/package.json +4 -1
  5. package/pyproject.toml +3 -0
  6. package/servers/agent-knowledge/index.js +592 -0
  7. package/servers/agent-knowledge/package.json +15 -0
  8. package/src/agent/__init__.py +2 -0
  9. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  10. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  11. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  12. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  13. package/src/agent/agent.py +513 -0
  14. package/src/agent/config_loader.py +215 -0
  15. package/src/agent/context.py +380 -0
  16. package/src/agent/data/capability-matrix.json +106 -0
  17. package/src/agent/health_check.py +341 -0
  18. package/src/agent/prompts/system.md +173 -0
  19. package/src/agent/requirements-agent.txt +3 -0
  20. package/src/lib/generated/cli-options.js +1 -1
  21. package/src/lib/generated/parameter-matrix.js +1 -1
  22. package/src/lib/generated/validation-rules.js +1 -1
  23. package/src/lib/tune-config-state.js +89 -68
  24. package/templates/do/config +6 -1
  25. package/src/lib/auto-prompt-builder.js +0 -172
  26. package/src/lib/cli-handler.js +0 -529
  27. package/src/lib/community-reports-validator.js +0 -91
  28. package/src/lib/configuration-exporter.js +0 -204
  29. package/src/lib/dataset-slug.js +0 -152
  30. package/src/lib/docker-introspection-validator.js +0 -51
  31. package/src/lib/known-flags-validator.js +0 -200
  32. package/src/lib/schema-validator.js +0 -157
  33. package/src/lib/train-config-parser.js +0 -136
  34. package/src/lib/train-config-persistence.js +0 -143
  35. package/src/lib/train-config-validator.js +0 -112
  36. package/src/lib/train-feedback.js +0 -46
  37. package/src/lib/train-idempotency.js +0 -97
  38. package/src/lib/train-request-builder.js +0 -120
  39. package/src/lib/tune-dataset-validator.js +0 -279
  40. package/src/lib/tune-output-resolver.js +0 -66
@@ -2,65 +2,91 @@
2
2
  // SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  /**
5
- * Tune Config State Manager
5
+ * Tune Config State
6
6
  *
7
- * JavaScript module that mimics the bash _update_config_var() behavior
8
- * from do/tune for testing purposes. Manages config variables written
9
- * after job submission.
7
+ * Manages bash-style config files (do/config) that contain lines like:
8
+ * export VAR_NAME="value"
9
+ *
10
+ * Provides read/write access for tuning job state variables.
10
11
  */
11
12
 
12
13
  import { readFileSync, writeFileSync } from 'node:fs';
13
14
 
14
15
  /**
15
- * Update or add a config variable in a do/config-style file.
16
- * Mimics the bash _update_config_var() function:
17
- * - If the variable exists (line starts with `export VAR_NAME=`), replace it
18
- * - Otherwise, append a new line
16
+ * Read a variable value from a bash config file.
17
+ * Looks for lines matching: export VAR_NAME="value", export VAR_NAME='value', or export VAR_NAME=value
19
18
  *
20
19
  * @param {string} configPath - Path to the config file
21
- * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
22
- * @param {string} varValue - Variable value
20
+ * @param {string} varName - Variable name to read
21
+ * @returns {string|null} The unquoted value, or null if not found
23
22
  */
24
- export function updateConfigVar(configPath, varName, varValue) {
25
- let content = readFileSync(configPath, 'utf8');
26
- const pattern = new RegExp(`^export ${varName}=.*$`, 'm');
23
+ export function readConfigVar(configPath, varName) {
24
+ const content = readFileSync(configPath, 'utf8');
25
+ const lines = content.split('\n');
27
26
 
28
- if (pattern.test(content)) {
29
- content = content.replace(pattern, `export ${varName}="${varValue}"`);
30
- } else {
31
- if (content.length > 0 && !content.endsWith('\n')) {
32
- content += '\n';
27
+ for (const line of lines) {
28
+ const trimmed = line.trim();
29
+ const prefix = `export ${varName}=`;
30
+ if (trimmed.startsWith(prefix)) {
31
+ let value = trimmed.slice(prefix.length);
32
+ // Strip surrounding quotes (double or single)
33
+ if ((value.startsWith('"') && value.endsWith('"')) ||
34
+ (value.startsWith('\'') && value.endsWith('\''))) {
35
+ value = value.slice(1, -1);
36
+ }
37
+ return value;
33
38
  }
34
- content += `export ${varName}="${varValue}"\n`;
35
39
  }
36
40
 
37
- writeFileSync(configPath, content, 'utf8');
41
+ return null;
38
42
  }
39
43
 
40
44
  /**
41
- * Read a config variable from a do/config-style file.
45
+ * Write or update a variable in a bash config file.
46
+ * If the variable already exists, replaces that line.
47
+ * If not, appends the new export line.
42
48
  *
43
49
  * @param {string} configPath - Path to the config file
44
- * @param {string} varName - Variable name to read
45
- * @returns {string|null} The variable value, or null if not found
50
+ * @param {string} varName - Variable name to set
51
+ * @param {string} value - Value to assign
46
52
  */
47
- export function readConfigVar(configPath, varName) {
53
+ export function updateConfigVar(configPath, varName, value) {
48
54
  const content = readFileSync(configPath, 'utf8');
49
- const pattern = new RegExp(`^export ${varName}="([^"]*)"`, 'm');
50
- const match = content.match(pattern);
51
- return match ? match[1] : null;
55
+ const lines = content.split('\n');
56
+ const prefix = `export ${varName}=`;
57
+ const newLine = `export ${varName}="${value}"`;
58
+
59
+ let found = false;
60
+ for (let i = 0; i < lines.length; i++) {
61
+ if (lines[i].trim().startsWith(prefix)) {
62
+ lines[i] = newLine;
63
+ found = true;
64
+ break;
65
+ }
66
+ }
67
+
68
+ if (found) {
69
+ writeFileSync(configPath, lines.join('\n'), 'utf8');
70
+ } else {
71
+ // Append to end of file
72
+ let appendContent = content;
73
+ if (appendContent.length > 0 && !appendContent.endsWith('\n')) {
74
+ appendContent += '\n';
75
+ }
76
+ appendContent += `${newLine }\n`;
77
+ writeFileSync(configPath, appendContent, 'utf8');
78
+ }
52
79
  }
53
80
 
54
81
  /**
55
- * Simulate the config writes that happen after a successful job submission.
56
- * This mirrors the behavior in do/tune's _submit_job() function.
82
+ * Write tuning job submission state to config.
57
83
  *
58
84
  * @param {string} configPath - Path to the config file
59
- * @param {object} params - Submission parameters
60
- * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
61
- * @param {string} params.trainingType - Training type (lora, full-rank)
62
- * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
63
- * @param {string} params.jobName - Generated job name
85
+ * @param {object} state - Submission state
86
+ * @param {string} state.technique - Tuning technique (e.g., 'sft', 'dpo')
87
+ * @param {string} state.trainingType - Training type (e.g., 'lora', 'full-rank')
88
+ * @param {string} state.datasetPath - Dataset path (S3 or HF URI)
89
+ * @param {string} state.jobName - Generated job name
64
90
  */
65
91
  export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
66
92
  const techniqueUpper = technique.toUpperCase();
@@ -71,59 +97,54 @@ export function persistSubmissionState(configPath, { technique, trainingType, da
71
97
  }
72
98
 
73
99
  /**
74
- * Simulate the config writes that happen after a job completes successfully.
75
- * This mirrors the behavior in do/tune's _handle_completion() function.
76
- *
77
- * Writes three levels of tracking (AC-4.1, AC-4.2):
78
- * - Level 1: TUNE_OUTPUT_PATH_LATEST (always the last run, any technique)
79
- * - Level 2: TUNE_ADAPTER_PATH_<TECHNIQUE> (last run per technique)
80
- * - Level 3: TUNE_ADAPTER_PATH_<TECHNIQUE>_<SLUG> (per technique + dataset slug)
100
+ * Write tuning job completion state to config.
81
101
  *
82
102
  * @param {string} configPath - Path to the config file
83
- * @param {object} params - Completion parameters
84
- * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
85
- * @param {string} params.trainingType - Training type (lora, full-rank)
86
- * @param {string} params.artifactPath - S3 path to the output artifact
87
- * @param {string} params.outputType - Output type (adapter, full-model)
88
- * @param {string} [params.datasetSlug] - Optional dataset slug for per-technique-per-dataset tracking
103
+ * @param {object} state - Completion state
104
+ * @param {string} state.technique - Tuning technique
105
+ * @param {string} state.trainingType - Training type
106
+ * @param {string} state.artifactPath - Output artifact path (S3 URI)
107
+ * @param {string} state.outputType - Output type ('adapter' or 'model')
108
+ * @param {string} [state.datasetSlug] - Dataset slug for named paths
89
109
  */
90
- export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType, datasetSlug }) {
110
+ export function persistCompletionState(configPath, { technique, trainingType: _trainingType, artifactPath, outputType, datasetSlug }) {
91
111
  const techniqueUpper = technique.toUpperCase();
92
112
 
93
- if (trainingType === 'lora') {
94
- // Level 2: per-technique
113
+ updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
114
+ updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
115
+
116
+ if (outputType === 'adapter') {
95
117
  updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueUpper}`, artifactPath);
96
- // Level 3: per-technique + per-dataset (if slug available)
97
118
  if (datasetSlug) {
98
119
  const slugUpper = datasetSlug.toUpperCase().replace(/-/g, '_');
99
120
  updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueUpper}_${slugUpper}`, artifactPath);
100
121
  }
101
- } else if (trainingType === 'full-rank') {
122
+ } else {
102
123
  updateConfigVar(configPath, `TUNE_MODEL_PATH_${techniqueUpper}`, artifactPath);
103
124
  }
104
-
105
- // Level 1: latest
106
- updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
107
- updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
108
125
  }
109
126
 
110
127
  /**
111
- * Generate a job name following the pattern used by do/tune.
112
- * Pattern: ${projectName}-tune-${technique}-YYYYMMDD-HHMMSS
128
+ * Generate a job name matching pattern: ${projectName}-tune-${technique}-YYYYMMDD-HHMMSS
129
+ * Uses local time for the timestamp.
113
130
  *
114
131
  * @param {string} projectName - Project name
115
- * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
116
- * @param {Date} [timestamp] - Optional timestamp (defaults to now)
117
- * @returns {string} Generated job name
132
+ * @param {string} technique - Tuning technique
133
+ * @param {Date} [timestamp] - Optional timestamp (defaults to new Date())
134
+ * @returns {string} Formatted job name
118
135
  */
119
- export function generateJobName(projectName, technique, timestamp = new Date()) {
120
- const year = timestamp.getFullYear().toString();
121
- const month = (timestamp.getMonth() + 1).toString().padStart(2, '0');
122
- const day = timestamp.getDate().toString().padStart(2, '0');
123
- const hours = timestamp.getHours().toString().padStart(2, '0');
124
- const minutes = timestamp.getMinutes().toString().padStart(2, '0');
125
- const seconds = timestamp.getSeconds().toString().padStart(2, '0');
136
+ export function generateJobName(projectName, technique, timestamp) {
137
+ const ts = timestamp || new Date();
138
+
139
+ const year = ts.getFullYear().toString();
140
+ const month = (ts.getMonth() + 1).toString().padStart(2, '0');
141
+ const day = ts.getDate().toString().padStart(2, '0');
142
+ const hours = ts.getHours().toString().padStart(2, '0');
143
+ const minutes = ts.getMinutes().toString().padStart(2, '0');
144
+ const seconds = ts.getSeconds().toString().padStart(2, '0');
145
+
126
146
  const dateStr = `${year}${month}${day}`;
127
147
  const timeStr = `${hours}${minutes}${seconds}`;
148
+
128
149
  return `${projectName}-tune-${technique}-${dateStr}-${timeStr}`;
129
150
  }
@@ -220,6 +220,9 @@ export <%= key %>=${<%= key %>:-<%= value %>}
220
220
  <% Object.entries(icEnvVars).forEach(([key, value]) => { %>
221
221
  export IC_ENV_<%= key %>=${IC_ENV_<%= key %>:-<%= value %>}
222
222
  <% }); %>
223
+ <% if ((modelServer === 'vllm' || modelServer === 'sglang') && !icEnvVars['VLLM_MAX_MODEL_LEN'] && !icEnvVars['SGLANG_MAX_MODEL_LEN']) { %>
224
+ export IC_ENV_VLLM_MAX_MODEL_LEN=${IC_ENV_VLLM_MAX_MODEL_LEN:-4096}
225
+ <% } %>
223
226
  <% } else if (deploymentTarget === 'realtime-inference') { %>
224
227
  # ─── Deploy-time IC environment variables (uncomment to configure) ─────────────
225
228
  # These are passed as the Environment field in InferenceComponent.create() at deploy time.
@@ -227,7 +230,9 @@ export IC_ENV_<%= key %>=${IC_ENV_<%= key %>:-<%= value %>}
227
230
  # Max 16 vars, max 1024 chars per key/value.
228
231
  # WARNING: Do not store raw secrets here. Use Secrets Manager ARN pattern instead:
229
232
  # export IC_ENV_HF_TOKEN_ARN=arn:aws:secretsmanager:REGION:ACCOUNT:secret:NAME
230
- # export IC_ENV_VLLM_MAX_MODEL_LEN=8192
233
+ <% if (modelServer === 'vllm' || modelServer === 'sglang') { %>
234
+ export IC_ENV_VLLM_MAX_MODEL_LEN=${IC_ENV_VLLM_MAX_MODEL_LEN:-4096}
235
+ <% } %>
231
236
  # export IC_ENV_VLLM_GPU_MEMORY_UTILIZATION=0.85
232
237
  <% } %>
233
238
 
@@ -1,172 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Auto-Prompt Builder — generates targeted prompts for missing required parameters.
6
- *
7
- * Used by --auto-prompt mode to ask only for values that cannot be inferred
8
- * or defaulted from the provided CLI flags.
9
- */
10
-
11
- /**
12
- * Builds a minimal set of prompts for the given missing parameters.
13
- * Each prompt is self-contained and doesn't depend on multi-phase wizard state.
14
- *
15
- * @param {string[]} missingParams - Parameter names that need values
16
- * @param {object} currentConfig - Current configuration (with defaults filled)
17
- * @returns {Array} Array of prompt objects compatible with runPrompts()
18
- */
19
- export function buildAutoPrompts(missingParams, currentConfig) {
20
- const prompts = [];
21
-
22
- for (const param of missingParams) {
23
- const builder = PROMPT_BUILDERS[param];
24
- if (builder) {
25
- const prompt = builder(currentConfig);
26
- if (prompt) {
27
- prompts.push(prompt);
28
- }
29
- } else {
30
- // Fallback: generic text input for unknown parameters
31
- prompts.push({
32
- type: 'input',
33
- name: param,
34
- message: `Enter value for ${param}:`
35
- });
36
- }
37
- }
38
-
39
- return prompts;
40
- }
41
-
42
- /**
43
- * Map of parameter names to prompt builder functions.
44
- * Each builder receives the current config and returns a prompt object.
45
- */
46
- const PROMPT_BUILDERS = {
47
- deploymentConfig: (_config) => ({
48
- type: 'list',
49
- name: 'deploymentConfig',
50
- message: 'Select deployment configuration:',
51
- choices: [
52
- { type: 'separator', separator: '── Large Language Models ──' },
53
- { name: 'Transformers with vLLM', value: 'transformers-vllm' },
54
- { name: 'Transformers with SGLang', value: 'transformers-sglang' },
55
- { name: 'Transformers with TensorRT-LLM', value: 'transformers-tensorrt-llm' },
56
- { name: 'Transformers with LMI', value: 'transformers-lmi' },
57
- { name: 'Transformers with DJL', value: 'transformers-djl' },
58
- { type: 'separator', separator: '── HTTP Serving ──' },
59
- { name: 'HTTP with Flask', value: 'http-flask' },
60
- { name: 'HTTP with FastAPI', value: 'http-fastapi' },
61
- { type: 'separator', separator: '── NVIDIA Triton ──' },
62
- { name: 'Triton FIL (XGBoost, LightGBM)', value: 'triton-fil' },
63
- { name: 'Triton ONNX Runtime', value: 'triton-onnxruntime' },
64
- { name: 'Triton TensorFlow', value: 'triton-tensorflow' },
65
- { name: 'Triton PyTorch', value: 'triton-pytorch' },
66
- { name: 'Triton vLLM', value: 'triton-vllm' },
67
- { name: 'Triton TensorRT-LLM', value: 'triton-tensorrtllm' },
68
- { name: 'Triton Python Backend', value: 'triton-python' },
69
- { type: 'separator', separator: '── Diffusion Models ──' },
70
- { name: 'Diffusors with vLLM Omni', value: 'diffusors-vllm-omni' }
71
- ]
72
- }),
73
-
74
- instanceType: (config) => {
75
- const architecture = config.architecture || 'http';
76
- const isGpu = architecture === 'transformers' || architecture === 'triton' || architecture === 'diffusors';
77
-
78
- const gpuChoices = [
79
- { name: 'ml.g5.xlarge (1× A10G 24GB — small LLMs)', value: 'ml.g5.xlarge' },
80
- { name: 'ml.g5.2xlarge (1× A10G 24GB — medium LLMs)', value: 'ml.g5.2xlarge' },
81
- { name: 'ml.g5.4xlarge (1× A10G 24GB — larger models)', value: 'ml.g5.4xlarge' },
82
- { name: 'ml.g5.12xlarge (4× A10G 96GB — large LLMs)', value: 'ml.g5.12xlarge' },
83
- { name: 'ml.g5.48xlarge (8× A10G 192GB — very large)', value: 'ml.g5.48xlarge' },
84
- { name: 'ml.g6.xlarge (1× L4 24GB)', value: 'ml.g6.xlarge' },
85
- { name: 'ml.g6.2xlarge (1× L4 24GB)', value: 'ml.g6.2xlarge' },
86
- { name: 'ml.p4d.24xlarge (8× A100 320GB)', value: 'ml.p4d.24xlarge' },
87
- { name: 'ml.p5.48xlarge (8× H100 640GB)', value: 'ml.p5.48xlarge' },
88
- { name: 'Custom (enter manually)', value: '_custom' }
89
- ];
90
-
91
- const cpuChoices = [
92
- { name: 'ml.m5.large (2 vCPU, 8GB — lightweight)', value: 'ml.m5.large' },
93
- { name: 'ml.m5.xlarge (4 vCPU, 16GB — small models)', value: 'ml.m5.xlarge' },
94
- { name: 'ml.m5.2xlarge (8 vCPU, 32GB — medium models)', value: 'ml.m5.2xlarge' },
95
- { name: 'ml.m5.4xlarge (16 vCPU, 64GB — large models)', value: 'ml.m5.4xlarge' },
96
- { name: 'ml.c5.xlarge (4 vCPU, 8GB — compute-heavy)', value: 'ml.c5.xlarge' },
97
- { name: 'ml.c5.2xlarge (8 vCPU, 16GB — compute-heavy)', value: 'ml.c5.2xlarge' },
98
- { name: 'Custom (enter manually)', value: '_custom' }
99
- ];
100
-
101
- return {
102
- type: 'list',
103
- name: 'instanceType',
104
- message: `Select instance type${isGpu ? ' (GPU recommended for this architecture)' : ''}:`,
105
- choices: isGpu ? gpuChoices : cpuChoices
106
- };
107
- },
108
-
109
- deploymentTarget: (_config) => ({
110
- type: 'list',
111
- name: 'deploymentTarget',
112
- message: 'Select deployment target:',
113
- choices: [
114
- { name: 'Real-Time Inference', value: 'realtime-inference' },
115
- { name: 'Async Inference', value: 'async-inference' },
116
- { name: 'Batch Transform', value: 'batch-transform' },
117
- { name: 'HyperPod EKS', value: 'hyperpod-eks' }
118
- ]
119
- }),
120
-
121
- modelFormat: (config) => {
122
- const engine = config.engine || 'sklearn';
123
- const formatMap = {
124
- sklearn: [
125
- { name: 'pkl (pickle)', value: 'pkl' },
126
- { name: 'joblib', value: 'joblib' }
127
- ],
128
- xgboost: [
129
- { name: 'json', value: 'json' },
130
- { name: 'model (binary)', value: 'model' },
131
- { name: 'ubj (universal binary JSON)', value: 'ubj' }
132
- ],
133
- tensorflow: [
134
- { name: 'keras', value: 'keras' },
135
- { name: 'h5', value: 'h5' },
136
- { name: 'SavedModel', value: 'SavedModel' }
137
- ]
138
- };
139
-
140
- const choices = formatMap[engine] || formatMap.sklearn;
141
-
142
- return {
143
- type: 'list',
144
- name: 'modelFormat',
145
- message: `Select model format for ${engine}:`,
146
- choices
147
- };
148
- },
149
-
150
- awsRegion: (_config) => ({
151
- type: 'list',
152
- name: 'awsRegion',
153
- message: 'Select AWS region:',
154
- choices: [
155
- { name: 'us-east-1 (N. Virginia)', value: 'us-east-1' },
156
- { name: 'us-west-2 (Oregon)', value: 'us-west-2' },
157
- { name: 'eu-west-1 (Ireland)', value: 'eu-west-1' },
158
- { name: 'ap-northeast-1 (Tokyo)', value: 'ap-northeast-1' },
159
- { name: 'ap-southeast-1 (Singapore)', value: 'ap-southeast-1' },
160
- { name: 'Custom (enter manually)', value: '_custom' }
161
- ]
162
- }),
163
-
164
- buildTarget: (_config) => ({
165
- type: 'list',
166
- name: 'buildTarget',
167
- message: 'Select build target:',
168
- choices: [
169
- { name: 'CodeBuild (recommended)', value: 'codebuild' }
170
- ]
171
- })
172
- };