@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Unmanaged SFT fine-tuning using TRL SFTTrainer + PEFT LoRA.
|
|
6
|
+
|
|
7
|
+
This script performs LoRA-based supervised fine-tuning on a causal language
|
|
8
|
+
model. It is invoked via `accelerate launch` and works on both single-GPU
|
|
9
|
+
and multi-GPU instances without code changes.
|
|
10
|
+
|
|
11
|
+
Portable env-var contract (works on SageMaker AI and HyperPod EKS):
|
|
12
|
+
DATA_DIR / SM_CHANNEL_TRAINING -> training data path
|
|
13
|
+
OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
|
|
14
|
+
CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
|
|
15
|
+
HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
|
|
16
|
+
SM_HPS -> JSON blob of all hyperparameters
|
|
17
|
+
|
|
18
|
+
Output:
|
|
19
|
+
LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
|
|
20
|
+
Metrics logged to stdout in SageMaker-parseable format
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import glob
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
import os
|
|
27
|
+
import sys
|
|
28
|
+
|
|
29
|
+
# ── Logging ───────────────────────────────────────────────────────────────────
|
|
30
|
+
# Configure the root logger once at import time: INFO and above, with a
# timestamp, level and logger name on every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-wide logger used by every function in this script.
logger = logging.getLogger("sft-trainer")
|
|
35
|
+
|
|
36
|
+
# ── Portable Path Resolution ─────────────────────────────────────────────────
|
|
37
|
+
# Fallback chain: generic env var -> SageMaker env var -> default path
|
|
38
|
+
# This allows the same script to run on SageMaker, HyperPod, or locally.
|
|
39
|
+
|
|
40
|
+
# Directory holding the training data files.
DATA_DIR = (
    os.environ.get("DATA_DIR")
    or os.environ.get("SM_CHANNEL_TRAINING")
    or "/opt/ml/input/data/training"
)

# Directory the final LoRA adapter and tokenizer are written to.
OUTPUT_DIR = (
    os.environ.get("OUTPUT_DIR")
    or os.environ.get("SM_MODEL_DIR")
    or "/opt/ml/model"
)

# Directory for intermediate trainer checkpoints (used to resume a run).
CHECKPOINT_DIR = (
    os.environ.get("CHECKPOINT_DIR")
    or os.environ.get("SM_CHECKPOINT_DIR")
    or "/opt/ml/checkpoints"
)

# Base model identifier. There is no usable default: an empty string here is
# only acceptable if the model_id hyperparameter supplies a value later.
MODEL_ID = (
    os.environ.get("HF_MODEL_ID")
    or os.environ.get("SM_HP_MODEL_ID")
    or ""
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ── Hyperparameter Loading ────────────────────────────────────────────────────
|
|
66
|
+
|
|
67
|
+
def load_hyperparameters():
    """Build the hyperparameter dict from built-in defaults plus env overrides.

    Overrides come from the ``SM_HPS`` JSON blob when it is set and parses
    cleanly. Otherwise (unset, or unparseable) they come from individual
    ``SM_HP_<KEY>`` environment variables. Only keys present in the defaults
    table are honoured, and every override is cast to the type of its default.

    Returns:
        dict with typed hyperparameter values.

    Raises:
        ValueError: if an individual ``SM_HP_<KEY>`` value cannot be cast to
            the type of its default (raised by ``_cast``).
    """
    defaults = {
        "model_id": MODEL_ID,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 2e-4,
        "epochs": 3,
        "batch_size": 4,
        "max_seq_length": 2048,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "dataset_text_field": "text",
    }

    # Try SM_HPS (JSON blob of all hyperparameters)
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            if not isinstance(raw, dict):
                raise ValueError(
                    "expected a JSON object, got %s" % type(raw).__name__
                )
            # Values are cast from whatever JSON produced (typically strings).
            # Cast into a scratch dict first so that one bad value cannot leave
            # `defaults` half-overridden before the fallback below runs.
            overrides = {
                key: _cast(raw[key], type(default_val))
                for key, default_val in defaults.items()
                if key in raw
            }
            defaults.update(overrides)
            return defaults
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            # TypeError covers null / nested values that cannot be cast.
            logger.warning("Failed to parse SM_HPS: %s", e)

    # Fallback: individual SM_HP_* env vars
    for key, default_val in list(defaults.items()):
        env_key = f"SM_HP_{key.upper()}"
        env_val = os.environ.get(env_key)
        if env_val is not None:
            defaults[key] = _cast(env_val, type(default_val))

    return defaults
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _cast(value, target_type):
|
|
111
|
+
"""Cast a string value to the target type."""
|
|
112
|
+
if target_type == bool:
|
|
113
|
+
return str(value).lower() in ("true", "1", "yes")
|
|
114
|
+
if target_type == int:
|
|
115
|
+
return int(float(value))
|
|
116
|
+
if target_type == float:
|
|
117
|
+
return float(value)
|
|
118
|
+
return str(value)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ── Dataset Loading ───────────────────────────────────────────────────────────
|
|
122
|
+
|
|
123
|
+
def load_dataset(data_dir, text_field):
    """Load the training dataset from a directory of data files.

    Searches ``data_dir`` (top level and all subdirectories) for .jsonl,
    .json, .parquet and .csv files and loads them as the "train" split.
    All files must belong to one loader format (.json and .jsonl both count
    as "json"); a directory mixing formats is rejected with a clear error
    instead of being handed to a single-format loader.

    Args:
        data_dir: Path to directory containing training data files.
        text_field: Name of the text column in the dataset.

    Returns:
        A Hugging Face Dataset object.

    Exits the process with status 1 when no data files are found, when the
    files mix loader formats, or when ``text_field`` is not a column.
    """
    from datasets import load_dataset as hf_load_dataset

    extensions = ["jsonl", "json", "parquet", "csv"]
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}

    # Collect into a set: the top-level and recursive patterns overlap.
    found = set()
    for ext in extensions:
        found.update(glob.glob(os.path.join(data_dir, f"*.{ext}")))
        found.update(glob.glob(os.path.join(data_dir, f"**/*.{ext}"), recursive=True))

    if not found:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    data_files = sorted(found)
    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    def _format_of(path):
        # Loader format implied by the file extension ("json" if unknown).
        return format_map.get(path.rsplit(".", 1)[-1].lower(), "json")

    # One loader handles every file, so they must all share its format.
    formats_present = sorted({_format_of(path) for path in data_files})
    if len(formats_present) > 1:
        logger.error(
            "Data files in %s mix formats %s; provide files of a single format",
            data_dir, formats_present,
        )
        sys.exit(1)
    file_format = formats_present[0]

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify text field exists
    if text_field not in dataset.column_names:
        logger.error(
            "Text field '%s' not found in dataset. Available columns: %s",
            text_field, dataset.column_names,
        )
        sys.exit(1)

    return dataset
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ── Main Training Function ────────────────────────────────────────────────────
|
|
172
|
+
|
|
173
|
+
def main():
    """Run SFT training with TRL SFTTrainer + PEFT LoRA.

    Flow: resolve hyperparameters, load tokenizer and base model, build the
    LoRA config, load the dataset, resume from the newest numbered checkpoint
    under CHECKPOINT_DIR/trainer-state if one exists, train, then write the
    adapter and tokenizer to OUTPUT_DIR and print the final metrics.

    Exits the process with status 1 when no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
    )
    from trl import SFTTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()

    # Load hyperparameters
    hparams = load_hyperparameters()
    model_id = hparams["model_id"]

    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if accelerator.is_main_process:
        logger.info("=" * 60)
        logger.info("SFT Training Configuration")
        logger.info("=" * 60)
        logger.info(" Model: %s", model_id)
        logger.info(" Data dir: %s", DATA_DIR)
        logger.info(" Output dir: %s", OUTPUT_DIR)
        logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
        logger.info(" LoRA r: %d", hparams["lora_r"])
        logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
        logger.info(" Learning rate: %s", hparams["learning_rate"])
        logger.info(" Epochs: %d", hparams["epochs"])
        logger.info(" Batch size: %d", hparams["batch_size"])
        logger.info(" Max seq len: %d", hparams["max_seq_length"])
        logger.info(" Text field: %s", hparams["dataset_text_field"])
        logger.info("=" * 60)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Loading tokenizer and model: %s", model_id)

    # NOTE(review): trust_remote_code=True runs Python shipped in the model
    # repository; confirm model_id always comes from a trusted source.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load dataset ──────────────────────────────────────────────────────────
    dataset = load_dataset(DATA_DIR, hparams["dataset_text_field"])

    # ── Training arguments ────────────────────────────────────────────────────
    # Single source of truth for where the trainer writes and finds checkpoints.
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")

    # NOTE(review): bf16=True is unconditional and presumably requires
    # bf16-capable accelerators — confirm for the supported instance types.
    training_args = TrainingArguments(
        output_dir=trainer_state_dir,
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        gradient_accumulation_steps=hparams["gradient_accumulation_steps"],
        learning_rate=hparams["learning_rate"],
        warmup_ratio=hparams["warmup_ratio"],
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = None
    if os.path.isdir(trainer_state_dir):
        # Only "checkpoint-<number>" directories are resumable; ignore stray
        # files or oddly named entries rather than ranking them as step 0.
        numbered = []
        for path in glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")):
            suffix = path.rsplit("-", 1)[-1]
            if suffix.isdigit() and os.path.isdir(path):
                numbered.append((int(suffix), path))
        if numbered:
            # Highest step number is the most recent checkpoint.
            resume_from_checkpoint = max(numbered)[1]
            if accelerator.is_main_process:
                logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    # ── Initialize SFTTrainer ─────────────────────────────────────────────────
    # NOTE(review): tokenizer / max_seq_length / dataset_text_field are passed
    # straight to SFTTrainer; verify these keyword arguments against the TRL
    # version pinned for this image.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        peft_config=lora_config,
        args=training_args,
        max_seq_length=hparams["max_seq_length"],
        dataset_text_field=hparams["dataset_text_field"],
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Starting SFT training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter ──────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)
    # Called on every rank on purpose: the trainer itself limits the write to
    # the saving process, and sharded setups (FSDP / DeepSpeed ZeRO-3) need all
    # ranks to take part in gathering the weights. Guarding this call with
    # is_main_process can stall such runs.
    trainer.save_model(OUTPUT_DIR)

    if accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)

        # ── Log final metrics ─────────────────────────────────────────────────
        # Format: metric_name: value (SageMaker captures via regex in config.yaml)
        metrics = train_result.metrics
        print(f"train_loss: {metrics.get('train_loss', 0.0):.4f}")
        print(f"train_runtime: {metrics.get('train_runtime', 0.0):.1f}")
        print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")
        print(f"epochs: {hparams['epochs']}")

        logger.info("Training complete!")
        logger.info(" Loss: %.4f", metrics.get("train_loss", 0.0))
        logger.info(" Runtime: %.1fs", metrics.get("train_runtime", 0.0))
        logger.info(" Samples/s: %.2f", metrics.get("train_samples_per_second", 0.0))

    # Wait for all processes before exit
    accelerator.wait_for_everyone()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# ── Entry Point ───────────────────────────────────────────────────────────────
|
|
308
|
+
|
|
309
|
+
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
package/templates/do/tune
CHANGED
|
@@ -796,9 +796,16 @@ else:
|
|
|
796
796
|
_validate_dataset() {
|
|
797
797
|
local dataset="${ARG_DATASET}"
|
|
798
798
|
|
|
799
|
-
# ── Parse @v<
|
|
799
|
+
# ── Parse @v<version> suffix (AC-2.1, AC-2.3) ───────────────────────────
|
|
800
800
|
# Syntax: dataset-name@v2 → name="dataset-name", version ordinal=2
|
|
801
|
-
|
|
801
|
+
# dataset-name@v1.0.0 → name="dataset-name", version="1.0.0" (semver)
|
|
802
|
+
if [[ "${dataset}" =~ ^(.+)@v([0-9]+\.[0-9]+\.[0-9]+)$ ]]; then
|
|
803
|
+
# Semver syntax: @v1.0.0
|
|
804
|
+
ARG_DATASET_NAME="${BASH_REMATCH[1]}"
|
|
805
|
+
ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
|
|
806
|
+
dataset="" # Clear so name-based resolution takes over
|
|
807
|
+
elif [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
|
|
808
|
+
# Ordinal syntax: @v1, @v2
|
|
802
809
|
ARG_DATASET_NAME="${BASH_REMATCH[1]}"
|
|
803
810
|
ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
|
|
804
811
|
dataset="" # Clear so name-based resolution takes over
|
|
@@ -824,6 +831,8 @@ _validate_dataset() {
|
|
|
824
831
|
echo " Resolved to: ${resolved_uri}"
|
|
825
832
|
dataset="${resolved_uri}"
|
|
826
833
|
ARG_DATASET="${resolved_uri}"
|
|
834
|
+
RESOLVED_DATASET_S3_URI="${resolved_uri}"
|
|
835
|
+
return 0
|
|
827
836
|
else
|
|
828
837
|
echo "❌ Dataset '${ARG_DATASET_NAME}' not found in registry"
|
|
829
838
|
echo " Register it first: ./do/register --dataset --dataset-name ${ARG_DATASET_NAME} --dataset-s3-uri s3://..."
|
|
@@ -1,172 +0,0 @@
|
|
|
1
|
-
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
-
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
|
|
4
|
-
/**
|
|
5
|
-
* Auto-Prompt Builder — generates targeted prompts for missing required parameters.
|
|
6
|
-
*
|
|
7
|
-
* Used by --auto-prompt mode to ask only for values that cannot be inferred
|
|
8
|
-
* or defaulted from the provided CLI flags.
|
|
9
|
-
*/
|
|
10
|
-
|
|
11
|
-
/**
 * Builds a minimal set of prompts for the given missing parameters.
 * Each prompt is self-contained and doesn't depend on multi-phase wizard state.
 *
 * Parameters with an entry in PROMPT_BUILDERS get that builder's prompt
 * (skipped when the builder returns a falsy value); every other parameter
 * gets a generic free-text prompt.
 *
 * @param {string[]} missingParams - Parameter names that need values
 * @param {object} currentConfig - Current configuration (with defaults filled)
 * @returns {Array} Array of prompt objects compatible with runPrompts()
 */
export function buildAutoPrompts(missingParams, currentConfig) {
    const prompts = [];

    for (const param of missingParams) {
        // Own-property lookup only: a parameter named like an inherited
        // Object.prototype member (e.g. "toString") must not be treated
        // as a registered builder.
        const builder = Object.hasOwn(PROMPT_BUILDERS, param)
            ? PROMPT_BUILDERS[param]
            : undefined;

        if (builder) {
            const prompt = builder(currentConfig);
            if (prompt) {
                prompts.push(prompt);
            }
        } else {
            // Fallback: generic text input for unknown parameters
            prompts.push({
                type: 'input',
                name: param,
                message: `Enter value for ${param}:`
            });
        }
    }

    return prompts;
}
|
|
41
|
-
|
|
42
|
-
/**
 * Map of parameter names to prompt builder functions.
 * Each builder receives the current config and returns a prompt object
 * (type, name, message, choices) for a single-select list.
 */
const PROMPT_BUILDERS = {
    // Deployment configuration: fixed list grouped by serving family.
    deploymentConfig: (_config) => ({
        type: 'list',
        name: 'deploymentConfig',
        message: 'Select deployment configuration:',
        choices: [
            { type: 'separator', separator: '── Large Language Models ──' },
            { name: 'Transformers with vLLM', value: 'transformers-vllm' },
            { name: 'Transformers with SGLang', value: 'transformers-sglang' },
            { name: 'Transformers with TensorRT-LLM', value: 'transformers-tensorrt-llm' },
            { name: 'Transformers with LMI', value: 'transformers-lmi' },
            { name: 'Transformers with DJL', value: 'transformers-djl' },
            { type: 'separator', separator: '── HTTP Serving ──' },
            { name: 'HTTP with Flask', value: 'http-flask' },
            { name: 'HTTP with FastAPI', value: 'http-fastapi' },
            { type: 'separator', separator: '── NVIDIA Triton ──' },
            { name: 'Triton FIL (XGBoost, LightGBM)', value: 'triton-fil' },
            { name: 'Triton ONNX Runtime', value: 'triton-onnxruntime' },
            { name: 'Triton TensorFlow', value: 'triton-tensorflow' },
            { name: 'Triton PyTorch', value: 'triton-pytorch' },
            { name: 'Triton vLLM', value: 'triton-vllm' },
            { name: 'Triton TensorRT-LLM', value: 'triton-tensorrtllm' },
            { name: 'Triton Python Backend', value: 'triton-python' },
            { type: 'separator', separator: '── Diffusion Models ──' },
            { name: 'Diffusors with vLLM Omni', value: 'diffusors-vllm-omni' }
        ]
    }),

    // Instance type: GPU list for transformers/triton/diffusors, CPU list
    // otherwise (architecture defaults to 'http' when unset).
    instanceType: (config) => {
        const architecture = config.architecture || 'http';
        const isGpu = ['transformers', 'triton', 'diffusors'].includes(architecture);

        const gpuChoices = [
            { name: 'ml.g5.xlarge (1× A10G 24GB — small LLMs)', value: 'ml.g5.xlarge' },
            { name: 'ml.g5.2xlarge (1× A10G 24GB — medium LLMs)', value: 'ml.g5.2xlarge' },
            { name: 'ml.g5.4xlarge (1× A10G 24GB — larger models)', value: 'ml.g5.4xlarge' },
            { name: 'ml.g5.12xlarge (4× A10G 96GB — large LLMs)', value: 'ml.g5.12xlarge' },
            { name: 'ml.g5.48xlarge (8× A10G 192GB — very large)', value: 'ml.g5.48xlarge' },
            { name: 'ml.g6.xlarge (1× L4 24GB)', value: 'ml.g6.xlarge' },
            { name: 'ml.g6.2xlarge (1× L4 24GB)', value: 'ml.g6.2xlarge' },
            { name: 'ml.p4d.24xlarge (8× A100 320GB)', value: 'ml.p4d.24xlarge' },
            { name: 'ml.p5.48xlarge (8× H100 640GB)', value: 'ml.p5.48xlarge' },
            { name: 'Custom (enter manually)', value: '_custom' }
        ];

        const cpuChoices = [
            { name: 'ml.m5.large (2 vCPU, 8GB — lightweight)', value: 'ml.m5.large' },
            { name: 'ml.m5.xlarge (4 vCPU, 16GB — small models)', value: 'ml.m5.xlarge' },
            { name: 'ml.m5.2xlarge (8 vCPU, 32GB — medium models)', value: 'ml.m5.2xlarge' },
            { name: 'ml.m5.4xlarge (16 vCPU, 64GB — large models)', value: 'ml.m5.4xlarge' },
            { name: 'ml.c5.xlarge (4 vCPU, 8GB — compute-heavy)', value: 'ml.c5.xlarge' },
            { name: 'ml.c5.2xlarge (8 vCPU, 16GB — compute-heavy)', value: 'ml.c5.2xlarge' },
            { name: 'Custom (enter manually)', value: '_custom' }
        ];

        return {
            type: 'list',
            name: 'instanceType',
            message: `Select instance type${isGpu ? ' (GPU recommended for this architecture)' : ''}:`,
            choices: isGpu ? gpuChoices : cpuChoices
        };
    },

    // Deployment target: fixed list.
    deploymentTarget: (_config) => ({
        type: 'list',
        name: 'deploymentTarget',
        message: 'Select deployment target:',
        choices: [
            { name: 'Real-Time Inference', value: 'realtime-inference' },
            { name: 'Async Inference', value: 'async-inference' },
            { name: 'Batch Transform', value: 'batch-transform' },
            { name: 'HyperPod EKS', value: 'hyperpod-eks' }
        ]
    }),

    // Model format: choices depend on config.engine (defaults to 'sklearn');
    // engines without an entry fall back to the sklearn formats.
    modelFormat: (config) => {
        const engine = config.engine || 'sklearn';
        const formatMap = {
            sklearn: [
                { name: 'pkl (pickle)', value: 'pkl' },
                { name: 'joblib', value: 'joblib' }
            ],
            xgboost: [
                { name: 'json', value: 'json' },
                { name: 'model (binary)', value: 'model' },
                { name: 'ubj (universal binary JSON)', value: 'ubj' }
            ],
            tensorflow: [
                { name: 'keras', value: 'keras' },
                { name: 'h5', value: 'h5' },
                { name: 'SavedModel', value: 'SavedModel' }
            ]
        };

        // Own-property lookup only, so an engine named like an inherited
        // Object.prototype member (e.g. "constructor") still falls back.
        const choices = Object.hasOwn(formatMap, engine)
            ? formatMap[engine]
            : formatMap.sklearn;

        return {
            type: 'list',
            name: 'modelFormat',
            message: `Select model format for ${engine}:`,
            choices
        };
    },

    // AWS region: short list plus a manual-entry escape hatch.
    awsRegion: (_config) => ({
        type: 'list',
        name: 'awsRegion',
        message: 'Select AWS region:',
        choices: [
            { name: 'us-east-1 (N. Virginia)', value: 'us-east-1' },
            { name: 'us-west-2 (Oregon)', value: 'us-west-2' },
            { name: 'eu-west-1 (Ireland)', value: 'eu-west-1' },
            { name: 'ap-northeast-1 (Tokyo)', value: 'ap-northeast-1' },
            { name: 'ap-southeast-1 (Singapore)', value: 'ap-southeast-1' },
            { name: 'Custom (enter manually)', value: '_custom' }
        ]
    }),

    // Build target: a single option.
    buildTarget: (_config) => ({
        type: 'list',
        name: 'buildTarget',
        message: 'Select build target:',
        choices: [
            { name: 'CodeBuild (recommended)', value: 'codebuild' }
        ]
    })
};
|