@aws/ml-container-creator 0.13.3 → 0.13.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. package/README.md +23 -5
  2. package/infra/ci-harness/package-lock.json +1 -5
  3. package/package.json +5 -3
  4. package/pyproject.toml +21 -0
  5. package/requirements.txt +19 -0
  6. package/servers/instance-sizer/lib/model-resolver.js +127 -185
  7. package/servers/instance-sizer/lib/vram-estimator.js +86 -0
  8. package/servers/lib/catalogs/instances.json +0 -27
  9. package/src/app.js +2 -0
  10. package/src/lib/bootstrap-command-handler.js +35 -25
  11. package/src/lib/generated/cli-options.js +1 -1
  12. package/src/lib/generated/parameter-matrix.js +1 -1
  13. package/src/lib/generated/validation-rules.js +1 -1
  14. package/src/lib/prompt-runner.js +14 -31
  15. package/templates/IAM_PERMISSIONS.md +64 -13
  16. package/templates/do/.adapter_helper.py +451 -0
  17. package/templates/do/.benchmark_writer.py +13 -0
  18. package/templates/do/.stage_helper.py +419 -0
  19. package/templates/do/.tune_helper.py +218 -67
  20. package/templates/do/README.md +50 -604
  21. package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
  22. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  23. package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
  24. package/templates/do/adapter +109 -4
  25. package/templates/do/benchmark +150 -12
  26. package/templates/do/build +2 -5
  27. package/templates/do/clean.d/async-inference.ejs +2 -5
  28. package/templates/do/clean.d/batch-transform.ejs +2 -5
  29. package/templates/do/clean.d/hyperpod-eks.ejs +2 -5
  30. package/templates/do/clean.d/managed-inference.ejs +2 -5
  31. package/templates/do/config +4 -0
  32. package/templates/do/deploy.d/async-inference.ejs +6 -9
  33. package/templates/do/deploy.d/batch-transform.ejs +4 -7
  34. package/templates/do/deploy.d/hyperpod-eks.ejs +1 -4
  35. package/templates/do/deploy.d/managed-inference.ejs +15 -6
  36. package/templates/do/lib/profile.sh +24 -15
  37. package/templates/do/push +2 -5
  38. package/templates/do/register +2 -5
  39. package/templates/do/stage +114 -292
  40. package/templates/do/submit +1 -4
  41. package/templates/do/tune +64 -10
  42. package/templates/MIGRATION.md +0 -488
  43. package/templates/TEMPLATE_SYSTEM.md +0 -243
@@ -0,0 +1,451 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Processing Job helper for adapter staging.
6
+
7
+ Subcommands:
8
+ stage-from-tune - Submit Processing Job to copy adapter from training output to S3
9
+ status - Check Processing Job status
10
+
11
+ All output is JSON on stdout for bash consumption.
12
+
13
+ Uses sagemaker-core ProcessingJob.create() / ProcessingJob.get() per SDK v3 policy.
14
+ """
15
+
16
+ import argparse
17
+ import logging
18
+ import json
19
+ import os
20
+ import sys
21
+ import time
22
+ import warnings
23
+
24
# Silence dependency noise: deprecation notices and any warning whose message
# mentions urllib3.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", message=".*urllib3.*")

# Raise the sagemaker loggers to ERROR so their INFO/WARNING records do not
# pollute this helper's output.
for _quiet_logger in ("sagemaker.config", "sagemaker.core", "sagemaker"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)
del _quiet_logger
32
+
33
# ── Constants ─────────────────────────────────────────────────────────────────
# Seconds to sleep between Processing Job status polls.
POLL_INTERVAL_SECONDS = 30
# Sent to SageMaker as the job's max_runtime_in_seconds stopping condition.
MAX_RUNTIME_SECONDS = 3600  # 1 hour timeout for adapter staging
# Instance type and volume size (GB) used for the single-instance cluster
# that runs the staging job.
INSTANCE_TYPE = "ml.m5.large"
VOLUME_SIZE_GB = 100
38
+
39
+ # ── Utility functions ─────────────────────────────────────────────────────────
40
+
41
+
42
def _error_exit(message, exit_code=1):
    """Report a fatal error and terminate the process.

    Writes ``Error: <message>`` as one line on stderr, then raises
    ``SystemExit`` carrying ``exit_code`` (1 unless overridden). Never returns.
    """
    sys.stderr.write(f"Error: {message}\n")
    raise SystemExit(exit_code)
46
+
47
+
48
def _output(data):
    """Emit ``data`` as a single line of JSON on stdout and exit with status 0.

    Never returns: the process always terminates via ``SystemExit(0)`` once
    the JSON line has been written.
    """
    sys.stdout.write(json.dumps(data) + "\n")
    raise SystemExit(0)
52
+
53
+
54
+ # ── Dependency checks ─────────────────────────────────────────────────────────
55
+
56
+
57
def _check_sagemaker_core():
    """Abort with an install hint unless sagemaker-core is importable.

    Probes for ``ProcessingJob`` in ``sagemaker.core.resources``; on
    ``ImportError`` the process exits through ``_error_exit``.
    """
    install_hint = (
        "sagemaker-core is not installed. "
        "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
    )
    try:
        from sagemaker.core.resources import ProcessingJob  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
66
+
67
+
68
def _check_boto3():
    """Abort with an install hint unless boto3 is importable.

    boto3 is required for uploading the entrypoint script to S3; on
    ``ImportError`` the process exits through ``_error_exit``.
    """
    install_hint = (
        "boto3 is not installed. "
        "Please install: pip install boto3"
    )
    try:
        import boto3  # noqa: F401
    except ImportError:
        _error_exit(install_hint)
77
+
78
+
79
+ # ── Processing Job helpers ────────────────────────────────────────────────────
80
+
81
+
82
def _generate_job_name(project_name, adapter_name):
    """Generate a unique Processing Job name that the SageMaker API accepts.

    The result has the form
    ``mlcc-adapter-<project>-<adapter>-<YYYYmmdd-HHMMSS>`` and is at most
    63 characters long.

    Args:
        project_name: Project name; characters other than ASCII letters,
            digits and hyphens are replaced with hyphens.
        adapter_name: Adapter name; sanitised the same way.

    Returns:
        The job name string, using only alphanumerics and hyphens and
        starting and ending with an alphanumeric character.
    """
    import re  # local import so this helper does not rely on file-level imports

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    # Job names must be <= 63 chars and may only contain alphanumerics and
    # hyphens, so collapse every run of other characters (underscores, dots,
    # slashes, spaces, ...) into a single hyphen.
    base = re.sub(r"[^A-Za-z0-9-]+", "-", f"mlcc-adapter-{project_name}-{adapter_name}")
    # Truncate the base to leave room for "-" plus the timestamp, then drop a
    # trailing hyphen that truncation or sanitising may have left behind.
    max_base = 63 - len(timestamp) - 1
    base = base[:max_base].rstrip("-")
    return f"{base}-{timestamp}"
92
+
93
+
94
def _upload_entrypoint(bucket, job_name, region):
    """Upload the Processing Job entrypoint shell script to S3.

    The script lists the Processing input directory, copies its contents to
    the Processing output directory and lists the result (SageMaker handles
    the S3 download/upload around it).

    Args:
        bucket: Destination S3 bucket.
        job_name: Job name; the object is stored under
            ``staging-jobs/<job_name>/entrypoint.sh``.
        region: Region used to build the S3 client.

    Returns:
        The ``s3://`` URI of the uploaded script. If the upload fails the
        process exits through ``_error_exit`` instead of returning.
    """
    import boto3

    object_key = f"staging-jobs/{job_name}/entrypoint.sh"

    script_text = """#!/bin/bash
set -e
echo "Adapter staging: copying input to output..."
echo "Input contents:"
ls -la /opt/ml/processing/input/adapter/ || echo "No input files found"
echo ""
echo "Copying adapter files..."
cp -r /opt/ml/processing/input/adapter/* /opt/ml/processing/output/ 2>/dev/null || \
cp -r /opt/ml/processing/input/adapter/. /opt/ml/processing/output/
echo "Output contents:"
ls -la /opt/ml/processing/output/
echo ""
echo "Adapter staging complete."
"""

    upload_request = {
        "Bucket": bucket,
        "Key": object_key,
        "Body": script_text.encode("utf-8"),
        "ContentType": "text/x-shellscript",
    }
    client = boto3.client("s3", region_name=region)
    try:
        client.put_object(**upload_request)
    except Exception as e:
        _error_exit(f"Failed to upload entrypoint to S3: {e}")

    return f"s3://{bucket}/{object_key}"
134
+
135
+
136
def _resolve_container_image(region):
    """Build the default container image URI for ``region``.

    Returns the ECR URI of the ``pytorch-training`` 2.2.0 CPU / py310 image
    in registry account 763104351884 for the given region.

    NOTE(review): every listed region and the fallback use the same account,
    and the host suffix is always ``amazonaws.com``. Opt-in regions and other
    partitions presumably need different values; confirm against the AWS
    Deep Learning Containers image list. The image is assumed to ship the AWS
    CLI, which the staging entrypoint command relies on; verify.
    """
    # SageMaker DLC account IDs per region
    # https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html
    default_account = "763104351884"
    listed_regions = (
        "us-east-1",
        "us-east-2",
        "us-west-1",
        "us-west-2",
        "eu-west-1",
        "eu-west-2",
        "eu-central-1",
        "ap-northeast-1",
        "ap-southeast-1",
        "ap-southeast-2",
        "ap-south-1",
        "ca-central-1",
    )
    accounts_by_region = dict.fromkeys(listed_regions, default_account)
    registry_account = accounts_by_region.get(region, default_account)

    image = "pytorch-training:2.2.0-cpu-py310-ubuntu20.04-sagemaker"
    return f"{registry_account}.dkr.ecr.{region}.amazonaws.com/{image}"
161
+
162
+
163
+ # ── Subcommand: stage-from-tune ───────────────────────────────────────────────
164
+
165
+
166
def cmd_stage_from_tune(args):
    """Submit a Processing Job that copies an adapter from training output to S3.

    The job reads ``args.training_output_s3_uri`` and writes to
    ``s3://<bucket>/<project>/adapters/<adapter_name>/``. Unless
    ``args.no_wait`` is set, the job is polled every
    ``POLL_INTERVAL_SECONDS`` until it reaches a terminal state.

    Args:
        args: Parsed CLI namespace providing ``training_output_s3_uri``,
            ``adapter_name``, ``bucket``, ``project``, ``role_arn``,
            ``region``, ``container_image`` and ``no_wait``.

    Output (JSON on stdout, exit status 0):
        {"job_name": str, "status": str, "adapter_s3_uri": str}

    Exits with status 1 (message on stderr) when an argument is missing, the
    job cannot be created, the job ends as Failed or Stopped, or status
    polling fails too many times in a row. This function never returns.
    """
    _check_sagemaker_core()
    _check_boto3()

    from sagemaker.core.resources import ProcessingJob

    # Validate required arguments
    if not args.training_output_s3_uri:
        _error_exit("--training-output-s3-uri is required")
    if not args.adapter_name:
        _error_exit("--adapter-name is required")
    if not args.bucket:
        _error_exit("--bucket is required")
    if not args.project:
        _error_exit("--project is required")
    if not args.role_arn:
        _error_exit("--role-arn is required")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    # Ensure region is set in env for sagemaker-core
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Generate job name
    job_name = _generate_job_name(args.project, args.adapter_name)

    # Build adapter output S3 URI
    adapter_s3_uri = f"s3://{args.bucket}/{args.project}/adapters/{args.adapter_name}/"

    # Resolve container image
    container_image = args.container_image or _resolve_container_image(region)

    # Upload entrypoint script to S3
    entrypoint_s3_uri = _upload_entrypoint(args.bucket, job_name, region)

    # Build entrypoint command — download script from S3 then execute
    entrypoint_cmd = (
        f"aws s3 cp {entrypoint_s3_uri} /tmp/entrypoint.sh && "
        "chmod +x /tmp/entrypoint.sh && /tmp/entrypoint.sh"
    )

    # Normalize training output S3 URI (ensure trailing slash for S3Prefix)
    training_output_s3_uri = args.training_output_s3_uri
    if not training_output_s3_uri.endswith("/"):
        training_output_s3_uri += "/"

    # Submit Processing Job via sagemaker-core
    try:
        ProcessingJob.create(
            processing_job_name=job_name,
            processing_resources={
                "cluster_config": {
                    "instance_count": 1,
                    "instance_type": INSTANCE_TYPE,
                    "volume_size_in_gb": VOLUME_SIZE_GB,
                }
            },
            processing_inputs=[{
                "input_name": "adapter",
                "s3_input": {
                    "s3_uri": training_output_s3_uri,
                    "s3_data_type": "S3Prefix",
                    "s3_input_mode": "File",
                    "local_path": "/opt/ml/processing/input/adapter",
                }
            }],
            processing_output_config={
                "outputs": [{
                    "output_name": "staged-adapter",
                    "s3_output": {
                        "s3_uri": adapter_s3_uri,
                        "s3_upload_mode": "EndOfJob",
                        "local_path": "/opt/ml/processing/output",
                    }
                }]
            },
            app_specification={
                "image_uri": container_image,
                "container_entrypoint": ["bash", "-c", entrypoint_cmd],
            },
            role_arn=args.role_arn,
            stopping_condition={"max_runtime_in_seconds": MAX_RUNTIME_SECONDS},
        )
    except Exception as e:
        # Map the most common failure categories to actionable hints.
        error_msg = str(e)
        if "AccessDeniedException" in error_msg or "AccessDenied" in error_msg:
            _error_exit(
                "Access denied when creating Processing Job. "
                "Ensure the role has sagemaker:CreateProcessingJob permission. "
                f"Details: {error_msg}"
            )
        elif "ResourceLimitExceeded" in error_msg:
            _error_exit(
                "Resource limit exceeded. You may need to request a quota increase. "
                f"Details: {error_msg}"
            )
        else:
            _error_exit(f"Failed to create Processing Job: {error_msg}")

    print(f"Processing Job submitted: {job_name}", file=sys.stderr)
    print(f"Adapter output: {adapter_s3_uri}", file=sys.stderr)

    # If --no-wait, return immediately
    if args.no_wait:
        _output({
            "job_name": job_name,
            "status": "InProgress",
            "adapter_s3_uri": adapter_s3_uri,
        })

    # Poll until completion. Transient status-check errors are tolerated, but
    # a long unbroken run of them (e.g. expired credentials) must not hang the
    # caller forever, so give up after a bounded number of consecutive errors.
    max_consecutive_poll_failures = 10
    consecutive_poll_failures = 0
    print(f"Polling every {POLL_INTERVAL_SECONDS}s...", file=sys.stderr)
    while True:
        try:
            job_desc = ProcessingJob.get(processing_job_name=job_name)
            status = job_desc.processing_job_status
        except Exception as e:
            consecutive_poll_failures += 1
            if consecutive_poll_failures >= max_consecutive_poll_failures:
                _error_exit(
                    f"Giving up after {consecutive_poll_failures} consecutive failed "
                    f"status checks for Processing Job {job_name} "
                    f"(the job itself may still be running): {e}"
                )
            print(f"Warning: failed to get job status: {e}", file=sys.stderr)
            time.sleep(POLL_INTERVAL_SECONDS)
            continue

        # A successful check resets the failure streak.
        consecutive_poll_failures = 0

        print(
            f"  [{time.strftime('%H:%M:%S')}] Status: {status}",
            file=sys.stderr,
        )

        if status in ("Completed", "Failed", "Stopped"):
            break

        time.sleep(POLL_INTERVAL_SECONDS)

    # Handle terminal states
    if status == "Failed":
        failure_reason = getattr(job_desc, "failure_reason", None) or "Unknown failure"
        print(f"Processing Job failed: {failure_reason}", file=sys.stderr)
        sys.exit(1)

    if status == "Stopped":
        print("Processing Job was stopped.", file=sys.stderr)
        sys.exit(1)

    # Success
    _output({
        "job_name": job_name,
        "status": "Completed",
        "adapter_s3_uri": adapter_s3_uri,
    })
317
+
318
+
319
+ # ── Subcommand: status ────────────────────────────────────────────────────────
320
+
321
+
322
def cmd_status(args):
    """Look up a Processing Job and print its status as JSON.

    Args:
        args: Parsed CLI namespace providing ``job_name`` and ``region``.

    Output (JSON on stdout, exit status 0):
        {"job_name": str, "status": str, "failure_reason": str|None}

    Exits with status 1 through ``_error_exit`` when the job name is missing
    or the lookup fails. This function never returns.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ProcessingJob

    if not args.job_name:
        _error_exit("--job-name is required")

    # Resolve the region (flag, then AWS_DEFAULT_REGION, then AWS_REGION,
    # then us-west-2) and export it for the SDK.
    region = args.region or os.environ.get("AWS_DEFAULT_REGION")
    if not region:
        region = os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        description = ProcessingJob.get(processing_job_name=args.job_name)
    except Exception as e:
        detail = str(e)
        # Either marker in the error text is treated as "no such job".
        looks_missing = any(
            marker in detail for marker in ("does not exist", "ValidationException")
        )
        if looks_missing:
            _error_exit(f"Processing Job not found: {args.job_name}")
        else:
            _error_exit(f"Failed to get Processing Job status: {detail}")

    result = {
        "job_name": args.job_name,
        "status": description.processing_job_status,
        "failure_reason": None,
    }

    if result["status"] == "Failed":
        result["failure_reason"] = (
            getattr(description, "failure_reason", None) or "Unknown failure"
        )
        print(f"Processing Job failed: {result['failure_reason']}", file=sys.stderr)

    _output(result)
359
+
360
+
361
+ # ── Argument parsing ──────────────────────────────────────────────────────────
362
+
363
+
364
def main():
    """Build the CLI, parse ``sys.argv`` and run the chosen subcommand.

    Subcommands are ``stage-from-tune`` and ``status``. With no subcommand
    the help text is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        prog=".adapter_helper.py",
        description="SageMaker Processing Job helper for adapter staging",
    )
    subparsers = parser.add_subparsers(dest="subcommand", help="Subcommand")

    # Each subcommand's options as (flag, add_argument keyword arguments), in
    # the order they are registered.
    region_option = (
        "--region",
        {"default": None, "help": "AWS region (default: from environment)"},
    )
    stage_options = (
        (
            "--training-output-s3-uri",
            {"required": True, "help": "S3 URI of training output (adapter artifacts)"},
        ),
        (
            "--adapter-name",
            {"required": True, "help": "Name of the adapter (used in output S3 path)"},
        ),
        (
            "--bucket",
            {"required": True, "help": "S3 bucket for adapter output"},
        ),
        (
            "--project",
            {"required": True, "help": "Project name (used in S3 path prefix)"},
        ),
        (
            "--role-arn",
            {"required": True, "help": "SageMaker execution role ARN"},
        ),
        region_option,
        (
            "--container-image",
            {
                "default": None,
                "help": "Override container image URI (default: SageMaker PyTorch CPU image)",
            },
        ),
        (
            "--no-wait",
            {
                "action": "store_true",
                "default": False,
                "help": "Return immediately after submitting the job",
            },
        ),
    )
    status_options = (
        (
            "--job-name",
            {"required": True, "help": "Processing Job name to check"},
        ),
        region_option,
    )

    # ── stage-from-tune ───────────────────────────────────────────────────
    stage_parser = subparsers.add_parser(
        "stage-from-tune",
        help="Submit Processing Job to stage adapter from training output",
    )
    for flag, options in stage_options:
        stage_parser.add_argument(flag, **options)

    # ── status ────────────────────────────────────────────────────────────
    status_parser = subparsers.add_parser(
        "status",
        help="Check Processing Job status",
    )
    for flag, options in status_options:
        status_parser.add_argument(flag, **options)

    # ── Parse and dispatch ────────────────────────────────────────────────
    args = parser.parse_args()

    if not args.subcommand:
        parser.print_help()
        sys.exit(1)

    handlers = {
        "stage-from-tune": cmd_stage_from_tune,
        "status": cmd_status,
    }
    handler = handlers.get(args.subcommand)
    if handler is None:
        _error_exit(f"Unknown subcommand: {args.subcommand}")
    else:
        handler(args)
448
+
449
+
450
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
@@ -1385,6 +1385,7 @@ def _load_config_file(config_path):
1385
1385
  shell_map = {
1386
1386
  'PROJECT_NAME': 'project_name',
1387
1387
  'MODEL_NAME': 'model_name',
1388
+ 'HF_MODEL_ID': 'hf_model_id',
1388
1389
  'INSTANCE_TYPE': 'instance_type',
1389
1390
  'DEPLOYMENT_CONFIG': 'deployment_config',
1390
1391
  'DEPLOYMENT_TARGET': 'deployment_target',
@@ -1402,6 +1403,18 @@ def _load_config_file(config_path):
1402
1403
  except Exception:
1403
1404
  pass
1404
1405
 
1406
+ # Prefer HF_MODEL_ID over MODEL_NAME for the model_name field.
1407
+ # After do/stage runs, MODEL_NAME is rewritten to an S3 URI which is
1408
+ # unsuitable for S3 result paths (nested s3:// in path) and model family
1409
+ # derivation. HF_MODEL_ID preserves the original HuggingFace repo ID.
1410
+ if context.get('hf_model_id'):
1411
+ context['model_name'] = context.pop('hf_model_id')
1412
+ elif context.get('model_name', '').startswith('s3://'):
1413
+ # Fallback: if no HF_MODEL_ID but MODEL_NAME is an S3 URI, extract
1414
+ # the model slug from the S3 path (last non-empty segment)
1415
+ parts = context['model_name'].rstrip('/').split('/')
1416
+ context['model_name'] = parts[-1] if parts else context['model_name']
1417
+
1405
1418
  return context
1406
1419
 
1407
1420