@aws/ml-container-creator 0.13.4 → 0.13.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -5
- package/infra/ci-harness/package-lock.json +1 -5
- package/package.json +4 -2
- package/pyproject.toml +21 -0
- package/requirements.txt +19 -0
- package/src/app.js +2 -0
- package/src/lib/bootstrap-command-handler.js +33 -23
- package/templates/do/.adapter_helper.py +451 -0
- package/templates/do/.benchmark_writer.py +13 -0
- package/templates/do/.stage_helper.py +419 -0
- package/templates/do/.tune_helper.py +213 -65
- package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +108 -0
- package/templates/do/benchmark +150 -12
- package/templates/do/config +4 -0
- package/templates/do/lib/profile.sh +5 -0
- package/templates/do/stage +91 -272
- package/templates/do/tune +63 -6
package/README.md
CHANGED
|
@@ -87,11 +87,29 @@ Full documentation is available at [awslabs.github.io/ml-container-creator](http
|
|
|
87
87
|
|
|
88
88
|
## Prerequisites
|
|
89
89
|
|
|
90
|
-
| Tool | Version | Purpose |
|
|
91
|
-
|
|
92
|
-
| [Node.js](https://nodejs.org/) | 24+ | Runs the CLI |
|
|
93
|
-
| [
|
|
94
|
-
| [
|
|
90
|
+
| Tool | Version | Purpose | Required |
|
|
91
|
+
|---|---|---|---|
|
|
92
|
+
| [Node.js](https://nodejs.org/) | 24+ | Runs the CLI | Yes |
|
|
93
|
+
| [Python](https://www.python.org/) | 3.10+ | `do/` lifecycle scripts (stage, tune, benchmark) | Yes |
|
|
94
|
+
| [uv](https://docs.astral.sh/uv/) | latest | Fast Python package installer | Recommended |
|
|
95
|
+
| [Docker](https://docs.docker.com/get-docker/) | 20+ | Container builds | Yes |
|
|
96
|
+
| [AWS CLI](https://aws.amazon.com/cli/) | 2+ | AWS resource management | Yes |
|
|
97
|
+
|
|
98
|
+
### Python dependencies
|
|
99
|
+
|
|
100
|
+
The `do/` lifecycle scripts (`do/tune`, `do/stage`, `do/adapter`) require Python packages. Install them in your Python environment before first use:
|
|
101
|
+
|
|
102
|
+
```bash
|
|
103
|
+
# Recommended (fast):
|
|
104
|
+
uv pip install -r requirements.txt
|
|
105
|
+
|
|
106
|
+
# Or with pip:
|
|
107
|
+
pip install -r requirements.txt
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
If you use virtual environments, activate yours first. See [`requirements.txt`](requirements.txt) for the full list (`boto3`, `sagemaker-core`, `huggingface_hub`, `pyarrow`, etc.).
|
|
111
|
+
|
|
112
|
+
> **Tip:** Install [uv](https://docs.astral.sh/uv/) for 10-50x faster Python package installs: `curl -LsSf https://astral.sh/uv/install.sh | sh`
|
|
95
113
|
|
|
96
114
|
## Contributing
|
|
97
115
|
|
|
@@ -48,7 +48,6 @@
|
|
|
48
48
|
"semver"
|
|
49
49
|
],
|
|
50
50
|
"license": "Apache-2.0",
|
|
51
|
-
"peer": true,
|
|
52
51
|
"dependencies": {
|
|
53
52
|
"jsonschema": "~1.4.1",
|
|
54
53
|
"semver": "^7.7.4"
|
|
@@ -2151,7 +2150,6 @@
|
|
|
2151
2150
|
"integrity": "sha512-wGdMcf+vPYM6jikpS/qhg6WiqSV/OhG+jeeHT/KlVqxYfD40iYJf9/AE1uQxVWFvU7MipKRkRv8NSHiCGgPr8Q==",
|
|
2152
2151
|
"dev": true,
|
|
2153
2152
|
"license": "MIT",
|
|
2154
|
-
"peer": true,
|
|
2155
2153
|
"dependencies": {
|
|
2156
2154
|
"undici-types": "~6.21.0"
|
|
2157
2155
|
}
|
|
@@ -2791,8 +2789,7 @@
|
|
|
2791
2789
|
"version": "10.6.0",
|
|
2792
2790
|
"resolved": "https://registry.npmjs.org/constructs/-/constructs-10.6.0.tgz",
|
|
2793
2791
|
"integrity": "sha512-TxHOnBO5zMo/G76ykzGF/wMpEHu257TbWiIxP9K0Yv/+t70UzgBQiTqjkAsWOPC6jW91DzJI0+ehQV6xDRNBuQ==",
|
|
2794
|
-
"license": "Apache-2.0"
|
|
2795
|
-
"peer": true
|
|
2792
|
+
"license": "Apache-2.0"
|
|
2796
2793
|
},
|
|
2797
2794
|
"node_modules/create-require": {
|
|
2798
2795
|
"version": "1.1.1",
|
|
@@ -3697,7 +3694,6 @@
|
|
|
3697
3694
|
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
|
3698
3695
|
"dev": true,
|
|
3699
3696
|
"license": "Apache-2.0",
|
|
3700
|
-
"peer": true,
|
|
3701
3697
|
"bin": {
|
|
3702
3698
|
"tsc": "bin/tsc",
|
|
3703
3699
|
"tsserver": "bin/tsserver"
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@aws/ml-container-creator",
|
|
3
|
-
"version": "0.13.
|
|
3
|
+
"version": "0.13.5",
|
|
4
4
|
"description": "Build and deploy custom ML containers on AWS SageMaker with minimal configuration.",
|
|
5
5
|
"main": "src/index.js",
|
|
6
6
|
"bin": {
|
|
@@ -76,7 +76,9 @@
|
|
|
76
76
|
"README.md",
|
|
77
77
|
"LICENSE",
|
|
78
78
|
"LICENSE-THIRD-PARTY",
|
|
79
|
-
"NOTICE"
|
|
79
|
+
"NOTICE",
|
|
80
|
+
"requirements.txt",
|
|
81
|
+
"pyproject.toml"
|
|
80
82
|
],
|
|
81
83
|
"type": "module",
|
|
82
84
|
"license": "Apache-2.0",
|
package/pyproject.toml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "ml-container-creator"
|
|
3
|
+
version = "0.13.4"
|
|
4
|
+
description = "Python dependencies for ml-container-creator do/ lifecycle scripts"
|
|
5
|
+
requires-python = ">=3.10"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"boto3>=1.35.0",
|
|
8
|
+
"huggingface-hub>=0.25.0",
|
|
9
|
+
"hf-transfer>=0.1.8",
|
|
10
|
+
"pyarrow>=17.0.0",
|
|
11
|
+
"sagemaker-core>=1.0.0",
|
|
12
|
+
"sagemaker[train]>=3.0.0",
|
|
13
|
+
"sagemaker[serve]>=3.0.0",
|
|
14
|
+
"packaging>=24.0",
|
|
15
|
+
"pyyaml>=6.0",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[dependency-groups]
|
|
19
|
+
dev = [
|
|
20
|
+
"pytest>=8.0",
|
|
21
|
+
]
|
package/requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Python dependencies for do/ lifecycle scripts
|
|
2
|
+
#
|
|
3
|
+
# Install with uv (recommended):
|
|
4
|
+
# uv pip install -r requirements.txt
|
|
5
|
+
#
|
|
6
|
+
# Or with pip:
|
|
7
|
+
# pip install -r requirements.txt
|
|
8
|
+
#
|
|
9
|
+
# Source of truth: pyproject.toml
|
|
10
|
+
|
|
11
|
+
boto3>=1.35.0
|
|
12
|
+
huggingface_hub>=0.25.0
|
|
13
|
+
hf_transfer>=0.1.8
|
|
14
|
+
pyarrow>=17.0.0
|
|
15
|
+
sagemaker-core>=1.0.0
|
|
16
|
+
sagemaker[train]>=3.0.0
|
|
17
|
+
sagemaker[serve]>=3.0.0
|
|
18
|
+
packaging>=24.0
|
|
19
|
+
PyYAML>=6.0
|
package/src/app.js
CHANGED
|
@@ -400,6 +400,8 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
400
400
|
ignorePatterns.push('**/do/adapters/**');
|
|
401
401
|
ignorePatterns.push('**/do/tune');
|
|
402
402
|
ignorePatterns.push('**/do/.tune_helper.py');
|
|
403
|
+
ignorePatterns.push('**/do/.stage_helper.py');
|
|
404
|
+
ignorePatterns.push('**/do/.adapter_helper.py');
|
|
403
405
|
ignorePatterns.push('**/do/train');
|
|
404
406
|
ignorePatterns.push('**/do/.train_build_request.py');
|
|
405
407
|
ignorePatterns.push('**/do/.train_status_parser.py');
|
|
@@ -459,39 +459,49 @@ export default class BootstrapCommandHandler {
|
|
|
459
459
|
|
|
460
460
|
// --no-rollback prevents rollback on AlreadyExists errors for IAM roles
|
|
461
461
|
// that may pre-exist from a prior deployment or another region.
|
|
462
|
-
// Check if benchmark bucket already exists
|
|
463
|
-
|
|
462
|
+
// Check if benchmark results bucket already exists.
|
|
463
|
+
// If it does, skip CDK deploy for benchmark infra — just update the profile.
|
|
464
|
+
let benchmarkBucketExists = false;
|
|
464
465
|
if (options.benchmarkInfra) {
|
|
466
|
+
const resultsBucketName = `mlcc-benchmark-results-${profileData.accountId}-${profileData.awsRegion}`;
|
|
465
467
|
try {
|
|
466
468
|
execSync(
|
|
467
|
-
`aws s3api head-bucket --bucket
|
|
469
|
+
`aws s3api head-bucket --bucket ${resultsBucketName}${profileData.awsProfile ? ` --profile ${profileData.awsProfile}` : ''} --region ${profileData.awsRegion}`,
|
|
468
470
|
{ encoding: 'utf8', stdio: ['pipe', 'pipe', 'pipe'] }
|
|
469
471
|
);
|
|
470
|
-
|
|
471
|
-
console.log(
|
|
472
|
+
benchmarkBucketExists = true;
|
|
473
|
+
console.log(` ✅ Benchmark results bucket already exists: ${resultsBucketName}`);
|
|
474
|
+
console.log(' Skipping CDK deploy for benchmark infra — updating profile only.');
|
|
475
|
+
profileData.benchmarkInfraProvisioned = true;
|
|
476
|
+
profileData.ciGlueDatabase = profileData.ciGlueDatabase || 'mlcc_ci';
|
|
477
|
+
profileData.ciBenchmarkResultsBucket = resultsBucketName;
|
|
472
478
|
} catch {
|
|
473
479
|
// Bucket doesn't exist — will be created fresh
|
|
474
480
|
}
|
|
475
481
|
}
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
482
|
+
|
|
483
|
+
// Only run CDK deploy if we actually need to create infrastructure
|
|
484
|
+
if (!benchmarkBucketExists || !options.benchmarkInfra) {
|
|
485
|
+
const cdkDeployCmd = options.benchmarkInfra
|
|
486
|
+
? 'npx cdk deploy MlccCiHarnessStack --require-approval never --no-rollback --parameters MlccCiHarnessStack:CreateBenchmarkInfra=true'
|
|
487
|
+
: 'npx cdk deploy MlccCiHarnessStack --require-approval never --no-rollback';
|
|
488
|
+
execSync(
|
|
489
|
+
cdkDeployCmd,
|
|
490
|
+
{
|
|
491
|
+
cwd: ciHarnessDir,
|
|
492
|
+
encoding: 'utf8',
|
|
493
|
+
stdio: 'inherit',
|
|
494
|
+
env: {
|
|
495
|
+
...process.env,
|
|
496
|
+
AWS_REGION: profileData.awsRegion,
|
|
497
|
+
CDK_DEFAULT_REGION: profileData.awsRegion,
|
|
498
|
+
CDK_DEFAULT_ACCOUNT: profileData.accountId,
|
|
499
|
+
AWS_PROFILE: profileData.awsProfile
|
|
500
|
+
}
|
|
491
501
|
}
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
502
|
+
);
|
|
503
|
+
console.log(' ✅ CI harness stack deployed');
|
|
504
|
+
}
|
|
495
505
|
|
|
496
506
|
profileData.ciInfraProvisioned = true;
|
|
497
507
|
profileData.ciTableName = 'mlcc-ci-table';
|
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""SageMaker Processing Job helper for adapter staging.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
stage-from-tune - Submit Processing Job to copy adapter from training output to S3
|
|
9
|
+
status - Check Processing Job status
|
|
10
|
+
|
|
11
|
+
All output is JSON on stdout for bash consumption.
|
|
12
|
+
|
|
13
|
+
Uses sagemaker-core ProcessingJob.create() / ProcessingJob.get() per SDK v3 policy.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import logging
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import time
|
|
22
|
+
import warnings
|
|
23
|
+
|
|
24
|
+
# ── Noise suppression ─────────────────────────────────────────────────────────
# Dependency-version chatter (DeprecationWarning, urllib3 messages) is hidden.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", message=".*urllib3.*")

# sagemaker-core logs at INFO/WARNING; raise the threshold so only errors
# get through (the original author notes this output pollutes stdout).
for _noisy_logger in ("sagemaker.config", "sagemaker.core", "sagemaker"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger  # keep the loop variable out of the module namespace

# ── Constants ─────────────────────────────────────────────────────────────────
POLL_INTERVAL_SECONDS = 30       # delay between status polls
MAX_RUNTIME_SECONDS = 60 * 60    # 1 hour timeout for adapter staging
INSTANCE_TYPE = "ml.m5.large"    # Processing Job instance type
VOLUME_SIZE_GB = 100             # Processing Job volume size
|
|
38
|
+
|
|
39
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _error_exit(message, exit_code=1):
    """Report a fatal problem on stderr and terminate the process.

    Args:
        message: Description of the failure; written with an ``Error: `` prefix.
        exit_code: Status handed to ``sys.exit`` (defaults to 1).

    Raises:
        SystemExit: Always.
    """
    sys.stderr.write(f"Error: {message}\n")
    sys.exit(exit_code)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _output(data):
    """Emit ``data`` as a single JSON line on stdout, then exit successfully.

    Args:
        data: Any JSON-serializable object.

    Raises:
        SystemExit: Always, with status 0, once the line has been written.
    """
    serialized = json.dumps(data)
    sys.stdout.write(serialized + "\n")
    sys.exit(0)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ── Dependency checks ─────────────────────────────────────────────────────────
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _check_sagemaker_core():
    """Abort with an install hint unless ``ProcessingJob`` is importable.

    The import is attempted purely as a probe; the imported name is discarded.
    On ``ImportError`` the helper reports the problem through ``_error_exit``.
    """
    install_hint = (
        "sagemaker-core is not installed. "
        "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
    )
    try:
        from sagemaker.core.resources import ProcessingJob  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _check_boto3():
    """Abort with an install hint unless boto3 is importable.

    boto3 is needed for the S3 upload of the entrypoint script. The import is
    only a probe; on ``ImportError`` the problem is reported via ``_error_exit``.
    """
    install_hint = (
        "boto3 is not installed. "
        "Please install: pip install boto3"
    )
    try:
        import boto3  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ── Processing Job helpers ────────────────────────────────────────────────────
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _generate_job_name(project_name, adapter_name):
    """Generate a unique, SageMaker-valid Processing Job name.

    The name has the form ``mlcc-adapter-<project>-<adapter>-<timestamp>``.
    SageMaker job names may contain only alphanumerics and hyphens, must start
    with an alphanumeric, and may be at most 63 characters long. Any other
    character in the project or adapter name (underscore, dot, slash, space,
    ...) is therefore replaced with a hyphen instead of being passed through
    to the API, where it would be rejected.

    Args:
        project_name: Project name embedded in the job name.
        adapter_name: Adapter name embedded in the job name.

    Returns:
        A job name of at most 63 characters ending in a local-time
        ``YYYYMMDD-HHMMSS`` timestamp.
    """
    import re  # local import: `re` is not among the module-level imports

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    # One-for-one replacement keeps names that were already valid unchanged.
    base = re.sub(r"[^A-Za-z0-9-]", "-", f"mlcc-adapter-{project_name}-{adapter_name}")
    # Truncate base to leave room for the "-<timestamp>" suffix.
    max_base = 63 - len(timestamp) - 1
    if len(base) > max_base:
        base = base[:max_base]
    return f"{base}-{timestamp}"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _upload_entrypoint(bucket, job_name, region):
    """Store the Processing Job's shell entrypoint in S3 and return its URI.

    The script lists the Processing input directory, copies its contents into
    the Processing output directory, and lists the result. The object is
    written to ``staging-jobs/<job_name>/entrypoint.sh`` in ``bucket``.

    Args:
        bucket: Destination S3 bucket name.
        job_name: Processing Job name, used as part of the object key.
        region: AWS region used for the S3 client.

    Returns:
        The ``s3://`` URI of the uploaded script. A failed upload is reported
        through ``_error_exit``.
    """
    import boto3

    # NOTE: this is a regular (non-raw) string, so the trailing backslash
    # before the newline is consumed by Python and the two `cp` commands end
    # up joined by `||` on one line of the uploaded script.
    script_text = """#!/bin/bash
set -e
echo "Adapter staging: copying input to output..."
echo "Input contents:"
ls -la /opt/ml/processing/input/adapter/ || echo "No input files found"
echo ""
echo "Copying adapter files..."
cp -r /opt/ml/processing/input/adapter/* /opt/ml/processing/output/ 2>/dev/null || \
cp -r /opt/ml/processing/input/adapter/. /opt/ml/processing/output/
echo "Output contents:"
ls -la /opt/ml/processing/output/
echo ""
echo "Adapter staging complete."
"""

    object_key = f"staging-jobs/{job_name}/entrypoint.sh"
    destination_uri = f"s3://{bucket}/{object_key}"

    client = boto3.client("s3", region_name=region)
    try:
        client.put_object(
            Bucket=bucket,
            Key=object_key,
            Body=script_text.encode("utf-8"),
            ContentType="text/x-shellscript",
        )
    except Exception as upload_error:
        _error_exit(f"Failed to upload entrypoint to S3: {upload_error}")

    return destination_uri
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _resolve_container_image(region):
    """Resolve the SageMaker-managed PyTorch CPU image URI for the region.

    Uses the SageMaker DLC (Deep Learning Container) PyTorch CPU image, which
    per the original author's notes includes AWS CLI and Python 3.10.

    Most commercial regions share one DLC registry account. Opt-in, China and
    GovCloud regions use their own accounts, and China additionally uses the
    ``amazonaws.com.cn`` DNS suffix; previously all of these silently fell
    back to the default account and ``amazonaws.com``.

    Args:
        region: AWS region name, e.g. ``"us-west-2"``.

    Returns:
        The fully qualified ECR image URI for ``region``.
    """
    # Registry account shared by the standard commercial regions.
    default_account = "763104351884"
    # Regions whose DLC registry lives in a different account.
    # NOTE(review): IDs taken from the AWS Deep Learning Containers image
    # list — re-verify there before adding or relying on further regions.
    account_overrides = {
        "af-south-1": "626614931356",
        "ap-east-1": "871362719292",
        "eu-south-1": "692866216735",
        "me-south-1": "217643126080",
        "cn-north-1": "727897471807",
        "cn-northwest-1": "727897471807",
        "us-gov-east-1": "446045086412",
        "us-gov-west-1": "442386744353",
    }
    account_id = account_overrides.get(region, default_account)
    # ECR endpoints in the China partition use a different DNS suffix.
    domain = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    # Use PyTorch CPU processing image
    return f"{account_id}.dkr.ecr.{region}.{domain}/pytorch-training:2.2.0-cpu-py310-ubuntu20.04-sagemaker"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# ── Subcommand: stage-from-tune ───────────────────────────────────────────────
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def cmd_stage_from_tune(args):
    """Submit a Processing Job to copy adapter from training output to S3 adapter location.

    Steps: validate arguments, resolve the region, upload the shell entrypoint
    to S3, create the Processing Job via sagemaker-core, then (unless
    ``--no-wait`` was given) poll until the job reaches a terminal state.

    Args:
        args: Parsed ``stage-from-tune`` arguments. Reads
            ``training_output_s3_uri``, ``adapter_name``, ``bucket``,
            ``project``, ``role_arn``, ``region``, ``container_image`` and
            ``no_wait``.

    Output:
        On success prints JSON to stdout and exits 0:
        ``{"job_name": str, "status": str, "adapter_s3_uri": str}``.
        Exits 1 (message on stderr) when the job cannot be created, ends in
        ``Failed``/``Stopped``, or its status cannot be read repeatedly.
    """
    import shlex  # local import: `shlex` is not among the module-level imports

    _check_sagemaker_core()
    _check_boto3()

    from sagemaker.core.resources import ProcessingJob

    # argparse already marks these as required, but it accepts empty strings,
    # so each value is still checked for truthiness.
    if not args.training_output_s3_uri:
        _error_exit("--training-output-s3-uri is required")
    if not args.adapter_name:
        _error_exit("--adapter-name is required")
    if not args.bucket:
        _error_exit("--bucket is required")
    if not args.project:
        _error_exit("--project is required")
    if not args.role_arn:
        _error_exit("--role-arn is required")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    # Ensure region is set in env for sagemaker-core
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Generate job name
    job_name = _generate_job_name(args.project, args.adapter_name)

    # Build adapter output S3 URI
    adapter_s3_uri = f"s3://{args.bucket}/{args.project}/adapters/{args.adapter_name}/"

    # Resolve container image
    container_image = args.container_image or _resolve_container_image(region)

    # Upload entrypoint script to S3
    entrypoint_s3_uri = _upload_entrypoint(args.bucket, job_name, region)

    # Build entrypoint command — download script from S3 then execute.
    # The URI embeds caller-supplied names and is evaluated by `bash -c`, so
    # it is shell-quoted; for ordinary URIs quoting leaves it unchanged.
    entrypoint_cmd = (
        f"aws s3 cp {shlex.quote(entrypoint_s3_uri)} /tmp/entrypoint.sh && "
        "chmod +x /tmp/entrypoint.sh && /tmp/entrypoint.sh"
    )

    # Normalize training output S3 URI (ensure trailing slash for S3Prefix)
    training_output_s3_uri = args.training_output_s3_uri
    if not training_output_s3_uri.endswith("/"):
        training_output_s3_uri += "/"

    # Submit Processing Job via sagemaker-core
    try:
        job = ProcessingJob.create(
            processing_job_name=job_name,
            processing_resources={
                "cluster_config": {
                    "instance_count": 1,
                    "instance_type": INSTANCE_TYPE,
                    "volume_size_in_gb": VOLUME_SIZE_GB,
                }
            },
            processing_inputs=[{
                "input_name": "adapter",
                "s3_input": {
                    "s3_uri": training_output_s3_uri,
                    "s3_data_type": "S3Prefix",
                    "s3_input_mode": "File",
                    "local_path": "/opt/ml/processing/input/adapter",
                }
            }],
            processing_output_config={
                "outputs": [{
                    "output_name": "staged-adapter",
                    "s3_output": {
                        "s3_uri": adapter_s3_uri,
                        "s3_upload_mode": "EndOfJob",
                        "local_path": "/opt/ml/processing/output",
                    }
                }]
            },
            app_specification={
                "image_uri": container_image,
                "container_entrypoint": ["bash", "-c", entrypoint_cmd],
            },
            role_arn=args.role_arn,
            stopping_condition={"max_runtime_in_seconds": MAX_RUNTIME_SECONDS},
        )
    except Exception as e:
        # Classify by message text to give targeted remediation hints.
        error_msg = str(e)
        if "AccessDeniedException" in error_msg or "AccessDenied" in error_msg:
            _error_exit(
                f"Access denied when creating Processing Job. "
                f"Ensure the role has sagemaker:CreateProcessingJob permission. "
                f"Details: {error_msg}"
            )
        elif "ResourceLimitExceeded" in error_msg:
            _error_exit(
                f"Resource limit exceeded. You may need to request a quota increase. "
                f"Details: {error_msg}"
            )
        else:
            _error_exit(f"Failed to create Processing Job: {error_msg}")

    # Progress messages go to stderr so stdout carries only the JSON result.
    print(f"Processing Job submitted: {job_name}", file=sys.stderr)
    print(f"Adapter output: {adapter_s3_uri}", file=sys.stderr)

    # If --no-wait, return immediately (_output exits the process).
    if args.no_wait:
        _output({
            "job_name": job_name,
            "status": "InProgress",
            "adapter_s3_uri": adapter_s3_uri,
        })

    # Poll until completion. Transient lookup errors are retried, but a long
    # unbroken run of failures (e.g. expired credentials) ends the wait
    # instead of looping forever.
    max_consecutive_failures = 10
    consecutive_failures = 0
    print(f"Polling every {POLL_INTERVAL_SECONDS}s...", file=sys.stderr)
    while True:
        try:
            job_desc = ProcessingJob.get(processing_job_name=job_name)
            status = job_desc.processing_job_status
        except Exception as e:
            consecutive_failures += 1
            print(f"Warning: failed to get job status: {e}", file=sys.stderr)
            if consecutive_failures >= max_consecutive_failures:
                _error_exit(
                    f"Giving up after {consecutive_failures} consecutive failed status checks. "
                    f"The job may still be running; check it with: status --job-name {job_name}"
                )
            time.sleep(POLL_INTERVAL_SECONDS)
            continue

        consecutive_failures = 0  # a successful lookup resets the failure streak

        print(
            f"  [{time.strftime('%H:%M:%S')}] Status: {status}",
            file=sys.stderr,
        )

        if status in ("Completed", "Failed", "Stopped"):
            break

        time.sleep(POLL_INTERVAL_SECONDS)

    # Handle terminal states
    if status == "Failed":
        failure_reason = getattr(job_desc, "failure_reason", None) or "Unknown failure"
        print(f"Processing Job failed: {failure_reason}", file=sys.stderr)
        sys.exit(1)

    if status == "Stopped":
        print("Processing Job was stopped.", file=sys.stderr)
        sys.exit(1)

    # Success
    _output({
        "job_name": job_name,
        "status": "Completed",
        "adapter_s3_uri": adapter_s3_uri,
    })
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# ── Subcommand: status ────────────────────────────────────────────────────────
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def cmd_status(args):
    """Look up a Processing Job and print its current status as JSON.

    Args:
        args: Parsed ``status`` arguments. Reads ``job_name`` and ``region``.

    Output:
        Prints ``{"job_name": str, "status": str, "failure_reason": str|None}``
        to stdout and exits 0. ``failure_reason`` is filled in only when the
        status is ``Failed``. Lookup errors are reported via ``_error_exit``.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ProcessingJob

    if not args.job_name:
        _error_exit("--job-name is required")

    # Explicit flag wins, then AWS_DEFAULT_REGION, then AWS_REGION, then us-west-2.
    env = os.environ
    resolved_region = args.region or env.get("AWS_DEFAULT_REGION") or env.get("AWS_REGION", "us-west-2")
    env["AWS_DEFAULT_REGION"] = resolved_region
    env.setdefault("AWS_REGION", resolved_region)

    try:
        description = ProcessingJob.get(processing_job_name=args.job_name)
    except Exception as lookup_error:
        detail = str(lookup_error)
        # The error text is the only signal used to tell "missing job" apart
        # from other failures.
        looks_missing = "does not exist" in detail or "ValidationException" in detail
        _error_exit(
            f"Processing Job not found: {args.job_name}"
            if looks_missing
            else f"Failed to get Processing Job status: {detail}"
        )

    current_status = description.processing_job_status
    reason = None

    if current_status == "Failed":
        reason = getattr(description, "failure_reason", None) or "Unknown failure"
        print(f"Processing Job failed: {reason}", file=sys.stderr)

    _output({
        "job_name": args.job_name,
        "status": current_status,
        "failure_reason": reason,
    })
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# ── Argument parsing ──────────────────────────────────────────────────────────
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def main():
    """Build the CLI, parse ``sys.argv`` and run the selected subcommand.

    Two subcommands are defined: ``stage-from-tune`` and ``status``. With no
    subcommand the help text is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        prog=".adapter_helper.py",
        description="SageMaker Processing Job helper for adapter staging",
    )
    subparsers = parser.add_subparsers(dest="subcommand", help="Subcommand")

    # ── stage-from-tune ───────────────────────────────────────────────────
    stage_parser = subparsers.add_parser(
        "stage-from-tune",
        help="Submit Processing Job to stage adapter from training output",
    )
    # Mandatory options, registered in the order they appear in --help.
    for flag, description in (
        ("--training-output-s3-uri", "S3 URI of training output (adapter artifacts)"),
        ("--adapter-name", "Name of the adapter (used in output S3 path)"),
        ("--bucket", "S3 bucket for adapter output"),
        ("--project", "Project name (used in S3 path prefix)"),
        ("--role-arn", "SageMaker execution role ARN"),
    ):
        stage_parser.add_argument(flag, required=True, help=description)
    # Optional overrides that default to None.
    for flag, description in (
        ("--region", "AWS region (default: from environment)"),
        ("--container-image", "Override container image URI (default: SageMaker PyTorch CPU image)"),
    ):
        stage_parser.add_argument(flag, default=None, help=description)
    stage_parser.add_argument(
        "--no-wait",
        action="store_true",
        default=False,
        help="Return immediately after submitting the job",
    )

    # ── status ────────────────────────────────────────────────────────────
    status_parser = subparsers.add_parser("status", help="Check Processing Job status")
    status_parser.add_argument("--job-name", required=True, help="Processing Job name to check")
    status_parser.add_argument("--region", default=None, help="AWS region (default: from environment)")

    # ── Parse and dispatch ────────────────────────────────────────────────
    args = parser.parse_args()
    chosen = args.subcommand

    if not chosen:
        parser.print_help()
        sys.exit(1)

    handlers = {
        "stage-from-tune": cmd_stage_from_tune,
        "status": cmd_status,
    }
    handler = handlers.get(chosen)
    if handler is not None:
        handler(args)
    else:
        _error_exit(f"Unknown subcommand: {chosen}")


if __name__ == "__main__":
    main()
|