@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. package/templates/do/training/{train.py → custom/train.py} +0 -0 (renamed)
@@ -112,6 +112,74 @@ def _truncate_metadata(props):
112
112
  return result
113
113
 
114
114
 
def _inject_eval_metrics(metadata, args):
    """Inject evaluation metrics from .mlcc/eval-results/ into metadata.

    Looks for ``<adapter_name>.json`` in the eval-results directory. If no
    adapter name is given, or that file does not exist, falls back to the
    most recently modified ``*.json`` result. When an adapter name was given
    but not matched, a warning names the fallback file, because it may belong
    to a different adapter. Each entry of the file's ``"metrics"`` object is
    added with an ``eval_`` prefix (G4 AC-3.1, AC-3.2).

    Non-fatal: a missing directory, an unreadable or malformed file, or a
    file without a ``"metrics"`` object leaves the metadata without eval
    metrics instead of aborting registration.

    Args:
        metadata: existing metadata dict (may be None)
        args: parsed args; ``adapter_name`` is read if present

    Returns:
        metadata dict with eval metrics injected (or without them, as above);
        an empty dict is created when ``metadata`` is None
    """
    if metadata is None:
        metadata = {}

    # Eval results live relative to this script:
    #   <script_dir>/../.mlcc/eval-results/<adapter-or-ic-name>.json
    script_dir = os.path.dirname(os.path.abspath(__file__))
    eval_results_dir = os.path.join(script_dir, "..", ".mlcc", "eval-results")

    if not os.path.isdir(eval_results_dir):
        return metadata

    adapter_name = getattr(args, 'adapter_name', '') or ''

    # Prefer the file named after the adapter.
    eval_file = None
    if adapter_name:
        candidate = os.path.join(eval_results_dir, f"{adapter_name}.json")
        if os.path.isfile(candidate):
            eval_file = candidate

    # Otherwise fall back to the most recently modified result.
    if not eval_file:
        try:
            json_files = [f for f in os.listdir(eval_results_dir) if f.endswith('.json')]
            if json_files:
                newest = max(
                    json_files,
                    key=lambda f: os.path.getmtime(os.path.join(eval_results_dir, f)),
                )
                eval_file = os.path.join(eval_results_dir, newest)
        except OSError:
            # A file may vanish between listdir() and getmtime(); non-fatal.
            pass
        if eval_file and adapter_name:
            # The fallback may be another adapter's results; make that visible.
            _warn(f"No eval results found for adapter '{adapter_name}'; "
                  f"using most recent file {os.path.basename(eval_file)}")

    if not eval_file:
        return metadata

    try:
        with open(eval_file, 'r') as f:
            eval_data = json.load(f)
    except (OSError, ValueError):
        # OSError: unreadable file. ValueError: covers json.JSONDecodeError
        # and UnicodeDecodeError. Skipping eval metrics is non-fatal by design.
        return metadata

    # Only a JSON object with a "metrics" object is usable; anything else
    # (list, string, "metrics": [...]) is skipped instead of raising.
    metrics = eval_data.get("metrics") if isinstance(eval_data, dict) else None
    if not isinstance(metrics, dict):
        return metadata

    for metric_name, metric_value in metrics.items():
        # eval_ prefix; value truncated to the registry's metadata value limit.
        metadata[f"eval_{metric_name}"] = str(metric_value)[:MAX_METADATA_VALUE_LEN]
    if metrics:
        _warn(f"Injected {len(metrics)} eval metric(s) from {os.path.basename(eval_file)}")

    return metadata


115
183
  def _build_metadata(args):
116
184
  """Build customer_metadata_properties dict from CLI args.
117
185
 
@@ -283,7 +351,7 @@ def cmd_register_model(args):
283
351
 
284
352
  # Step 3: Build inference specification
285
353
  container_image = args.container_image or ""
286
- model_data_url = args.model_data_url or ""
354
+ model_data_url = (args.model_data_url or "").rstrip("/")
287
355
 
288
356
  # Step 4: Create Model Package version (AC-1.2, AC-1.7)
289
357
  description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"
@@ -437,7 +505,7 @@ def cmd_register_adapter(args):
437
505
 
438
506
  # Step 3: Build inference specification
439
507
  container_image = args.container_image or ""
440
- model_data_url = args.model_data_url or ""
508
+ model_data_url = (args.model_data_url or "").rstrip("/")
441
509
 
442
510
  # Step 4: Create adapter Model Package version (AC-2.1)
443
511
  technique = args.tune_technique or "unknown"
@@ -463,12 +531,21 @@ def cmd_register_adapter(args):
463
531
  "SupportedContentTypes": ["application/json"],
464
532
  "SupportedResponseMIMETypes": ["application/json"],
465
533
  }
466
- if model_data_url:
534
+ # ModelDataUrl in InferenceSpecification requires a tar.gz object —
535
+ # uncompressed S3 prefixes (adapter directories) are not supported.
536
+ # Store uncompressed paths in metadata instead.
537
+ if model_data_url and model_data_url.endswith(".tar.gz"):
467
538
  create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
468
- elif model_data_url:
539
+
540
+ # Always store model/adapter data URL in metadata for registry queries
541
+ if model_data_url:
469
542
  if not metadata:
470
543
  metadata = {}
471
544
  metadata["modelDataUrl"] = model_data_url[:1024]
545
+
546
+ # Inject evaluation metrics if available (G4 AC-3.1, AC-3.2)
547
+ metadata = _inject_eval_metrics(metadata, args)
548
+
472
549
  if metadata:
473
550
  create_params["CustomerMetadataProperties"] = metadata
474
551
 
@@ -1366,9 +1443,24 @@ def cmd_get_version(args):
1366
1443
  os.environ.setdefault("AWS_REGION", region)
1367
1444
 
1368
1445
  try:
1369
- from sagemaker.core.resources import ModelPackage
1446
+ import boto3
1447
+ sm_client = boto3.client("sagemaker", region_name=region)
1370
1448
 
1371
- pkg = ModelPackage.get(model_package_arn=version_arn)
1449
+ # Use boto3 directly — sagemaker-core v2.14 ModelPackage.get() requires
1450
+ # model_package_name (not ARN) and rejects model_package_arn as unexpected kwarg.
1451
+ pkg_response = sm_client.describe_model_package(ModelPackageName=version_arn)
1452
+
1453
+ # Wrap in a simple namespace for consistent access below
1454
+ class _Pkg:
1455
+ def __init__(self, data):
1456
+ self._data = data
1457
+ self.model_package_arn = data.get("ModelPackageArn", version_arn)
1458
+ self.inference_specification = data.get("InferenceSpecification")
1459
+ self.customer_metadata_properties = data.get("CustomerMetadataProperties", {})
1460
+ self.model_approval_status = data.get("ModelApprovalStatus", "")
1461
+ self.model_package_description = data.get("ModelPackageDescription", "")
1462
+ self.creation_time = data.get("CreationTime")
1463
+ pkg = _Pkg(pkg_response)
1372
1464
 
1373
1465
  # Extract model data URL from inference spec
1374
1466
  model_data_url = ""
@@ -1381,6 +1473,10 @@ def cmd_get_version(args):
1381
1473
  # Get metadata
1382
1474
  metadata = getattr(pkg, "customer_metadata_properties", None) or {}
1383
1475
 
1476
+ # Fallback: modelDataUrl stored in metadata when adapter is uncompressed S3 prefix
1477
+ if not model_data_url and metadata.get("modelDataUrl"):
1478
+ model_data_url = metadata["modelDataUrl"]
1479
+
1384
1480
  # Get status
1385
1481
  status = getattr(pkg, "model_approval_status", "") or ""
1386
1482
 
@@ -1414,6 +1510,7 @@ def cmd_resolve_dataset(args):
1414
1510
 
1415
1511
  Version resolution (AC-2.1, AC-2.4):
1416
1512
  - --version N: resolve the Nth version (ordinal, 1-based) for this name
1513
+ - --version X.Y.Z: resolve by semver string match
1417
1514
  - No --version: resolve latest (existing behavior)
1418
1515
  - If requested version doesn't exist: print available versions and exit 1 (AC-2.5)
1419
1516
 
@@ -1421,14 +1518,20 @@ def cmd_resolve_dataset(args):
1421
1518
  or error if not found.
1422
1519
  """
1423
1520
  name = args.name
1424
- version_ordinal = getattr(args, "version", None)
1521
+ version_spec = getattr(args, "version", None)
1425
1522
 
1426
1523
  if not name:
1427
1524
  _error_exit("--name is required", code="MISSING_ARGUMENT")
1428
1525
 
1429
1526
  # If version is specified, use version-aware resolution
1430
- if version_ordinal is not None:
1431
- return _resolve_dataset_version(name, version_ordinal)
1527
+ if version_spec is not None:
1528
+ # Determine if it's an ordinal (pure integer) or semver string
1529
+ try:
1530
+ version_ordinal = int(version_spec)
1531
+ return _resolve_dataset_version(name, version_ordinal)
1532
+ except ValueError:
1533
+ # Not an integer — treat as semver string
1534
+ return _resolve_dataset_version_by_semver(name, version_spec)
1432
1535
 
1433
1536
  # No version — resolve latest (existing behavior)
1434
1537
  # Try SageMaker AI Registry API first
@@ -1545,6 +1648,77 @@ def _resolve_dataset_version(name, version_ordinal):
1545
1648
  _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
1546
1649
 
1547
1650
 
def _resolve_dataset_version_by_semver(name, version_str):
    """Resolve a specific version of a named dataset by semver string match.

    Searches the local registry entry's ``versions[]`` array for a version
    whose label equals *version_str* (e.g. '1.0.0'). A version without a
    ``version`` field is labelled ``"<ordinal>.0.0"``. That is the same label
    printed in the "available versions" listing, so every listed label can be
    resolved. A legacy entry with no ``versions[]`` array is treated as having
    the single version "1.0.0".

    On a match the resolved record is emitted via ``_output`` and the
    function returns. If the version doesn't exist, prints available versions
    and exits 1 (VERSION_NOT_FOUND). If the dataset name is unknown, calls
    ``_error_exit`` with DATASET_NOT_FOUND.

    Args:
        name: Dataset name
        version_str: Semver string to match (e.g., '1.0.0', '2.1.0')
    """
    # Load local registry
    entries = _load_registry(_DATASETS_REGISTRY)

    for entry in entries:
        if entry.get("name") != name:
            continue

        versions = entry.get("versions", [])

        if not versions:
            # Legacy entry without versions array — treat as having version "1.0.0"
            if version_str == "1.0.0":
                output = dict(entry)
                output["version"] = "1.0.0"
                output["ordinal"] = 1
                if "arn" not in output:
                    output["arn"] = None
                _output(output)
                # Return explicitly so success never falls through to the
                # VERSION_NOT_FOUND exit below.
                return
            print(f"Error: Version {version_str} not found for dataset '{name}'", file=sys.stderr)
            print("Available versions: 1.0.0", file=sys.stderr)
            print(json.dumps({
                "error": f"Version {version_str} not found for dataset '{name}'",
                "code": "VERSION_NOT_FOUND",
                "available_versions": [{"ordinal": 1, "version": "1.0.0"}],
            }))
            sys.exit(1)

        # Search for matching version string (same default label as the listing)
        for i, v in enumerate(versions, 1):
            ver = v.get("version", f"{i}.0.0")
            if ver == version_str:
                _output({
                    "name": name,
                    "s3_uri": v.get("s3_uri", entry.get("s3_uri", "")),
                    "arn": entry.get("arn"),
                    "format": v.get("format", entry.get("format", "jsonl")),
                    "technique": v.get("technique", entry.get("technique", "")),
                    "version": ver,
                    "ordinal": i,
                    "hash": v.get("hash"),
                })
                return

        # Version string not found — show available
        print(f"Error: Version {version_str} not found for dataset '{name}'", file=sys.stderr)
        available = []
        for i, v in enumerate(versions, 1):
            ver = v.get("version", f"{i}.0.0")
            available.append({"ordinal": i, "version": ver})
            print(f" v{i} ({ver})", file=sys.stderr)
        print(json.dumps({
            "error": f"Version {version_str} not found for dataset '{name}'",
            "code": "VERSION_NOT_FOUND",
            "available_versions": available,
        }))
        sys.exit(1)

    # Dataset name not found at all
    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")


1548
1722
  # ── Subcommand: resolve-evaluator ────────────────────────────────────────────
1549
1723
 
1550
1724
 
@@ -1706,8 +1880,8 @@ def main():
1706
1880
  help="Resolve a registered dataset by name",
1707
1881
  )
1708
1882
  resolve_dataset_parser.add_argument("--name", required=True, help="Dataset name to resolve")
1709
- resolve_dataset_parser.add_argument("--version", type=int, default=None,
1710
- help="Version ordinal to resolve (e.g., 2 for the 2nd version). Default: latest.")
1883
+ resolve_dataset_parser.add_argument("--version", type=str, default=None,
1884
+ help="Version to resolve: ordinal (e.g., 2) or semver (e.g., 1.0.0). Default: latest.")
1711
1885
 
1712
1886
  # ── resolve-evaluator ─────────────────────────────────────────────────
1713
1887
  resolve_evaluator_parser = subparsers.add_parser(
@@ -1,14 +1,11 @@
1
1
  #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
2
  # SPDX-License-Identifier: Apache-2.0
3
+ """Build a CreateTrainingJob JSON request from CLI arguments.
4
4
 
5
- """
6
- Build the CreateTrainingJob JSON request for SageMaker.
5
+ Called by do/train _build_job_request() to construct the JSON payload
6
+ that is later passed to either AWS CLI or .train_helper.py for submission.
7
7
 
8
- This helper is called by do/train to construct the full API request body.
9
- It handles conditional fields (spot training, metric definitions, environment,
10
- tags) and writes the result to a JSON file for use with:
11
- aws sagemaker create-training-job --cli-input-json file://path.json
8
+ Outputs a JSON file at --output-file containing the full CreateTrainingJob request.
12
9
  """
13
10
 
14
11
  import argparse
@@ -16,126 +13,118 @@ import json
16
13
  import sys
17
14
 
18
15
 
def _parse_json_arg(raw, default, flag, expected_types):
    """Decode a JSON-valued CLI argument.

    Returns *default* when *raw* is empty. When *raw* is not valid JSON, or
    decodes to a value that is not an instance of *expected_types*, a warning
    naming *flag* is printed to stderr and *default* is returned. The request
    is still built, but the bad value is no longer dropped silently.
    """
    if not raw:
        return default
    try:
        value = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"⚠️  Ignoring invalid JSON for {flag}: {e}", file=sys.stderr)
        return default
    if not isinstance(value, expected_types):
        print(f"⚠️  Ignoring {flag}: unexpected JSON type {type(value).__name__}", file=sys.stderr)
        return default
    return value


def _normalize_tags(tags):
    """Return tags in CreateTrainingJob's [{"Key": ..., "Value": ...}] form.

    A list is passed through unchanged. A legacy {key: value} map is
    converted to the list form.
    """
    if isinstance(tags, dict):
        return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]
    return tags


def _normalize_metric_definitions(metric_definitions):
    """Return metric definitions using CreateTrainingJob's "Name"/"Regex" keys.

    Entries in the legacy lowercase {"name": ..., "regex": ...} form are
    converted. All other entries are passed through unchanged.
    """
    return [
        {"Name": m["name"], "Regex": m["regex"]}
        if isinstance(m, dict) and "name" in m and "regex" in m and "Name" not in m
        else m
        for m in metric_definitions
    ]


def main():
    """Parse CLI args, build the CreateTrainingJob request, write it to --output-file."""
    parser = argparse.ArgumentParser(description="Build CreateTrainingJob JSON request")
    parser.add_argument("--job-name", required=True)
    parser.add_argument("--role-arn", required=True)
    parser.add_argument("--image", required=True)
    parser.add_argument("--instance-type", required=True)
    parser.add_argument("--instance-count", default="1")
    parser.add_argument("--volume-size", default="50")
    parser.add_argument("--dataset", default="")
    parser.add_argument("--output-path", required=True)
    parser.add_argument("--max-runtime", default="86400")
    parser.add_argument("--hyperparams", default="{}")
    parser.add_argument("--enable-spot", default="false")
    parser.add_argument("--max-wait", default="172800")
    parser.add_argument("--checkpoint-path", default="")
    parser.add_argument("--metric-definitions", default="[]")
    parser.add_argument("--environment", default="{}")
    parser.add_argument("--tags", default="[]")
    parser.add_argument("--output-file", required=True)
    args = parser.parse_args()

    # Parse JSON args (invalid input is reported, then replaced by the default)
    hyperparams = _parse_json_arg(args.hyperparams, {}, "--hyperparams", dict)
    metric_definitions = _parse_json_arg(args.metric_definitions, [], "--metric-definitions", list)
    environment = _parse_json_arg(args.environment, {}, "--environment", dict)
    # Tags may be the API's list form or a legacy {key: value} map.
    tags = _parse_json_arg(args.tags, [], "--tags", (list, dict))

    # Build the request
    request = {
        "TrainingJobName": args.job_name,
        "RoleArn": args.role_arn,
        "AlgorithmSpecification": {
            "TrainingImage": args.image,
            "TrainingInputMode": "File",
        },
        "ResourceConfig": {
            "InstanceType": args.instance_type,
            "InstanceCount": int(args.instance_count),
            "VolumeSizeInGB": int(args.volume_size),
        },
        "OutputDataConfig": {
            "S3OutputPath": args.output_path,
        },
        "StoppingCondition": {
            "MaxRuntimeInSeconds": int(args.max_runtime),
        },
    }

    # Input data channel (optional)
    if args.dataset:
        request["InputDataConfig"] = [
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": args.dataset,
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "ContentType": "application/jsonlines",
            }
        ]

    # Hyperparameters (all values must be strings)
    if hyperparams:
        request["HyperParameters"] = {k: str(v) for k, v in hyperparams.items()}

    # Environment variables
    if environment:
        request["Environment"] = {k: str(v) for k, v in environment.items()}

    # Metric definitions
    if metric_definitions:
        request["AlgorithmSpecification"]["MetricDefinitions"] = (
            _normalize_metric_definitions(metric_definitions)
        )

    # Spot training
    if args.enable_spot.lower() == "true":
        request["EnableManagedSpotTraining"] = True
        request["StoppingCondition"]["MaxWaitTimeInSeconds"] = int(args.max_wait)

    # Checkpoint config
    if args.checkpoint_path:
        request["CheckpointConfig"] = {
            "S3Uri": args.checkpoint_path,
        }

    # Tags
    if tags:
        request["Tags"] = _normalize_tags(tags)

    # Write to output file
    with open(args.output_file, "w") as f:
        json.dump(request, f, indent=2)

    print(f"✅ Request written to {args.output_file}", file=sys.stderr)


if __name__ == "__main__":
    main()