@aws/ml-container-creator 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/bin/cli.js +1 -1
- package/config/tune-catalog.json +303 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +3 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1516
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/bootstrap-command-handler.js +12 -2
- package/src/lib/bootstrap-profile-manager.js +16 -0
- package/src/lib/cross-cutting-checker.js +6 -1
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -1,14 +1,11 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Build a CreateTrainingJob JSON request from CLI arguments.
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
Called by do/train _build_job_request() to construct the JSON payload
|
|
6
|
+
that is later passed to either AWS CLI or .train_helper.py for submission.
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
It handles conditional fields (spot training, metric definitions, environment,
|
|
10
|
-
tags) and writes the result to a JSON file for use with:
|
|
11
|
-
aws sagemaker create-training-job --cli-input-json file://path.json
|
|
8
|
+
Outputs a JSON file at --output-file containing the full CreateTrainingJob request.
|
|
12
9
|
"""
|
|
13
10
|
|
|
14
11
|
import argparse
|
|
@@ -16,126 +13,118 @@ import json
|
|
|
16
13
|
import sys
|
|
17
14
|
|
|
18
15
|
|
|
19
|
-
def
|
|
20
|
-
|
|
21
|
-
parser
|
|
22
|
-
parser.add_argument(
|
|
23
|
-
parser.add_argument(
|
|
24
|
-
parser.add_argument(
|
|
25
|
-
parser.add_argument(
|
|
26
|
-
parser.add_argument(
|
|
27
|
-
parser.add_argument(
|
|
28
|
-
parser.add_argument(
|
|
29
|
-
parser.add_argument(
|
|
30
|
-
parser.add_argument(
|
|
31
|
-
parser.add_argument(
|
|
32
|
-
parser.add_argument(
|
|
33
|
-
parser.add_argument(
|
|
34
|
-
parser.add_argument(
|
|
35
|
-
parser.add_argument(
|
|
36
|
-
parser.add_argument(
|
|
37
|
-
parser.add_argument(
|
|
38
|
-
parser.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
16
|
+
def main():
|
|
17
|
+
parser = argparse.ArgumentParser(description="Build CreateTrainingJob JSON request")
|
|
18
|
+
parser.add_argument("--job-name", required=True)
|
|
19
|
+
parser.add_argument("--role-arn", required=True)
|
|
20
|
+
parser.add_argument("--image", required=True)
|
|
21
|
+
parser.add_argument("--instance-type", required=True)
|
|
22
|
+
parser.add_argument("--instance-count", default="1")
|
|
23
|
+
parser.add_argument("--volume-size", default="50")
|
|
24
|
+
parser.add_argument("--dataset", default="")
|
|
25
|
+
parser.add_argument("--output-path", required=True)
|
|
26
|
+
parser.add_argument("--max-runtime", default="86400")
|
|
27
|
+
parser.add_argument("--hyperparams", default="{}")
|
|
28
|
+
parser.add_argument("--enable-spot", default="false")
|
|
29
|
+
parser.add_argument("--max-wait", default="172800")
|
|
30
|
+
parser.add_argument("--checkpoint-path", default="")
|
|
31
|
+
parser.add_argument("--metric-definitions", default="[]")
|
|
32
|
+
parser.add_argument("--environment", default="{}")
|
|
33
|
+
parser.add_argument("--tags", default="[]")
|
|
34
|
+
parser.add_argument("--output-file", required=True)
|
|
35
|
+
args = parser.parse_args()
|
|
36
|
+
|
|
37
|
+
# Parse JSON args
|
|
38
|
+
try:
|
|
39
|
+
hyperparams = json.loads(args.hyperparams) if args.hyperparams else {}
|
|
40
|
+
except json.JSONDecodeError:
|
|
41
|
+
hyperparams = {}
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
metric_definitions = json.loads(args.metric_definitions) if args.metric_definitions else []
|
|
45
|
+
except json.JSONDecodeError:
|
|
46
|
+
metric_definitions = []
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
environment = json.loads(args.environment) if args.environment else {}
|
|
50
|
+
except json.JSONDecodeError:
|
|
51
|
+
environment = {}
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
tags = json.loads(args.tags) if args.tags else []
|
|
55
|
+
except json.JSONDecodeError:
|
|
56
|
+
tags = []
|
|
57
|
+
|
|
58
|
+
# Build the request
|
|
51
59
|
request = {
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
60
|
+
"TrainingJobName": args.job_name,
|
|
61
|
+
"RoleArn": args.role_arn,
|
|
62
|
+
"AlgorithmSpecification": {
|
|
63
|
+
"TrainingImage": args.image,
|
|
64
|
+
"TrainingInputMode": "File",
|
|
57
65
|
},
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
'S3DataSource': {
|
|
63
|
-
'S3DataType': 'S3Prefix',
|
|
64
|
-
'S3Uri': args.dataset,
|
|
65
|
-
'S3DataDistributionType': 'FullyReplicated'
|
|
66
|
-
}
|
|
67
|
-
}
|
|
68
|
-
}
|
|
69
|
-
],
|
|
70
|
-
'OutputDataConfig': {
|
|
71
|
-
'S3OutputPath': args.output_path
|
|
66
|
+
"ResourceConfig": {
|
|
67
|
+
"InstanceType": args.instance_type,
|
|
68
|
+
"InstanceCount": int(args.instance_count),
|
|
69
|
+
"VolumeSizeInGB": int(args.volume_size),
|
|
72
70
|
},
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
71
|
+
"OutputDataConfig": {
|
|
72
|
+
"S3OutputPath": args.output_path,
|
|
73
|
+
},
|
|
74
|
+
"StoppingCondition": {
|
|
75
|
+
"MaxRuntimeInSeconds": int(args.max_runtime),
|
|
77
76
|
},
|
|
78
|
-
'StoppingCondition': {
|
|
79
|
-
'MaxRuntimeInSeconds': int(args.max_runtime)
|
|
80
|
-
}
|
|
81
77
|
}
|
|
82
78
|
|
|
83
|
-
#
|
|
84
|
-
if
|
|
85
|
-
request[
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
'S3Uri': args.checkpoint_path
|
|
98
|
-
}
|
|
99
|
-
|
|
100
|
-
# Metric definitions (custom CloudWatch metrics)
|
|
101
|
-
if metric_definitions and metric_definitions != []:
|
|
102
|
-
request['AlgorithmSpecification']['MetricDefinitions'] = [
|
|
103
|
-
{'Name': m['name'], 'Regex': m['regex']}
|
|
104
|
-
for m in metric_definitions
|
|
79
|
+
# Input data channels
|
|
80
|
+
if args.dataset:
|
|
81
|
+
request["InputDataConfig"] = [
|
|
82
|
+
{
|
|
83
|
+
"ChannelName": "training",
|
|
84
|
+
"DataSource": {
|
|
85
|
+
"S3DataSource": {
|
|
86
|
+
"S3DataType": "S3Prefix",
|
|
87
|
+
"S3Uri": args.dataset,
|
|
88
|
+
"S3DataDistributionType": "FullyReplicated",
|
|
89
|
+
}
|
|
90
|
+
},
|
|
91
|
+
"ContentType": "application/jsonlines",
|
|
92
|
+
}
|
|
105
93
|
]
|
|
106
94
|
|
|
107
|
-
#
|
|
108
|
-
if
|
|
109
|
-
request[
|
|
95
|
+
# Hyperparameters (all values must be strings)
|
|
96
|
+
if hyperparams:
|
|
97
|
+
request["HyperParameters"] = {k: str(v) for k, v in hyperparams.items()}
|
|
110
98
|
|
|
111
|
-
#
|
|
112
|
-
if
|
|
113
|
-
request[
|
|
114
|
-
{'Key': str(k), 'Value': str(v)}
|
|
115
|
-
for k, v in tags.items()
|
|
116
|
-
]
|
|
99
|
+
# Environment variables
|
|
100
|
+
if environment:
|
|
101
|
+
request["Environment"] = {k: str(v) for k, v in environment.items()}
|
|
117
102
|
|
|
118
|
-
|
|
103
|
+
# Metric definitions
|
|
104
|
+
if metric_definitions:
|
|
105
|
+
request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions
|
|
119
106
|
|
|
107
|
+
# Spot training
|
|
108
|
+
if args.enable_spot.lower() == "true":
|
|
109
|
+
request["EnableManagedSpotTraining"] = True
|
|
110
|
+
request["StoppingCondition"]["MaxWaitTimeInSeconds"] = int(args.max_wait)
|
|
120
111
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
112
|
+
# Checkpoint config
|
|
113
|
+
if args.checkpoint_path:
|
|
114
|
+
request["CheckpointConfig"] = {
|
|
115
|
+
"S3Uri": args.checkpoint_path,
|
|
116
|
+
}
|
|
124
117
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
print(f'❌ Failed to build request: {e}', file=sys.stderr)
|
|
129
|
-
sys.exit(1)
|
|
118
|
+
# Tags
|
|
119
|
+
if tags:
|
|
120
|
+
request["Tags"] = tags
|
|
130
121
|
|
|
131
|
-
# Write
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
print(f'❌ Failed to write request file: {e}', file=sys.stderr)
|
|
137
|
-
sys.exit(1)
|
|
122
|
+
# Write to output file
|
|
123
|
+
with open(args.output_file, "w") as f:
|
|
124
|
+
json.dump(request, f, indent=2)
|
|
125
|
+
|
|
126
|
+
print(f"✅ Request written to {args.output_file}", file=sys.stderr)
|
|
138
127
|
|
|
139
128
|
|
|
140
|
-
if __name__ ==
|
|
129
|
+
if __name__ == "__main__":
|
|
141
130
|
main()
|