@aws/ml-container-creator 0.13.4 → 0.13.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,6 +29,12 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
29
29
  warnings.filterwarnings("ignore", message=".*urllib3.*")
30
30
  warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
31
31
 
32
+ # Suppress ALL logging to prevent sagemaker-core/rich from writing to stdout.
33
+ # This script outputs JSON on stdout — any other stdout output corrupts parsing.
34
+ import logging as _logging
35
+ _logging.disable(_logging.CRITICAL)
36
+ os.environ.setdefault("SAGEMAKER_LOG_LEVEL", "CRITICAL")
37
+
32
38
  # ── Inline dependency check ───────────────────────────────────────────────────
33
39
  MIN_SAGEMAKER_VERSION = "3.0"
34
40
 
@@ -71,6 +77,34 @@ def _output(data):
71
77
  sys.exit(0)
72
78
 
73
79
 
80
+ def _sanitize_for_json(value):
81
+ """Convert sagemaker-core Unassigned sentinel values to None for JSON serialization.
82
+
83
+ sagemaker-core uses an 'Unassigned' type instead of None for unset fields.
84
+ This function converts any non-standard types to JSON-safe values.
85
+ """
86
+ if value is None:
87
+ return None
88
+ # Check for Unassigned type from sagemaker-core
89
+ type_name = type(value).__name__
90
+ if type_name == "Unassigned" or type_name == "UnassignedValue":
91
+ return None
92
+ if isinstance(value, (str, int, float, bool)):
93
+ return value
94
+ if isinstance(value, dict):
95
+ return {k: _sanitize_for_json(v) for k, v in value.items()}
96
+ if isinstance(value, list):
97
+ return [_sanitize_for_json(v) for v in value]
98
+ # For other types, try str conversion as fallback
99
+ try:
100
+ # Check if it's JSON serializable as-is
101
+ import json as _json
102
+ _json.dumps(value)
103
+ return value
104
+ except (TypeError, ValueError):
105
+ return str(value) if value else None
106
+
107
+
74
108
  # ── Subcommand: submit ────────────────────────────────────────────────────────
75
109
 
76
110
 
@@ -171,20 +205,25 @@ def cmd_submit(args):
171
205
  trainer_kwargs["accept_eula"] = True
172
206
 
173
207
  # Resolve model package group — create if it doesn't exist
208
+ # Using sagemaker-core ModelPackageGroup.create() per SDK v3 policy
174
209
  mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
175
210
  try:
176
- import boto3 as _boto3
177
- _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
178
- _sm.describe_model_package_group(ModelPackageGroupName=mpg_name)
179
- except Exception as _mpg_err:
180
- if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
181
- try:
182
- _sm.create_model_package_group(
183
- ModelPackageGroupName=mpg_name,
184
- ModelPackageGroupDescription=f"Fine-tuned models for {args.project_name}",
185
- )
186
- except Exception:
187
- pass # May already exist or lack permissions — let the trainer handle it
211
+ from sagemaker.core.resources import ModelPackageGroup
212
+ from botocore.exceptions import ClientError as _ClientError
213
+ try:
214
+ ModelPackageGroup.get(model_package_group_name=mpg_name)
215
+ except (_ClientError, Exception) as _mpg_err:
216
+ if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
217
+ try:
218
+ ModelPackageGroup.create(
219
+ model_package_group_name=mpg_name,
220
+ model_package_group_description=f"Fine-tuned models for {args.project_name}",
221
+ )
222
+ except Exception:
223
+ pass # May already exist or lack permissions — let the trainer handle it
224
+ except ImportError:
225
+ # sagemaker-core not available — skip MPG creation, let trainer handle it
226
+ pass
188
227
  trainer_kwargs["model_package_group"] = mpg_name
189
228
 
190
229
  trainer = trainer_cls(**trainer_kwargs)
@@ -267,7 +306,9 @@ def cmd_submit(args):
267
306
  job_arn = job_arn or getattr(latest_job, 'arn', None)
268
307
 
269
308
  # If we still don't have the actual job name (SDK appends suffix),
270
- # query ListTrainingJobs to find it by our base_job_name prefix
309
+ # query ListTrainingJobs to find it by our base_job_name prefix.
310
+ # Note: list_training_jobs with NameContains filter is not available
311
+ # via sagemaker-core resource API, so boto3 is retained here.
271
312
  if not job_name or job_name == args.job_name:
272
313
  import boto3 as _boto3
273
314
  _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
@@ -335,23 +376,22 @@ def cmd_submit(args):
335
376
 
336
377
 
337
378
  def cmd_status(args):
338
- """Query job status via DescribeTrainingJob.
379
+ """Query job status via sagemaker-core TrainingJob.get().
339
380
 
340
- Falls back to ListTrainingJobs with name-contains if exact name not found
381
+ Falls back to boto3 ListTrainingJobs with name-contains if exact name not found
341
382
  (SDK v3 appends a timestamp suffix to the base job name).
342
383
 
343
384
  Returns: {"status": str, "failure_reason": str|None,
344
385
  "metrics": dict|None, "elapsed_seconds": int}
345
386
  """
346
- import boto3
347
-
348
- client = boto3.client("sagemaker", region_name=args.region)
387
+ from sagemaker.core.resources import TrainingJob
388
+ from botocore.exceptions import ClientError
349
389
 
350
- # Try exact name first
351
- response = None
390
+ # Try exact name first via sagemaker-core
391
+ job = None
352
392
  try:
353
- response = client.describe_training_job(TrainingJobName=args.job_name)
354
- except client.exceptions.ClientError as e:
393
+ job = TrainingJob.get(training_job_name=args.job_name)
394
+ except ClientError as e:
355
395
  error_code = e.response["Error"]["Code"]
356
396
  if error_code != "ValidationException":
357
397
  _error_exit(f"Failed to describe training job: {e}")
@@ -360,8 +400,13 @@ def cmd_status(args):
360
400
  _error_exit(f"Failed to describe training job: {e}")
361
401
 
362
402
  # Fallback: search by name prefix (SDK appends timestamp suffix)
363
- if response is None:
403
+ # Note: TrainingJob.get_all() with name_contains is not available in
404
+ # sagemaker-core for list operations, so we use boto3 list_training_jobs
405
+ # to find the actual name, then call TrainingJob.get() with it.
406
+ if job is None:
364
407
  try:
408
+ import boto3
409
+ client = boto3.client("sagemaker", region_name=args.region)
365
410
  list_response = client.list_training_jobs(
366
411
  NameContains=args.job_name,
367
412
  SortBy="CreationTime",
@@ -371,18 +416,26 @@ def cmd_status(args):
371
416
  summaries = list_response.get("TrainingJobSummaries", [])
372
417
  if summaries:
373
418
  actual_name = summaries[0]["TrainingJobName"]
374
- response = client.describe_training_job(TrainingJobName=actual_name)
419
+ job = TrainingJob.get(training_job_name=actual_name)
375
420
  else:
376
421
  _error_exit(f"Training job not found: {args.job_name}")
377
422
  except Exception as e:
378
423
  _error_exit(f"Failed to find training job: {e}")
379
424
 
380
- status = response.get("TrainingJobStatus", "Unknown")
381
- failure_reason = response.get("FailureReason")
425
+ # Read status attributes directly from the TrainingJob resource object.
426
+ # sagemaker-core returns status values in the same casing as the API
427
+ # (e.g., "InProgress", "Completed", "Failed", "Stopped").
428
+ status = getattr(job, "training_job_status", "Unknown") or "Unknown"
429
+ failure_reason = getattr(job, "failure_reason", None)
382
430
 
383
431
  # Calculate elapsed time
384
- start_time = response.get("TrainingStartTime")
385
- end_time = response.get("TrainingEndTime")
432
+ start_time = getattr(job, "training_start_time", None)
433
+ end_time = getattr(job, "training_end_time", None)
434
+ # Convert Unassigned sentinel to None
435
+ if start_time and type(start_time).__name__ in ("Unassigned", "UnassignedValue"):
436
+ start_time = None
437
+ if end_time and type(end_time).__name__ in ("Unassigned", "UnassignedValue"):
438
+ end_time = None
386
439
  elapsed_seconds = 0
387
440
 
388
441
  if start_time:
@@ -393,24 +446,30 @@ def cmd_status(args):
393
446
 
394
447
  # Extract final metrics if available
395
448
  metrics = None
396
- final_metrics = response.get("FinalMetricDataList")
449
+ final_metrics = getattr(job, "final_metric_data_list", None)
450
+ if final_metrics and type(final_metrics).__name__ in ("Unassigned", "UnassignedValue"):
451
+ final_metrics = None
397
452
  if final_metrics:
398
453
  metrics = {}
399
454
  for metric in final_metrics:
400
- metrics[metric["MetricName"]] = metric["Value"]
455
+ # sagemaker-core returns metrics as objects with snake_case attributes
456
+ metric_name = getattr(metric, "metric_name", None) or metric.get("MetricName", "")
457
+ metric_value = getattr(metric, "value", None) or metric.get("Value", 0)
458
+ metrics[metric_name] = metric_value
401
459
 
402
460
  # Get output path if completed
403
461
  output_path = None
404
462
  if status == "Completed":
405
- model_artifacts = response.get("ModelArtifacts", {})
406
- output_path = model_artifacts.get("S3ModelArtifacts")
463
+ model_artifacts = getattr(job, "model_artifacts", None)
464
+ if model_artifacts:
465
+ output_path = getattr(model_artifacts, "s3_model_artifacts", None)
407
466
 
408
467
  _output({
409
- "status": status,
410
- "failure_reason": failure_reason,
411
- "metrics": metrics,
468
+ "status": _sanitize_for_json(status),
469
+ "failure_reason": _sanitize_for_json(failure_reason),
470
+ "metrics": _sanitize_for_json(metrics),
412
471
  "elapsed_seconds": elapsed_seconds,
413
- "output_path": output_path,
472
+ "output_path": _sanitize_for_json(output_path),
414
473
  })
415
474
 
416
475
 
@@ -420,28 +479,31 @@ def cmd_status(args):
420
479
  def cmd_resolve(args):
421
480
  """Resolve artifact path within S3 output directory.
422
481
 
482
+ Uses sagemaker-core TrainingJob.get() to read model_artifacts and
483
+ output_data_config. Uses ModelPackage for model package lookup.
484
+
423
485
  Returns: {"artifact_path": str, "model_package_arn": str|None,
424
486
  "output_type": str}
425
487
  """
426
- import boto3
427
-
428
- client = boto3.client("sagemaker", region_name=args.region)
488
+ from sagemaker.core.resources import TrainingJob
429
489
 
430
490
  try:
431
- response = client.describe_training_job(TrainingJobName=args.job_name)
491
+ job = TrainingJob.get(training_job_name=args.job_name)
432
492
  except Exception as e:
433
493
  _error_exit(f"Failed to describe training job: {e}")
434
494
 
435
- status = response.get("TrainingJobStatus")
495
+ status = getattr(job, "training_job_status", None)
436
496
  if status != "Completed":
437
497
  _error_exit(
438
498
  f"Cannot resolve artifacts for job in status: {status}. "
439
499
  f"Job must be Completed."
440
500
  )
441
501
 
442
- # Get the S3 model artifacts path
443
- model_artifacts = response.get("ModelArtifacts", {})
444
- artifact_path = model_artifacts.get("S3ModelArtifacts", "")
502
+ # Get the S3 model artifacts path from TrainingJob resource
503
+ model_artifacts = getattr(job, "model_artifacts", None)
504
+ artifact_path = ""
505
+ if model_artifacts:
506
+ artifact_path = getattr(model_artifacts, "s3_model_artifacts", "") or ""
445
507
 
446
508
  if not artifact_path:
447
509
  _error_exit("No model artifacts found in training job output.")
@@ -461,6 +523,9 @@ def cmd_resolve(args):
461
523
  model_package_arn = None
462
524
  if args.model_package_group:
463
525
  try:
526
+ # Use boto3 for list_model_packages since sagemaker-core ModelPackage
527
+ # doesn't have a direct list-by-group method with sort/limit
528
+ import boto3
464
529
  mp_client = boto3.client("sagemaker", region_name=args.region)
465
530
  packages = mp_client.list_model_packages(
466
531
  ModelPackageGroupName=args.model_package_group,
@@ -476,8 +541,8 @@ def cmd_resolve(args):
476
541
  pass
477
542
 
478
543
  _output({
479
- "artifact_path": artifact_path,
480
- "model_package_arn": model_package_arn,
544
+ "artifact_path": _sanitize_for_json(artifact_path),
545
+ "model_package_arn": _sanitize_for_json(model_package_arn),
481
546
  "output_type": output_type,
482
547
  })
483
548
 
@@ -765,11 +830,12 @@ def _get_schema_types(technique):
765
830
  return schemas.get(technique, {"prompt": "string", "completion": "string"})
766
831
 
767
832
 
768
- def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
833
+ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id, take=None):
769
834
  """Validate that the first record has required columns after mapping.
770
835
 
771
836
  Returns (mapped_record, column_map_dict) on success.
772
837
  Calls _error_exit with helpful suggestions on failure.
838
+ If take is provided, includes --take N in the suggested command.
773
839
  """
774
840
  column_map = _parse_column_map(column_map_str)
775
841
  mapped = _apply_column_map(first_record, column_map)
@@ -794,12 +860,14 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
794
860
  if suggestion:
795
861
  lines.append(f"")
796
862
  lines.append(f" 💡 Suggested fix:")
797
- lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}")
863
+ take_suffix = f" --take {take}" if take else ""
864
+ lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}{take_suffix}")
798
865
  else:
799
866
  lines.append(f"")
800
867
  lines.append(f" 💡 Use --column-map to rename columns:")
801
868
  example_map = ",".join(f"{r}=<your_column>" for r in missing)
802
- lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}")
869
+ take_suffix = f" --take {take}" if take else ""
870
+ lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}{take_suffix}")
803
871
 
804
872
  lines.append(f"")
805
873
  lines.append(f" First record sample:")
@@ -811,6 +879,16 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
811
879
  _error_exit("\n".join(lines))
812
880
 
813
881
 
882
+ def _check_empty_fields(record, required_columns):
883
+ """Return list of required column names that are empty/blank in this record."""
884
+ empty = []
885
+ for col in required_columns:
886
+ value = record.get(col, "")
887
+ if value is None or (isinstance(value, str) and not value.strip()):
888
+ empty.append(col)
889
+ return empty
890
+
891
+
814
892
  def cmd_stage_hf(args):
815
893
  """Download HF dataset to S3 using huggingface_hub.
816
894
 
@@ -854,21 +932,35 @@ def cmd_stage_hf(args):
854
932
 
855
933
  # Find the appropriate data file for the split
856
934
  data_files = _find_data_files(repo_files, split)
935
+
936
+ # Apply file filter if --hf-file is provided
937
+ hf_file_pattern = getattr(args, 'hf_file', None)
938
+
939
+ if not data_files and hf_file_pattern:
940
+ # Split-based lookup found nothing, but user specified a file filter.
941
+ # Fall back to filtering directly from all data files in the repo.
942
+ all_data_files = [
943
+ f for f in repo_files
944
+ if f.endswith(('.parquet', '.jsonl', '.json'))
945
+ and not f.startswith('.')
946
+ ]
947
+ if all_data_files:
948
+ data_files = _filter_data_files(all_data_files, hf_file_pattern)
949
+ elif hf_file_pattern and data_files:
950
+ # Normal case: apply file filter to split-matched results
951
+ data_files = _filter_data_files(data_files, hf_file_pattern)
952
+
857
953
  if not data_files:
858
954
  _error_exit(
859
955
  f"No data files found for split '{split}' in dataset {dataset_id}. "
860
956
  f"Available files: {', '.join(repo_files[:20])}"
861
957
  )
862
958
 
863
- # Apply file filter if --hf-file is provided
864
- hf_file_pattern = getattr(args, 'hf_file', None)
865
- if hf_file_pattern:
866
- data_files = _filter_data_files(data_files, hf_file_pattern)
867
-
868
959
  # Download and upload to S3
869
960
  s3_client = boto3.client("s3", region_name=args.region)
870
961
  s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
871
962
  num_records = 0
963
+ empty_field_counts = {} # Track empty required fields: {field_name: count}
872
964
 
873
965
  with tempfile.TemporaryDirectory() as tmpdir:
874
966
  # Schema divergence check (skip for single file)
@@ -907,7 +999,7 @@ def cmd_stage_hf(args):
907
999
  no_transform = getattr(args, 'no_transform', False)
908
1000
  batches = table.to_batches(max_chunksize=1)
909
1001
  first_record = batches[0].to_pylist()[0] if batches else {}
910
- _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
1002
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
911
1003
 
912
1004
  # Apply column map to first record for detection
913
1005
  mapped_first = _apply_column_map(first_record, column_map)
@@ -944,16 +1036,31 @@ def cmd_stage_hf(args):
944
1036
  f" Detected format: {strategy_desc}"
945
1037
  )
946
1038
 
1039
+ take_limit = getattr(args, 'take', None)
947
1040
  with open(jsonl_path, "w", encoding="utf-8") as out_f:
948
1041
  for batch in table.to_batches():
949
1042
  for row in batch.to_pylist():
1043
+ if take_limit and num_records >= take_limit:
1044
+ break
950
1045
  mapped_row = _apply_column_map(row, column_map)
951
1046
  if chat_columns and not no_transform:
952
1047
  mapped_row = _flatten_record(mapped_row, chat_columns)
1048
+ # Track empty required fields
1049
+ for col in _check_empty_fields(mapped_row, required_columns):
1050
+ empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
953
1051
  out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
954
1052
  num_records += 1
1053
+ if take_limit and num_records >= take_limit:
1054
+ break
955
1055
 
956
1056
  # Upload converted JSONL
1057
+ # Verify file has content before uploading
1058
+ file_size = os.path.getsize(jsonl_path)
1059
+ if file_size == 0:
1060
+ _error_exit(
1061
+ f"Converted JSONL file is empty (0 bytes) after processing "
1062
+ f"{num_records} records. This is a bug — please report it."
1063
+ )
957
1064
  s3_key = f"{s3_prefix}/{jsonl_filename}"
958
1065
  s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
959
1066
 
@@ -975,7 +1082,7 @@ def cmd_stage_hf(args):
975
1082
  first_line = f.readline().strip()
976
1083
  if first_line:
977
1084
  first_record = json_mod.loads(first_line)
978
- _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
1085
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
979
1086
 
980
1087
  # Apply column map to first record for detection
981
1088
  mapped_first = _apply_column_map(first_record, column_map)
@@ -1014,11 +1121,14 @@ def cmd_stage_hf(args):
1014
1121
 
1015
1122
  # Rewrite the file with mapped (and optionally flattened) columns
1016
1123
  should_flatten = bool(chat_columns) and not no_transform
1017
- if column_map or should_flatten:
1124
+ take_limit = getattr(args, 'take', None)
1125
+ if column_map or should_flatten or take_limit:
1018
1126
  mapped_path = local_path + ".mapped"
1019
1127
  with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
1020
1128
  open(mapped_path, "w", encoding="utf-8") as f_out:
1021
1129
  for line in f_in:
1130
+ if take_limit and num_records >= take_limit:
1131
+ break
1022
1132
  line = line.strip()
1023
1133
  if not line:
1024
1134
  continue
@@ -1026,15 +1136,32 @@ def cmd_stage_hf(args):
1026
1136
  mapped_record = _apply_column_map(record, column_map)
1027
1137
  if should_flatten:
1028
1138
  mapped_record = _flatten_record(mapped_record, chat_columns)
1139
+ # Track empty required fields
1140
+ for col in _check_empty_fields(mapped_record, _get_required_columns(technique)):
1141
+ empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
1029
1142
  f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
1030
1143
  num_records += 1
1031
1144
  local_path = mapped_path
1032
1145
  else:
1033
- # Count records
1034
- with open(local_path, "r", encoding="utf-8", errors="replace") as f:
1035
- for line in f:
1036
- if line.strip():
1037
- num_records += 1
1146
+ # Count records (and truncate if --take specified)
1147
+ take_limit = getattr(args, 'take', None)
1148
+ if take_limit:
1149
+ # Need to rewrite the file truncated
1150
+ mapped_path = local_path + ".mapped"
1151
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
1152
+ open(mapped_path, "w", encoding="utf-8") as f_out:
1153
+ for line in f_in:
1154
+ if num_records >= take_limit:
1155
+ break
1156
+ if line.strip():
1157
+ f_out.write(line)
1158
+ num_records += 1
1159
+ local_path = mapped_path
1160
+ else:
1161
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f:
1162
+ for line in f:
1163
+ if line.strip():
1164
+ num_records += 1
1038
1165
 
1039
1166
  # Upload to S3
1040
1167
  s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
@@ -1048,6 +1175,19 @@ def cmd_stage_hf(args):
1048
1175
  output_filename = os.path.basename(first_file)
1049
1176
  s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
1050
1177
 
1178
+ # Warn if required columns have many empty values
1179
+ if num_records > 0 and empty_field_counts:
1180
+ for field, count in empty_field_counts.items():
1181
+ pct = (count / num_records) * 100
1182
+ if pct > 30:
1183
+ print(
1184
+ f"\u26a0\ufe0f Warning: {pct:.0f}% of records ({count}/{num_records}) "
1185
+ f"have empty '{field}' after column mapping.\n"
1186
+ f" SageMaker may reject these as invalid samples.\n"
1187
+ f" Consider using a different --column-map or dataset.",
1188
+ file=sys.stderr,
1189
+ )
1190
+
1051
1191
  _output({
1052
1192
  "s3_uri": s3_uri,
1053
1193
  "num_records": num_records,
@@ -1124,12 +1264,12 @@ def _find_data_files(repo_files, split):
1124
1264
  if pattern in repo_files:
1125
1265
  return [pattern]
1126
1266
 
1127
- # Prefix match for sharded files
1128
- matches = []
1267
+ # Prefix match for sharded files (deduplicate via set)
1268
+ matches = set()
1129
1269
  for f in repo_files:
1130
1270
  for pattern in patterns[4:]:
1131
1271
  if pattern in f:
1132
- matches.append(f)
1272
+ matches.add(f)
1133
1273
 
1134
1274
  if matches:
1135
1275
  return sorted(matches)
@@ -1508,6 +1648,10 @@ def _build_expected_format(schema):
1508
1648
  def cmd_discover(args):
1509
1649
  """Query JumpStart Hub for tune-eligible models matching a family.
1510
1650
 
1651
+ NOTE: This subcommand intentionally stays on boto3.client('sagemaker')
1652
+ because list_hub_contents / Hub API is NOT available in sagemaker-core.
1653
+ This is a documented exception per the SDK v3 migration policy.
1654
+
1511
1655
  Returns: {"models": [str], "count": int}
1512
1656
  """
1513
1657
  region = args.region or os.environ.get('AWS_REGION', 'us-east-1')
@@ -1532,6 +1676,8 @@ def cmd_discover(args):
1532
1676
  _error_exit("Hub discovery failed: boto3 is not installed. Install with: pip install boto3")
1533
1677
 
1534
1678
  try:
1679
+ # Documented exception: Hub API (list_hub_contents) is not available in
1680
+ # sagemaker-core, so we retain boto3.client('sagemaker') here.
1535
1681
  client = boto3.client("sagemaker", region_name=region)
1536
1682
  models = []
1537
1683
  paginator = client.get_paginator('list_hub_contents')
@@ -1650,6 +1796,8 @@ def main():
1650
1796
  help="Customization technique (determines required columns)")
1651
1797
  stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
1652
1798
  help="Disable automatic chat-format flattening")
1799
+ stage_hf_parser.add_argument("--take", type=int, default=None,
1800
+ help="Take only the first N records from the dataset")
1653
1801
 
1654
1802
  # ── validate ──────────────────────────────────────────────────────────────
1655
1803
  validate_parser = subparsers.add_parser("validate",