@aws/ml-container-creator 0.5.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +9 -0
- package/config/bootstrap-stack.json +106 -9
- package/infra/ci-harness/package-lock.json +5 -1
- package/package.json +1 -1
- package/servers/instance-sizer/index.js +4 -4
- package/servers/instance-sizer/lib/model-resolver.js +1 -1
- package/servers/lib/catalogs/model-sizes.json +135 -90
- package/servers/lib/catalogs/models.json +483 -411
- package/src/app.js +29 -1
- package/src/lib/bootstrap-command-handler.js +71 -23
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +1 -1
- package/src/lib/mcp-client.js +3 -3
- package/src/lib/prompt-runner.js +5 -5
- package/src/lib/prompts.js +31 -5
- package/src/lib/tune-catalog-validator.js +143 -0
- package/src/lib/tune-config-state.js +116 -0
- package/src/lib/tune-dataset-validator.js +279 -0
- package/src/lib/tune-output-resolver.js +66 -0
- package/templates/do/.tune_helper.py +768 -0
- package/templates/do/adapter +128 -17
- package/templates/do/add-ic +155 -19
- package/templates/do/config +11 -4
- package/templates/do/tune +1143 -0
|
@@ -0,0 +1,768 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""SageMaker Managed Model Customization helper.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
submit - Submit a new customization job
|
|
9
|
+
status - Get job status and metrics
|
|
10
|
+
resolve - Resolve output artifact path from job
|
|
11
|
+
stage-hf - Download HF dataset to S3
|
|
12
|
+
validate - Validate dataset format against schema
|
|
13
|
+
|
|
14
|
+
All output is JSON on stdout for bash consumption.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import time
|
|
22
|
+
|
|
23
|
+
# ── Inline dependency check ───────────────────────────────────────────────────
|
|
24
|
+
# Oldest sagemaker SDK release accepted by _check_sagemaker_sdk(); also used
# in the pip install hint printed when the check fails.
MIN_SAGEMAKER_VERSION = "2.232.0"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _check_sagemaker_sdk():
    """Verify the sagemaker SDK is installed and meets the minimum version.

    Exits through ``_error_exit`` (JSON error on stdout, exit code 1) when the
    SDK cannot be imported or when its ``__version__`` is older than
    ``MIN_SAGEMAKER_VERSION``.

    The ``packaging`` import is kept separate from the ``sagemaker`` import so
    that a missing ``packaging`` distribution is not misreported as a missing
    SDK; in that case a plain numeric comparison of the dotted release
    components is used instead.
    """
    try:
        import sagemaker  # noqa: F401
    except ImportError:
        _error_exit(
            f"sagemaker Python SDK is not installed. "
            f"Please install: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
        )

    def _release_tuple(version, width=3):
        # Leading numeric components only ("2.232.0.dev1" -> (2, 232, 0)),
        # zero-padded so "2.232" and "2.232.0" compare as equal.
        parts = []
        for piece in str(version).split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            parts.append(int(digits))
        parts.extend([0] * (width - len(parts)))
        return tuple(parts)

    installed = sagemaker.__version__
    try:
        from packaging.version import Version
    except ImportError:
        too_old = _release_tuple(installed) < _release_tuple(MIN_SAGEMAKER_VERSION)
    else:
        too_old = Version(installed) < Version(MIN_SAGEMAKER_VERSION)

    if too_old:
        _error_exit(
            f"sagemaker SDK version {installed} is below minimum "
            f"required version {MIN_SAGEMAKER_VERSION}. "
            f"Please upgrade: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
        )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _error_exit(message):
    """Emit ``{"error": message}`` as one JSON line on stdout and exit with code 1."""
    payload = json.dumps({"error": message})
    print(payload)
    raise SystemExit(1)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _output(data):
    """Emit ``data`` as one JSON line on stdout and exit with code 0.

    This terminates the process, so code placed after a call never runs.
    """
    payload = json.dumps(data)
    print(payload)
    raise SystemExit(0)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# ── Subcommand: submit ────────────────────────────────────────────────────────
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def cmd_submit(args):
    """Submit a customization job through the SageMaker trainer classes.

    Selects the trainer class from ``args.technique`` and the training type
    from ``args.training_type``, forwards only the hyperparameter overrides
    that were supplied, starts the job without waiting for it to finish, and
    prints the result as JSON.

    Returns: {"job_name": str, "job_arn": str, "mlflow_url": str|None,
              "model_package_group": str}
    """
    _check_sagemaker_sdk()

    from sagemaker.modules.train.sft_trainer import SFTTrainer
    from sagemaker.modules.train.dpo_trainer import DPOTrainer
    from sagemaker.modules.train.common import TrainingType

    technique = args.technique

    # RLAIF and RLVR go through SFTTrainer; they differ only in the evaluator
    # config attached further down.
    trainer_cls = {
        "sft": SFTTrainer,
        "dpo": DPOTrainer,
        "rlaif": SFTTrainer,
        "rlvr": SFTTrainer,
    }.get(technique)
    if not trainer_cls:
        _error_exit(f"Unsupported technique: {technique}")

    training_type = {
        "lora": TrainingType.LORA,
        "full-rank": TrainingType.FULL_RANK,
    }.get(args.training_type)
    if not training_type:
        _error_exit(f"Unsupported training type: {args.training_type}")

    # Only overrides the caller actually passed are forwarded.
    override_names = (
        "epochs",
        "learning_rate",
        "max_seq_length",
        "lora_rank",
        "lora_alpha",
        "batch_size",
    )
    hyperparameters = {
        name: getattr(args, name)
        for name in override_names
        if getattr(args, name) is not None
    }

    trainer_kwargs = dict(
        model_id=args.model_id,
        training_type=training_type,
        train_data_uri=args.dataset_s3_uri,
        output_path=f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/",
        role=args.role_arn,
        job_name=args.job_name,
    )

    # Optional: register the resulting artifact in a model package group.
    if args.model_package_group:
        trainer_kwargs["model_package_group_name"] = args.model_package_group

    if hyperparameters:
        trainer_kwargs["hyperparameters"] = hyperparameters

    # Evaluator config applies to RLVR/RLAIF only; a reward function takes
    # precedence over a reward prompt when both are given.
    evaluator_config = None
    if technique in ("rlvr", "rlaif"):
        if args.reward_function:
            evaluator_config = {"reward_function_arn": args.reward_function}
        elif args.reward_prompt:
            evaluator_config = {"reward_prompt_s3_uri": args.reward_prompt}
    if evaluator_config is not None:
        trainer_kwargs["evaluator_config"] = evaluator_config

    try:
        trainer = trainer_cls(**trainer_kwargs)
        trainer.train(wait=False)

        submitted_name = trainer.training_job_name
        submitted_arn = getattr(trainer, "training_job_arn", None)

        # The MLflow URL is optional; any failure reading it is ignored.
        try:
            mlflow_url = getattr(trainer, "mlflow_tracking_uri", None)
        except Exception:
            mlflow_url = None

        _output({
            "job_name": submitted_name,
            "job_arn": submitted_arn or "",
            "mlflow_url": mlflow_url,
            "model_package_group": args.model_package_group or "",
        })

    except Exception as e:
        detail = str(e)
        # Prefix well-known failure classes with an actionable hint.
        # "AccessDenied" also matches "AccessDeniedException".
        if "AccessDenied" in detail:
            prefix = (
                "Access denied when submitting training job. "
                "Ensure the role has sagemaker:CreateTrainingJob permission. "
                "Details: "
            )
        elif "ResourceLimitExceeded" in detail:
            prefix = (
                "Resource limit exceeded. You may need to request a quota increase. "
                "Details: "
            )
        elif "ValidationException" in detail and "license" in detail.lower():
            prefix = (
                "Model license not accepted. Accept the license in JumpStart before "
                "using this model for customization. Details: "
            )
        else:
            prefix = "Failed to submit training job: "
        _error_exit(prefix + detail)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# ── Subcommand: status ────────────────────────────────────────────────────────
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def cmd_status(args):
    """Report the state of a training job using DescribeTrainingJob.

    Prints the job status, failure reason, final metrics, elapsed seconds
    since training started, and (for completed jobs) the S3 artifact path.

    Returns: {"status": str, "failure_reason": str|None,
              "metrics": dict|None, "elapsed_seconds": int,
              "output_path": str|None}
    """
    import boto3

    sm = boto3.client("sagemaker", region_name=args.region)

    try:
        desc = sm.describe_training_job(TrainingJobName=args.job_name)
    except sm.exceptions.ClientError as err:
        # This code treats a ValidationException as "job not found".
        if err.response["Error"]["Code"] == "ValidationException":
            _error_exit(f"Training job not found: {args.job_name}")
        _error_exit(f"Failed to describe training job: {err}")
    except Exception as err:
        _error_exit(f"Failed to describe training job: {err}")

    status = desc.get("TrainingJobStatus", "Unknown")

    # Elapsed time runs from training start to training end, or to "now" for
    # a job that has not finished; it stays 0 until training has started.
    elapsed_seconds = 0
    started_at = desc.get("TrainingStartTime")
    if started_at:
        stop = desc.get("TrainingEndTime") or time.time()
        stop_ts = stop.timestamp() if hasattr(stop, "timestamp") else stop
        elapsed_seconds = int(stop_ts - started_at.timestamp())

    # Flatten the final metric list into {name: value}; None when absent/empty.
    final_metrics = desc.get("FinalMetricDataList")
    metrics = (
        {entry["MetricName"]: entry["Value"] for entry in final_metrics}
        if final_metrics
        else None
    )

    # The artifact location is only reported for completed jobs.
    output_path = (
        desc.get("ModelArtifacts", {}).get("S3ModelArtifacts")
        if status == "Completed"
        else None
    )

    _output({
        "status": status,
        "failure_reason": desc.get("FailureReason"),
        "metrics": metrics,
        "elapsed_seconds": elapsed_seconds,
        "output_path": output_path,
    })
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# ── Subcommand: resolve ───────────────────────────────────────────────────────
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def cmd_resolve(args):
    """Resolve the output artifact location of a completed training job.

    Describes the job, requires it to be in ``Completed`` status, reads the S3
    model artifact path, and derives the output type from
    ``args.training_type`` ("lora" -> "adapter", anything else ->
    "full-model"). When ``args.model_package_group`` is set, the most recently
    created package in that group is looked up on a best-effort basis.

    Returns: {"artifact_path": str, "model_package_arn": str|None,
              "output_type": str}
    """
    import boto3

    # One client serves both the job description and the package lookup.
    client = boto3.client("sagemaker", region_name=args.region)

    try:
        response = client.describe_training_job(TrainingJobName=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job: {e}")

    status = response.get("TrainingJobStatus")
    if status != "Completed":
        _error_exit(
            f"Cannot resolve artifacts for job in status: {status}. "
            f"Job must be Completed."
        )

    artifact_path = response.get("ModelArtifacts", {}).get("S3ModelArtifacts", "")
    if not artifact_path:
        _error_exit("No model artifacts found in training job output.")

    output_type = "adapter" if args.training_type == "lora" else "full-model"

    model_package_arn = None
    if args.model_package_group:
        # NOTE(review): this returns the newest package in the group, which is
        # not necessarily the one produced by this job — confirm that is the
        # intended behavior when several jobs share a group.
        try:
            packages = client.list_model_packages(
                ModelPackageGroupName=args.model_package_group,
                SortBy="CreationTime",
                SortOrder="Descending",
                MaxResults=1,
            )
            summaries = packages.get("ModelPackageSummaryList", [])
            if summaries:
                model_package_arn = summaries[0].get("ModelPackageArn")
        except Exception:
            # Model package lookup is best-effort: report no ARN rather than
            # failing the whole resolve.
            model_package_arn = None

    _output({
        "artifact_path": artifact_path,
        "model_package_arn": model_package_arn,
        "output_type": output_type,
    })
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# ── Subcommand: stage-hf ─────────────────────────────────────────────────────
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def cmd_stage_hf(args):
    """Download a Hugging Face dataset split and upload it to S3.

    The token is resolved via ``_resolve_hf_token`` (Secrets Manager, then the
    HF_TOKEN environment variable). Files for the split are chosen by
    ``_find_data_files``, downloaded into a temporary directory, and uploaded
    under ``{project_name}/datasets/{org}/{name}/{split}/`` in
    ``args.output_bucket``.

    ``num_records`` counts non-blank lines and is only computed for
    ``.jsonl``/``.json`` files. Other formats (e.g. parquet shards) are
    uploaded but not counted, because reading them as text either fails to
    decode or yields a meaningless number.

    Returns: {"s3_uri": str, "num_records": int}
    """
    try:
        from huggingface_hub import hf_hub_download, HfApi
    except ImportError:
        _error_exit(
            "huggingface_hub is not installed. "
            "Please install: pip install huggingface_hub"
        )

    import boto3
    import tempfile

    # Resolve HF token: Secrets Manager first, then env var
    hf_token = _resolve_hf_token(args.region, args.hf_secret_name)

    org = args.hf_org
    name = args.hf_name
    split = args.hf_split or "train"
    dataset_id = f"{org}/{name}"

    # Extensions whose records can be approximated by counting lines.
    line_delimited_suffixes = (".jsonl", ".json")

    try:
        api = HfApi(token=hf_token)

        repo_files = api.list_repo_files(
            repo_id=dataset_id,
            repo_type="dataset",
            token=hf_token,
        )

        data_files = _find_data_files(repo_files, split)
        if not data_files:
            _error_exit(
                f"No data files found for split '{split}' in dataset {dataset_id}. "
                f"Available files: {', '.join(repo_files[:20])}"
            )

        s3_client = boto3.client("s3", region_name=args.region)
        s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
        num_records = 0

        with tempfile.TemporaryDirectory() as tmpdir:
            for data_file in data_files:
                local_path = hf_hub_download(
                    repo_id=dataset_id,
                    filename=data_file,
                    repo_type="dataset",
                    token=hf_token,
                    local_dir=tmpdir,
                )

                # Binary mode keeps the count independent of the platform's
                # default text encoding. For a ".json" file that is not
                # line-delimited this is a line count, not a record count.
                if data_file.endswith(line_delimited_suffixes):
                    with open(local_path, "rb") as f:
                        num_records += sum(1 for line in f if line.strip())

                # Files are flattened to their basename under the S3 prefix.
                s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
                s3_client.upload_file(local_path, args.output_bucket, s3_key)

        # The reported URI points at the first staged file only.
        s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{os.path.basename(data_files[0])}"

        _output({
            "s3_uri": s3_uri,
            "num_records": num_records,
        })

    except Exception as e:
        error_msg = str(e)
        # Failures are classified by substrings of the exception message.
        if "404" in error_msg or "not found" in error_msg.lower():
            _error_exit(
                f"Dataset not found: {dataset_id}. "
                f"Check the dataset name and ensure it exists on Hugging Face Hub."
            )
        elif "401" in error_msg or "unauthorized" in error_msg.lower():
            _error_exit(
                f"Authentication failed for dataset {dataset_id}. "
                f"Ensure HF_TOKEN is set or configured via Secrets Manager."
            )
        else:
            _error_exit(f"Failed to stage HF dataset: {error_msg}")
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _resolve_hf_token(region, secret_name=None):
    """Look up the Hugging Face token.

    Secrets Manager is consulted first when ``secret_name`` is given; any
    failure there, or an empty secret, falls back to the HF_TOKEN environment
    variable.

    Args:
        region: AWS region used for the Secrets Manager client
        secret_name: Optional Secrets Manager secret name/ARN

    Returns:
        str or None: The token, or None when neither source provides one
    """
    if secret_name:
        from_secret = None
        try:
            import boto3
            secrets = boto3.client("secretsmanager", region_name=region)
            raw = secrets.get_secret_value(SecretId=secret_name).get("SecretString", "")
            if raw:
                from_secret = raw.strip()
        except Exception:
            # Any Secrets Manager problem falls through to the env var.
            from_secret = None
        if from_secret is not None:
            return from_secret

    return os.environ.get("HF_TOKEN")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _find_data_files(repo_files, split):
    """Find data files matching the requested split.

    Tried in order, returning at the first step that matches:
      1. An exact path: data/{split}.jsonl, {split}.jsonl, data/{split}.json,
         {split}.json (first hit only).
      2. Paths containing "{split}-00000-of-" (first shard of a sharded file).
      3. Any .jsonl path containing the split name.
      4. Any .jsonl path under data/.

    Args:
        repo_files: List of file paths in the repo
        split: The dataset split name (e.g., "train")

    Returns:
        list: Matching file paths (sorted, each path at most once per
        occurrence in ``repo_files``), or an empty list when nothing matches
    """
    exact_candidates = (
        f"data/{split}.jsonl",
        f"{split}.jsonl",
        f"data/{split}.json",
        f"{split}.json",
    )
    for candidate in exact_candidates:
        if candidate in repo_files:
            return [candidate]

    # Each path is tested once against all markers, so a path such as
    # "data/train-00000-of-00002.parquet" — which contains both markers — is
    # returned once rather than twice.
    # NOTE(review): only the "-00000-of-" shard is matched; later shards of
    # the same split are not selected — confirm whether that is intended.
    shard_markers = (f"data/{split}-00000-of-", f"{split}-00000-of-")
    shards = [
        path for path in repo_files
        if any(marker in path for marker in shard_markers)
    ]
    if shards:
        return sorted(shards)

    # Fallback: any JSONL file containing the split name
    named_jsonl = [p for p in repo_files if p.endswith(".jsonl") and split in p]
    if named_jsonl:
        return sorted(named_jsonl)

    # Last resort: any JSONL file in data/ directory
    data_jsonl = [
        p for p in repo_files if p.startswith("data/") and p.endswith(".jsonl")
    ]
    if data_jsonl:
        return sorted(data_jsonl)

    return []
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
# ── Subcommand: validate ──────────────────────────────────────────────────────
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def cmd_validate(args):
    """Validate the first lines of a JSONL dataset against a simple schema.

    The schema is passed as a JSON string argument and must be a JSON object;
    its optional "required" list names keys every record must have and its
    optional "types" mapping gives the expected type per key. Only the first
    10 lines of the input (``args.file``, or stdin when it is "-" or empty)
    are inspected; blank lines are skipped. Validation stops at the first
    offending line.

    Returns: {"valid": bool, "error": str|None, "line_number": int|None,
              "malformed_line": str|None}
    (failure results additionally carry "expected_format": str)
    """
    from itertools import islice

    max_lines = 10  # Only inspect first 10 lines

    try:
        schema = json.loads(args.schema)
    except json.JSONDecodeError as e:
        _error_exit(f"Invalid schema JSON: {e}")

    # A JSON array/string/number parses fine but has no .get(); report it as
    # a JSON error instead of crashing with a traceback.
    if not isinstance(schema, dict):
        _error_exit("Invalid schema JSON: schema must be a JSON object.")

    required_keys = schema.get("required", [])
    type_map = schema.get("types", {})

    def _fail(line_number, line, message):
        # Shared failure payload for every rejection path.
        _output({
            "valid": False,
            "error": message,
            "line_number": line_number,
            "malformed_line": line,
            "expected_format": _build_expected_format(schema),
        })

    if args.file and args.file != "-":
        try:
            # JSON text is UTF-8; do not depend on the platform default.
            with open(args.file, "r", encoding="utf-8") as f:
                lines = [line.rstrip("\n") for line in islice(f, max_lines)]
        except FileNotFoundError:
            _error_exit(f"Dataset file not found: {args.file}")
        except Exception as e:
            _error_exit(f"Failed to read dataset file: {e}")
    else:
        lines = [line.rstrip("\n") for line in islice(sys.stdin, max_lines)]

    for line_number, line in enumerate(lines, start=1):
        # Skip empty lines
        if not line or not line.strip():
            continue

        try:
            parsed = json.loads(line)
        except json.JSONDecodeError as e:
            _fail(line_number, line, f"Line {line_number} is not valid JSON: {e}")
            return

        if not isinstance(parsed, dict):
            _fail(line_number, line, f"Line {line_number} must be a JSON object.")
            return

        for key in required_keys:
            if key not in parsed:
                _fail(
                    line_number,
                    line,
                    f'Line {line_number} is missing required key "{key}".',
                )
                return

        # Types are only checked for keys that are present.
        for key, expected_type in type_map.items():
            if key not in parsed:
                continue

            value = parsed[key]
            if not _check_type(value, expected_type):
                actual_type = _get_type(value)
                _fail(
                    line_number,
                    line,
                    f'Line {line_number} has key "{key}" with wrong type. '
                    f'Expected "{expected_type}", got "{actual_type}".',
                )
                return

    # Input with no non-blank lines in the inspected window also ends up here.
    _output({
        "valid": True,
        "error": None,
        "line_number": None,
        "malformed_line": None,
    })
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _check_type(value, expected_type):
    """Check if a value matches the expected schema type.

    Args:
        value: The value to check
        expected_type: One of "string", "array", "object", "number"

    Returns:
        bool: True if the value matches the expected type. Unrecognized
        ``expected_type`` names are not enforced and always yield True.
    """
    if expected_type == "number":
        # bool is a subclass of int in Python; JSON true/false are booleans,
        # not numbers, so they are rejected explicitly.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    python_types = {
        "string": str,
        "array": list,
        "object": dict,
    }
    wanted = python_types.get(expected_type)
    if wanted is None:
        return True
    return isinstance(value, wanted)
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def _get_type(value):
    """Return the JSON-style type name of ``value`` for error messages.

    Values of any other kind fall back to their Python class name.
    """
    if value is None:
        return "null"

    # Order matters: bool must be tested before the numeric types because
    # bool is a subclass of int.
    labelled_types = (
        (list, "array"),
        (dict, "object"),
        (bool, "boolean"),
        ((int, float), "number"),
        (str, "string"),
    )
    for python_type, label in labelled_types:
        if isinstance(value, python_type):
            return label

    return type(value).__name__
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def _build_expected_format(schema):
    """Describe the record shape a schema expects, for use in error output.

    Args:
        schema: The dataset schema dict ("required" list, "types" mapping)

    Returns:
        str: One-sentence description listing each required key with its
        declared type, or "any" when no type is declared for it
    """
    declared_types = schema.get("types", {})
    described = ", ".join(
        f'"{key}": <{declared_types.get(key, "any")}>'
        for key in schema.get("required", [])
    )
    return "Each line must be a JSON object with: {" + described + "}"
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
# ── CLI argument parsing ──────────────────────────────────────────────────────
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
def main():
    """Build the command-line interface and run the selected subcommand.

    With no subcommand the help text is printed and the process exits with
    code 1. Otherwise the parsed arguments are handed to the matching
    ``cmd_*`` handler.
    """
    parser = argparse.ArgumentParser(
        description="SageMaker Managed Model Customization helper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Subcommand to run")

    def required(help_text, **extra):
        # Keyword arguments for a mandatory option.
        return dict(required=True, help=help_text, **extra)

    def optional(help_text, default=None, **extra):
        # Keyword arguments for an option with a default value.
        return dict(default=default, help=help_text, **extra)

    # (subcommand, help, [(flag, add_argument kwargs), ...]) in display order.
    cli_spec = [
        ("submit", "Submit a customization job", [
            ("--model-id", required("JumpStart model ID")),
            ("--technique", required("Customization technique",
                                     choices=["sft", "dpo", "rlaif", "rlvr"])),
            ("--training-type", required("Training type (lora or full-rank)",
                                         choices=["lora", "full-rank"])),
            ("--dataset-s3-uri", required("S3 URI of the training dataset")),
            ("--output-bucket", required("S3 bucket for output artifacts")),
            ("--role-arn", required("IAM execution role ARN")),
            ("--job-name", required("Unique job name")),
            ("--project-name", required("Project name for S3 path prefix")),
            ("--model-package-group",
             optional("Model package group name for registration")),
            ("--epochs", optional("Number of training epochs", type=int)),
            ("--learning-rate", optional("Learning rate", type=float)),
            ("--max-seq-length", optional("Maximum sequence length", type=int)),
            ("--lora-rank", optional("LoRA rank", type=int)),
            ("--lora-alpha", optional("LoRA alpha scaling factor", type=int)),
            ("--batch-size", optional("Global batch size", type=int)),
            ("--reward-function",
             optional("Lambda ARN for reward function (RLVR)")),
            ("--reward-prompt", optional("S3 URI for reward prompt (RLAIF)")),
        ]),
        ("status", "Get job status and metrics", [
            ("--job-name", required("Training job name")),
            ("--region", required("AWS region")),
        ]),
        ("resolve", "Resolve output artifact path", [
            ("--job-name", required("Training job name")),
            ("--region", required("AWS region")),
            ("--training-type", required("Training type used for the job",
                                         choices=["lora", "full-rank"])),
            ("--model-package-group", optional("Model package group name")),
        ]),
        ("stage-hf", "Download HF dataset to S3", [
            ("--hf-org", required("Hugging Face organization/user")),
            ("--hf-name", required("Hugging Face dataset name")),
            ("--hf-split", optional("Dataset split (default: train)",
                                    default="train")),
            ("--output-bucket", required("S3 bucket for staged dataset")),
            ("--project-name", required("Project name for S3 path prefix")),
            ("--region", required("AWS region")),
            ("--hf-secret-name",
             optional("Secrets Manager secret name for HF token")),
        ]),
        ("validate", "Validate dataset format", [
            ("--schema", required("JSON string of the expected dataset schema")),
            ("--file", optional("Path to dataset file (default: stdin)",
                                default="-")),
        ]),
    ]

    for command, summary, options in cli_spec:
        sub = subparsers.add_parser(command, help=summary)
        for flag, settings in options:
            sub.add_argument(flag, **settings)

    args = parser.parse_args()

    # The subcommand is not marked required, so its absence is handled here.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    handlers = {
        "submit": cmd_submit,
        "status": cmd_status,
        "resolve": cmd_resolve,
        "stage-hf": cmd_stage_hf,
        "validate": cmd_validate,
    }

    handler = handlers.get(args.command)
    if not handler:
        _error_exit(f"Unknown command: {args.command}")
        return
    handler(args)
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|