@aws/ml-container-creator 0.13.4 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -5
- package/config/parameter-schema-v2.json +32 -4
- package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
- package/infra/ci-harness/package-lock.json +122 -116
- package/infra/ci-harness/package.json +1 -1
- package/package.json +5 -3
- package/pyproject.toml +21 -0
- package/requirements.txt +19 -0
- package/servers/instance-sizer/index.js +72 -4
- package/servers/instance-sizer/lib/model-resolver.js +28 -2
- package/src/app.js +17 -0
- package/src/lib/bootstrap-command-handler.js +33 -23
- package/src/lib/config-loader.js +18 -0
- package/src/lib/config-manager.js +6 -1
- package/src/lib/dataset-slug.js +152 -0
- package/src/lib/generated/cli-options.js +9 -3
- package/src/lib/generated/parameter-matrix.js +14 -3
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +6 -0
- package/src/lib/prompt-runner.js +5 -0
- package/src/lib/prompts/feature-prompts.js +1 -1
- package/src/lib/template-manager.js +0 -7
- package/src/lib/template-variable-resolver.js +51 -1
- package/src/lib/tune-config-state.js +14 -1
- package/templates/do/.adapter_helper.py +451 -0
- package/templates/do/.benchmark_writer.py +22 -0
- package/templates/do/.register_helper.py +1163 -0
- package/templates/do/.stage_helper.py +419 -0
- package/templates/do/.tune_helper.py +379 -65
- package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +427 -27
- package/templates/do/add-ic +85 -3
- package/templates/do/benchmark +173 -15
- package/templates/do/config +24 -0
- package/templates/do/lib/inference-component.sh +56 -3
- package/templates/do/lib/profile.sh +5 -0
- package/templates/do/register +552 -6
- package/templates/do/stage +91 -272
- package/templates/do/test +12 -2
- package/templates/do/tune +264 -12
package/src/lib/prompt-runner.js
CHANGED
|
@@ -710,6 +710,11 @@ export default class PromptRunner {
|
|
|
710
710
|
delete combinedAnswers.customHyperPodCluster;
|
|
711
711
|
}
|
|
712
712
|
|
|
713
|
+
// Propagate max_model_len from instance-sizer context capping (AC-1.7)
|
|
714
|
+
if (this._sizerMaxModelLen) {
|
|
715
|
+
combinedAnswers.sizerMaxModelLen = this._sizerMaxModelLen;
|
|
716
|
+
}
|
|
717
|
+
|
|
713
718
|
// Apply CUDA version selection → inference AMI override
|
|
714
719
|
if (combinedAnswers._resolvedInferenceAmiVersion) {
|
|
715
720
|
combinedAnswers.inferenceAmiVersion = combinedAnswers._resolvedInferenceAmiVersion;
|
|
@@ -90,7 +90,7 @@ const loraPrompts = [
|
|
|
90
90
|
type: 'confirm',
|
|
91
91
|
name: 'enableLora',
|
|
92
92
|
message: 'Enable LoRA adapter serving?',
|
|
93
|
-
default:
|
|
93
|
+
default: true,
|
|
94
94
|
when: (answers) => {
|
|
95
95
|
const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
|
|
96
96
|
const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');
|
|
@@ -314,13 +314,6 @@ export default class TemplateManager {
|
|
|
314
314
|
_validateBenchmarkConfig() {
|
|
315
315
|
if (!this.answers.includeBenchmark) return;
|
|
316
316
|
|
|
317
|
-
// Gate to supported architectures
|
|
318
|
-
const dc = this.answers.deploymentConfig;
|
|
319
|
-
const arch = dc ? dc.split('-')[0] : this.answers.architecture;
|
|
320
|
-
if (arch !== 'transformers' && arch !== 'diffusors') {
|
|
321
|
-
throw new Error('⚠️ Benchmarking is only supported with transformers and diffusors architectures.');
|
|
322
|
-
}
|
|
323
|
-
|
|
324
317
|
// Gate to supported deployment targets
|
|
325
318
|
if (this.answers.deploymentTarget === 'hyperpod-eks') {
|
|
326
319
|
throw new Error('⚠️ Benchmarking is only supported with managed-inference, async-inference, and batch-transform deployment targets');
|
|
@@ -232,7 +232,7 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
|
|
|
232
232
|
artifactUri: '',
|
|
233
233
|
modelLoadStrategy: 'runtime',
|
|
234
234
|
existingEndpointName: null,
|
|
235
|
-
enableLora:
|
|
235
|
+
enableLora: true,
|
|
236
236
|
maxLoras: 30,
|
|
237
237
|
maxLoraRank: 64
|
|
238
238
|
};
|
|
@@ -261,6 +261,20 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
|
|
|
261
261
|
}
|
|
262
262
|
}
|
|
263
263
|
|
|
264
|
+
// Always include benchmarking by default (AC-2.3 — enabled for all architectures).
|
|
265
|
+
// Only set when not explicitly provided by user (AC-2.4, AC-2.7 — respect explicit opt-out).
|
|
266
|
+
if (answers.includeBenchmark === undefined) {
|
|
267
|
+
answers.includeBenchmark = true;
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
// Enforce enableLora scoping: only LoRA-capable servers get enableLora=true
|
|
271
|
+
// (AC-2.1, NFR-2). All incompatible backends are forced to false.
|
|
272
|
+
const loraCapableServers = ['vllm', 'sglang', 'djl-lmi', 'lmi', 'djl'];
|
|
273
|
+
const resolvedBackend = answers.backend || answers.modelServer;
|
|
274
|
+
if (!loraCapableServers.includes(resolvedBackend)) {
|
|
275
|
+
answers.enableLora = false;
|
|
276
|
+
}
|
|
277
|
+
|
|
264
278
|
// Merge catalog env vars into answers.envVars with correct precedence
|
|
265
279
|
await _mergeEnvVarsWithPrecedence(answers, registryConfigManager);
|
|
266
280
|
|
|
@@ -445,6 +459,35 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
|
|
|
445
459
|
}
|
|
446
460
|
}
|
|
447
461
|
|
|
462
|
+
// Propagate max_model_len from instance-sizer context capping to env vars (AC-1.7).
|
|
463
|
+
// The instance-sizer sets sizerMaxModelLen when the model's full context doesn't fit
|
|
464
|
+
// on the recommended instance. Write as VLLM_MAX_MODEL_LEN or SGLANG_MAX_MODEL_LEN.
|
|
465
|
+
const _MAX_MODEL_LEN_ENGINE_MAP = {
|
|
466
|
+
'vllm': 'VLLM_MAX_MODEL_LEN',
|
|
467
|
+
'vllm-omni': 'VLLM_MAX_MODEL_LEN',
|
|
468
|
+
'sglang': 'SGLANG_MAX_MODEL_LEN'
|
|
469
|
+
};
|
|
470
|
+
|
|
471
|
+
if (answers.sizerMaxModelLen) {
|
|
472
|
+
const maxLenEngine = answers.backend || answers.modelServer;
|
|
473
|
+
const maxLenEnvKey = maxLenEngine ? _MAX_MODEL_LEN_ENGINE_MAP[maxLenEngine] : null;
|
|
474
|
+
if (maxLenEnvKey) {
|
|
475
|
+
// Only set if user hasn't explicitly provided this env var
|
|
476
|
+
const userServerEnvVars = answers.serverEnvVars || {};
|
|
477
|
+
const userExplicitlySetMaxLen = (
|
|
478
|
+
userServerEnvVars['MAX_MODEL_LEN'] !== undefined ||
|
|
479
|
+
userServerEnvVars[maxLenEnvKey] !== undefined
|
|
480
|
+
);
|
|
481
|
+
if (!userExplicitlySetMaxLen && (!answers.envVars || !answers.envVars[maxLenEnvKey])) {
|
|
482
|
+
if (!answers.envVars) {
|
|
483
|
+
answers.envVars = {};
|
|
484
|
+
}
|
|
485
|
+
answers.envVars[maxLenEnvKey] = String(answers.sizerMaxModelLen);
|
|
486
|
+
console.log(` ℹ️ max_model_len: ${answers.sizerMaxModelLen} (context capped by instance-sizer)`);
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
}
|
|
490
|
+
|
|
448
491
|
// Determine tune support based on model presence in the tune catalog.
|
|
449
492
|
// Used by the do/config template to write TUNE_SUPPORTED=true|false.
|
|
450
493
|
if (answers.tuneSupported === undefined) {
|
|
@@ -481,4 +524,11 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
|
|
|
481
524
|
answers.tuneModelId = null;
|
|
482
525
|
}
|
|
483
526
|
}
|
|
527
|
+
|
|
528
|
+
// Propagate --ic-env KEY=VALUE pairs to icEnvVars for do/config template rendering.
|
|
529
|
+
// These are rendered as IC_ENV_* exports in do/config, which inference-component.sh
|
|
530
|
+
// reads at deploy time and passes as the Environment field in InferenceComponent.create().
|
|
531
|
+
if (!answers.icEnvVars) {
|
|
532
|
+
answers.icEnvVars = {};
|
|
533
|
+
}
|
|
484
534
|
}
|
|
@@ -74,22 +74,35 @@ export function persistSubmissionState(configPath, { technique, trainingType, da
|
|
|
74
74
|
* Simulate the config writes that happen after a job completes successfully.
|
|
75
75
|
* This mirrors the behavior in do/tune's _handle_completion() function.
|
|
76
76
|
*
|
|
77
|
+
* Writes three levels of tracking (AC-4.1, AC-4.2):
|
|
78
|
+
* - Level 1: TUNE_OUTPUT_PATH_LATEST (always the last run, any technique)
|
|
79
|
+
* - Level 2: TUNE_ADAPTER_PATH_<TECHNIQUE> (last run per technique)
|
|
80
|
+
* - Level 3: TUNE_ADAPTER_PATH_<TECHNIQUE>_<SLUG> (per technique + dataset slug)
|
|
81
|
+
*
|
|
77
82
|
* @param {string} configPath - Path to the config file
|
|
78
83
|
* @param {object} params - Completion parameters
|
|
79
84
|
* @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
|
|
80
85
|
* @param {string} params.trainingType - Training type (lora, full-rank)
|
|
81
86
|
* @param {string} params.artifactPath - S3 path to the output artifact
|
|
82
87
|
* @param {string} params.outputType - Output type (adapter, full-model)
|
|
88
|
+
* @param {string} [params.datasetSlug] - Optional dataset slug for per-technique-per-dataset tracking
|
|
83
89
|
*/
|
|
84
|
-
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType }) {
|
|
90
|
+
export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType, datasetSlug }) {
    const techniqueKey = technique.toUpperCase();

    switch (trainingType) {
        case 'lora': {
            // Level 2: most recent adapter for this technique.
            updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueKey}`, artifactPath);

            // Level 3: most recent adapter for this technique + dataset.
            // Only written when a slug was supplied; hyphens become underscores
            // so the slug can be embedded in the variable name.
            if (datasetSlug) {
                const slugKey = datasetSlug.toUpperCase().replaceAll('-', '_');
                updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueKey}_${slugKey}`, artifactPath);
            }
            break;
        }
        case 'full-rank':
            updateConfigVar(configPath, `TUNE_MODEL_PATH_${techniqueKey}`, artifactPath);
            break;
        default:
            // Other training types only update the "latest" pointers below.
            break;
    }

    // Level 1: always record the last run, regardless of technique or type.
    updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
    updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
}
|
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""SageMaker Processing Job helper for adapter staging.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
stage-from-tune - Submit Processing Job to copy adapter from training output to S3
|
|
9
|
+
status - Check Processing Job status
|
|
10
|
+
|
|
11
|
+
All output is JSON on stdout for bash consumption.
|
|
12
|
+
|
|
13
|
+
Uses sagemaker-core ProcessingJob.create() / ProcessingJob.get() per SDK v3 policy.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import logging
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import time
|
|
22
|
+
import warnings
|
|
23
|
+
|
|
24
|
+
# Suppress noisy dependency version warnings
|
|
25
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
26
|
+
warnings.filterwarnings("ignore", message=".*urllib3.*")
|
|
27
|
+
|
|
28
|
+
# Suppress sagemaker-core INFO/WARNING logging that pollutes stdout
|
|
29
|
+
logging.getLogger("sagemaker.config").setLevel(logging.ERROR)
|
|
30
|
+
logging.getLogger("sagemaker.core").setLevel(logging.ERROR)
|
|
31
|
+
logging.getLogger("sagemaker").setLevel(logging.ERROR)
|
|
32
|
+
|
|
33
|
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
|
34
|
+
POLL_INTERVAL_SECONDS = 30
|
|
35
|
+
MAX_RUNTIME_SECONDS = 3600 # 1 hour timeout for adapter staging
|
|
36
|
+
INSTANCE_TYPE = "ml.m5.large"
|
|
37
|
+
VOLUME_SIZE_GB = 100
|
|
38
|
+
|
|
39
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _error_exit(message, exit_code=1):
|
|
43
|
+
"""Print error to stderr and exit."""
|
|
44
|
+
print(f"Error: {message}", file=sys.stderr)
|
|
45
|
+
sys.exit(exit_code)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _output(data):
|
|
49
|
+
"""Print JSON result to stdout."""
|
|
50
|
+
print(json.dumps(data))
|
|
51
|
+
sys.exit(0)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ── Dependency checks ─────────────────────────────────────────────────────────
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _check_sagemaker_core():
    """Abort with an install hint unless ``sagemaker.core.resources.ProcessingJob`` imports.

    Returns normally (``None``) when the import succeeds; otherwise exits
    through ``_error_exit``.
    """
    try:
        from sagemaker.core.resources import ProcessingJob  # noqa: F401
    except ImportError:
        install_hint = (
            "sagemaker-core is not installed. "
            "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
        )
        _error_exit(install_hint)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _check_boto3():
    """Abort with an install hint unless ``boto3`` can be imported.

    boto3 is used by this helper to upload the entrypoint script to S3.
    Returns normally (``None``) when the import succeeds; otherwise exits
    through ``_error_exit``.
    """
    try:
        import boto3  # noqa: F401
    except ImportError:
        install_hint = (
            "boto3 is not installed. "
            "Please install: pip install boto3"
        )
        _error_exit(install_hint)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ── Processing Job helpers ────────────────────────────────────────────────────
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _generate_job_name(project_name, adapter_name):
|
|
83
|
+
"""Generate a unique Processing Job name."""
|
|
84
|
+
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
|
85
|
+
# Job names must be <= 63 chars, start with alphanumeric
|
|
86
|
+
base = f"mlcc-adapter-{project_name}-{adapter_name}"
|
|
87
|
+
# Truncate base to leave room for timestamp
|
|
88
|
+
max_base = 63 - len(timestamp) - 1
|
|
89
|
+
if len(base) > max_base:
|
|
90
|
+
base = base[:max_base]
|
|
91
|
+
return f"{base}-{timestamp}"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _upload_entrypoint(bucket, job_name, region):
    """Upload the Processing Job entrypoint shell script to S3.

    The script lists the Processing input directory, copies its contents to
    the Processing output directory, and lists the result.

    Args:
        bucket: Destination S3 bucket.
        job_name: Processing Job name; used in the object key
            ``staging-jobs/<job_name>/entrypoint.sh``.
        region: AWS region used for the S3 client.

    Returns:
        The ``s3://`` URI of the uploaded script. On upload failure the
        process exits through ``_error_exit``.
    """
    import boto3

    # NOTE: the backslash-newline below is consumed by Python's string
    # literal, so both `cp` alternatives end up on one shell line.
    script = """#!/bin/bash
set -e
echo "Adapter staging: copying input to output..."
echo "Input contents:"
ls -la /opt/ml/processing/input/adapter/ || echo "No input files found"
echo ""
echo "Copying adapter files..."
cp -r /opt/ml/processing/input/adapter/* /opt/ml/processing/output/ 2>/dev/null || \
cp -r /opt/ml/processing/input/adapter/. /opt/ml/processing/output/
echo "Output contents:"
ls -la /opt/ml/processing/output/
echo ""
echo "Adapter staging complete."
"""

    object_key = f"staging-jobs/{job_name}/entrypoint.sh"
    client = boto3.client("s3", region_name=region)

    try:
        client.put_object(
            Bucket=bucket,
            Key=object_key,
            Body=script.encode("utf-8"),
            ContentType="text/x-shellscript",
        )
    except Exception as e:
        _error_exit(f"Failed to upload entrypoint to S3: {e}")

    return f"s3://{bucket}/{object_key}"
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _resolve_container_image(region):
|
|
137
|
+
"""Resolve the SageMaker-managed PyTorch CPU image URI for the region.
|
|
138
|
+
|
|
139
|
+
Uses the standard SageMaker DLC (Deep Learning Container) PyTorch CPU image
|
|
140
|
+
which includes AWS CLI and Python 3.10.
|
|
141
|
+
"""
|
|
142
|
+
# SageMaker DLC account IDs per region
|
|
143
|
+
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html
|
|
144
|
+
dlc_accounts = {
|
|
145
|
+
"us-east-1": "763104351884",
|
|
146
|
+
"us-east-2": "763104351884",
|
|
147
|
+
"us-west-1": "763104351884",
|
|
148
|
+
"us-west-2": "763104351884",
|
|
149
|
+
"eu-west-1": "763104351884",
|
|
150
|
+
"eu-west-2": "763104351884",
|
|
151
|
+
"eu-central-1": "763104351884",
|
|
152
|
+
"ap-northeast-1": "763104351884",
|
|
153
|
+
"ap-southeast-1": "763104351884",
|
|
154
|
+
"ap-southeast-2": "763104351884",
|
|
155
|
+
"ap-south-1": "763104351884",
|
|
156
|
+
"ca-central-1": "763104351884",
|
|
157
|
+
}
|
|
158
|
+
account_id = dlc_accounts.get(region, "763104351884")
|
|
159
|
+
# Use PyTorch CPU processing image
|
|
160
|
+
return f"{account_id}.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.2.0-cpu-py310-ubuntu20.04-sagemaker"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# ── Subcommand: stage-from-tune ───────────────────────────────────────────────
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def cmd_stage_from_tune(args):
    """Submit a Processing Job that copies an adapter from training output to S3.

    The job reads ``args.training_output_s3_uri`` as its input and writes to
    ``s3://<bucket>/<project>/adapters/<adapter_name>/``. With ``--no-wait``
    the function returns as soon as the job is submitted; otherwise it polls
    every ``POLL_INTERVAL_SECONDS`` until the job reaches a terminal state.

    Transient polling errors are retried, but only up to a fixed number of
    consecutive failures, so a persistent problem (for example expired
    credentials) ends the command instead of looping forever.

    Args:
        args: Parsed argparse namespace for the ``stage-from-tune`` subcommand.

    Output (JSON on stdout via ``_output``):
        {"job_name": str, "status": str, "adapter_s3_uri": str}

    Exits non-zero when validation fails, the job cannot be created, the job
    ends as ``Failed``/``Stopped``, or status polling keeps failing.
    """
    _check_sagemaker_core()
    _check_boto3()

    from sagemaker.core.resources import ProcessingJob

    # Validate required arguments (argparse enforces these too; kept for
    # callers that build the namespace themselves).
    required_args = (
        (args.training_output_s3_uri, "--training-output-s3-uri"),
        (args.adapter_name, "--adapter-name"),
        (args.bucket, "--bucket"),
        (args.project, "--project"),
        (args.role_arn, "--role-arn"),
    )
    for value, flag in required_args:
        if not value:
            _error_exit(f"{flag} is required")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    # Ensure region is set in env for sagemaker-core
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Generate job name
    job_name = _generate_job_name(args.project, args.adapter_name)

    # Build adapter output S3 URI
    adapter_s3_uri = f"s3://{args.bucket}/{args.project}/adapters/{args.adapter_name}/"

    # Resolve container image
    container_image = args.container_image or _resolve_container_image(region)

    # Upload entrypoint script to S3
    entrypoint_s3_uri = _upload_entrypoint(args.bucket, job_name, region)

    # Build entrypoint command — download script from S3 then execute
    entrypoint_cmd = (
        f"aws s3 cp {entrypoint_s3_uri} /tmp/entrypoint.sh && "
        "chmod +x /tmp/entrypoint.sh && /tmp/entrypoint.sh"
    )

    # Normalize training output S3 URI (ensure trailing slash for S3Prefix)
    training_output_s3_uri = args.training_output_s3_uri
    if not training_output_s3_uri.endswith("/"):
        training_output_s3_uri += "/"

    # Submit Processing Job via sagemaker-core
    try:
        ProcessingJob.create(
            processing_job_name=job_name,
            processing_resources={
                "cluster_config": {
                    "instance_count": 1,
                    "instance_type": INSTANCE_TYPE,
                    "volume_size_in_gb": VOLUME_SIZE_GB,
                }
            },
            processing_inputs=[{
                "input_name": "adapter",
                "s3_input": {
                    "s3_uri": training_output_s3_uri,
                    "s3_data_type": "S3Prefix",
                    "s3_input_mode": "File",
                    "local_path": "/opt/ml/processing/input/adapter",
                }
            }],
            processing_output_config={
                "outputs": [{
                    "output_name": "staged-adapter",
                    "s3_output": {
                        "s3_uri": adapter_s3_uri,
                        "s3_upload_mode": "EndOfJob",
                        "local_path": "/opt/ml/processing/output",
                    }
                }]
            },
            app_specification={
                "image_uri": container_image,
                "container_entrypoint": ["bash", "-c", entrypoint_cmd],
            },
            role_arn=args.role_arn,
            stopping_condition={"max_runtime_in_seconds": MAX_RUNTIME_SECONDS},
        )
    except Exception as e:
        error_msg = str(e)
        if "AccessDeniedException" in error_msg or "AccessDenied" in error_msg:
            _error_exit(
                "Access denied when creating Processing Job. "
                "Ensure the role has sagemaker:CreateProcessingJob permission. "
                f"Details: {error_msg}"
            )
        elif "ResourceLimitExceeded" in error_msg:
            _error_exit(
                "Resource limit exceeded. You may need to request a quota increase. "
                f"Details: {error_msg}"
            )
        else:
            _error_exit(f"Failed to create Processing Job: {error_msg}")

    print(f"Processing Job submitted: {job_name}", file=sys.stderr)
    print(f"Adapter output: {adapter_s3_uri}", file=sys.stderr)

    # If --no-wait, return immediately
    if args.no_wait:
        _output({
            "job_name": job_name,
            "status": "InProgress",
            "adapter_s3_uri": adapter_s3_uri,
        })

    # Poll until completion
    print(f"Polling every {POLL_INTERVAL_SECONDS}s...", file=sys.stderr)
    # Upper bound on back-to-back status failures before giving up.
    max_poll_failures = 10
    poll_failures = 0
    while True:
        try:
            job_desc = ProcessingJob.get(processing_job_name=job_name)
            status = job_desc.processing_job_status
        except Exception as e:
            poll_failures += 1
            if poll_failures >= max_poll_failures:
                _error_exit(
                    f"Giving up after {poll_failures} consecutive failures to get the "
                    f"status of Processing Job {job_name} (the job may still be running): {e}"
                )
            print(f"Warning: failed to get job status: {e}", file=sys.stderr)
            time.sleep(POLL_INTERVAL_SECONDS)
            continue

        # A successful poll resets the failure streak.
        poll_failures = 0

        print(
            f"  [{time.strftime('%H:%M:%S')}] Status: {status}",
            file=sys.stderr,
        )

        if status in ("Completed", "Failed", "Stopped"):
            break

        time.sleep(POLL_INTERVAL_SECONDS)

    # Handle terminal states
    if status == "Failed":
        failure_reason = getattr(job_desc, "failure_reason", None) or "Unknown failure"
        print(f"Processing Job failed: {failure_reason}", file=sys.stderr)
        sys.exit(1)

    if status == "Stopped":
        print("Processing Job was stopped.", file=sys.stderr)
        sys.exit(1)

    # Success
    _output({
        "job_name": job_name,
        "status": "Completed",
        "adapter_s3_uri": adapter_s3_uri,
    })
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# ── Subcommand: status ────────────────────────────────────────────────────────
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def cmd_status(args):
    """Look up a Processing Job and report its status as JSON.

    Args:
        args: Parsed argparse namespace for the ``status`` subcommand
            (reads ``job_name`` and ``region``).

    Output (JSON on stdout via ``_output``):
        {"job_name": str, "status": str, "failure_reason": str|None}

    ``failure_reason`` is only populated when the status is ``Failed``; in
    that case it is also echoed on stderr. Lookup errors exit through
    ``_error_exit``.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ProcessingJob

    if not args.job_name:
        _error_exit("--job-name is required")

    env = os.environ
    region = args.region or env.get("AWS_DEFAULT_REGION") or env.get("AWS_REGION", "us-west-2")
    env["AWS_DEFAULT_REGION"] = region
    env.setdefault("AWS_REGION", region)

    try:
        job_desc = ProcessingJob.get(processing_job_name=args.job_name)
    except Exception as e:
        detail = str(e)
        # These substrings are treated as "job not found"; anything else is
        # reported as a generic lookup failure.
        job_missing = "does not exist" in detail or "ValidationException" in detail
        if job_missing:
            _error_exit(f"Processing Job not found: {args.job_name}")
        else:
            _error_exit(f"Failed to get Processing Job status: {detail}")

    status = job_desc.processing_job_status

    if status == "Failed":
        failure_reason = getattr(job_desc, "failure_reason", None) or "Unknown failure"
        print(f"Processing Job failed: {failure_reason}", file=sys.stderr)
    else:
        failure_reason = None

    result = {
        "job_name": args.job_name,
        "status": status,
        "failure_reason": failure_reason,
    }
    _output(result)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# ── Argument parsing ──────────────────────────────────────────────────────────
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def main():
    """Build the command-line parser and run the requested subcommand.

    Subcommands are ``stage-from-tune`` and ``status``. When no subcommand
    is given, the help text is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="SageMaker Processing Job helper for adapter staging",
        prog=".adapter_helper.py",
    )
    subparsers = parser.add_subparsers(dest="subcommand", help="Subcommand")

    # ── stage-from-tune ───────────────────────────────────────────────────
    stage_parser = subparsers.add_parser(
        "stage-from-tune",
        help="Submit Processing Job to stage adapter from training output",
    )
    required_stage_flags = (
        ("--training-output-s3-uri", "S3 URI of training output (adapter artifacts)"),
        ("--adapter-name", "Name of the adapter (used in output S3 path)"),
        ("--bucket", "S3 bucket for adapter output"),
        ("--project", "Project name (used in S3 path prefix)"),
        ("--role-arn", "SageMaker execution role ARN"),
    )
    for flag, description in required_stage_flags:
        stage_parser.add_argument(flag, required=True, help=description)
    stage_parser.add_argument(
        "--region",
        default=None,
        help="AWS region (default: from environment)",
    )
    stage_parser.add_argument(
        "--container-image",
        default=None,
        help="Override container image URI (default: SageMaker PyTorch CPU image)",
    )
    stage_parser.add_argument(
        "--no-wait",
        action="store_true",
        default=False,
        help="Return immediately after submitting the job",
    )

    # ── status ────────────────────────────────────────────────────────────
    status_parser = subparsers.add_parser(
        "status",
        help="Check Processing Job status",
    )
    status_parser.add_argument(
        "--job-name",
        required=True,
        help="Processing Job name to check",
    )
    status_parser.add_argument(
        "--region",
        default=None,
        help="AWS region (default: from environment)",
    )

    # ── Parse and dispatch ────────────────────────────────────────────────
    args = parser.parse_args()
    chosen = args.subcommand

    if not chosen:
        parser.print_help()
        sys.exit(1)

    # Lambdas defer the handler lookup until a subcommand is actually run.
    handlers = {
        "stage-from-tune": lambda: cmd_stage_from_tune(args),
        "status": lambda: cmd_status(args),
    }
    run = handlers.get(chosen)
    if run is None:
        _error_exit(f"Unknown subcommand: {chosen}")
    else:
        run()
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
if __name__ == "__main__":
|
|
451
|
+
main()
|
|
@@ -487,6 +487,7 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
487
487
|
'mcc_version': mcc_version,
|
|
488
488
|
'run_timestamp': run_timestamp.isoformat(),
|
|
489
489
|
'region': region,
|
|
490
|
+
'adapter_name': config.get('adapter_name', ''),
|
|
490
491
|
}
|
|
491
492
|
records.append(record)
|
|
492
493
|
|
|
@@ -859,6 +860,7 @@ def get_parquet_schema():
|
|
|
859
860
|
pa.field("mcc_version", pa.string()),
|
|
860
861
|
pa.field("run_timestamp", pa.string()),
|
|
861
862
|
pa.field("region", pa.string()),
|
|
863
|
+
pa.field("adapter_name", pa.string()),
|
|
862
864
|
])
|
|
863
865
|
|
|
864
866
|
|
|
@@ -1177,6 +1179,8 @@ def cmd_write(args):
|
|
|
1177
1179
|
input_data['workload'] = args.workload
|
|
1178
1180
|
if args.region:
|
|
1179
1181
|
input_data['region'] = args.region
|
|
1182
|
+
if args.adapter_name:
|
|
1183
|
+
input_data['adapter_name'] = args.adapter_name
|
|
1180
1184
|
|
|
1181
1185
|
# ── Validate before any S3 interaction ────────────────────────────────
|
|
1182
1186
|
errors = validate_benchmark_input(input_data)
|
|
@@ -1385,6 +1389,7 @@ def _load_config_file(config_path):
|
|
|
1385
1389
|
shell_map = {
|
|
1386
1390
|
'PROJECT_NAME': 'project_name',
|
|
1387
1391
|
'MODEL_NAME': 'model_name',
|
|
1392
|
+
'HF_MODEL_ID': 'hf_model_id',
|
|
1388
1393
|
'INSTANCE_TYPE': 'instance_type',
|
|
1389
1394
|
'DEPLOYMENT_CONFIG': 'deployment_config',
|
|
1390
1395
|
'DEPLOYMENT_TARGET': 'deployment_target',
|
|
@@ -1402,6 +1407,18 @@ def _load_config_file(config_path):
|
|
|
1402
1407
|
except Exception:
|
|
1403
1408
|
pass
|
|
1404
1409
|
|
|
1410
|
+
# Prefer HF_MODEL_ID over MODEL_NAME for the model_name field.
|
|
1411
|
+
# After do/stage runs, MODEL_NAME is rewritten to an S3 URI which is
|
|
1412
|
+
# unsuitable for S3 result paths (nested s3:// in path) and model family
|
|
1413
|
+
# derivation. HF_MODEL_ID preserves the original HuggingFace repo ID.
|
|
1414
|
+
if context.get('hf_model_id'):
|
|
1415
|
+
context['model_name'] = context.pop('hf_model_id')
|
|
1416
|
+
elif context.get('model_name', '').startswith('s3://'):
|
|
1417
|
+
# Fallback: if no HF_MODEL_ID but MODEL_NAME is an S3 URI, extract
|
|
1418
|
+
# the model slug from the S3 path (last non-empty segment)
|
|
1419
|
+
parts = context['model_name'].rstrip('/').split('/')
|
|
1420
|
+
context['model_name'] = parts[-1] if parts else context['model_name']
|
|
1421
|
+
|
|
1405
1422
|
return context
|
|
1406
1423
|
|
|
1407
1424
|
|
|
@@ -1449,6 +1466,11 @@ def main():
|
|
|
1449
1466
|
'--region',
|
|
1450
1467
|
help='AWS region'
|
|
1451
1468
|
)
|
|
1469
|
+
write_parser.add_argument(
|
|
1470
|
+
'--adapter-name', dest='adapter_name', default=None,
|
|
1471
|
+
help='LoRA adapter name (differentiates adapter benchmarks from base model in Athena)'
|
|
1472
|
+
)
|
|
1473
|
+
|
|
1452
1474
|
write_parser.add_argument(
|
|
1453
1475
|
'--dry-run', dest='dry_run', action='store_true',
|
|
1454
1476
|
help='Output enriched records as JSON without writing to S3'
|