@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -112,6 +112,74 @@ def _truncate_metadata(props):
|
|
|
112
112
|
return result
|
|
113
113
|
|
|
114
114
|
|
|
115
|
+
def _inject_eval_metrics(metadata, args):
|
|
116
|
+
"""Inject evaluation metrics from .mlcc/eval-results/ into metadata.
|
|
117
|
+
|
|
118
|
+
Looks for eval results matching the adapter name or project.
|
|
119
|
+
Adds metrics with 'eval_' prefix (G4 AC-3.1, AC-3.2).
|
|
120
|
+
Non-fatal: if no eval results exist, returns metadata unchanged.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
metadata: existing metadata dict (may be None)
|
|
124
|
+
args: parsed args with project_name, adapter name hints
|
|
125
|
+
|
|
126
|
+
Returns:
|
|
127
|
+
metadata dict with eval metrics injected (or unchanged)
|
|
128
|
+
"""
|
|
129
|
+
if metadata is None:
|
|
130
|
+
metadata = {}
|
|
131
|
+
|
|
132
|
+
# Determine eval results directory (relative to script location)
|
|
133
|
+
# Convention: .mlcc/eval-results/<adapter-or-ic-name>.json
|
|
134
|
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
135
|
+
eval_results_dir = os.path.join(script_dir, "..", ".mlcc", "eval-results")
|
|
136
|
+
|
|
137
|
+
if not os.path.isdir(eval_results_dir):
|
|
138
|
+
return metadata
|
|
139
|
+
|
|
140
|
+
# Try to find eval results for this adapter
|
|
141
|
+
# Prioritize: adapter name from args > any available result
|
|
142
|
+
adapter_name = getattr(args, 'adapter_name', '') or ''
|
|
143
|
+
|
|
144
|
+
# Search for matching eval result file
|
|
145
|
+
eval_file = None
|
|
146
|
+
if adapter_name:
|
|
147
|
+
candidate = os.path.join(eval_results_dir, f"{adapter_name}.json")
|
|
148
|
+
if os.path.isfile(candidate):
|
|
149
|
+
eval_file = candidate
|
|
150
|
+
|
|
151
|
+
# If no specific adapter match, try to find any recent result
|
|
152
|
+
if not eval_file:
|
|
153
|
+
try:
|
|
154
|
+
json_files = [f for f in os.listdir(eval_results_dir) if f.endswith('.json')]
|
|
155
|
+
if json_files:
|
|
156
|
+
# Use most recently modified
|
|
157
|
+
json_files.sort(key=lambda f: os.path.getmtime(os.path.join(eval_results_dir, f)), reverse=True)
|
|
158
|
+
eval_file = os.path.join(eval_results_dir, json_files[0])
|
|
159
|
+
except OSError:
|
|
160
|
+
pass
|
|
161
|
+
|
|
162
|
+
if not eval_file:
|
|
163
|
+
return metadata
|
|
164
|
+
|
|
165
|
+
# Load and inject metrics
|
|
166
|
+
try:
|
|
167
|
+
with open(eval_file, 'r') as f:
|
|
168
|
+
eval_data = json.load(f)
|
|
169
|
+
metrics = eval_data.get("metrics", {})
|
|
170
|
+
for metric_name, metric_value in metrics.items():
|
|
171
|
+
# Add with eval_ prefix, truncate to 256 chars
|
|
172
|
+
key = f"eval_{metric_name}"
|
|
173
|
+
str_val = str(metric_value)[:MAX_METADATA_VALUE_LEN]
|
|
174
|
+
metadata[key] = str_val
|
|
175
|
+
if metrics:
|
|
176
|
+
_warn(f"Injected {len(metrics)} eval metric(s) from {os.path.basename(eval_file)}")
|
|
177
|
+
except (IOError, json.JSONDecodeError, KeyError):
|
|
178
|
+
pass # Non-fatal — skip eval metrics if file is unreadable
|
|
179
|
+
|
|
180
|
+
return metadata
|
|
181
|
+
|
|
182
|
+
|
|
115
183
|
def _build_metadata(args):
|
|
116
184
|
"""Build customer_metadata_properties dict from CLI args.
|
|
117
185
|
|
|
@@ -283,7 +351,7 @@ def cmd_register_model(args):
|
|
|
283
351
|
|
|
284
352
|
# Step 3: Build inference specification
|
|
285
353
|
container_image = args.container_image or ""
|
|
286
|
-
model_data_url = args.model_data_url or ""
|
|
354
|
+
model_data_url = (args.model_data_url or "").rstrip("/")
|
|
287
355
|
|
|
288
356
|
# Step 4: Create Model Package version (AC-1.2, AC-1.7)
|
|
289
357
|
description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"
|
|
@@ -437,7 +505,7 @@ def cmd_register_adapter(args):
|
|
|
437
505
|
|
|
438
506
|
# Step 3: Build inference specification
|
|
439
507
|
container_image = args.container_image or ""
|
|
440
|
-
model_data_url = args.model_data_url or ""
|
|
508
|
+
model_data_url = (args.model_data_url or "").rstrip("/")
|
|
441
509
|
|
|
442
510
|
# Step 4: Create adapter Model Package version (AC-2.1)
|
|
443
511
|
technique = args.tune_technique or "unknown"
|
|
@@ -463,12 +531,21 @@ def cmd_register_adapter(args):
|
|
|
463
531
|
"SupportedContentTypes": ["application/json"],
|
|
464
532
|
"SupportedResponseMIMETypes": ["application/json"],
|
|
465
533
|
}
|
|
466
|
-
|
|
534
|
+
# ModelDataUrl in InferenceSpecification requires a tar.gz object —
|
|
535
|
+
# uncompressed S3 prefixes (adapter directories) are not supported.
|
|
536
|
+
# Store uncompressed paths in metadata instead.
|
|
537
|
+
if model_data_url and model_data_url.endswith(".tar.gz"):
|
|
467
538
|
create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
468
|
-
|
|
539
|
+
|
|
540
|
+
# Always store model/adapter data URL in metadata for registry queries
|
|
541
|
+
if model_data_url:
|
|
469
542
|
if not metadata:
|
|
470
543
|
metadata = {}
|
|
471
544
|
metadata["modelDataUrl"] = model_data_url[:1024]
|
|
545
|
+
|
|
546
|
+
# Inject evaluation metrics if available (G4 AC-3.1, AC-3.2)
|
|
547
|
+
metadata = _inject_eval_metrics(metadata, args)
|
|
548
|
+
|
|
472
549
|
if metadata:
|
|
473
550
|
create_params["CustomerMetadataProperties"] = metadata
|
|
474
551
|
|
|
@@ -1366,9 +1443,24 @@ def cmd_get_version(args):
|
|
|
1366
1443
|
os.environ.setdefault("AWS_REGION", region)
|
|
1367
1444
|
|
|
1368
1445
|
try:
|
|
1369
|
-
|
|
1446
|
+
import boto3
|
|
1447
|
+
sm_client = boto3.client("sagemaker", region_name=region)
|
|
1370
1448
|
|
|
1371
|
-
|
|
1449
|
+
# Use boto3 directly — sagemaker-core v2.14 ModelPackage.get() requires
|
|
1450
|
+
# model_package_name (not ARN) and rejects model_package_arn as unexpected kwarg.
|
|
1451
|
+
pkg_response = sm_client.describe_model_package(ModelPackageName=version_arn)
|
|
1452
|
+
|
|
1453
|
+
# Wrap in a simple namespace for consistent access below
|
|
1454
|
+
class _Pkg:
|
|
1455
|
+
def __init__(self, data):
|
|
1456
|
+
self._data = data
|
|
1457
|
+
self.model_package_arn = data.get("ModelPackageArn", version_arn)
|
|
1458
|
+
self.inference_specification = data.get("InferenceSpecification")
|
|
1459
|
+
self.customer_metadata_properties = data.get("CustomerMetadataProperties", {})
|
|
1460
|
+
self.model_approval_status = data.get("ModelApprovalStatus", "")
|
|
1461
|
+
self.model_package_description = data.get("ModelPackageDescription", "")
|
|
1462
|
+
self.creation_time = data.get("CreationTime")
|
|
1463
|
+
pkg = _Pkg(pkg_response)
|
|
1372
1464
|
|
|
1373
1465
|
# Extract model data URL from inference spec
|
|
1374
1466
|
model_data_url = ""
|
|
@@ -1381,6 +1473,10 @@ def cmd_get_version(args):
|
|
|
1381
1473
|
# Get metadata
|
|
1382
1474
|
metadata = getattr(pkg, "customer_metadata_properties", None) or {}
|
|
1383
1475
|
|
|
1476
|
+
# Fallback: modelDataUrl stored in metadata when adapter is uncompressed S3 prefix
|
|
1477
|
+
if not model_data_url and metadata.get("modelDataUrl"):
|
|
1478
|
+
model_data_url = metadata["modelDataUrl"]
|
|
1479
|
+
|
|
1384
1480
|
# Get status
|
|
1385
1481
|
status = getattr(pkg, "model_approval_status", "") or ""
|
|
1386
1482
|
|
|
@@ -1414,6 +1510,7 @@ def cmd_resolve_dataset(args):
|
|
|
1414
1510
|
|
|
1415
1511
|
Version resolution (AC-2.1, AC-2.4):
|
|
1416
1512
|
- --version N: resolve the Nth version (ordinal, 1-based) for this name
|
|
1513
|
+
- --version X.Y.Z: resolve by semver string match
|
|
1417
1514
|
- No --version: resolve latest (existing behavior)
|
|
1418
1515
|
- If requested version doesn't exist: print available versions and exit 1 (AC-2.5)
|
|
1419
1516
|
|
|
@@ -1421,14 +1518,20 @@ def cmd_resolve_dataset(args):
|
|
|
1421
1518
|
or error if not found.
|
|
1422
1519
|
"""
|
|
1423
1520
|
name = args.name
|
|
1424
|
-
|
|
1521
|
+
version_spec = getattr(args, "version", None)
|
|
1425
1522
|
|
|
1426
1523
|
if not name:
|
|
1427
1524
|
_error_exit("--name is required", code="MISSING_ARGUMENT")
|
|
1428
1525
|
|
|
1429
1526
|
# If version is specified, use version-aware resolution
|
|
1430
|
-
if
|
|
1431
|
-
|
|
1527
|
+
if version_spec is not None:
|
|
1528
|
+
# Determine if it's an ordinal (pure integer) or semver string
|
|
1529
|
+
try:
|
|
1530
|
+
version_ordinal = int(version_spec)
|
|
1531
|
+
return _resolve_dataset_version(name, version_ordinal)
|
|
1532
|
+
except ValueError:
|
|
1533
|
+
# Not an integer — treat as semver string
|
|
1534
|
+
return _resolve_dataset_version_by_semver(name, version_spec)
|
|
1432
1535
|
|
|
1433
1536
|
# No version — resolve latest (existing behavior)
|
|
1434
1537
|
# Try SageMaker AI Registry API first
|
|
@@ -1545,6 +1648,77 @@ def _resolve_dataset_version(name, version_ordinal):
|
|
|
1545
1648
|
_error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
|
|
1546
1649
|
|
|
1547
1650
|
|
|
1651
|
+
def _resolve_dataset_version_by_semver(name, version_str):
    """Resolve a specific version of a named dataset by semver string match.

    Loads the local dataset registry, takes the first entry whose ``name``
    matches, and searches its ``versions`` array for an item whose
    ``version`` field equals ``version_str`` exactly (e.g., '1.0.0'). The
    match is emitted through ``_output``.

    An entry without a ``versions`` array is treated as a legacy entry with
    the single version "1.0.0".

    If the version doesn't exist, prints the available versions to stderr,
    prints a JSON error object with code ``VERSION_NOT_FOUND`` to stdout,
    and exits 1. If no entry has the given name, reports
    ``DATASET_NOT_FOUND`` through ``_error_exit``.

    Args:
        name: Dataset name
        version_str: Semver string to match (e.g., '1.0.0', '2.1.0')
    """
    # Load local registry
    entries = _load_registry(_DATASETS_REGISTRY)

    # Only the first registry entry with this name is considered.
    entry = next((e for e in entries if e.get("name") == name), None)
    if entry is None:
        _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
        return

    versions = entry.get("versions", [])
    not_found = f"Version {version_str} not found for dataset '{name}'"

    if not versions:
        # Legacy entry without versions array — treat as having version "1.0.0"
        if version_str == "1.0.0":
            output = dict(entry)
            output["version"] = "1.0.0"
            output["ordinal"] = 1
            output.setdefault("arn", None)
            _output(output)
            # Explicit return: if _output returns normally, the success path
            # must not fall through to the not-found report below.
            return

        print(f"Error: {not_found}", file=sys.stderr)
        print("Available versions: 1.0.0", file=sys.stderr)
        print(json.dumps({
            "error": not_found,
            "code": "VERSION_NOT_FOUND",
            "available_versions": [{"ordinal": 1, "version": "1.0.0"}],
        }))
        sys.exit(1)

    # Search for matching version string; ordinals are 1-based positions.
    # NOTE(review): items lacking a "version" field never match here, yet the
    # not-found listing below advertises them as "<ordinal>.0.0" — confirm
    # which behavior is intended.
    for ordinal, item in enumerate(versions, 1):
        ver = item.get("version", "")
        if ver == version_str:
            _output({
                "name": name,
                "s3_uri": item.get("s3_uri", entry.get("s3_uri", "")),
                "arn": entry.get("arn"),
                "format": item.get("format", entry.get("format", "jsonl")),
                "technique": item.get("technique", entry.get("technique", "")),
                "version": ver,
                "ordinal": ordinal,
                "hash": item.get("hash"),
            })
            # Same reasoning as above: never continue into the error path
            # after a successful match.
            return

    # Version string not found — show available
    print(f"Error: {not_found}", file=sys.stderr)
    available = []
    for ordinal, item in enumerate(versions, 1):
        ver = item.get("version", f"{ordinal}.0.0")
        available.append({"ordinal": ordinal, "version": ver})
        print(f"  v{ordinal} ({ver})", file=sys.stderr)
    print(json.dumps({
        "error": not_found,
        "code": "VERSION_NOT_FOUND",
        "available_versions": available,
    }))
    sys.exit(1)
|
|
1720
|
+
|
|
1721
|
+
|
|
1548
1722
|
# ── Subcommand: resolve-evaluator ────────────────────────────────────────────
|
|
1549
1723
|
|
|
1550
1724
|
|
|
@@ -1706,8 +1880,8 @@ def main():
|
|
|
1706
1880
|
help="Resolve a registered dataset by name",
|
|
1707
1881
|
)
|
|
1708
1882
|
resolve_dataset_parser.add_argument("--name", required=True, help="Dataset name to resolve")
|
|
1709
|
-
resolve_dataset_parser.add_argument("--version", type=
|
|
1710
|
-
help="Version
|
|
1883
|
+
resolve_dataset_parser.add_argument("--version", type=str, default=None,
|
|
1884
|
+
help="Version to resolve: ordinal (e.g., 2) or semver (e.g., 1.0.0). Default: latest.")
|
|
1711
1885
|
|
|
1712
1886
|
# ── resolve-evaluator ─────────────────────────────────────────────────
|
|
1713
1887
|
resolve_evaluator_parser = subparsers.add_parser(
|
|
@@ -1,14 +1,11 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Build a CreateTrainingJob JSON request from CLI arguments.
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
Called by do/train _build_job_request() to construct the JSON payload
|
|
6
|
+
that is later passed to either AWS CLI or .train_helper.py for submission.
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
It handles conditional fields (spot training, metric definitions, environment,
|
|
10
|
-
tags) and writes the result to a JSON file for use with:
|
|
11
|
-
aws sagemaker create-training-job --cli-input-json file://path.json
|
|
8
|
+
Outputs a JSON file at --output-file containing the full CreateTrainingJob request.
|
|
12
9
|
"""
|
|
13
10
|
|
|
14
11
|
import argparse
|
|
@@ -16,126 +13,118 @@ import json
|
|
|
16
13
|
import sys
|
|
17
14
|
|
|
18
15
|
|
|
19
|
-
def
|
|
20
|
-
|
|
21
|
-
parser
|
|
22
|
-
parser.add_argument(
|
|
23
|
-
parser.add_argument(
|
|
24
|
-
parser.add_argument(
|
|
25
|
-
parser.add_argument(
|
|
26
|
-
parser.add_argument(
|
|
27
|
-
parser.add_argument(
|
|
28
|
-
parser.add_argument(
|
|
29
|
-
parser.add_argument(
|
|
30
|
-
parser.add_argument(
|
|
31
|
-
parser.add_argument(
|
|
32
|
-
parser.add_argument(
|
|
33
|
-
parser.add_argument(
|
|
34
|
-
parser.add_argument(
|
|
35
|
-
parser.add_argument(
|
|
36
|
-
parser.add_argument(
|
|
37
|
-
parser.add_argument(
|
|
38
|
-
parser.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
16
|
+
def _parse_json_option(raw, flag, accepted_types):
    """Decode a JSON-valued CLI option, warning instead of failing silently.

    Args:
        raw: the option's string value (may be empty).
        flag: the option name, used only in warning messages.
        accepted_types: tuple of accepted Python types for the decoded value;
            an empty instance of the first one is the fallback.

    Returns:
        The decoded value, or the empty fallback when ``raw`` is empty, is
        JSON ``null``, is not valid JSON, or decodes to an unexpected type.
        The last two cases print a warning to stderr.
    """
    fallback = accepted_types[0]()
    if not raw:
        return fallback
    try:
        value = json.loads(raw)
    except json.JSONDecodeError as exc:
        print(f"Warning: ignoring {flag}: invalid JSON ({exc})", file=sys.stderr)
        return fallback
    if value is None:
        return fallback
    if not isinstance(value, accepted_types):
        expected = " or ".join(t.__name__ for t in accepted_types)
        print(
            f"Warning: ignoring {flag}: expected JSON {expected}, "
            f"got {type(value).__name__}",
            file=sys.stderr,
        )
        return fallback
    return value


def main():
    """Build a CreateTrainingJob request from CLI arguments and write it as JSON.

    Required fields are always emitted; InputDataConfig, HyperParameters,
    Environment, MetricDefinitions, managed spot settings, CheckpointConfig
    and Tags are added only when the corresponding option is non-empty.
    The request is written to ``--output-file`` and a confirmation line is
    printed to stderr.
    """
    parser = argparse.ArgumentParser(description="Build CreateTrainingJob JSON request")
    parser.add_argument("--job-name", required=True)
    parser.add_argument("--role-arn", required=True)
    parser.add_argument("--image", required=True)
    parser.add_argument("--instance-type", required=True)
    # These three are always converted to int, so let argparse validate them
    # and report a usage error instead of a traceback.
    parser.add_argument("--instance-count", type=int, default=1)
    parser.add_argument("--volume-size", type=int, default=50)
    parser.add_argument("--dataset", default="")
    parser.add_argument("--output-path", required=True)
    parser.add_argument("--max-runtime", type=int, default=86400)
    parser.add_argument("--hyperparams", default="{}")
    parser.add_argument("--enable-spot", default="false")
    # Kept as a string: it is only interpreted when spot training is enabled.
    parser.add_argument("--max-wait", default="172800")
    parser.add_argument("--checkpoint-path", default="")
    parser.add_argument("--metric-definitions", default="[]")
    parser.add_argument("--environment", default="{}")
    parser.add_argument("--tags", default="[]")
    parser.add_argument("--output-file", required=True)
    args = parser.parse_args()

    # Parse JSON args. Bad values fall back to empty, but with a warning so a
    # job is never silently built without the caller's hyperparameters.
    hyperparams = _parse_json_option(args.hyperparams, "--hyperparams", (dict,))
    metric_definitions = _parse_json_option(
        args.metric_definitions, "--metric-definitions", (list,)
    )
    environment = _parse_json_option(args.environment, "--environment", (dict,))
    tags = _parse_json_option(args.tags, "--tags", (list, dict))
    if isinstance(tags, dict):
        # Accept the {"key": "value"} form and expand it to Key/Value pairs.
        tags = [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]

    # Build the request
    request = {
        "TrainingJobName": args.job_name,
        "RoleArn": args.role_arn,
        "AlgorithmSpecification": {
            "TrainingImage": args.image,
            "TrainingInputMode": "File",
        },
        "ResourceConfig": {
            "InstanceType": args.instance_type,
            "InstanceCount": args.instance_count,
            "VolumeSizeInGB": args.volume_size,
        },
        "OutputDataConfig": {
            "S3OutputPath": args.output_path,
        },
        "StoppingCondition": {
            "MaxRuntimeInSeconds": args.max_runtime,
        },
    }

    # Input data channels
    if args.dataset:
        request["InputDataConfig"] = [
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": args.dataset,
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "ContentType": "application/jsonlines",
            }
        ]

    # Hyperparameters (all values must be strings)
    if hyperparams:
        request["HyperParameters"] = {k: str(v) for k, v in hyperparams.items()}

    # Environment variables
    if environment:
        request["Environment"] = {k: str(v) for k, v in environment.items()}

    # Metric definitions
    if metric_definitions:
        request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions

    # Spot training
    if args.enable_spot.lower() == "true":
        request["EnableManagedSpotTraining"] = True
        request["StoppingCondition"]["MaxWaitTimeInSeconds"] = int(args.max_wait)

    # Checkpoint config
    if args.checkpoint_path:
        request["CheckpointConfig"] = {
            "S3Uri": args.checkpoint_path,
        }

    # Tags
    if tags:
        request["Tags"] = tags

    # Write to output file
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(request, f, indent=2)

    print(f"✅ Request written to {args.output_file}", file=sys.stderr)


if __name__ == "__main__":
    main()
|