@aws/ml-container-creator 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +1 -1
  2. package/bin/cli.js +1 -1
  3. package/config/tune-catalog.json +303 -1
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +3 -2
  6. package/servers/base-image-picker/index.js +65 -18
  7. package/servers/instance-sizer/index.js +32 -0
  8. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  9. package/servers/lib/catalogs/model-arch-support.json +51 -0
  10. package/servers/lib/catalogs/model-servers.json +2842 -1516
  11. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  12. package/src/app.js +6 -4
  13. package/src/lib/bootstrap-command-handler.js +12 -2
  14. package/src/lib/bootstrap-profile-manager.js +16 -0
  15. package/src/lib/cross-cutting-checker.js +6 -1
  16. package/src/lib/generated/cli-options.js +1 -1
  17. package/src/lib/generated/parameter-matrix.js +1 -1
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +110 -3
  20. package/src/lib/prompt-runner.js +66 -22
  21. package/src/lib/template-variable-resolver.js +8 -0
  22. package/src/lib/train-config-builder.js +339 -0
  23. package/templates/do/.benchmark_writer.py +3 -0
  24. package/templates/do/.eval_helper.py +409 -0
  25. package/templates/do/.register_helper.py +185 -11
  26. package/templates/do/.train_build_request.py +102 -113
  27. package/templates/do/.train_helper.py +433 -0
  28. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  29. package/templates/do/adapter +157 -0
  30. package/templates/do/benchmark +60 -3
  31. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  32. package/templates/do/evaluate +272 -0
  33. package/templates/do/lib/resolve-instance.sh +155 -0
  34. package/templates/do/register +5 -0
  35. package/templates/do/test +1 -0
  36. package/templates/do/train +879 -126
  37. package/templates/do/training/config.yaml +83 -11
  38. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  39. package/templates/do/training/dpo/defaults.yaml +26 -0
  40. package/templates/do/training/dpo/prompts.json +8 -0
  41. package/templates/do/training/dpo/train.py +363 -0
  42. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  43. package/templates/do/training/sft/defaults.yaml +18 -0
  44. package/templates/do/training/sft/prompts.json +7 -0
  45. package/templates/do/training/sft/train.py +310 -0
  46. package/templates/do/tune +11 -2
  47. package/templates/do/.train_poll_parser.py +0 -135
  48. package/templates/do/.train_status_parser.py +0 -187
  49. package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -1,14 +1,11 @@
1
1
  #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
2
  # SPDX-License-Identifier: Apache-2.0
3
+ """Build a CreateTrainingJob JSON request from CLI arguments.
4
4
 
5
- """
6
- Build the CreateTrainingJob JSON request for SageMaker.
5
+ Called by do/train _build_job_request() to construct the JSON payload
6
+ that is later passed to either AWS CLI or .train_helper.py for submission.
7
7
 
8
- This helper is called by do/train to construct the full API request body.
9
- It handles conditional fields (spot training, metric definitions, environment,
10
- tags) and writes the result to a JSON file for use with:
11
- aws sagemaker create-training-job --cli-input-json file://path.json
8
+ Outputs a JSON file at --output-file containing the full CreateTrainingJob request.
12
9
  """
13
10
 
14
11
  import argparse
@@ -16,126 +13,118 @@ import json
16
13
  import sys
17
14
 
18
15
 
19
- def parse_args():
20
- """Parse command-line arguments."""
21
- parser = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
22
- parser.add_argument('--job-name', required=True, help='Training job name')
23
- parser.add_argument('--role-arn', required=True, help='SageMaker execution role ARN')
24
- parser.add_argument('--image', required=True, help='Training container image URI')
25
- parser.add_argument('--instance-type', required=True, help='Instance type')
26
- parser.add_argument('--instance-count', required=True, help='Instance count')
27
- parser.add_argument('--volume-size', required=True, help='Volume size in GB')
28
- parser.add_argument('--dataset', required=True, help='S3 URI for training dataset')
29
- parser.add_argument('--output-path', required=True, help='S3 URI for output')
30
- parser.add_argument('--max-runtime', required=True, help='Max runtime in seconds')
31
- parser.add_argument('--hyperparams', required=True, help='Hyperparameters as JSON string')
32
- parser.add_argument('--enable-spot', required=True, help='Enable spot training (true/false)')
33
- parser.add_argument('--max-wait', required=True, help='Max wait time for spot in seconds')
34
- parser.add_argument('--checkpoint-path', required=True, help='S3 checkpoint path')
35
- parser.add_argument('--metric-definitions', required=True, help='Metric definitions as JSON array')
36
- parser.add_argument('--environment', required=True, help='Environment variables as JSON object')
37
- parser.add_argument('--tags', required=True, help='Tags as JSON object (key-value map)')
38
- parser.add_argument('--output-file', required=True, help='Output file path for the JSON')
39
- return parser.parse_args()
40
-
41
-
42
- def build_request(args):
43
- """Construct the CreateTrainingJob request dictionary."""
44
- # Parse JSON inputs
45
- hyperparams = json.loads(args.hyperparams) if args.hyperparams else {}
46
- metric_definitions = json.loads(args.metric_definitions) if args.metric_definitions else []
47
- environment = json.loads(args.environment) if args.environment else {}
48
- tags = json.loads(args.tags) if args.tags else {}
49
-
50
- # Base request structure
16
+ def main():
17
+ parser = argparse.ArgumentParser(description="Build CreateTrainingJob JSON request")
18
+ parser.add_argument("--job-name", required=True)
19
+ parser.add_argument("--role-arn", required=True)
20
+ parser.add_argument("--image", required=True)
21
+ parser.add_argument("--instance-type", required=True)
22
+ parser.add_argument("--instance-count", default="1")
23
+ parser.add_argument("--volume-size", default="50")
24
+ parser.add_argument("--dataset", default="")
25
+ parser.add_argument("--output-path", required=True)
26
+ parser.add_argument("--max-runtime", default="86400")
27
+ parser.add_argument("--hyperparams", default="{}")
28
+ parser.add_argument("--enable-spot", default="false")
29
+ parser.add_argument("--max-wait", default="172800")
30
+ parser.add_argument("--checkpoint-path", default="")
31
+ parser.add_argument("--metric-definitions", default="[]")
32
+ parser.add_argument("--environment", default="{}")
33
+ parser.add_argument("--tags", default="[]")
34
+ parser.add_argument("--output-file", required=True)
35
+ args = parser.parse_args()
36
+
37
+ # Parse JSON args
38
+ try:
39
+ hyperparams = json.loads(args.hyperparams) if args.hyperparams else {}
40
+ except json.JSONDecodeError:
41
+ hyperparams = {}
42
+
43
+ try:
44
+ metric_definitions = json.loads(args.metric_definitions) if args.metric_definitions else []
45
+ except json.JSONDecodeError:
46
+ metric_definitions = []
47
+
48
+ try:
49
+ environment = json.loads(args.environment) if args.environment else {}
50
+ except json.JSONDecodeError:
51
+ environment = {}
52
+
53
+ try:
54
+ tags = json.loads(args.tags) if args.tags else []
55
+ except json.JSONDecodeError:
56
+ tags = []
57
+
58
+ # Build the request
51
59
  request = {
52
- 'TrainingJobName': args.job_name,
53
- 'RoleArn': args.role_arn,
54
- 'AlgorithmSpecification': {
55
- 'TrainingImage': args.image,
56
- 'TrainingInputMode': 'File'
60
+ "TrainingJobName": args.job_name,
61
+ "RoleArn": args.role_arn,
62
+ "AlgorithmSpecification": {
63
+ "TrainingImage": args.image,
64
+ "TrainingInputMode": "File",
57
65
  },
58
- 'InputDataConfig': [
59
- {
60
- 'ChannelName': 'training',
61
- 'DataSource': {
62
- 'S3DataSource': {
63
- 'S3DataType': 'S3Prefix',
64
- 'S3Uri': args.dataset,
65
- 'S3DataDistributionType': 'FullyReplicated'
66
- }
67
- }
68
- }
69
- ],
70
- 'OutputDataConfig': {
71
- 'S3OutputPath': args.output_path
66
+ "ResourceConfig": {
67
+ "InstanceType": args.instance_type,
68
+ "InstanceCount": int(args.instance_count),
69
+ "VolumeSizeInGB": int(args.volume_size),
72
70
  },
73
- 'ResourceConfig': {
74
- 'InstanceType': args.instance_type,
75
- 'InstanceCount': int(args.instance_count),
76
- 'VolumeSizeInGB': int(args.volume_size)
71
+ "OutputDataConfig": {
72
+ "S3OutputPath": args.output_path,
73
+ },
74
+ "StoppingCondition": {
75
+ "MaxRuntimeInSeconds": int(args.max_runtime),
77
76
  },
78
- 'StoppingCondition': {
79
- 'MaxRuntimeInSeconds': int(args.max_runtime)
80
- }
81
77
  }
82
78
 
83
- # Hyperparameters ensure all values are strings (SageMaker requirement)
84
- if hyperparams:
85
- request['HyperParameters'] = {
86
- str(k): str(v) for k, v in hyperparams.items()
87
- }
88
-
89
- # Managed spot training
90
- if args.enable_spot == 'true':
91
- request['EnableManagedSpotTraining'] = True
92
- request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)
93
-
94
- # Checkpoint configuration (for spot training resumption)
95
- if args.checkpoint_path:
96
- request['CheckpointConfig'] = {
97
- 'S3Uri': args.checkpoint_path
98
- }
99
-
100
- # Metric definitions (custom CloudWatch metrics)
101
- if metric_definitions and metric_definitions != []:
102
- request['AlgorithmSpecification']['MetricDefinitions'] = [
103
- {'Name': m['name'], 'Regex': m['regex']}
104
- for m in metric_definitions
79
+ # Input data channels
80
+ if args.dataset:
81
+ request["InputDataConfig"] = [
82
+ {
83
+ "ChannelName": "training",
84
+ "DataSource": {
85
+ "S3DataSource": {
86
+ "S3DataType": "S3Prefix",
87
+ "S3Uri": args.dataset,
88
+ "S3DataDistributionType": "FullyReplicated",
89
+ }
90
+ },
91
+ "ContentType": "application/jsonlines",
92
+ }
105
93
  ]
106
94
 
107
- # Environment variables for the container
108
- if environment and environment != {}:
109
- request['Environment'] = environment
95
+ # Hyperparameters (all values must be strings)
96
+ if hyperparams:
97
+ request["HyperParameters"] = {k: str(v) for k, v in hyperparams.items()}
110
98
 
111
- # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
112
- if tags and tags != {}:
113
- request['Tags'] = [
114
- {'Key': str(k), 'Value': str(v)}
115
- for k, v in tags.items()
116
- ]
99
+ # Environment variables
100
+ if environment:
101
+ request["Environment"] = {k: str(v) for k, v in environment.items()}
117
102
 
118
- return request
103
+ # Metric definitions
104
+ if metric_definitions:
105
+ request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions
119
106
 
107
+ # Spot training
108
+ if args.enable_spot.lower() == "true":
109
+ request["EnableManagedSpotTraining"] = True
110
+ request["StoppingCondition"]["MaxWaitTimeInSeconds"] = int(args.max_wait)
120
111
 
121
- def main():
122
- """Main entry point."""
123
- args = parse_args()
112
+ # Checkpoint config
113
+ if args.checkpoint_path:
114
+ request["CheckpointConfig"] = {
115
+ "S3Uri": args.checkpoint_path,
116
+ }
124
117
 
125
- try:
126
- request = build_request(args)
127
- except (json.JSONDecodeError, ValueError) as e:
128
- print(f'❌ Failed to build request: {e}', file=sys.stderr)
129
- sys.exit(1)
118
+ # Tags
119
+ if tags:
120
+ request["Tags"] = tags
130
121
 
131
- # Write the JSON request to the output file
132
- try:
133
- with open(args.output_file, 'w') as f:
134
- json.dump(request, f, indent=2)
135
- except IOError as e:
136
- print(f'❌ Failed to write request file: {e}', file=sys.stderr)
137
- sys.exit(1)
122
+ # Write to output file
123
+ with open(args.output_file, "w") as f:
124
+ json.dump(request, f, indent=2)
125
+
126
+ print(f"✅ Request written to {args.output_file}", file=sys.stderr)
138
127
 
139
128
 
140
- if __name__ == '__main__':
129
+ if __name__ == "__main__":
141
130
  main()