@aws/ml-container-creator 0.13.4 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. package/README.md +23 -5
  2. package/config/parameter-schema-v2.json +32 -4
  3. package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
  4. package/infra/ci-harness/package-lock.json +122 -116
  5. package/infra/ci-harness/package.json +1 -1
  6. package/package.json +5 -3
  7. package/pyproject.toml +21 -0
  8. package/requirements.txt +19 -0
  9. package/servers/instance-sizer/index.js +72 -4
  10. package/servers/instance-sizer/lib/model-resolver.js +28 -2
  11. package/src/app.js +17 -0
  12. package/src/lib/bootstrap-command-handler.js +33 -23
  13. package/src/lib/config-loader.js +18 -0
  14. package/src/lib/config-manager.js +6 -1
  15. package/src/lib/dataset-slug.js +152 -0
  16. package/src/lib/generated/cli-options.js +9 -3
  17. package/src/lib/generated/parameter-matrix.js +14 -3
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +6 -0
  20. package/src/lib/prompt-runner.js +5 -0
  21. package/src/lib/prompts/feature-prompts.js +1 -1
  22. package/src/lib/template-manager.js +0 -7
  23. package/src/lib/template-variable-resolver.js +51 -1
  24. package/src/lib/tune-config-state.js +14 -1
  25. package/templates/do/.adapter_helper.py +451 -0
  26. package/templates/do/.benchmark_writer.py +22 -0
  27. package/templates/do/.register_helper.py +1163 -0
  28. package/templates/do/.stage_helper.py +419 -0
  29. package/templates/do/.tune_helper.py +379 -65
  30. package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
  31. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  32. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  33. package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
  34. package/templates/do/adapter +427 -27
  35. package/templates/do/add-ic +85 -3
  36. package/templates/do/benchmark +173 -15
  37. package/templates/do/config +24 -0
  38. package/templates/do/lib/inference-component.sh +56 -3
  39. package/templates/do/lib/profile.sh +5 -0
  40. package/templates/do/register +552 -6
  41. package/templates/do/stage +91 -272
  42. package/templates/do/test +12 -2
  43. package/templates/do/tune +264 -12
@@ -29,6 +29,12 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
29
29
  warnings.filterwarnings("ignore", message=".*urllib3.*")
30
30
  warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
31
31
 
32
+ # Suppress ALL logging to prevent sagemaker-core/rich from writing to stdout.
33
+ # This script outputs JSON on stdout — any other stdout output corrupts parsing.
34
+ import logging as _logging
35
+ _logging.disable(_logging.CRITICAL)
36
+ os.environ.setdefault("SAGEMAKER_LOG_LEVEL", "CRITICAL")
37
+
32
38
  # ── Inline dependency check ───────────────────────────────────────────────────
33
39
  MIN_SAGEMAKER_VERSION = "3.0"
34
40
 
@@ -71,6 +77,164 @@ def _output(data):
71
77
  sys.exit(0)
72
78
 
73
79
 
80
def _sanitize_for_json(value):
    """Convert *value* into something ``json.dumps`` can serialize.

    sagemaker-core uses an ``Unassigned`` sentinel type instead of ``None``
    for unset fields. This function:

    * maps ``None`` and ``Unassigned``/``UnassignedValue`` sentinels to ``None``;
    * passes JSON primitives (str, int, float, bool) through unchanged;
    * recurses into dicts, lists and tuples (tuples become lists, which is
      what ``json.dumps`` would emit for them anyway), so sentinels nested
      inside containers become ``null`` instead of stringifying the container;
    * returns any other already-serializable value unchanged;
    * otherwise falls back to ``str(value)``, or ``None`` if the object is falsy.

    Args:
        value: Any value, typically read from a sagemaker-core resource object.

    Returns:
        A JSON-serializable equivalent of *value*.
    """
    if value is None:
        return None
    # Match the sentinel by type name so sagemaker-core need not be imported here.
    if type(value).__name__ in ("Unassigned", "UnassignedValue"):
        return None
    if isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dict):
        return {k: _sanitize_for_json(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        # Recurse so that sentinels inside tuples are converted element-wise
        # rather than the whole tuple failing the json.dumps probe below.
        return [_sanitize_for_json(v) for v in value]
    # Any other type: keep it if json can already serialize it, else stringify.
    import json as _json
    try:
        _json.dumps(value)
        return value
    except (TypeError, ValueError):
        return str(value) if value else None
106
+
107
+
108
+ # ── Registry resolution helpers ───────────────────────────────────────────────
109
+
110
+
111
def _run_register_helper(subcommand, kind, name, register_hint):
    """Run ``.register_helper.py <subcommand> --name <name>`` and return its JSON reply.

    Shared implementation for the dataset and evaluator resolvers. Every
    failure path calls ``_error_exit`` with a message naming the *kind* of
    entity being resolved.

    Args:
        subcommand: Helper subcommand, e.g. ``"resolve-dataset"``.
        kind: Entity kind used in messages (``"dataset"`` or ``"evaluator"``);
            also forms the ``./do/register --<kind>`` hint.
        name: Registered name to resolve.
        register_hint: Command suggested to the user for registering *name*.

    Returns:
        dict: The JSON object printed by the helper on stdout.
    """
    import subprocess

    label = kind.capitalize()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    helper_path = os.path.join(script_dir, ".register_helper.py")

    if not os.path.exists(helper_path):
        _error_exit(
            f"Cannot resolve {kind} '{name}': .register_helper.py not found. "
            f"Register {kind}s first with: ./do/register --{kind}"
        )

    try:
        # Use the interpreter running this script rather than whatever
        # "python3" is first on PATH, so the helper sees the same
        # virtualenv and installed packages.
        result = subprocess.run(
            [sys.executable or "python3", helper_path, subcommand, "--name", name],
            capture_output=True, text=True, timeout=30,
        )
    except subprocess.TimeoutExpired:
        _error_exit(f"Timeout resolving {kind} '{name}' from registry")
    except Exception as e:
        _error_exit(f"Failed to resolve {kind} '{name}': {e}")

    if result.returncode != 0:
        # A non-zero exit may be a credentials/permission error or a crash,
        # not only "not registered", so surface the helper's stderr.
        stderr_tail = (result.stderr or "").strip()[-500:]
        detail = f" Helper error: {stderr_tail}" if stderr_tail else ""
        _error_exit(
            f"{label} '{name}' not found in registry. "
            f"Register it first: {register_hint}{detail}"
        )

    # Parse the JSON reply: try the whole stdout first, then the last
    # non-empty line in case something was printed before the JSON.
    stdout = (result.stdout or "").strip()
    candidates = [stdout]
    non_empty_lines = [ln for ln in stdout.splitlines() if ln.strip()]
    if len(non_empty_lines) > 1:
        candidates.append(non_empty_lines[-1])
    output = None
    for candidate in candidates:
        try:
            output = json.loads(candidate)
            break
        except ValueError:  # includes json.JSONDecodeError
            continue

    if not isinstance(output, dict):
        _error_exit(
            f"Failed to parse registry response for {kind} '{name}'. "
            f"Raw output: {(result.stdout or '')[:200]}"
        )

    if "error" in output:
        _error_exit(
            f"{label} '{name}' not found in registry: {output['error']}. "
            f"Register it first: {register_hint}"
        )

    return output


def _resolve_dataset_name(dataset_name):
    """Resolve a registered dataset name to an ARN or S3 URI via .register_helper.py.

    Calls the ``resolve-dataset`` subcommand of .register_helper.py. If the
    response contains an ``arn`` field (Backlog #023, AI Registry mode), the
    ARN is returned for use with ``SFTTrainer(training_dataset=arn)``;
    otherwise the ``s3_uri`` field is returned for backward compatibility.
    Exits via ``_error_exit`` if the name cannot be resolved.
    """
    output = _run_register_helper(
        "resolve-dataset",
        "dataset",
        dataset_name,
        f"./do/register --dataset --dataset-name {dataset_name} --dataset-s3-uri s3://...",
    )

    # Prefer ARN if available (Backlog #023 — AI Registry mode)
    arn = output.get("arn")
    if arn:
        return arn

    # Fallback: use S3 URI
    s3_uri = output.get("s3_uri", "")
    if not s3_uri:
        _error_exit(
            f"Dataset '{dataset_name}' resolved but has no S3 URI or ARN. "
            f"Re-register with: ./do/register --dataset --dataset-name {dataset_name} --dataset-s3-uri s3://..."
        )

    return s3_uri


def _resolve_evaluator_name(evaluator_name):
    """Resolve a registered evaluator name to (type, ARN/URI) via .register_helper.py.

    Calls the ``resolve-evaluator`` subcommand of .register_helper.py.

    Returns:
        tuple[str, str]: ``(evaluator_type, arn_or_uri)``. ``evaluator_type``
        is expected to be ``"lambda"`` (RLVR) or ``"model"`` (RLAIF); it is
        passed through as reported by the helper (``""`` if absent).
    Exits via ``_error_exit`` if the name cannot be resolved or has no ARN/URI.
    """
    output = _run_register_helper(
        "resolve-evaluator",
        "evaluator",
        evaluator_name,
        f"./do/register --evaluator --evaluator-name {evaluator_name} ...",
    )

    ev_type = output.get("type", "")
    arn_or_uri = output.get("arn_or_uri", "")

    if not arn_or_uri:
        _error_exit(
            f"Evaluator '{evaluator_name}' resolved but has no ARN/URI. "
            f"Re-register with: ./do/register --evaluator --evaluator-name {evaluator_name} ..."
        )

    return ev_type, arn_or_uri
236
+
237
+
74
238
  # ── Subcommand: submit ────────────────────────────────────────────────────────
75
239
 
76
240
 
@@ -90,6 +254,26 @@ def cmd_submit(args):
90
254
  os.environ["AWS_DEFAULT_REGION"] = region
91
255
  os.environ.setdefault("AWS_REGION", region)
92
256
 
257
+ # ── Resolve --dataset-name from registry (AC-2b.4) ────────────────────────
258
+ # --dataset-s3-uri wins if both are provided (backward compatible override)
259
+ if not args.dataset_s3_uri and args.dataset_name:
260
+ resolved_uri = _resolve_dataset_name(args.dataset_name)
261
+ args.dataset_s3_uri = resolved_uri
262
+ elif not args.dataset_s3_uri and not args.dataset_name:
263
+ _error_exit(
264
+ "Either --dataset-s3-uri or --dataset-name is required. "
265
+ "Provide an S3 URI directly or a registered dataset name."
266
+ )
267
+
268
+ # ── Resolve --evaluator-name from registry (AC-2c.3, AC-2c.4) ────────────
269
+ # --reward-function / --reward-prompt win if provided (backward compatible override)
270
+ if args.evaluator_name and not args.reward_function and not args.reward_prompt:
271
+ ev_type, ev_arn_or_uri = _resolve_evaluator_name(args.evaluator_name)
272
+ if ev_type == "lambda":
273
+ args.reward_function = ev_arn_or_uri
274
+ else:
275
+ args.reward_prompt = ev_arn_or_uri
276
+
93
277
  _check_sagemaker_sdk()
94
278
 
95
279
  # SDK v3 moved trainers from sagemaker.modules.train → sagemaker.train
@@ -171,20 +355,25 @@ def cmd_submit(args):
171
355
  trainer_kwargs["accept_eula"] = True
172
356
 
173
357
  # Resolve model package group — create if it doesn't exist
358
+ # Using sagemaker-core ModelPackageGroup.create() per SDK v3 policy
174
359
  mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
175
360
  try:
176
- import boto3 as _boto3
177
- _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
178
- _sm.describe_model_package_group(ModelPackageGroupName=mpg_name)
179
- except Exception as _mpg_err:
180
- if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
181
- try:
182
- _sm.create_model_package_group(
183
- ModelPackageGroupName=mpg_name,
184
- ModelPackageGroupDescription=f"Fine-tuned models for {args.project_name}",
185
- )
186
- except Exception:
187
- pass # May already exist or lack permissions — let the trainer handle it
361
+ from sagemaker.core.resources import ModelPackageGroup
362
+ from botocore.exceptions import ClientError as _ClientError
363
+ try:
364
+ ModelPackageGroup.get(model_package_group_name=mpg_name)
365
+ except (_ClientError, Exception) as _mpg_err:
366
+ if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
367
+ try:
368
+ ModelPackageGroup.create(
369
+ model_package_group_name=mpg_name,
370
+ model_package_group_description=f"Fine-tuned models for {args.project_name}",
371
+ )
372
+ except Exception:
373
+ pass # May already exist or lack permissions — let the trainer handle it
374
+ except ImportError:
375
+ # sagemaker-core not available — skip MPG creation, let trainer handle it
376
+ pass
188
377
  trainer_kwargs["model_package_group"] = mpg_name
189
378
 
190
379
  trainer = trainer_cls(**trainer_kwargs)
@@ -267,7 +456,9 @@ def cmd_submit(args):
267
456
  job_arn = job_arn or getattr(latest_job, 'arn', None)
268
457
 
269
458
  # If we still don't have the actual job name (SDK appends suffix),
270
- # query ListTrainingJobs to find it by our base_job_name prefix
459
+ # query ListTrainingJobs to find it by our base_job_name prefix.
460
+ # Note: list_training_jobs with NameContains filter is not available
461
+ # via sagemaker-core resource API, so boto3 is retained here.
271
462
  if not job_name or job_name == args.job_name:
272
463
  import boto3 as _boto3
273
464
  _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
@@ -335,23 +526,28 @@ def cmd_submit(args):
335
526
 
336
527
 
337
528
  def cmd_status(args):
338
- """Query job status via DescribeTrainingJob.
529
+ """Query job status via sagemaker-core TrainingJob.get().
339
530
 
340
- Falls back to ListTrainingJobs with name-contains if exact name not found
531
+ Falls back to boto3 ListTrainingJobs with name-contains if exact name not found
341
532
  (SDK v3 appends a timestamp suffix to the base job name).
342
533
 
343
534
  Returns: {"status": str, "failure_reason": str|None,
344
535
  "metrics": dict|None, "elapsed_seconds": int}
345
536
  """
346
- import boto3
537
+ # Set region before any sagemaker import (creates boto3 clients at import time)
538
+ region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
539
+ if region:
540
+ os.environ['AWS_DEFAULT_REGION'] = region
541
+ os.environ.setdefault('AWS_REGION', region)
347
542
 
348
- client = boto3.client("sagemaker", region_name=args.region)
543
+ from sagemaker.core.resources import TrainingJob
544
+ from botocore.exceptions import ClientError
349
545
 
350
- # Try exact name first
351
- response = None
546
+ # Try exact name first via sagemaker-core
547
+ job = None
352
548
  try:
353
- response = client.describe_training_job(TrainingJobName=args.job_name)
354
- except client.exceptions.ClientError as e:
549
+ job = TrainingJob.get(training_job_name=args.job_name)
550
+ except ClientError as e:
355
551
  error_code = e.response["Error"]["Code"]
356
552
  if error_code != "ValidationException":
357
553
  _error_exit(f"Failed to describe training job: {e}")
@@ -360,8 +556,13 @@ def cmd_status(args):
360
556
  _error_exit(f"Failed to describe training job: {e}")
361
557
 
362
558
  # Fallback: search by name prefix (SDK appends timestamp suffix)
363
- if response is None:
559
+ # Note: TrainingJob.get_all() with name_contains is not available in
560
+ # sagemaker-core for list operations, so we use boto3 list_training_jobs
561
+ # to find the actual name, then call TrainingJob.get() with it.
562
+ if job is None:
364
563
  try:
564
+ import boto3
565
+ client = boto3.client("sagemaker", region_name=args.region)
365
566
  list_response = client.list_training_jobs(
366
567
  NameContains=args.job_name,
367
568
  SortBy="CreationTime",
@@ -371,18 +572,26 @@ def cmd_status(args):
371
572
  summaries = list_response.get("TrainingJobSummaries", [])
372
573
  if summaries:
373
574
  actual_name = summaries[0]["TrainingJobName"]
374
- response = client.describe_training_job(TrainingJobName=actual_name)
575
+ job = TrainingJob.get(training_job_name=actual_name)
375
576
  else:
376
577
  _error_exit(f"Training job not found: {args.job_name}")
377
578
  except Exception as e:
378
579
  _error_exit(f"Failed to find training job: {e}")
379
580
 
380
- status = response.get("TrainingJobStatus", "Unknown")
381
- failure_reason = response.get("FailureReason")
581
+ # Read status attributes directly from the TrainingJob resource object.
582
+ # sagemaker-core returns status values in the same casing as the API
583
+ # (e.g., "InProgress", "Completed", "Failed", "Stopped").
584
+ status = getattr(job, "training_job_status", "Unknown") or "Unknown"
585
+ failure_reason = getattr(job, "failure_reason", None)
382
586
 
383
587
  # Calculate elapsed time
384
- start_time = response.get("TrainingStartTime")
385
- end_time = response.get("TrainingEndTime")
588
+ start_time = getattr(job, "training_start_time", None)
589
+ end_time = getattr(job, "training_end_time", None)
590
+ # Convert Unassigned sentinel to None
591
+ if start_time and type(start_time).__name__ in ("Unassigned", "UnassignedValue"):
592
+ start_time = None
593
+ if end_time and type(end_time).__name__ in ("Unassigned", "UnassignedValue"):
594
+ end_time = None
386
595
  elapsed_seconds = 0
387
596
 
388
597
  if start_time:
@@ -393,24 +602,30 @@ def cmd_status(args):
393
602
 
394
603
  # Extract final metrics if available
395
604
  metrics = None
396
- final_metrics = response.get("FinalMetricDataList")
605
+ final_metrics = getattr(job, "final_metric_data_list", None)
606
+ if final_metrics and type(final_metrics).__name__ in ("Unassigned", "UnassignedValue"):
607
+ final_metrics = None
397
608
  if final_metrics:
398
609
  metrics = {}
399
610
  for metric in final_metrics:
400
- metrics[metric["MetricName"]] = metric["Value"]
611
+ # sagemaker-core returns metrics as objects with snake_case attributes
612
+ metric_name = getattr(metric, "metric_name", None) or metric.get("MetricName", "")
613
+ metric_value = getattr(metric, "value", None) or metric.get("Value", 0)
614
+ metrics[metric_name] = metric_value
401
615
 
402
616
  # Get output path if completed
403
617
  output_path = None
404
618
  if status == "Completed":
405
- model_artifacts = response.get("ModelArtifacts", {})
406
- output_path = model_artifacts.get("S3ModelArtifacts")
619
+ model_artifacts = getattr(job, "model_artifacts", None)
620
+ if model_artifacts:
621
+ output_path = getattr(model_artifacts, "s3_model_artifacts", None)
407
622
 
408
623
  _output({
409
- "status": status,
410
- "failure_reason": failure_reason,
411
- "metrics": metrics,
624
+ "status": _sanitize_for_json(status),
625
+ "failure_reason": _sanitize_for_json(failure_reason),
626
+ "metrics": _sanitize_for_json(metrics),
412
627
  "elapsed_seconds": elapsed_seconds,
413
- "output_path": output_path,
628
+ "output_path": _sanitize_for_json(output_path),
414
629
  })
415
630
 
416
631
 
@@ -420,28 +635,37 @@ def cmd_status(args):
420
635
  def cmd_resolve(args):
421
636
  """Resolve artifact path within S3 output directory.
422
637
 
638
+ Uses sagemaker-core TrainingJob.get() to read model_artifacts and
639
+ output_data_config. Uses ModelPackage for model package lookup.
640
+
423
641
  Returns: {"artifact_path": str, "model_package_arn": str|None,
424
642
  "output_type": str}
425
643
  """
426
- import boto3
644
+ # Set region before any sagemaker import (creates boto3 clients at import time)
645
+ region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
646
+ if region:
647
+ os.environ['AWS_DEFAULT_REGION'] = region
648
+ os.environ.setdefault('AWS_REGION', region)
427
649
 
428
- client = boto3.client("sagemaker", region_name=args.region)
650
+ from sagemaker.core.resources import TrainingJob
429
651
 
430
652
  try:
431
- response = client.describe_training_job(TrainingJobName=args.job_name)
653
+ job = TrainingJob.get(training_job_name=args.job_name)
432
654
  except Exception as e:
433
655
  _error_exit(f"Failed to describe training job: {e}")
434
656
 
435
- status = response.get("TrainingJobStatus")
657
+ status = getattr(job, "training_job_status", None)
436
658
  if status != "Completed":
437
659
  _error_exit(
438
660
  f"Cannot resolve artifacts for job in status: {status}. "
439
661
  f"Job must be Completed."
440
662
  )
441
663
 
442
- # Get the S3 model artifacts path
443
- model_artifacts = response.get("ModelArtifacts", {})
444
- artifact_path = model_artifacts.get("S3ModelArtifacts", "")
664
+ # Get the S3 model artifacts path from TrainingJob resource
665
+ model_artifacts = getattr(job, "model_artifacts", None)
666
+ artifact_path = ""
667
+ if model_artifacts:
668
+ artifact_path = getattr(model_artifacts, "s3_model_artifacts", "") or ""
445
669
 
446
670
  if not artifact_path:
447
671
  _error_exit("No model artifacts found in training job output.")
@@ -461,6 +685,9 @@ def cmd_resolve(args):
461
685
  model_package_arn = None
462
686
  if args.model_package_group:
463
687
  try:
688
+ # Use boto3 for list_model_packages since sagemaker-core ModelPackage
689
+ # doesn't have a direct list-by-group method with sort/limit
690
+ import boto3
464
691
  mp_client = boto3.client("sagemaker", region_name=args.region)
465
692
  packages = mp_client.list_model_packages(
466
693
  ModelPackageGroupName=args.model_package_group,
@@ -476,8 +703,8 @@ def cmd_resolve(args):
476
703
  pass
477
704
 
478
705
  _output({
479
- "artifact_path": artifact_path,
480
- "model_package_arn": model_package_arn,
706
+ "artifact_path": _sanitize_for_json(artifact_path),
707
+ "model_package_arn": _sanitize_for_json(model_package_arn),
481
708
  "output_type": output_type,
482
709
  })
483
710
 
@@ -765,11 +992,12 @@ def _get_schema_types(technique):
765
992
  return schemas.get(technique, {"prompt": "string", "completion": "string"})
766
993
 
767
994
 
768
- def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
995
+ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id, take=None):
769
996
  """Validate that the first record has required columns after mapping.
770
997
 
771
998
  Returns (mapped_record, column_map_dict) on success.
772
999
  Calls _error_exit with helpful suggestions on failure.
1000
+ If take is provided, includes --take N in the suggested command.
773
1001
  """
774
1002
  column_map = _parse_column_map(column_map_str)
775
1003
  mapped = _apply_column_map(first_record, column_map)
@@ -794,12 +1022,14 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
794
1022
  if suggestion:
795
1023
  lines.append(f"")
796
1024
  lines.append(f" 💡 Suggested fix:")
797
- lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}")
1025
+ take_suffix = f" --take {take}" if take else ""
1026
+ lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}{take_suffix}")
798
1027
  else:
799
1028
  lines.append(f"")
800
1029
  lines.append(f" 💡 Use --column-map to rename columns:")
801
1030
  example_map = ",".join(f"{r}=<your_column>" for r in missing)
802
- lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}")
1031
+ take_suffix = f" --take {take}" if take else ""
1032
+ lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}{take_suffix}")
803
1033
 
804
1034
  lines.append(f"")
805
1035
  lines.append(f" First record sample:")
@@ -811,6 +1041,16 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
811
1041
  _error_exit("\n".join(lines))
812
1042
 
813
1043
 
1044
def _check_empty_fields(record, required_columns):
    """Report which required columns hold no usable value in *record*.

    A column counts as empty when it is absent from *record*, maps to
    ``None``, or holds a string consisting only of whitespace. Non-string
    values (numbers, lists, dicts) are never reported, even when falsy.

    Args:
        record: Mapping for a single dataset row (after column mapping).
        required_columns: Iterable of column names to inspect.

    Returns:
        list[str]: The empty column names, in ``required_columns`` order.
    """
    def _is_blank(cell):
        if cell is None:
            return True
        return isinstance(cell, str) and cell.strip() == ""

    # A missing key is looked up as "" and therefore counts as blank.
    return [name for name in required_columns if _is_blank(record.get(name, ""))]
1052
+
1053
+
814
1054
  def cmd_stage_hf(args):
815
1055
  """Download HF dataset to S3 using huggingface_hub.
816
1056
 
@@ -854,21 +1094,35 @@ def cmd_stage_hf(args):
854
1094
 
855
1095
  # Find the appropriate data file for the split
856
1096
  data_files = _find_data_files(repo_files, split)
1097
+
1098
+ # Apply file filter if --hf-file is provided
1099
+ hf_file_pattern = getattr(args, 'hf_file', None)
1100
+
1101
+ if not data_files and hf_file_pattern:
1102
+ # Split-based lookup found nothing, but user specified a file filter.
1103
+ # Fall back to filtering directly from all data files in the repo.
1104
+ all_data_files = [
1105
+ f for f in repo_files
1106
+ if f.endswith(('.parquet', '.jsonl', '.json'))
1107
+ and not f.startswith('.')
1108
+ ]
1109
+ if all_data_files:
1110
+ data_files = _filter_data_files(all_data_files, hf_file_pattern)
1111
+ elif hf_file_pattern and data_files:
1112
+ # Normal case: apply file filter to split-matched results
1113
+ data_files = _filter_data_files(data_files, hf_file_pattern)
1114
+
857
1115
  if not data_files:
858
1116
  _error_exit(
859
1117
  f"No data files found for split '{split}' in dataset {dataset_id}. "
860
1118
  f"Available files: {', '.join(repo_files[:20])}"
861
1119
  )
862
1120
 
863
- # Apply file filter if --hf-file is provided
864
- hf_file_pattern = getattr(args, 'hf_file', None)
865
- if hf_file_pattern:
866
- data_files = _filter_data_files(data_files, hf_file_pattern)
867
-
868
1121
  # Download and upload to S3
869
1122
  s3_client = boto3.client("s3", region_name=args.region)
870
1123
  s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
871
1124
  num_records = 0
1125
+ empty_field_counts = {} # Track empty required fields: {field_name: count}
872
1126
 
873
1127
  with tempfile.TemporaryDirectory() as tmpdir:
874
1128
  # Schema divergence check (skip for single file)
@@ -907,7 +1161,7 @@ def cmd_stage_hf(args):
907
1161
  no_transform = getattr(args, 'no_transform', False)
908
1162
  batches = table.to_batches(max_chunksize=1)
909
1163
  first_record = batches[0].to_pylist()[0] if batches else {}
910
- _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
1164
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
911
1165
 
912
1166
  # Apply column map to first record for detection
913
1167
  mapped_first = _apply_column_map(first_record, column_map)
@@ -944,16 +1198,31 @@ def cmd_stage_hf(args):
944
1198
  f" Detected format: {strategy_desc}"
945
1199
  )
946
1200
 
1201
+ take_limit = getattr(args, 'take', None)
947
1202
  with open(jsonl_path, "w", encoding="utf-8") as out_f:
948
1203
  for batch in table.to_batches():
949
1204
  for row in batch.to_pylist():
1205
+ if take_limit and num_records >= take_limit:
1206
+ break
950
1207
  mapped_row = _apply_column_map(row, column_map)
951
1208
  if chat_columns and not no_transform:
952
1209
  mapped_row = _flatten_record(mapped_row, chat_columns)
1210
+ # Track empty required fields
1211
+ for col in _check_empty_fields(mapped_row, required_columns):
1212
+ empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
953
1213
  out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
954
1214
  num_records += 1
1215
+ if take_limit and num_records >= take_limit:
1216
+ break
955
1217
 
956
1218
  # Upload converted JSONL
1219
+ # Verify file has content before uploading
1220
+ file_size = os.path.getsize(jsonl_path)
1221
+ if file_size == 0:
1222
+ _error_exit(
1223
+ f"Converted JSONL file is empty (0 bytes) after processing "
1224
+ f"{num_records} records. This is a bug — please report it."
1225
+ )
957
1226
  s3_key = f"{s3_prefix}/{jsonl_filename}"
958
1227
  s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
959
1228
 
@@ -975,7 +1244,7 @@ def cmd_stage_hf(args):
975
1244
  first_line = f.readline().strip()
976
1245
  if first_line:
977
1246
  first_record = json_mod.loads(first_line)
978
- _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
1247
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
979
1248
 
980
1249
  # Apply column map to first record for detection
981
1250
  mapped_first = _apply_column_map(first_record, column_map)
@@ -1014,11 +1283,14 @@ def cmd_stage_hf(args):
1014
1283
 
1015
1284
  # Rewrite the file with mapped (and optionally flattened) columns
1016
1285
  should_flatten = bool(chat_columns) and not no_transform
1017
- if column_map or should_flatten:
1286
+ take_limit = getattr(args, 'take', None)
1287
+ if column_map or should_flatten or take_limit:
1018
1288
  mapped_path = local_path + ".mapped"
1019
1289
  with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
1020
1290
  open(mapped_path, "w", encoding="utf-8") as f_out:
1021
1291
  for line in f_in:
1292
+ if take_limit and num_records >= take_limit:
1293
+ break
1022
1294
  line = line.strip()
1023
1295
  if not line:
1024
1296
  continue
@@ -1026,15 +1298,32 @@ def cmd_stage_hf(args):
1026
1298
  mapped_record = _apply_column_map(record, column_map)
1027
1299
  if should_flatten:
1028
1300
  mapped_record = _flatten_record(mapped_record, chat_columns)
1301
+ # Track empty required fields
1302
+ for col in _check_empty_fields(mapped_record, _get_required_columns(technique)):
1303
+ empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
1029
1304
  f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
1030
1305
  num_records += 1
1031
1306
  local_path = mapped_path
1032
1307
  else:
1033
- # Count records
1034
- with open(local_path, "r", encoding="utf-8", errors="replace") as f:
1035
- for line in f:
1036
- if line.strip():
1037
- num_records += 1
1308
+ # Count records (and truncate if --take specified)
1309
+ take_limit = getattr(args, 'take', None)
1310
+ if take_limit:
1311
+ # Need to rewrite the file truncated
1312
+ mapped_path = local_path + ".mapped"
1313
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
1314
+ open(mapped_path, "w", encoding="utf-8") as f_out:
1315
+ for line in f_in:
1316
+ if num_records >= take_limit:
1317
+ break
1318
+ if line.strip():
1319
+ f_out.write(line)
1320
+ num_records += 1
1321
+ local_path = mapped_path
1322
+ else:
1323
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f:
1324
+ for line in f:
1325
+ if line.strip():
1326
+ num_records += 1
1038
1327
 
1039
1328
  # Upload to S3
1040
1329
  s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
@@ -1048,6 +1337,19 @@ def cmd_stage_hf(args):
1048
1337
  output_filename = os.path.basename(first_file)
1049
1338
  s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
1050
1339
 
1340
+ # Warn if required columns have many empty values
1341
+ if num_records > 0 and empty_field_counts:
1342
+ for field, count in empty_field_counts.items():
1343
+ pct = (count / num_records) * 100
1344
+ if pct > 30:
1345
+ print(
1346
+ f"\u26a0\ufe0f Warning: {pct:.0f}% of records ({count}/{num_records}) "
1347
+ f"have empty '{field}' after column mapping.\n"
1348
+ f" SageMaker may reject these as invalid samples.\n"
1349
+ f" Consider using a different --column-map or dataset.",
1350
+ file=sys.stderr,
1351
+ )
1352
+
1051
1353
  _output({
1052
1354
  "s3_uri": s3_uri,
1053
1355
  "num_records": num_records,
@@ -1124,12 +1426,12 @@ def _find_data_files(repo_files, split):
1124
1426
  if pattern in repo_files:
1125
1427
  return [pattern]
1126
1428
 
1127
- # Prefix match for sharded files
1128
- matches = []
1429
+ # Prefix match for sharded files (deduplicate via set)
1430
+ matches = set()
1129
1431
  for f in repo_files:
1130
1432
  for pattern in patterns[4:]:
1131
1433
  if pattern in f:
1132
- matches.append(f)
1434
+ matches.add(f)
1133
1435
 
1134
1436
  if matches:
1135
1437
  return sorted(matches)
@@ -1508,6 +1810,10 @@ def _build_expected_format(schema):
1508
1810
  def cmd_discover(args):
1509
1811
  """Query JumpStart Hub for tune-eligible models matching a family.
1510
1812
 
1813
+ NOTE: This subcommand intentionally stays on boto3.client('sagemaker')
1814
+ because list_hub_contents / Hub API is NOT available in sagemaker-core.
1815
+ This is a documented exception per the SDK v3 migration policy.
1816
+
1511
1817
  Returns: {"models": [str], "count": int}
1512
1818
  """
1513
1819
  region = args.region or os.environ.get('AWS_REGION', 'us-east-1')
@@ -1532,6 +1838,8 @@ def cmd_discover(args):
1532
1838
  _error_exit("Hub discovery failed: boto3 is not installed. Install with: pip install boto3")
1533
1839
 
1534
1840
  try:
1841
+ # Documented exception: Hub API (list_hub_contents) is not available in
1842
+ # sagemaker-core, so we retain boto3.client('sagemaker') here.
1535
1843
  client = boto3.client("sagemaker", region_name=region)
1536
1844
  models = []
1537
1845
  paginator = client.get_paginator('list_hub_contents')
@@ -1573,8 +1881,10 @@ def main():
1573
1881
  submit_parser.add_argument("--training-type", required=True,
1574
1882
  choices=["lora", "full-rank"],
1575
1883
  help="Training type (lora or full-rank)")
1576
- submit_parser.add_argument("--dataset-s3-uri", required=True,
1577
- help="S3 URI of the training dataset")
1884
+ submit_parser.add_argument("--dataset-s3-uri", required=False, default=None,
1885
+ help="S3 URI of the training dataset (direct override)")
1886
+ submit_parser.add_argument("--dataset-name", default=None,
1887
+ help="Registered dataset name to resolve from registry")
1578
1888
  submit_parser.add_argument("--output-bucket", required=True,
1579
1889
  help="S3 bucket for output artifacts")
1580
1890
  submit_parser.add_argument("--role-arn", required=True,
@@ -1601,6 +1911,8 @@ def main():
1601
1911
  help="Lambda ARN for reward function (RLVR)")
1602
1912
  submit_parser.add_argument("--reward-prompt", default=None,
1603
1913
  help="S3 URI for reward prompt (RLAIF)")
1914
+ submit_parser.add_argument("--evaluator-name", default=None,
1915
+ help="Registered evaluator name to resolve from registry")
1604
1916
  submit_parser.add_argument("--accept-eula", action="store_true", default=False,
1605
1917
  help="Accept model EULA for gated models (e.g., Llama)")
1606
1918
 
@@ -1650,6 +1962,8 @@ def main():
1650
1962
  help="Customization technique (determines required columns)")
1651
1963
  stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
1652
1964
  help="Disable automatic chat-format flattening")
1965
+ stage_hf_parser.add_argument("--take", type=int, default=None,
1966
+ help="Take only the first N records from the dataset")
1653
1967
 
1654
1968
  # ── validate ──────────────────────────────────────────────────────────────
1655
1969
  validate_parser = subparsers.add_parser("validate",