@aws/ml-container-creator 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,768 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Managed Model Customization helper.
6
+
7
+ Subcommands:
8
+ submit - Submit a new customization job
9
+ status - Get job status and metrics
10
+ resolve - Resolve output artifact path from job
11
+ stage-hf - Download HF dataset to S3
12
+ validate - Validate dataset format against schema
13
+
14
+ All output is JSON on stdout for bash consumption.
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ import time
22
+
23
+ # ── Inline dependency check ───────────────────────────────────────────────────
24
# Oldest sagemaker SDK release this helper accepts.
MIN_SAGEMAKER_VERSION = "2.232.0"


def _check_sagemaker_sdk():
    """Exit with a JSON error unless a recent enough sagemaker SDK is importable.

    Imports ``sagemaker`` and ``packaging.version.Version``; if either import
    fails, or the installed SDK version compares below
    ``MIN_SAGEMAKER_VERSION``, the problem is reported via ``_error_exit``.
    Returns ``None`` when the check passes.
    """
    try:
        import sagemaker  # noqa: F401
        from packaging.version import Version
    except ImportError:
        _error_exit(
            f"sagemaker Python SDK is not installed. "
            f"Please install: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
        )

    installed = sagemaker.__version__
    too_old = Version(installed) < Version(MIN_SAGEMAKER_VERSION)
    if not too_old:
        return

    _error_exit(
        f"sagemaker SDK version {installed} is below minimum "
        f"required version {MIN_SAGEMAKER_VERSION}. "
        f"Please upgrade: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
    )
43
+
44
+
45
+ # ── Utility functions ─────────────────────────────────────────────────────────
46
+
47
+
48
def _error_exit(message):
    """Write ``{"error": message}`` as a JSON line to stdout and exit with status 1.

    Args:
        message: Text placed under the ``"error"`` key.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    payload = json.dumps({"error": message})
    print(payload)
    raise SystemExit(1)
52
+
53
+
54
def _output(data):
    """Write ``data`` as a JSON line to stdout and exit with status 0.

    Args:
        data: A JSON-serializable value (the callers in this file pass dicts).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    serialized = json.dumps(data)
    print(serialized)
    raise SystemExit(0)
58
+
59
+
60
+ # ── Subcommand: submit ────────────────────────────────────────────────────────
61
+
62
+
63
def cmd_submit(args):
    """Submit a customization job through the sagemaker SDK trainer classes.

    Picks a trainer class from ``args.technique``, builds its keyword
    arguments from the parsed CLI options, and starts training without
    waiting for completion.

    On success prints JSON with ``job_name``, ``job_arn``, ``mlflow_url`` and
    ``model_package_group`` and exits 0; on any failure prints a JSON error
    and exits 1.
    """
    _check_sagemaker_sdk()

    from sagemaker.modules.train.sft_trainer import SFTTrainer
    from sagemaker.modules.train.dpo_trainer import DPOTrainer
    from sagemaker.modules.train.common import TrainingType

    technique = args.technique

    # RLAIF and RLVR are submitted through SFTTrainer plus an evaluator config.
    trainers_by_technique = {
        "sft": SFTTrainer,
        "dpo": DPOTrainer,
        "rlaif": SFTTrainer,
        "rlvr": SFTTrainer,
    }
    trainer_cls = trainers_by_technique.get(technique)
    if not trainer_cls:
        _error_exit(f"Unsupported technique: {technique}")

    training_types_by_name = {
        "lora": TrainingType.LORA,
        "full-rank": TrainingType.FULL_RANK,
    }
    training_type = training_types_by_name.get(args.training_type)
    if not training_type:
        _error_exit(f"Unsupported training type: {args.training_type}")

    # Only overrides the caller actually supplied are forwarded.
    override_names = (
        "epochs",
        "learning_rate",
        "max_seq_length",
        "lora_rank",
        "lora_alpha",
        "batch_size",
    )
    hyperparameters = {}
    for option in override_names:
        supplied = getattr(args, option)
        if supplied is not None:
            hyperparameters[option] = supplied

    output_path = f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/"
    trainer_kwargs = {
        "model_id": args.model_id,
        "training_type": training_type,
        "train_data_uri": args.dataset_s3_uri,
        "output_path": output_path,
        "role": args.role_arn,
        "job_name": args.job_name,
    }

    # Model package group enables artifact registration.
    if args.model_package_group:
        trainer_kwargs["model_package_group_name"] = args.model_package_group

    if hyperparameters:
        trainer_kwargs["hyperparameters"] = hyperparameters

    # Reward function takes precedence over reward prompt when both are given.
    evaluator_config = None
    if technique in ("rlvr", "rlaif"):
        if args.reward_function:
            evaluator_config = {"reward_function_arn": args.reward_function}
        elif args.reward_prompt:
            evaluator_config = {"reward_prompt_s3_uri": args.reward_prompt}
    if evaluator_config is not None:
        trainer_kwargs["evaluator_config"] = evaluator_config

    try:
        trainer = trainer_cls(**trainer_kwargs)
        trainer.train(wait=False)

        submitted_name = trainer.training_job_name
        submitted_arn = getattr(trainer, "training_job_arn", None)

        # The MLflow URL is optional; any failure reading it is ignored.
        tracking_url = None
        try:
            tracking_url = getattr(trainer, "mlflow_tracking_uri", None)
        except Exception:
            pass

        result = {
            "job_name": submitted_name,
            "job_arn": submitted_arn or "",
            "mlflow_url": tracking_url,
            "model_package_group": args.model_package_group or "",
        }
        _output(result)

    except Exception as e:
        error_msg = str(e)
        # Add a hint for the failure modes recognised from the message text.
        if "AccessDeniedException" in error_msg or "AccessDenied" in error_msg:
            message = (
                f"Access denied when submitting training job. "
                f"Ensure the role has sagemaker:CreateTrainingJob permission. "
                f"Details: {error_msg}"
            )
        elif "ResourceLimitExceeded" in error_msg:
            message = (
                f"Resource limit exceeded. You may need to request a quota increase. "
                f"Details: {error_msg}"
            )
        elif "ValidationException" in error_msg and "license" in error_msg.lower():
            message = (
                f"Model license not accepted. Accept the license in JumpStart before "
                f"using this model for customization. Details: {error_msg}"
            )
        else:
            message = f"Failed to submit training job: {error_msg}"
        _error_exit(message)
184
+
185
+
186
+ # ── Subcommand: status ────────────────────────────────────────────────────────
187
+
188
+
189
def cmd_status(args):
    """Report the state of a training job using DescribeTrainingJob.

    Prints JSON with ``status``, ``failure_reason``, ``metrics``,
    ``elapsed_seconds`` and ``output_path`` and exits 0. If the describe call
    fails, prints a JSON error and exits 1.
    """
    import boto3

    client = boto3.client("sagemaker", region_name=args.region)

    try:
        response = client.describe_training_job(TrainingJobName=args.job_name)
    except client.exceptions.ClientError as e:
        # A ValidationException from this call is reported as "job not found".
        if e.response["Error"]["Code"] == "ValidationException":
            message = f"Training job not found: {args.job_name}"
        else:
            message = f"Failed to describe training job: {e}"
        _error_exit(message)
    except Exception as e:
        _error_exit(f"Failed to describe training job: {e}")

    status = response.get("TrainingJobStatus", "Unknown")

    # Elapsed time runs from training start to training end, or to "now"
    # while no end time has been recorded. It stays 0 without a start time.
    started_at = response.get("TrainingStartTime")
    finished_at = response.get("TrainingEndTime")
    elapsed_seconds = 0
    if started_at:
        stop = finished_at or time.time()
        stop_epoch = stop.timestamp() if hasattr(stop, "timestamp") else stop
        elapsed_seconds = int(stop_epoch - started_at.timestamp())

    # Flatten the final metric list into {name: value}; None when absent/empty.
    final_metrics = response.get("FinalMetricDataList")
    metrics = (
        {entry["MetricName"]: entry["Value"] for entry in final_metrics}
        if final_metrics
        else None
    )

    # The artifact location is only reported for completed jobs.
    output_path = (
        response.get("ModelArtifacts", {}).get("S3ModelArtifacts")
        if status == "Completed"
        else None
    )

    _output({
        "status": status,
        "failure_reason": response.get("FailureReason"),
        "metrics": metrics,
        "elapsed_seconds": elapsed_seconds,
        "output_path": output_path,
    })
244
+
245
+
246
+ # ── Subcommand: resolve ───────────────────────────────────────────────────────
247
+
248
+
249
def cmd_resolve(args):
    """Resolve the S3 artifact location of a completed training job.

    Prints JSON with ``artifact_path``, ``model_package_arn`` and
    ``output_type`` and exits 0. Exits 1 with a JSON error when the job cannot
    be described, is not in ``Completed`` status, or reports no artifacts.
    """
    import boto3

    sagemaker_client = boto3.client("sagemaker", region_name=args.region)

    try:
        job = sagemaker_client.describe_training_job(TrainingJobName=args.job_name)
    except Exception as e:
        _error_exit(f"Failed to describe training job: {e}")

    status = job.get("TrainingJobStatus")
    if status != "Completed":
        _error_exit(
            f"Cannot resolve artifacts for job in status: {status}. "
            f"Job must be Completed."
        )

    artifact_path = job.get("ModelArtifacts", {}).get("S3ModelArtifacts", "")
    if not artifact_path:
        _error_exit("No model artifacts found in training job output.")

    # LoRA jobs are labelled "adapter"; every other training type "full-model".
    output_type = "full-model"
    if args.training_type == "lora":
        output_type = "adapter"

    # Best-effort lookup of the newest package in the group; failures are ignored.
    model_package_arn = None
    if args.model_package_group:
        try:
            registry_client = boto3.client("sagemaker", region_name=args.region)
            listing = registry_client.list_model_packages(
                ModelPackageGroupName=args.model_package_group,
                SortBy="CreationTime",
                SortOrder="Descending",
                MaxResults=1,
            )
            summaries = listing.get("ModelPackageSummaryList", [])
            if summaries:
                model_package_arn = summaries[0].get("ModelPackageArn")
        except Exception:
            pass

    _output({
        "artifact_path": artifact_path,
        "model_package_arn": model_package_arn,
        "output_type": output_type,
    })
304
+
305
+
306
+ # ── Subcommand: stage-hf ─────────────────────────────────────────────────────
307
+
308
+
309
def cmd_stage_hf(args):
    """Download the data files of a Hugging Face dataset split and upload them to S3.

    The HF token comes from ``_resolve_hf_token`` (Secrets Manager first, then
    the ``HF_TOKEN`` environment variable). Files are chosen by
    ``_find_data_files``, downloaded into a temporary directory, and uploaded
    under ``{project_name}/datasets/{org}/{name}/{split}/`` in
    ``args.output_bucket``.

    On success prints JSON ``{"s3_uri": str, "num_records": int}`` and exits 0.
    ``s3_uri`` points at the first uploaded file. ``num_records`` is the number
    of non-blank lines across all downloaded files, so it is a record count
    only for line-delimited formats such as JSONL. On failure prints a JSON
    error and exits 1.
    """
    try:
        from huggingface_hub import hf_hub_download, HfApi
    except ImportError:
        _error_exit(
            "huggingface_hub is not installed. "
            "Please install: pip install huggingface_hub"
        )

    import boto3
    import tempfile

    # Resolve HF token: Secrets Manager first, then env var
    hf_token = _resolve_hf_token(args.region, args.hf_secret_name)

    org = args.hf_org
    name = args.hf_name
    split = args.hf_split or "train"
    dataset_id = f"{org}/{name}"

    try:
        api = HfApi(token=hf_token)

        repo_files = api.list_repo_files(
            repo_id=dataset_id,
            repo_type="dataset",
            token=hf_token,
        )

        data_files = _find_data_files(repo_files, split)
        if not data_files:
            _error_exit(
                f"No data files found for split '{split}' in dataset {dataset_id}. "
                f"Available files: {', '.join(repo_files[:20])}"
            )

        s3_client = boto3.client("s3", region_name=args.region)
        s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
        num_records = 0

        with tempfile.TemporaryDirectory() as tmpdir:
            for data_file in data_files:
                local_path = hf_hub_download(
                    repo_id=dataset_id,
                    filename=data_file,
                    repo_type="dataset",
                    token=hf_token,
                    local_dir=tmpdir,
                )

                # Count non-blank lines in binary mode: no text decoding is
                # needed to count lines, and this avoids a decode error on
                # hosts whose default encoding is not UTF-8 or on binary
                # shards (e.g. parquet) that _find_data_files may return.
                with open(local_path, "rb") as f:
                    num_records += sum(1 for line in f if line.strip())

                s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
                s3_client.upload_file(local_path, args.output_bucket, s3_key)

        s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{os.path.basename(data_files[0])}"

        _output({
            "s3_uri": s3_uri,
            "num_records": num_records,
        })

    except Exception as e:
        error_msg = str(e)
        # Classify by message text; SystemExit raised by _output/_error_exit
        # is not an Exception subclass and passes through untouched.
        if "404" in error_msg or "not found" in error_msg.lower():
            _error_exit(
                f"Dataset not found: {dataset_id}. "
                f"Check the dataset name and ensure it exists on Hugging Face Hub."
            )
        elif "401" in error_msg or "unauthorized" in error_msg.lower():
            _error_exit(
                f"Authentication failed for dataset {dataset_id}. "
                f"Ensure HF_TOKEN is set or configured via Secrets Manager."
            )
        else:
            _error_exit(f"Failed to stage HF dataset: {error_msg}")
401
+
402
+
403
def _resolve_hf_token(region, secret_name=None):
    """Look up the Hugging Face token, preferring Secrets Manager over the environment.

    Args:
        region: AWS region used for the Secrets Manager client.
        secret_name: Optional secret name/ARN. When falsy, Secrets Manager is
            not consulted.

    Returns:
        str or None: The stripped ``SecretString`` when the secret lookup
        succeeds and is non-empty; otherwise the ``HF_TOKEN`` environment
        variable, or None if that is unset.
    """
    if not secret_name:
        return os.environ.get("HF_TOKEN")

    try:
        import boto3
        secrets = boto3.client("secretsmanager", region_name=region)
        secret_text = secrets.get_secret_value(SecretId=secret_name).get("SecretString", "")
        if secret_text:
            return secret_text.strip()
    except Exception:
        # Any lookup failure falls back to the environment variable.
        pass

    return os.environ.get("HF_TOKEN")
428
+
429
+
430
def _find_data_files(repo_files, split):
    """Find data files matching the requested split.

    Candidates are tried in priority order and the first rule that matches
    decides the result:

    1. An exact file: ``data/{split}.jsonl``, ``{split}.jsonl``,
       ``data/{split}.json``, ``{split}.json`` (first one present wins).
    2. Sharded files whose path contains ``{split}-<digits>-of-`` (for example
       ``data/train-00000-of-00004.parquet``); all shards are returned, each
       path once.
    3. Any ``.jsonl`` file whose path contains the split name.
    4. Any ``.jsonl`` file under ``data/``.

    Args:
        repo_files: List of file paths in the repo
        split: The dataset split name (e.g., "train")

    Returns:
        list: Matching file paths (sorted when more than one rule result is
        possible), or an empty list when nothing matches.
    """
    import re

    exact_candidates = (
        f"data/{split}.jsonl",
        f"{split}.jsonl",
        f"data/{split}.json",
        f"{split}.json",
    )
    for candidate in exact_candidates:
        if candidate in repo_files:
            return [candidate]

    # Match every shard index, not just 00000, and collect into a set so a
    # path is returned once even though "data/<split>-..." also contains
    # "<split>-...". Duplicates would be downloaded and counted twice.
    shard_marker = re.compile(re.escape(split) + r"-\d+-of-")
    shards = {path for path in repo_files if shard_marker.search(path)}
    if shards:
        return sorted(shards)

    # Fallback: any JSONL file containing the split name
    jsonl_files = [f for f in repo_files if f.endswith(".jsonl") and split in f]
    if jsonl_files:
        return sorted(jsonl_files)

    # Last resort: any JSONL file in data/ directory
    data_jsonl = [f for f in repo_files if f.startswith("data/") and f.endswith(".jsonl")]
    if data_jsonl:
        return sorted(data_jsonl)

    return []
479
+
480
+
481
+ # ── Subcommand: validate ──────────────────────────────────────────────────────
482
+
483
+
484
def cmd_validate(args):
    """Validate the first lines of a JSONL dataset against a simple schema.

    ``args.schema`` is a JSON string holding an object with optional keys
    ``"required"`` (list of key names) and ``"types"`` (key name -> type
    name). Up to the first 10 lines are read from ``args.file``, or from stdin
    when ``args.file`` is empty or ``"-"``. Blank lines are skipped.

    Prints one JSON object and exits 0. For a valid dataset:
    ``{"valid": true, "error": null, "line_number": null,
    "malformed_line": null}``. For the first offending line: ``valid`` false
    plus ``error``, ``line_number``, ``malformed_line`` and
    ``expected_format``. An unusable schema or unreadable file is reported as
    a JSON error with exit code 1.
    """
    from itertools import islice

    max_lines = 10  # Only the head of the dataset is inspected.

    try:
        schema = json.loads(args.schema)
    except json.JSONDecodeError as e:
        _error_exit(f"Invalid schema JSON: {e}")

    # Valid JSON that is not an object (e.g. "[]") has no .get(); report it
    # as a JSON error instead of crashing with a traceback.
    if not isinstance(schema, dict):
        _error_exit("Invalid schema JSON: schema must be a JSON object.")

    required_keys = schema.get("required", [])
    type_map = schema.get("types", {})

    if args.file and args.file != "-":
        try:
            # JSON text is UTF-8; do not depend on the platform locale.
            with open(args.file, "r", encoding="utf-8") as f:
                lines = [line.rstrip("\n") for line in islice(f, max_lines)]
        except FileNotFoundError:
            _error_exit(f"Dataset file not found: {args.file}")
        except Exception as e:
            _error_exit(f"Failed to read dataset file: {e}")
    else:
        lines = [line.rstrip("\n") for line in islice(sys.stdin, max_lines)]

    def _reject(error, line_number, line):
        """Print the failure report for one offending line (exits via _output)."""
        _output({
            "valid": False,
            "error": error,
            "line_number": line_number,
            "malformed_line": line,
            "expected_format": _build_expected_format(schema),
        })

    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue

        try:
            parsed = json.loads(line)
        except json.JSONDecodeError as e:
            _reject(f"Line {line_number} is not valid JSON: {e}", line_number, line)
            return

        if not isinstance(parsed, dict):
            _reject(f"Line {line_number} must be a JSON object.", line_number, line)
            return

        for key in required_keys:
            if key not in parsed:
                _reject(
                    f'Line {line_number} is missing required key "{key}".',
                    line_number,
                    line,
                )
                return

        # Types are only checked for keys that are actually present.
        for key, expected_type in type_map.items():
            if key not in parsed:
                continue

            value = parsed[key]
            if not _check_type(value, expected_type):
                actual_type = _get_type(value)
                _reject(
                    f'Line {line_number} has key "{key}" with wrong type. '
                    f'Expected "{expected_type}", got "{actual_type}".',
                    line_number,
                    line,
                )
                return

    _output({
        "valid": True,
        "error": None,
        "line_number": None,
        "malformed_line": None,
    })
591
+
592
+
593
def _check_type(value, expected_type):
    """Check if a value matches the expected schema type.

    Args:
        value: The value to check
        expected_type: One of "string", "number", "boolean", "array",
            "object". Any other type name is not checked.

    Returns:
        bool: True if the value matches the expected type, or if
        ``expected_type`` is not one of the recognised names.
    """
    if expected_type == "number":
        # bool is a subclass of int, so JSON true/false must be excluded
        # explicitly; _get_type reports them as "boolean", not "number".
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    python_types = {
        "string": str,
        "boolean": bool,
        "array": list,
        "object": dict,
    }
    wanted = python_types.get(expected_type)
    if wanted is None:
        return True
    return isinstance(value, wanted)
612
+
613
+
614
def _get_type(value):
    """Return the JSON-style type name of ``value`` for use in error messages.

    Yields "null", "array", "object", "boolean", "number" or "string"; any
    other Python type falls back to its class name.
    """
    if value is None:
        return "null"

    # bool is listed before the numeric types because bool subclasses int.
    labelled_types = (
        (list, "array"),
        (dict, "object"),
        (bool, "boolean"),
        ((int, float), "number"),
        (str, "string"),
    )
    for python_type, label in labelled_types:
        if isinstance(value, python_type):
            return label

    return type(value).__name__
629
+
630
+
631
def _build_expected_format(schema):
    """Describe the expected shape of one dataset line in plain text.

    Args:
        schema: The dataset schema dict; its "required" list supplies the
            field names and its "types" mapping their type names.

    Returns:
        str: A sentence listing each required key with its type, using
        "any" for keys that have no entry in "types".
    """
    required = schema.get("required", [])
    types = schema.get("types", {})

    described = ", ".join(
        f'"{key}": <{types.get(key, "any")}>' for key in required
    )
    return "Each line must be a JSON object with: {" + described + "}"
649
+
650
+
651
+ # ── CLI argument parsing ──────────────────────────────────────────────────────
652
+
653
+
654
def main():
    """Parse the command line and run the selected subcommand handler.

    Defines the ``submit``, ``status``, ``resolve``, ``stage-hf`` and
    ``validate`` subcommands. When no subcommand is given, prints the help
    text and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="SageMaker Managed Model Customization helper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Subcommand to run")

    # Declarative CLI layout: (subcommand, help text, [(flag, add_argument kwargs)]).
    cli_layout = [
        ("submit", "Submit a customization job", [
            ("--model-id", {"required": True, "help": "JumpStart model ID"}),
            ("--technique", {"required": True,
                             "choices": ["sft", "dpo", "rlaif", "rlvr"],
                             "help": "Customization technique"}),
            ("--training-type", {"required": True,
                                 "choices": ["lora", "full-rank"],
                                 "help": "Training type (lora or full-rank)"}),
            ("--dataset-s3-uri", {"required": True,
                                  "help": "S3 URI of the training dataset"}),
            ("--output-bucket", {"required": True,
                                 "help": "S3 bucket for output artifacts"}),
            ("--role-arn", {"required": True, "help": "IAM execution role ARN"}),
            ("--job-name", {"required": True, "help": "Unique job name"}),
            ("--project-name", {"required": True,
                                "help": "Project name for S3 path prefix"}),
            ("--model-package-group", {"default": None,
                                       "help": "Model package group name for registration"}),
            ("--epochs", {"type": int, "default": None,
                          "help": "Number of training epochs"}),
            ("--learning-rate", {"type": float, "default": None,
                                 "help": "Learning rate"}),
            ("--max-seq-length", {"type": int, "default": None,
                                  "help": "Maximum sequence length"}),
            ("--lora-rank", {"type": int, "default": None, "help": "LoRA rank"}),
            ("--lora-alpha", {"type": int, "default": None,
                              "help": "LoRA alpha scaling factor"}),
            ("--batch-size", {"type": int, "default": None,
                              "help": "Global batch size"}),
            ("--reward-function", {"default": None,
                                   "help": "Lambda ARN for reward function (RLVR)"}),
            ("--reward-prompt", {"default": None,
                                 "help": "S3 URI for reward prompt (RLAIF)"}),
        ]),
        ("status", "Get job status and metrics", [
            ("--job-name", {"required": True, "help": "Training job name"}),
            ("--region", {"required": True, "help": "AWS region"}),
        ]),
        ("resolve", "Resolve output artifact path", [
            ("--job-name", {"required": True, "help": "Training job name"}),
            ("--region", {"required": True, "help": "AWS region"}),
            ("--training-type", {"required": True,
                                 "choices": ["lora", "full-rank"],
                                 "help": "Training type used for the job"}),
            ("--model-package-group", {"default": None,
                                       "help": "Model package group name"}),
        ]),
        ("stage-hf", "Download HF dataset to S3", [
            ("--hf-org", {"required": True,
                          "help": "Hugging Face organization/user"}),
            ("--hf-name", {"required": True, "help": "Hugging Face dataset name"}),
            ("--hf-split", {"default": "train",
                            "help": "Dataset split (default: train)"}),
            ("--output-bucket", {"required": True,
                                 "help": "S3 bucket for staged dataset"}),
            ("--project-name", {"required": True,
                                "help": "Project name for S3 path prefix"}),
            ("--region", {"required": True, "help": "AWS region"}),
            ("--hf-secret-name", {"default": None,
                                  "help": "Secrets Manager secret name for HF token"}),
        ]),
        ("validate", "Validate dataset format", [
            ("--schema", {"required": True,
                          "help": "JSON string of the expected dataset schema"}),
            ("--file", {"default": "-",
                        "help": "Path to dataset file (default: stdin)"}),
        ]),
    ]

    for subcommand, summary, options in cli_layout:
        subparser = subparsers.add_parser(subcommand, help=summary)
        for flag, option_kwargs in options:
            subparser.add_argument(flag, **option_kwargs)

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    handlers = {
        "submit": cmd_submit,
        "status": cmd_status,
        "resolve": cmd_resolve,
        "stage-hf": cmd_stage_hf,
        "validate": cmd_validate,
    }

    handler = handlers.get(args.command)
    if not handler:
        _error_exit(f"Unknown command: {args.command}")
    else:
        handler(args)


if __name__ == "__main__":
    main()