@aws/ml-container-creator 0.13.4 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package/README.md +23 -5
  2. package/config/parameter-schema-v2.json +32 -4
  3. package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
  4. package/infra/ci-harness/package-lock.json +122 -116
  5. package/infra/ci-harness/package.json +1 -1
  6. package/package.json +5 -3
  7. package/pyproject.toml +21 -0
  8. package/requirements.txt +19 -0
  9. package/servers/instance-sizer/index.js +72 -4
  10. package/servers/instance-sizer/lib/model-resolver.js +28 -2
  11. package/src/app.js +17 -0
  12. package/src/lib/bootstrap-command-handler.js +33 -23
  13. package/src/lib/config-loader.js +18 -0
  14. package/src/lib/config-manager.js +6 -1
  15. package/src/lib/dataset-slug.js +152 -0
  16. package/src/lib/generated/cli-options.js +9 -3
  17. package/src/lib/generated/parameter-matrix.js +14 -3
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +6 -0
  20. package/src/lib/prompt-runner.js +5 -0
  21. package/src/lib/prompts/feature-prompts.js +1 -1
  22. package/src/lib/template-manager.js +0 -7
  23. package/src/lib/template-variable-resolver.js +51 -1
  24. package/src/lib/tune-config-state.js +14 -1
  25. package/templates/do/.adapter_helper.py +451 -0
  26. package/templates/do/.benchmark_writer.py +22 -0
  27. package/templates/do/.register_helper.py +1163 -0
  28. package/templates/do/.stage_helper.py +419 -0
  29. package/templates/do/.tune_helper.py +379 -65
  30. package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
  31. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  32. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  33. package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
  34. package/templates/do/adapter +427 -27
  35. package/templates/do/add-ic +85 -3
  36. package/templates/do/benchmark +173 -15
  37. package/templates/do/config +24 -0
  38. package/templates/do/lib/inference-component.sh +56 -3
  39. package/templates/do/lib/profile.sh +5 -0
  40. package/templates/do/register +552 -6
  41. package/templates/do/stage +91 -272
  42. package/templates/do/test +12 -2
  43. package/templates/do/tune +264 -12
@@ -0,0 +1,419 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Processing Job helper for model staging.
6
+
7
+ Subcommands:
8
+ submit - Submit a Processing Job to download model from HuggingFace → S3
9
+ status - Check Processing Job status
10
+ cancel - Cancel a running Processing Job
11
+
12
+ Uses sagemaker-core ProcessingJob resource API (SDK v3).
13
+ No SageMaker SDK v2 imports.
14
+
15
+ All output is JSON on stdout for bash consumption.
16
+ """
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ import sys
22
+ import time
23
+
24
+
25
+ # ── Inline dependency check ───────────────────────────────────────────────────
26
+
27
+
28
def _check_sagemaker_core():
    """Abort with an install hint when sagemaker-core cannot be imported.

    The probe imports ``ProcessingJob`` specifically, so an installed package
    that does not expose that resource is treated exactly like a missing one.
    """
    install_hint = (
        "sagemaker-core is not installed. "
        "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
    )
    try:
        from sagemaker.core.resources import ProcessingJob  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
37
+
38
+
39
def _check_boto3():
    """Abort with an install hint when boto3 cannot be imported.

    boto3 is needed for the STS credential check and the S3 calls made
    while submitting a job.
    """
    install_hint = (
        "boto3 is not installed. "
        "Please install: pip install boto3"
    )
    try:
        import boto3  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
48
+
49
+
50
+ # ── Utility functions ─────────────────────────────────────────────────────────
51
+
52
+
53
+ def _error_exit(message, code=1):
54
+ """Print error to stderr and exit."""
55
+ print(message, file=sys.stderr)
56
+ sys.exit(code)
57
+
58
+
59
+ def _output(data):
60
+ """Print JSON result to stdout and exit 0."""
61
+ print(json.dumps(data))
62
+ sys.exit(0)
63
+
64
+
65
+ # ── Entrypoint script template ────────────────────────────────────────────────
66
+
67
# Bash script run inside the Processing Job container. cmd_submit uploads it to
# S3 and the job's container entrypoint copies it locally and executes it.
# Inputs arrive as environment variables set on the job: MODEL_ID,
# S3_OUTPUT_URI and, for gated models, HF_TOKEN.
ENTRYPOINT_SCRIPT = r"""#!/bin/bash
set -e
set -o pipefail

echo "=== MCC Model Staging Processing Job ==="
echo "Model: ${MODEL_ID}"
echo "Target: ${S3_OUTPUT_URI}"
echo ""

# Install dependencies (best effort: continue with whatever is already present)
echo "Installing huggingface_hub and hf_transfer..."
pip install -q huggingface_hub hf_transfer 2>/dev/null || true

# Enable fast parallel downloads only if hf_transfer is available
# NOTE(review): HF_XET_HIGH_PERFORMANCE tunes hf_xet; hf_transfer itself is
# switched on by HF_HUB_ENABLE_HF_TRANSFER - confirm which one is intended.
if python3 -c "import hf_transfer" 2>/dev/null; then
    export HF_XET_HIGH_PERFORMANCE=1
else
    echo "hf_transfer not available - using standard download"
    unset HF_XET_HIGH_PERFORMANCE 2>/dev/null || true
fi

# huggingface_hub reads HF_TOKEN from the environment by itself, so the token
# is never passed on a command line (where it would be visible in `ps`).
if [ -n "${HF_TOKEN:-}" ]; then
    echo "Using provided HuggingFace token"
fi

export CACHE_PATH="/opt/ml/processing/model"

# Download model from HuggingFace
echo ""
echo "Downloading model: ${MODEL_ID}"

# Use 'hf' CLI if available (modern), fall back to python snapshot_download
if command -v hf &>/dev/null; then
    hf download "${MODEL_ID}" --local-dir "${CACHE_PATH}"
else
    # Fallback: use the Python API. The model ID and target directory are read
    # from the environment instead of being pasted into the program text, so
    # quotes or other special characters cannot break (or inject into) it.
    python3 -c "
import os
from huggingface_hub import snapshot_download
token = os.environ.get('HF_TOKEN', None)
snapshot_download(os.environ['MODEL_ID'], local_dir=os.environ['CACHE_PATH'], token=token)
"
fi

echo ""
echo "Download complete"
echo "Model path: ${CACHE_PATH}"

# Sync to S3. huggingface_hub keeps download metadata under <local_dir>/.cache;
# it is not part of the model, so it is excluded along with lock files.
echo ""
echo "Syncing to S3: ${S3_OUTPUT_URI}"
aws s3 sync "${CACHE_PATH}" "${S3_OUTPUT_URI}" \
    --no-progress \
    --exclude "*.lock" \
    --exclude ".gitattributes" \
    --exclude ".cache/*"

echo ""
echo "Model staged successfully to: ${S3_OUTPUT_URI}"
"""
137
+
138
+
139
+ # ── Subcommand: submit ────────────────────────────────────────────────────────
140
+
141
+
142
def _build_job_name(project, timestamp):
    """Return a valid SageMaker Processing Job name for *project* at *timestamp*.

    SageMaker job names are at most 63 characters and must match
    ``[a-zA-Z0-9](-*[a-zA-Z0-9])*``. Characters outside ASCII letters/digits
    in *project* are replaced with ``-`` BEFORE truncating, and only the
    project portion is shortened, so the result never ends in ``-`` and the
    timestamp suffix (which keeps names distinct between runs) always
    survives.
    """
    prefix = "mlcc-stage-"
    suffix = f"-{timestamp}"
    safe_project = "".join(
        c if (c.isascii() and c.isalnum()) else "-" for c in project
    )
    budget = max(0, 63 - len(prefix) - len(suffix))
    safe_project = safe_project[:budget].strip("-")
    if not safe_project:
        # Nothing usable left of the project name; keep the name valid anyway.
        return f"mlcc-stage{suffix}"
    return f"{prefix}{safe_project}{suffix}"


def cmd_submit(args):
    """Submit a Processing Job to stage model from HuggingFace to S3.

    Steps: validate AWS credentials, short-circuit when the model is already
    staged (unless ``--force``), upload ``ENTRYPOINT_SCRIPT`` to S3, create the
    Processing Job, then either return at once (``--no-wait``) or poll until
    the job reaches a terminal state.

    Emits JSON on stdout: {"job_name": str, "status": str, "s3_uri": str}
    where status is "AlreadyStaged", "Submitted" or "Completed".

    Exit codes: 0 on success, 4 when AWS credentials are missing/expired,
    1 on S3/SageMaker errors or when the job fails or is stopped.
    """
    _check_sagemaker_core()
    _check_boto3()

    import boto3
    from sagemaker.core.resources import ProcessingJob

    # Validate AWS credentials
    try:
        sts = boto3.client("sts", region_name=args.region)
        sts.get_caller_identity()
    except Exception as e:
        _error_exit(
            f"AWS credentials not configured or expired: {e}\n"
            "Run: aws configure",
            code=4,
        )

    # Build S3 URI for staged model
    model_prefix = f"{args.project}/models/{args.model_name}/"
    s3_uri = f"s3://{args.bucket}/{model_prefix}"

    s3 = boto3.client("s3", region_name=args.region)

    # Idempotency: a config.json under the target prefix is taken as evidence
    # that a previous run already staged this model.
    if not args.force:
        try:
            s3.head_object(Bucket=args.bucket, Key=f"{model_prefix}config.json")
        except s3.exceptions.ClientError:
            pass  # Not staged yet (or not visible to us), proceed
        else:
            _output({
                "job_name": "",
                "status": "AlreadyStaged",
                "s3_uri": s3_uri,
            })

    job_name = _build_job_name(args.project, time.strftime("%Y%m%d-%H%M%S"))

    # Upload entrypoint script to S3
    entrypoint_s3_key = f"staging-jobs/{job_name}/entrypoint.sh"
    entrypoint_s3_uri = f"s3://{args.bucket}/{entrypoint_s3_key}"
    try:
        s3.put_object(
            Bucket=args.bucket,
            Key=entrypoint_s3_key,
            Body=ENTRYPOINT_SCRIPT.encode("utf-8"),
        )
    except Exception as e:
        _error_exit(f"Failed to upload entrypoint script to S3: {e}")

    # Build environment variables for the container.
    # NOTE: job environment variables are returned by DescribeProcessingJob,
    # so HF_TOKEN is readable by anyone allowed to describe this job.
    environment = {
        "MODEL_ID": args.model_name,
        "S3_OUTPUT_URI": s3_uri,
    }
    if args.hf_token:
        environment["HF_TOKEN"] = args.hf_token

    # Container image: SageMaker-managed PyTorch CPU image.
    # NOTE(review): account 763104351884 hosts the Deep Learning Containers in
    # most commercial regions; some opt-in regions use a different account.
    container_image = (
        f"763104351884.dkr.ecr.{args.region}.amazonaws.com/"
        "pytorch-training:2.1.0-cpu-py310-ubuntu20.04-sagemaker"
    )

    # The container fetches the script from S3 and runs it
    entrypoint_cmd = (
        f"aws s3 cp {entrypoint_s3_uri} /tmp/entrypoint.sh && "
        "chmod +x /tmp/entrypoint.sh && /tmp/entrypoint.sh"
    )

    # Submit Processing Job via sagemaker-core
    print(f"Submitting Processing Job: {job_name}", file=sys.stderr)
    try:
        ProcessingJob.create(
            processing_job_name=job_name,
            processing_resources={
                "cluster_config": {
                    "instance_count": 1,
                    "instance_type": args.instance_type,
                    "volume_size_in_gb": args.volume_size_gb,
                }
            },
            app_specification={
                "image_uri": container_image,
                "container_entrypoint": ["bash", "-c", entrypoint_cmd],
            },
            environment=environment,
            role_arn=args.role_arn,
            stopping_condition={"max_runtime_in_seconds": 86400},
        )
    except Exception as e:
        error_msg = str(e)
        # "AccessDenied" also matches "AccessDeniedException".
        if "AccessDenied" in error_msg:
            # The caller (not the execution role) needs these permissions.
            _error_exit(
                "Access denied creating Processing Job. "
                "Ensure your AWS credentials allow sagemaker:CreateProcessingJob "
                "and iam:PassRole on the execution role.\n"
                f"Details: {error_msg}"
            )
        _error_exit(f"Failed to create Processing Job: {error_msg}")

    # If --no-wait, return immediately with job name
    if args.no_wait:
        _output({
            "job_name": job_name,
            "status": "Submitted",
            "s3_uri": s3_uri,
        })

    # Poll until terminal state
    _poll_job(job_name, s3_uri, args.region)
266
+
267
+
268
def _poll_job(job_name, s3_uri, region, poll_interval=30, max_consecutive_errors=20):
    """Poll a Processing Job until it is Completed, Failed or Stopped.

    Args:
        job_name: Processing Job to watch.
        s3_uri: Staged-model URI echoed back in the success JSON.
        region: Accepted for interface compatibility; not used here
            (sagemaker-core resolves its region from the environment).
        poll_interval: Seconds to sleep between status checks.
        max_consecutive_errors: Give up after this many describe failures in
            a row (e.g. expired credentials) instead of retrying forever.
            ``None`` restores unbounded retrying.

    On success: emits {"job_name", "status": "Completed", "s3_uri"} on stdout
    and exits 0. On failure, stop, or giving up: message on stderr, exit 1.
    """
    from sagemaker.core.resources import ProcessingJob

    print(f"Polling Processing Job status (every {poll_interval}s)...", file=sys.stderr)

    consecutive_errors = 0
    while True:
        try:
            job_desc = ProcessingJob.get(processing_job_name=job_name)
        except Exception as e:
            consecutive_errors += 1
            if (
                max_consecutive_errors is not None
                and consecutive_errors >= max_consecutive_errors
            ):
                # The job itself may still be running; only monitoring stops.
                _error_exit(
                    f"Giving up on Processing Job {job_name} after "
                    f"{consecutive_errors} consecutive status-check failures: {e}\n"
                    "The job may still be running; check it with the 'status' subcommand."
                )
            print(f"Warning: failed to get job status (retrying): {e}", file=sys.stderr)
            time.sleep(poll_interval)
            continue

        consecutive_errors = 0
        status = job_desc.processing_job_status
        print(f"Status: {status}", file=sys.stderr)

        if status in ("Completed", "Failed", "Stopped"):
            break

        time.sleep(poll_interval)

    if status == "Failed":
        failure_reason = getattr(job_desc, "failure_reason", None)
        # Guard against a missing or non-string (placeholder) value.
        if not isinstance(failure_reason, str) or not failure_reason:
            failure_reason = "Unknown"
        print(f"Processing Job failed: {failure_reason}", file=sys.stderr)
        sys.exit(1)

    if status == "Stopped":
        print(f"Processing Job was stopped: {job_name}", file=sys.stderr)
        sys.exit(1)

    # Success
    _output({
        "job_name": job_name,
        "status": "Completed",
        "s3_uri": s3_uri,
    })
310
+
311
+
312
+ # ── Subcommand: status ────────────────────────────────────────────────────────
313
+
314
+
315
def cmd_status(args):
    """Report the current status of a Processing Job.

    Emits JSON on stdout:
        {"job_name": str, "status": str, "failure_reason": str | None}
    Exits 1 (message on stderr) if the job cannot be described.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ProcessingJob

    try:
        job_desc = ProcessingJob.get(processing_job_name=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to get Processing Job status: {e}")

    status = job_desc.processing_job_status
    failure_reason = getattr(job_desc, "failure_reason", None)
    # NOTE(review): sagemaker-core may represent an unset optional field with a
    # placeholder object rather than None. Anything that is not a string would
    # make json.dumps raise, so report it as null instead.
    if not isinstance(failure_reason, str):
        failure_reason = None

    _output({
        "job_name": args.job_name,
        "status": status,
        "failure_reason": failure_reason,
    })
337
+
338
+
339
+ # ── Subcommand: cancel ────────────────────────────────────────────────────────
340
+
341
+
342
def cmd_cancel(args):
    """Ask SageMaker to stop a Processing Job.

    If the job is already Completed, Failed or Stopped, nothing is stopped:
    the JSON payload echoes that status plus a "message" field. Otherwise
    ``stop()`` is called on the job and {"job_name": str, "status": "Stopping"}
    is emitted. Any error while describing or stopping exits 1 with the error
    on stderr.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ProcessingJob

    finished_states = ("Completed", "Failed", "Stopped")
    name = args.job_name

    try:
        job = ProcessingJob.get(processing_job_name=name)
        current = job.processing_job_status
        if current not in finished_states:
            job.stop()
        else:
            # _output exits via SystemExit, which the handler below does not catch.
            _output({
                "job_name": name,
                "status": current,
                "message": f"Job already in terminal state: {current}",
            })
    except Exception as exc:
        _error_exit(f"Failed to cancel Processing Job: {exc}")

    _output({"job_name": name, "status": "Stopping"})
370
+
371
+
372
+ # ── CLI argument parsing ──────────────────────────────────────────────────────
373
+
374
+
375
def main():
    """Parse the command line and dispatch to the selected subcommand."""
    parser = argparse.ArgumentParser(
        description="SageMaker Processing Job helper for model staging"
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # submit
    submit_parser = subparsers.add_parser("submit", help="Submit a Processing Job")
    submit_parser.add_argument("--model-name", required=True, help="HuggingFace model ID")
    submit_parser.add_argument("--bucket", required=True, help="S3 bucket for staging")
    submit_parser.add_argument("--project", required=True, help="Project name")
    submit_parser.add_argument("--role-arn", required=True, help="IAM execution role ARN")
    submit_parser.add_argument("--region", required=True, help="AWS region")
    submit_parser.add_argument("--hf-token", default="", help="HuggingFace token (for gated models)")
    submit_parser.add_argument("--instance-type", default="ml.m5.xlarge", help="Instance type")
    submit_parser.add_argument("--volume-size-gb", type=int, default=2048, help="Volume size in GB")
    submit_parser.add_argument("--no-wait", action="store_true", help="Return immediately without polling")
    submit_parser.add_argument("--force", action="store_true", help="Re-stage even if already present")
    submit_parser.set_defaults(func=cmd_submit)

    # status
    status_parser = subparsers.add_parser("status", help="Check Processing Job status")
    status_parser.add_argument("--job-name", required=True, help="Processing Job name")
    status_parser.add_argument("--region", default=None, help="AWS region")
    status_parser.set_defaults(func=cmd_status)

    # cancel
    cancel_parser = subparsers.add_parser("cancel", help="Cancel a Processing Job")
    cancel_parser.add_argument("--job-name", required=True, help="Processing Job name")
    cancel_parser.add_argument("--region", default=None, help="AWS region")
    cancel_parser.set_defaults(func=cmd_cancel)

    args = parser.parse_args()

    # sagemaker-core takes its region from environment variables. An explicit
    # --region must override any region already set there. Otherwise a stale
    # AWS_DEFAULT_REGION would send the SageMaker calls to a different region
    # than the boto3 clients and image URI that `submit` builds from args.region.
    region = getattr(args, "region", None)
    if region:
        os.environ["AWS_DEFAULT_REGION"] = region
        os.environ["AWS_REGION"] = region

    args.func(args)
416
+
417
+
418
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # unless a subcommand has already exited the process itself.
    sys.exit(main())