@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -0,0 +1,310 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Unmanaged SFT fine-tuning using TRL SFTTrainer + PEFT LoRA.
6
+
7
+ This script performs LoRA-based supervised fine-tuning on a causal language
8
+ model. It is invoked via `accelerate launch` and works on both single-GPU
9
+ and multi-GPU instances without code changes.
10
+
11
+ Portable env-var contract (works on SageMaker AI and HyperPod EKS):
12
+ DATA_DIR / SM_CHANNEL_TRAINING -> training data path
13
+ OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
14
+ CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
15
+ HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
16
+ SM_HPS -> JSON blob of all hyperparameters
17
+
18
+ Output:
19
+ LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
20
+ Metrics logged to stdout in SageMaker-parseable format
21
+ """
22
+
23
+ import glob
24
+ import json
25
+ import logging
26
+ import os
27
+ import sys
28
+
29
+ # ── Logging ───────────────────────────────────────────────────────────────────
30
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("sft-trainer")


# ── Portable Path Resolution ─────────────────────────────────────────────────
# Each setting is resolved as: generic env var -> SageMaker env var -> default.
# The same script can therefore run on SageMaker, HyperPod, or locally.

def _env_or_default(default, *names):
    """Return the first non-empty value among env vars ``names``, else ``default``.

    An env var that is set to the empty string is treated as unset.
    """
    for name in names:
        candidate = os.environ.get(name)
        if candidate:
            return candidate
    return default


# Directory holding the training data files.
DATA_DIR = _env_or_default(
    "/opt/ml/input/data/training", "DATA_DIR", "SM_CHANNEL_TRAINING"
)

# Directory the final model artifacts are written to.
OUTPUT_DIR = _env_or_default("/opt/ml/model", "OUTPUT_DIR", "SM_MODEL_DIR")

# Directory used for checkpoints (enables resume after interruption).
CHECKPOINT_DIR = _env_or_default(
    "/opt/ml/checkpoints", "CHECKPOINT_DIR", "SM_CHECKPOINT_DIR"
)

# Base model identifier; empty string when neither env var is provided.
MODEL_ID = _env_or_default("", "HF_MODEL_ID", "SM_HP_MODEL_ID")
63
+
64
+
65
+ # ── Hyperparameter Loading ────────────────────────────────────────────────────
66
+
67
def load_hyperparameters():
    """Load hyperparameters from the SM_HPS env var or individual SM_HP_* vars.

    Resolution order:
      1. ``SM_HPS`` — a JSON object holding all hyperparameters. When it parses
         and every recognised value casts cleanly, it is used exclusively.
      2. ``SM_HP_<NAME>`` — one env var per hyperparameter, used when ``SM_HPS``
         is absent or unusable.
      3. Built-in defaults for anything not supplied.

    Each value is cast to the type of its built-in default. Keys that are not
    in the defaults table are ignored.

    Returns:
        dict with typed hyperparameter values.

    Raises:
        ValueError: if an individual ``SM_HP_*`` value cannot be cast.
    """
    defaults = {
        "model_id": MODEL_ID,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 2e-4,
        "epochs": 3,
        "batch_size": 4,
        "max_seq_length": 2048,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "dataset_text_field": "text",
    }

    # Try SM_HPS (JSON blob of all hyperparameters)
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            if not isinstance(raw, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(raw).__name__}"
                )
            # Cast into a copy: if any value is malformed, the fallback below
            # must start from pristine defaults, not a half-applied mixture.
            resolved = dict(defaults)
            for key, default_val in defaults.items():
                if key in raw:
                    # SageMaker passes all values as strings — cast them
                    resolved[key] = _cast(raw[key], type(default_val))
            return resolved
        except (json.JSONDecodeError, TypeError, ValueError, OverflowError) as e:
            # TypeError covers JSON nulls / nested values that cannot be cast.
            logger.warning("Failed to parse SM_HPS: %s", e)

    # Fallback: individual SM_HP_* env vars
    for key, default_val in defaults.items():
        env_val = os.environ.get(f"SM_HP_{key.upper()}")
        if env_val is not None:
            defaults[key] = _cast(env_val, type(default_val))

    return defaults
108
+
109
+
110
+ def _cast(value, target_type):
111
+ """Cast a string value to the target type."""
112
+ if target_type == bool:
113
+ return str(value).lower() in ("true", "1", "yes")
114
+ if target_type == int:
115
+ return int(float(value))
116
+ if target_type == float:
117
+ return float(value)
118
+ return str(value)
119
+
120
+
121
+ # ── Dataset Loading ───────────────────────────────────────────────────────────
122
+
123
def load_dataset(data_dir, text_field):
    """Load the training dataset from a data directory.

    Searches ``data_dir`` recursively for .jsonl, .json, .parquet and .csv
    files and loads them as a single Hugging Face dataset (``train`` split).
    All files must share one loader format (.jsonl and .json both count as
    "json"); a mix of formats is rejected because one loader cannot read them.

    Args:
        data_dir: Path to directory containing training data files.
        text_field: Name of the text column in the dataset.

    Returns:
        A Hugging Face Dataset object.

    Exits the process with status 1 when no data files are found, when the
    files mix formats, or when ``text_field`` is not a column of the dataset.
    """
    from datasets import load_dataset as hf_load_dataset

    extensions = ["jsonl", "json", "parquet", "csv"]
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}

    # A recursive "**" pattern also matches files directly inside data_dir,
    # so one glob per extension is sufficient. The set removes duplicates.
    found = set()
    for ext in extensions:
        found.update(
            glob.glob(os.path.join(data_dir, f"**/*.{ext}"), recursive=True)
        )

    if not found:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    data_files = sorted(found)
    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    # Every file is handed to the same loader, so they must agree on format.
    formats = {
        format_map.get(path.rsplit(".", 1)[-1].lower(), "json")
        for path in data_files
    }
    if len(formats) > 1:
        logger.error(
            "Mixed data formats found in %s (%s). Provide files of a single format.",
            data_dir, sorted(formats),
        )
        sys.exit(1)
    file_format = formats.pop()

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify text field exists
    if text_field not in dataset.column_names:
        logger.error(
            "Text field '%s' not found in dataset. Available columns: %s",
            text_field, dataset.column_names,
        )
        sys.exit(1)

    return dataset
169
+
170
+
171
+ # ── Main Training Function ────────────────────────────────────────────────────
172
+
173
def main():
    """Run SFT training with TRL SFTTrainer + PEFT LoRA.

    Steps: resolve hyperparameters, load tokenizer and base model, build the
    LoRA config, load the dataset, resume from the newest checkpoint under
    ``CHECKPOINT_DIR/trainer-state`` if one exists, train, then save the
    adapter and tokenizer to ``OUTPUT_DIR`` and print final metrics.

    Exits the process with status 1 when no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
    )
    from trl import SFTTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()

    # Load hyperparameters
    hparams = load_hyperparameters()
    model_id = hparams["model_id"]

    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if accelerator.is_main_process:
        logger.info("=" * 60)
        logger.info("SFT Training Configuration")
        logger.info("=" * 60)
        logger.info(" Model: %s", model_id)
        logger.info(" Data dir: %s", DATA_DIR)
        logger.info(" Output dir: %s", OUTPUT_DIR)
        logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
        logger.info(" LoRA r: %d", hparams["lora_r"])
        logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
        logger.info(" Learning rate: %s", hparams["learning_rate"])
        logger.info(" Epochs: %d", hparams["epochs"])
        logger.info(" Batch size: %d", hparams["batch_size"])
        logger.info(" Max seq len: %d", hparams["max_seq_length"])
        logger.info(" Text field: %s", hparams["dataset_text_field"])
        logger.info("=" * 60)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Loading tokenizer and model: %s", model_id)

    # NOTE(security): trust_remote_code=True executes Python code shipped with
    # the model repository. Only point model_id at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load dataset ──────────────────────────────────────────────────────────
    dataset = load_dataset(DATA_DIR, hparams["dataset_text_field"])

    # ── Training arguments ────────────────────────────────────────────────────
    # Trainer checkpoints are written here; the same directory is scanned
    # below to find a checkpoint to resume from.
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")
    training_args = TrainingArguments(
        output_dir=trainer_state_dir,
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        gradient_accumulation_steps=hparams["gradient_accumulation_steps"],
        learning_rate=hparams["learning_rate"],
        warmup_ratio=hparams["warmup_ratio"],
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = None
    if os.path.isdir(trainer_state_dir):
        # Order by the numeric step suffix ("checkpoint-<step>"); entries with
        # a non-numeric suffix sort first so they are never picked as newest
        # when a numbered checkpoint exists.
        checkpoints = sorted(
            glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")),
            key=lambda x: int(x.rsplit("-", 1)[-1]) if x.rsplit("-", 1)[-1].isdigit() else 0,
        )
        if checkpoints:
            resume_from_checkpoint = checkpoints[-1]
            if accelerator.is_main_process:
                logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    # ── Initialize SFTTrainer ─────────────────────────────────────────────────
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        peft_config=lora_config,
        args=training_args,
        max_seq_length=hparams["max_seq_length"],
        dataset_text_field=hparams["dataset_text_field"],
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Starting SFT training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter ──────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)

    # save_model() must run on EVERY rank: with sharded strategies (FSDP,
    # DeepSpeed ZeRO-3) it gathers the weights collectively, so calling it on
    # rank 0 alone can hang. The Trainer itself writes files only from the
    # main process.
    trainer.save_model(OUTPUT_DIR)

    if accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)

        # ── Log final metrics ─────────────────────────────────────────────────
        # Format: metric_name: value (SageMaker captures via regex in config.yaml)
        metrics = train_result.metrics
        print(f"train_loss: {metrics.get('train_loss', 0.0):.4f}")
        print(f"train_runtime: {metrics.get('train_runtime', 0.0):.1f}")
        print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")
        print(f"epochs: {hparams['epochs']}")

        logger.info("Training complete!")
        logger.info(" Loss: %.4f", metrics.get("train_loss", 0.0))
        logger.info(" Runtime: %.1fs", metrics.get("train_runtime", 0.0))
        logger.info(" Samples/s: %.2f", metrics.get("train_samples_per_second", 0.0))

    # Wait for all processes before exit
    accelerator.wait_for_everyone()
305
+
306
+
307
+ # ── Entry Point ───────────────────────────────────────────────────────────────
308
+
309
# Start training only when this file is executed as a script (the module
# docstring describes invocation via `accelerate launch`), not when imported.
if __name__ == "__main__":
    main()
package/templates/do/tune CHANGED
@@ -796,9 +796,16 @@ else:
796
796
  _validate_dataset() {
797
797
  local dataset="${ARG_DATASET}"
798
798
 
799
- # ── Parse @v<N> version suffix (AC-2.1, AC-2.3) ──────────────────────────
799
+ # ── Parse @v<version> suffix (AC-2.1, AC-2.3) ───────────────────────────
800
800
  # Syntax: dataset-name@v2 → name="dataset-name", version ordinal=2
801
- if [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
801
+ # dataset-name@v1.0.0 → name="dataset-name", version="1.0.0" (semver)
802
+ if [[ "${dataset}" =~ ^(.+)@v([0-9]+\.[0-9]+\.[0-9]+)$ ]]; then
803
+ # Semver syntax: @v1.0.0
804
+ ARG_DATASET_NAME="${BASH_REMATCH[1]}"
805
+ ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
806
+ dataset="" # Clear so name-based resolution takes over
807
+ elif [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
808
+ # Ordinal syntax: @v1, @v2
802
809
  ARG_DATASET_NAME="${BASH_REMATCH[1]}"
803
810
  ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
804
811
  dataset="" # Clear so name-based resolution takes over
@@ -824,6 +831,8 @@ _validate_dataset() {
824
831
  echo " Resolved to: ${resolved_uri}"
825
832
  dataset="${resolved_uri}"
826
833
  ARG_DATASET="${resolved_uri}"
834
+ RESOLVED_DATASET_S3_URI="${resolved_uri}"
835
+ return 0
827
836
  else
828
837
  echo "❌ Dataset '${ARG_DATASET_NAME}' not found in registry"
829
838
  echo " Register it first: ./do/register --dataset --dataset-name ${ARG_DATASET_NAME} --dataset-s3-uri s3://..."
@@ -1,172 +0,0 @@
1
- // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
- // SPDX-License-Identifier: Apache-2.0
3
-
4
- /**
5
- * Auto-Prompt Builder — generates targeted prompts for missing required parameters.
6
- *
7
- * Used by --auto-prompt mode to ask only for values that cannot be inferred
8
- * or defaulted from the provided CLI flags.
9
- */
10
-
11
- /**
12
- * Builds a minimal set of prompts for the given missing parameters.
13
- * Each prompt is self-contained and doesn't depend on multi-phase wizard state.
14
- *
15
- * @param {string[]} missingParams - Parameter names that need values
16
- * @param {object} currentConfig - Current configuration (with defaults filled)
17
- * @returns {Array} Array of prompt objects compatible with runPrompts()
18
- */
19
export function buildAutoPrompts(missingParams, currentConfig) {
    const prompts = [];

    for (const param of missingParams) {
        // Only consult builders defined directly on PROMPT_BUILDERS. A plain
        // `PROMPT_BUILDERS[param]` lookup would also resolve inherited
        // Object.prototype members (e.g. "constructor", "toString") and call
        // them as if they were prompt builders.
        const hasBuilder = Object.prototype.hasOwnProperty.call(PROMPT_BUILDERS, param);

        if (!hasBuilder) {
            // Fallback: generic text input for unknown parameters
            prompts.push({
                type: 'input',
                name: param,
                message: `Enter value for ${param}:`
            });
            continue;
        }

        // A builder may return a falsy value to indicate "no prompt needed".
        const prompt = PROMPT_BUILDERS[param](currentConfig);
        if (prompt) {
            prompts.push(prompt);
        }
    }

    return prompts;
}
41
-
42
- /**
43
- * Map of parameter names to prompt builder functions.
44
- * Each builder receives the current config and returns a prompt object.
45
- */
46
// Small factories so each builder below reads as a table of labels and values.

/** Creates a selectable choice entry. */
const choice = (name, value) => ({ name, value });

/** Creates a visual separator entry for a choice list. */
const separator = (label) => ({ type: 'separator', separator: label });

/** Creates a single-select list prompt. */
const listPrompt = (name, message, choices) => ({ type: 'list', name, message, choices });

// Architectures for which the GPU instance list is offered.
const GPU_ARCHITECTURES = ['transformers', 'triton', 'diffusors'];

const GPU_INSTANCE_CHOICES = [
    choice('ml.g5.xlarge (1× A10G 24GB — small LLMs)', 'ml.g5.xlarge'),
    choice('ml.g5.2xlarge (1× A10G 24GB — medium LLMs)', 'ml.g5.2xlarge'),
    choice('ml.g5.4xlarge (1× A10G 24GB — larger models)', 'ml.g5.4xlarge'),
    choice('ml.g5.12xlarge (4× A10G 96GB — large LLMs)', 'ml.g5.12xlarge'),
    choice('ml.g5.48xlarge (8× A10G 192GB — very large)', 'ml.g5.48xlarge'),
    choice('ml.g6.xlarge (1× L4 24GB)', 'ml.g6.xlarge'),
    choice('ml.g6.2xlarge (1× L4 24GB)', 'ml.g6.2xlarge'),
    choice('ml.p4d.24xlarge (8× A100 320GB)', 'ml.p4d.24xlarge'),
    choice('ml.p5.48xlarge (8× H100 640GB)', 'ml.p5.48xlarge'),
    choice('Custom (enter manually)', '_custom')
];

const CPU_INSTANCE_CHOICES = [
    choice('ml.m5.large (2 vCPU, 8GB — lightweight)', 'ml.m5.large'),
    choice('ml.m5.xlarge (4 vCPU, 16GB — small models)', 'ml.m5.xlarge'),
    choice('ml.m5.2xlarge (8 vCPU, 32GB — medium models)', 'ml.m5.2xlarge'),
    choice('ml.m5.4xlarge (16 vCPU, 64GB — large models)', 'ml.m5.4xlarge'),
    choice('ml.c5.xlarge (4 vCPU, 8GB — compute-heavy)', 'ml.c5.xlarge'),
    choice('ml.c5.2xlarge (8 vCPU, 16GB — compute-heavy)', 'ml.c5.2xlarge'),
    choice('Custom (enter manually)', '_custom')
];

// Model file formats offered per engine; unknown engines use the sklearn list.
const MODEL_FORMATS_BY_ENGINE = {
    sklearn: [
        choice('pkl (pickle)', 'pkl'),
        choice('joblib', 'joblib')
    ],
    xgboost: [
        choice('json', 'json'),
        choice('model (binary)', 'model'),
        choice('ubj (universal binary JSON)', 'ubj')
    ],
    tensorflow: [
        choice('keras', 'keras'),
        choice('h5', 'h5'),
        choice('SavedModel', 'SavedModel')
    ]
};

const PROMPT_BUILDERS = {
    // Deployment configuration: framework/server pairs grouped by family.
    deploymentConfig: (_config) => listPrompt('deploymentConfig', 'Select deployment configuration:', [
        separator('── Large Language Models ──'),
        choice('Transformers with vLLM', 'transformers-vllm'),
        choice('Transformers with SGLang', 'transformers-sglang'),
        choice('Transformers with TensorRT-LLM', 'transformers-tensorrt-llm'),
        choice('Transformers with LMI', 'transformers-lmi'),
        choice('Transformers with DJL', 'transformers-djl'),
        separator('── HTTP Serving ──'),
        choice('HTTP with Flask', 'http-flask'),
        choice('HTTP with FastAPI', 'http-fastapi'),
        separator('── NVIDIA Triton ──'),
        choice('Triton FIL (XGBoost, LightGBM)', 'triton-fil'),
        choice('Triton ONNX Runtime', 'triton-onnxruntime'),
        choice('Triton TensorFlow', 'triton-tensorflow'),
        choice('Triton PyTorch', 'triton-pytorch'),
        choice('Triton vLLM', 'triton-vllm'),
        choice('Triton TensorRT-LLM', 'triton-tensorrtllm'),
        choice('Triton Python Backend', 'triton-python'),
        separator('── Diffusion Models ──'),
        choice('Diffusors with vLLM Omni', 'diffusors-vllm-omni')
    ]),

    // Instance type: GPU list for GPU architectures, CPU list otherwise.
    // A missing architecture is treated as 'http'.
    instanceType: (config) => {
        const architecture = config.architecture || 'http';
        if (GPU_ARCHITECTURES.includes(architecture)) {
            return listPrompt(
                'instanceType',
                'Select instance type (GPU recommended for this architecture):',
                GPU_INSTANCE_CHOICES
            );
        }
        return listPrompt('instanceType', 'Select instance type:', CPU_INSTANCE_CHOICES);
    },

    deploymentTarget: (_config) => listPrompt('deploymentTarget', 'Select deployment target:', [
        choice('Real-Time Inference', 'realtime-inference'),
        choice('Async Inference', 'async-inference'),
        choice('Batch Transform', 'batch-transform'),
        choice('HyperPod EKS', 'hyperpod-eks')
    ]),

    // Model format: choices depend on the engine; a missing engine is
    // treated as 'sklearn'.
    modelFormat: (config) => {
        const engine = config.engine || 'sklearn';
        const formats = MODEL_FORMATS_BY_ENGINE[engine] || MODEL_FORMATS_BY_ENGINE.sklearn;
        return listPrompt('modelFormat', `Select model format for ${engine}:`, formats);
    },

    awsRegion: (_config) => listPrompt('awsRegion', 'Select AWS region:', [
        choice('us-east-1 (N. Virginia)', 'us-east-1'),
        choice('us-west-2 (Oregon)', 'us-west-2'),
        choice('eu-west-1 (Ireland)', 'eu-west-1'),
        choice('ap-northeast-1 (Tokyo)', 'ap-northeast-1'),
        choice('ap-southeast-1 (Singapore)', 'ap-southeast-1'),
        choice('Custom (enter manually)', '_custom')
    ]),

    buildTarget: (_config) => listPrompt('buildTarget', 'Select build target:', [
        choice('CodeBuild (recommended)', 'codebuild')
    ])
};