@aws/ml-container-creator 0.13.4 → 0.13.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -5
- package/infra/ci-harness/package-lock.json +1 -5
- package/package.json +4 -2
- package/pyproject.toml +21 -0
- package/requirements.txt +19 -0
- package/src/app.js +2 -0
- package/src/lib/bootstrap-command-handler.js +33 -23
- package/templates/do/.adapter_helper.py +451 -0
- package/templates/do/.benchmark_writer.py +13 -0
- package/templates/do/.stage_helper.py +419 -0
- package/templates/do/.tune_helper.py +213 -65
- package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +108 -0
- package/templates/do/benchmark +150 -12
- package/templates/do/config +4 -0
- package/templates/do/lib/profile.sh +5 -0
- package/templates/do/stage +91 -272
- package/templates/do/tune +63 -6
|
@@ -29,6 +29,12 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
29
29
|
warnings.filterwarnings("ignore", message=".*urllib3.*")
|
|
30
30
|
warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
|
|
31
31
|
|
|
32
|
+
# Suppress ALL logging to prevent sagemaker-core/rich from writing to stdout.
|
|
33
|
+
# This script outputs JSON on stdout — any other stdout output corrupts parsing.
|
|
34
|
+
import logging as _logging
|
|
35
|
+
_logging.disable(_logging.CRITICAL)
|
|
36
|
+
os.environ.setdefault("SAGEMAKER_LOG_LEVEL", "CRITICAL")
|
|
37
|
+
|
|
32
38
|
# ── Inline dependency check ───────────────────────────────────────────────────
|
|
33
39
|
MIN_SAGEMAKER_VERSION = "3.0"
|
|
34
40
|
|
|
@@ -71,6 +77,34 @@ def _output(data):
|
|
|
71
77
|
sys.exit(0)
|
|
72
78
|
|
|
73
79
|
|
|
80
|
+
def _sanitize_for_json(value):
|
|
81
|
+
"""Convert sagemaker-core Unassigned sentinel values to None for JSON serialization.
|
|
82
|
+
|
|
83
|
+
sagemaker-core uses an 'Unassigned' type instead of None for unset fields.
|
|
84
|
+
This function converts any non-standard types to JSON-safe values.
|
|
85
|
+
"""
|
|
86
|
+
if value is None:
|
|
87
|
+
return None
|
|
88
|
+
# Check for Unassigned type from sagemaker-core
|
|
89
|
+
type_name = type(value).__name__
|
|
90
|
+
if type_name == "Unassigned" or type_name == "UnassignedValue":
|
|
91
|
+
return None
|
|
92
|
+
if isinstance(value, (str, int, float, bool)):
|
|
93
|
+
return value
|
|
94
|
+
if isinstance(value, dict):
|
|
95
|
+
return {k: _sanitize_for_json(v) for k, v in value.items()}
|
|
96
|
+
if isinstance(value, list):
|
|
97
|
+
return [_sanitize_for_json(v) for v in value]
|
|
98
|
+
# For other types, try str conversion as fallback
|
|
99
|
+
try:
|
|
100
|
+
# Check if it's JSON serializable as-is
|
|
101
|
+
import json as _json
|
|
102
|
+
_json.dumps(value)
|
|
103
|
+
return value
|
|
104
|
+
except (TypeError, ValueError):
|
|
105
|
+
return str(value) if value else None
|
|
106
|
+
|
|
107
|
+
|
|
74
108
|
# ── Subcommand: submit ────────────────────────────────────────────────────────
|
|
75
109
|
|
|
76
110
|
|
|
@@ -171,20 +205,25 @@ def cmd_submit(args):
|
|
|
171
205
|
trainer_kwargs["accept_eula"] = True
|
|
172
206
|
|
|
173
207
|
# Resolve model package group — create if it doesn't exist
|
|
208
|
+
# Using sagemaker-core ModelPackageGroup.create() per SDK v3 policy
|
|
174
209
|
mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
|
|
175
210
|
try:
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
211
|
+
from sagemaker.core.resources import ModelPackageGroup
|
|
212
|
+
from botocore.exceptions import ClientError as _ClientError
|
|
213
|
+
try:
|
|
214
|
+
ModelPackageGroup.get(model_package_group_name=mpg_name)
|
|
215
|
+
except (_ClientError, Exception) as _mpg_err:
|
|
216
|
+
if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
|
|
217
|
+
try:
|
|
218
|
+
ModelPackageGroup.create(
|
|
219
|
+
model_package_group_name=mpg_name,
|
|
220
|
+
model_package_group_description=f"Fine-tuned models for {args.project_name}",
|
|
221
|
+
)
|
|
222
|
+
except Exception:
|
|
223
|
+
pass # May already exist or lack permissions — let the trainer handle it
|
|
224
|
+
except ImportError:
|
|
225
|
+
# sagemaker-core not available — skip MPG creation, let trainer handle it
|
|
226
|
+
pass
|
|
188
227
|
trainer_kwargs["model_package_group"] = mpg_name
|
|
189
228
|
|
|
190
229
|
trainer = trainer_cls(**trainer_kwargs)
|
|
@@ -267,7 +306,9 @@ def cmd_submit(args):
|
|
|
267
306
|
job_arn = job_arn or getattr(latest_job, 'arn', None)
|
|
268
307
|
|
|
269
308
|
# If we still don't have the actual job name (SDK appends suffix),
|
|
270
|
-
# query ListTrainingJobs to find it by our base_job_name prefix
|
|
309
|
+
# query ListTrainingJobs to find it by our base_job_name prefix.
|
|
310
|
+
# Note: list_training_jobs with NameContains filter is not available
|
|
311
|
+
# via sagemaker-core resource API, so boto3 is retained here.
|
|
271
312
|
if not job_name or job_name == args.job_name:
|
|
272
313
|
import boto3 as _boto3
|
|
273
314
|
_sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
|
|
@@ -335,23 +376,22 @@ def cmd_submit(args):
|
|
|
335
376
|
|
|
336
377
|
|
|
337
378
|
def cmd_status(args):
|
|
338
|
-
"""Query job status via
|
|
379
|
+
"""Query job status via sagemaker-core TrainingJob.get().
|
|
339
380
|
|
|
340
|
-
Falls back to ListTrainingJobs with name-contains if exact name not found
|
|
381
|
+
Falls back to boto3 ListTrainingJobs with name-contains if exact name not found
|
|
341
382
|
(SDK v3 appends a timestamp suffix to the base job name).
|
|
342
383
|
|
|
343
384
|
Returns: {"status": str, "failure_reason": str|None,
|
|
344
385
|
"metrics": dict|None, "elapsed_seconds": int}
|
|
345
386
|
"""
|
|
346
|
-
import
|
|
347
|
-
|
|
348
|
-
client = boto3.client("sagemaker", region_name=args.region)
|
|
387
|
+
from sagemaker.core.resources import TrainingJob
|
|
388
|
+
from botocore.exceptions import ClientError
|
|
349
389
|
|
|
350
|
-
# Try exact name first
|
|
351
|
-
|
|
390
|
+
# Try exact name first via sagemaker-core
|
|
391
|
+
job = None
|
|
352
392
|
try:
|
|
353
|
-
|
|
354
|
-
except
|
|
393
|
+
job = TrainingJob.get(training_job_name=args.job_name)
|
|
394
|
+
except ClientError as e:
|
|
355
395
|
error_code = e.response["Error"]["Code"]
|
|
356
396
|
if error_code != "ValidationException":
|
|
357
397
|
_error_exit(f"Failed to describe training job: {e}")
|
|
@@ -360,8 +400,13 @@ def cmd_status(args):
|
|
|
360
400
|
_error_exit(f"Failed to describe training job: {e}")
|
|
361
401
|
|
|
362
402
|
# Fallback: search by name prefix (SDK appends timestamp suffix)
|
|
363
|
-
|
|
403
|
+
# Note: TrainingJob.get_all() with name_contains is not available in
|
|
404
|
+
# sagemaker-core for list operations, so we use boto3 list_training_jobs
|
|
405
|
+
# to find the actual name, then call TrainingJob.get() with it.
|
|
406
|
+
if job is None:
|
|
364
407
|
try:
|
|
408
|
+
import boto3
|
|
409
|
+
client = boto3.client("sagemaker", region_name=args.region)
|
|
365
410
|
list_response = client.list_training_jobs(
|
|
366
411
|
NameContains=args.job_name,
|
|
367
412
|
SortBy="CreationTime",
|
|
@@ -371,18 +416,26 @@ def cmd_status(args):
|
|
|
371
416
|
summaries = list_response.get("TrainingJobSummaries", [])
|
|
372
417
|
if summaries:
|
|
373
418
|
actual_name = summaries[0]["TrainingJobName"]
|
|
374
|
-
|
|
419
|
+
job = TrainingJob.get(training_job_name=actual_name)
|
|
375
420
|
else:
|
|
376
421
|
_error_exit(f"Training job not found: {args.job_name}")
|
|
377
422
|
except Exception as e:
|
|
378
423
|
_error_exit(f"Failed to find training job: {e}")
|
|
379
424
|
|
|
380
|
-
status
|
|
381
|
-
|
|
425
|
+
# Read status attributes directly from the TrainingJob resource object.
|
|
426
|
+
# sagemaker-core returns status values in the same casing as the API
|
|
427
|
+
# (e.g., "InProgress", "Completed", "Failed", "Stopped").
|
|
428
|
+
status = getattr(job, "training_job_status", "Unknown") or "Unknown"
|
|
429
|
+
failure_reason = getattr(job, "failure_reason", None)
|
|
382
430
|
|
|
383
431
|
# Calculate elapsed time
|
|
384
|
-
start_time =
|
|
385
|
-
end_time =
|
|
432
|
+
start_time = getattr(job, "training_start_time", None)
|
|
433
|
+
end_time = getattr(job, "training_end_time", None)
|
|
434
|
+
# Convert Unassigned sentinel to None
|
|
435
|
+
if start_time and type(start_time).__name__ in ("Unassigned", "UnassignedValue"):
|
|
436
|
+
start_time = None
|
|
437
|
+
if end_time and type(end_time).__name__ in ("Unassigned", "UnassignedValue"):
|
|
438
|
+
end_time = None
|
|
386
439
|
elapsed_seconds = 0
|
|
387
440
|
|
|
388
441
|
if start_time:
|
|
@@ -393,24 +446,30 @@ def cmd_status(args):
|
|
|
393
446
|
|
|
394
447
|
# Extract final metrics if available
|
|
395
448
|
metrics = None
|
|
396
|
-
final_metrics =
|
|
449
|
+
final_metrics = getattr(job, "final_metric_data_list", None)
|
|
450
|
+
if final_metrics and type(final_metrics).__name__ in ("Unassigned", "UnassignedValue"):
|
|
451
|
+
final_metrics = None
|
|
397
452
|
if final_metrics:
|
|
398
453
|
metrics = {}
|
|
399
454
|
for metric in final_metrics:
|
|
400
|
-
metrics
|
|
455
|
+
# sagemaker-core returns metrics as objects with snake_case attributes
|
|
456
|
+
metric_name = getattr(metric, "metric_name", None) or metric.get("MetricName", "")
|
|
457
|
+
metric_value = getattr(metric, "value", None) or metric.get("Value", 0)
|
|
458
|
+
metrics[metric_name] = metric_value
|
|
401
459
|
|
|
402
460
|
# Get output path if completed
|
|
403
461
|
output_path = None
|
|
404
462
|
if status == "Completed":
|
|
405
|
-
model_artifacts =
|
|
406
|
-
|
|
463
|
+
model_artifacts = getattr(job, "model_artifacts", None)
|
|
464
|
+
if model_artifacts:
|
|
465
|
+
output_path = getattr(model_artifacts, "s3_model_artifacts", None)
|
|
407
466
|
|
|
408
467
|
_output({
|
|
409
|
-
"status": status,
|
|
410
|
-
"failure_reason": failure_reason,
|
|
411
|
-
"metrics": metrics,
|
|
468
|
+
"status": _sanitize_for_json(status),
|
|
469
|
+
"failure_reason": _sanitize_for_json(failure_reason),
|
|
470
|
+
"metrics": _sanitize_for_json(metrics),
|
|
412
471
|
"elapsed_seconds": elapsed_seconds,
|
|
413
|
-
"output_path": output_path,
|
|
472
|
+
"output_path": _sanitize_for_json(output_path),
|
|
414
473
|
})
|
|
415
474
|
|
|
416
475
|
|
|
@@ -420,28 +479,31 @@ def cmd_status(args):
|
|
|
420
479
|
def cmd_resolve(args):
|
|
421
480
|
"""Resolve artifact path within S3 output directory.
|
|
422
481
|
|
|
482
|
+
Uses sagemaker-core TrainingJob.get() to read model_artifacts and
|
|
483
|
+
output_data_config. Uses ModelPackage for model package lookup.
|
|
484
|
+
|
|
423
485
|
Returns: {"artifact_path": str, "model_package_arn": str|None,
|
|
424
486
|
"output_type": str}
|
|
425
487
|
"""
|
|
426
|
-
import
|
|
427
|
-
|
|
428
|
-
client = boto3.client("sagemaker", region_name=args.region)
|
|
488
|
+
from sagemaker.core.resources import TrainingJob
|
|
429
489
|
|
|
430
490
|
try:
|
|
431
|
-
|
|
491
|
+
job = TrainingJob.get(training_job_name=args.job_name)
|
|
432
492
|
except Exception as e:
|
|
433
493
|
_error_exit(f"Failed to describe training job: {e}")
|
|
434
494
|
|
|
435
|
-
status =
|
|
495
|
+
status = getattr(job, "training_job_status", None)
|
|
436
496
|
if status != "Completed":
|
|
437
497
|
_error_exit(
|
|
438
498
|
f"Cannot resolve artifacts for job in status: {status}. "
|
|
439
499
|
f"Job must be Completed."
|
|
440
500
|
)
|
|
441
501
|
|
|
442
|
-
# Get the S3 model artifacts path
|
|
443
|
-
model_artifacts =
|
|
444
|
-
artifact_path =
|
|
502
|
+
# Get the S3 model artifacts path from TrainingJob resource
|
|
503
|
+
model_artifacts = getattr(job, "model_artifacts", None)
|
|
504
|
+
artifact_path = ""
|
|
505
|
+
if model_artifacts:
|
|
506
|
+
artifact_path = getattr(model_artifacts, "s3_model_artifacts", "") or ""
|
|
445
507
|
|
|
446
508
|
if not artifact_path:
|
|
447
509
|
_error_exit("No model artifacts found in training job output.")
|
|
@@ -461,6 +523,9 @@ def cmd_resolve(args):
|
|
|
461
523
|
model_package_arn = None
|
|
462
524
|
if args.model_package_group:
|
|
463
525
|
try:
|
|
526
|
+
# Use boto3 for list_model_packages since sagemaker-core ModelPackage
|
|
527
|
+
# doesn't have a direct list-by-group method with sort/limit
|
|
528
|
+
import boto3
|
|
464
529
|
mp_client = boto3.client("sagemaker", region_name=args.region)
|
|
465
530
|
packages = mp_client.list_model_packages(
|
|
466
531
|
ModelPackageGroupName=args.model_package_group,
|
|
@@ -476,8 +541,8 @@ def cmd_resolve(args):
|
|
|
476
541
|
pass
|
|
477
542
|
|
|
478
543
|
_output({
|
|
479
|
-
"artifact_path": artifact_path,
|
|
480
|
-
"model_package_arn": model_package_arn,
|
|
544
|
+
"artifact_path": _sanitize_for_json(artifact_path),
|
|
545
|
+
"model_package_arn": _sanitize_for_json(model_package_arn),
|
|
481
546
|
"output_type": output_type,
|
|
482
547
|
})
|
|
483
548
|
|
|
@@ -765,11 +830,12 @@ def _get_schema_types(technique):
|
|
|
765
830
|
return schemas.get(technique, {"prompt": "string", "completion": "string"})
|
|
766
831
|
|
|
767
832
|
|
|
768
|
-
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
|
|
833
|
+
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id, take=None):
|
|
769
834
|
"""Validate that the first record has required columns after mapping.
|
|
770
835
|
|
|
771
836
|
Returns (mapped_record, column_map_dict) on success.
|
|
772
837
|
Calls _error_exit with helpful suggestions on failure.
|
|
838
|
+
If take is provided, includes --take N in the suggested command.
|
|
773
839
|
"""
|
|
774
840
|
column_map = _parse_column_map(column_map_str)
|
|
775
841
|
mapped = _apply_column_map(first_record, column_map)
|
|
@@ -794,12 +860,14 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
|
|
|
794
860
|
if suggestion:
|
|
795
861
|
lines.append(f"")
|
|
796
862
|
lines.append(f" 💡 Suggested fix:")
|
|
797
|
-
|
|
863
|
+
take_suffix = f" --take {take}" if take else ""
|
|
864
|
+
lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}{take_suffix}")
|
|
798
865
|
else:
|
|
799
866
|
lines.append(f"")
|
|
800
867
|
lines.append(f" 💡 Use --column-map to rename columns:")
|
|
801
868
|
example_map = ",".join(f"{r}=<your_column>" for r in missing)
|
|
802
|
-
|
|
869
|
+
take_suffix = f" --take {take}" if take else ""
|
|
870
|
+
lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}{take_suffix}")
|
|
803
871
|
|
|
804
872
|
lines.append(f"")
|
|
805
873
|
lines.append(f" First record sample:")
|
|
@@ -811,6 +879,16 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
|
|
|
811
879
|
_error_exit("\n".join(lines))
|
|
812
880
|
|
|
813
881
|
|
|
882
|
+
def _check_empty_fields(record, required_columns):
|
|
883
|
+
"""Return list of required column names that are empty/blank in this record."""
|
|
884
|
+
empty = []
|
|
885
|
+
for col in required_columns:
|
|
886
|
+
value = record.get(col, "")
|
|
887
|
+
if value is None or (isinstance(value, str) and not value.strip()):
|
|
888
|
+
empty.append(col)
|
|
889
|
+
return empty
|
|
890
|
+
|
|
891
|
+
|
|
814
892
|
def cmd_stage_hf(args):
|
|
815
893
|
"""Download HF dataset to S3 using huggingface_hub.
|
|
816
894
|
|
|
@@ -854,21 +932,35 @@ def cmd_stage_hf(args):
|
|
|
854
932
|
|
|
855
933
|
# Find the appropriate data file for the split
|
|
856
934
|
data_files = _find_data_files(repo_files, split)
|
|
935
|
+
|
|
936
|
+
# Apply file filter if --hf-file is provided
|
|
937
|
+
hf_file_pattern = getattr(args, 'hf_file', None)
|
|
938
|
+
|
|
939
|
+
if not data_files and hf_file_pattern:
|
|
940
|
+
# Split-based lookup found nothing, but user specified a file filter.
|
|
941
|
+
# Fall back to filtering directly from all data files in the repo.
|
|
942
|
+
all_data_files = [
|
|
943
|
+
f for f in repo_files
|
|
944
|
+
if f.endswith(('.parquet', '.jsonl', '.json'))
|
|
945
|
+
and not f.startswith('.')
|
|
946
|
+
]
|
|
947
|
+
if all_data_files:
|
|
948
|
+
data_files = _filter_data_files(all_data_files, hf_file_pattern)
|
|
949
|
+
elif hf_file_pattern and data_files:
|
|
950
|
+
# Normal case: apply file filter to split-matched results
|
|
951
|
+
data_files = _filter_data_files(data_files, hf_file_pattern)
|
|
952
|
+
|
|
857
953
|
if not data_files:
|
|
858
954
|
_error_exit(
|
|
859
955
|
f"No data files found for split '{split}' in dataset {dataset_id}. "
|
|
860
956
|
f"Available files: {', '.join(repo_files[:20])}"
|
|
861
957
|
)
|
|
862
958
|
|
|
863
|
-
# Apply file filter if --hf-file is provided
|
|
864
|
-
hf_file_pattern = getattr(args, 'hf_file', None)
|
|
865
|
-
if hf_file_pattern:
|
|
866
|
-
data_files = _filter_data_files(data_files, hf_file_pattern)
|
|
867
|
-
|
|
868
959
|
# Download and upload to S3
|
|
869
960
|
s3_client = boto3.client("s3", region_name=args.region)
|
|
870
961
|
s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
|
|
871
962
|
num_records = 0
|
|
963
|
+
empty_field_counts = {} # Track empty required fields: {field_name: count}
|
|
872
964
|
|
|
873
965
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
874
966
|
# Schema divergence check (skip for single file)
|
|
@@ -907,7 +999,7 @@ def cmd_stage_hf(args):
|
|
|
907
999
|
no_transform = getattr(args, 'no_transform', False)
|
|
908
1000
|
batches = table.to_batches(max_chunksize=1)
|
|
909
1001
|
first_record = batches[0].to_pylist()[0] if batches else {}
|
|
910
|
-
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
1002
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
|
|
911
1003
|
|
|
912
1004
|
# Apply column map to first record for detection
|
|
913
1005
|
mapped_first = _apply_column_map(first_record, column_map)
|
|
@@ -944,16 +1036,31 @@ def cmd_stage_hf(args):
|
|
|
944
1036
|
f" Detected format: {strategy_desc}"
|
|
945
1037
|
)
|
|
946
1038
|
|
|
1039
|
+
take_limit = getattr(args, 'take', None)
|
|
947
1040
|
with open(jsonl_path, "w", encoding="utf-8") as out_f:
|
|
948
1041
|
for batch in table.to_batches():
|
|
949
1042
|
for row in batch.to_pylist():
|
|
1043
|
+
if take_limit and num_records >= take_limit:
|
|
1044
|
+
break
|
|
950
1045
|
mapped_row = _apply_column_map(row, column_map)
|
|
951
1046
|
if chat_columns and not no_transform:
|
|
952
1047
|
mapped_row = _flatten_record(mapped_row, chat_columns)
|
|
1048
|
+
# Track empty required fields
|
|
1049
|
+
for col in _check_empty_fields(mapped_row, required_columns):
|
|
1050
|
+
empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
|
|
953
1051
|
out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
|
|
954
1052
|
num_records += 1
|
|
1053
|
+
if take_limit and num_records >= take_limit:
|
|
1054
|
+
break
|
|
955
1055
|
|
|
956
1056
|
# Upload converted JSONL
|
|
1057
|
+
# Verify file has content before uploading
|
|
1058
|
+
file_size = os.path.getsize(jsonl_path)
|
|
1059
|
+
if file_size == 0:
|
|
1060
|
+
_error_exit(
|
|
1061
|
+
f"Converted JSONL file is empty (0 bytes) after processing "
|
|
1062
|
+
f"{num_records} records. This is a bug — please report it."
|
|
1063
|
+
)
|
|
957
1064
|
s3_key = f"{s3_prefix}/{jsonl_filename}"
|
|
958
1065
|
s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
|
|
959
1066
|
|
|
@@ -975,7 +1082,7 @@ def cmd_stage_hf(args):
|
|
|
975
1082
|
first_line = f.readline().strip()
|
|
976
1083
|
if first_line:
|
|
977
1084
|
first_record = json_mod.loads(first_line)
|
|
978
|
-
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
1085
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
|
|
979
1086
|
|
|
980
1087
|
# Apply column map to first record for detection
|
|
981
1088
|
mapped_first = _apply_column_map(first_record, column_map)
|
|
@@ -1014,11 +1121,14 @@ def cmd_stage_hf(args):
|
|
|
1014
1121
|
|
|
1015
1122
|
# Rewrite the file with mapped (and optionally flattened) columns
|
|
1016
1123
|
should_flatten = bool(chat_columns) and not no_transform
|
|
1017
|
-
|
|
1124
|
+
take_limit = getattr(args, 'take', None)
|
|
1125
|
+
if column_map or should_flatten or take_limit:
|
|
1018
1126
|
mapped_path = local_path + ".mapped"
|
|
1019
1127
|
with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
|
|
1020
1128
|
open(mapped_path, "w", encoding="utf-8") as f_out:
|
|
1021
1129
|
for line in f_in:
|
|
1130
|
+
if take_limit and num_records >= take_limit:
|
|
1131
|
+
break
|
|
1022
1132
|
line = line.strip()
|
|
1023
1133
|
if not line:
|
|
1024
1134
|
continue
|
|
@@ -1026,15 +1136,32 @@ def cmd_stage_hf(args):
|
|
|
1026
1136
|
mapped_record = _apply_column_map(record, column_map)
|
|
1027
1137
|
if should_flatten:
|
|
1028
1138
|
mapped_record = _flatten_record(mapped_record, chat_columns)
|
|
1139
|
+
# Track empty required fields
|
|
1140
|
+
for col in _check_empty_fields(mapped_record, _get_required_columns(technique)):
|
|
1141
|
+
empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
|
|
1029
1142
|
f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
|
|
1030
1143
|
num_records += 1
|
|
1031
1144
|
local_path = mapped_path
|
|
1032
1145
|
else:
|
|
1033
|
-
# Count records
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1146
|
+
# Count records (and truncate if --take specified)
|
|
1147
|
+
take_limit = getattr(args, 'take', None)
|
|
1148
|
+
if take_limit:
|
|
1149
|
+
# Need to rewrite the file truncated
|
|
1150
|
+
mapped_path = local_path + ".mapped"
|
|
1151
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
|
|
1152
|
+
open(mapped_path, "w", encoding="utf-8") as f_out:
|
|
1153
|
+
for line in f_in:
|
|
1154
|
+
if num_records >= take_limit:
|
|
1155
|
+
break
|
|
1156
|
+
if line.strip():
|
|
1157
|
+
f_out.write(line)
|
|
1158
|
+
num_records += 1
|
|
1159
|
+
local_path = mapped_path
|
|
1160
|
+
else:
|
|
1161
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
|
|
1162
|
+
for line in f:
|
|
1163
|
+
if line.strip():
|
|
1164
|
+
num_records += 1
|
|
1038
1165
|
|
|
1039
1166
|
# Upload to S3
|
|
1040
1167
|
s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
|
|
@@ -1048,6 +1175,19 @@ def cmd_stage_hf(args):
|
|
|
1048
1175
|
output_filename = os.path.basename(first_file)
|
|
1049
1176
|
s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
|
|
1050
1177
|
|
|
1178
|
+
# Warn if required columns have many empty values
|
|
1179
|
+
if num_records > 0 and empty_field_counts:
|
|
1180
|
+
for field, count in empty_field_counts.items():
|
|
1181
|
+
pct = (count / num_records) * 100
|
|
1182
|
+
if pct > 30:
|
|
1183
|
+
print(
|
|
1184
|
+
f"\u26a0\ufe0f Warning: {pct:.0f}% of records ({count}/{num_records}) "
|
|
1185
|
+
f"have empty '{field}' after column mapping.\n"
|
|
1186
|
+
f" SageMaker may reject these as invalid samples.\n"
|
|
1187
|
+
f" Consider using a different --column-map or dataset.",
|
|
1188
|
+
file=sys.stderr,
|
|
1189
|
+
)
|
|
1190
|
+
|
|
1051
1191
|
_output({
|
|
1052
1192
|
"s3_uri": s3_uri,
|
|
1053
1193
|
"num_records": num_records,
|
|
@@ -1124,12 +1264,12 @@ def _find_data_files(repo_files, split):
|
|
|
1124
1264
|
if pattern in repo_files:
|
|
1125
1265
|
return [pattern]
|
|
1126
1266
|
|
|
1127
|
-
# Prefix match for sharded files
|
|
1128
|
-
matches =
|
|
1267
|
+
# Prefix match for sharded files (deduplicate via set)
|
|
1268
|
+
matches = set()
|
|
1129
1269
|
for f in repo_files:
|
|
1130
1270
|
for pattern in patterns[4:]:
|
|
1131
1271
|
if pattern in f:
|
|
1132
|
-
matches.
|
|
1272
|
+
matches.add(f)
|
|
1133
1273
|
|
|
1134
1274
|
if matches:
|
|
1135
1275
|
return sorted(matches)
|
|
@@ -1508,6 +1648,10 @@ def _build_expected_format(schema):
|
|
|
1508
1648
|
def cmd_discover(args):
|
|
1509
1649
|
"""Query JumpStart Hub for tune-eligible models matching a family.
|
|
1510
1650
|
|
|
1651
|
+
NOTE: This subcommand intentionally stays on boto3.client('sagemaker')
|
|
1652
|
+
because list_hub_contents / Hub API is NOT available in sagemaker-core.
|
|
1653
|
+
This is a documented exception per the SDK v3 migration policy.
|
|
1654
|
+
|
|
1511
1655
|
Returns: {"models": [str], "count": int}
|
|
1512
1656
|
"""
|
|
1513
1657
|
region = args.region or os.environ.get('AWS_REGION', 'us-east-1')
|
|
@@ -1532,6 +1676,8 @@ def cmd_discover(args):
|
|
|
1532
1676
|
_error_exit("Hub discovery failed: boto3 is not installed. Install with: pip install boto3")
|
|
1533
1677
|
|
|
1534
1678
|
try:
|
|
1679
|
+
# Documented exception: Hub API (list_hub_contents) is not available in
|
|
1680
|
+
# sagemaker-core, so we retain boto3.client('sagemaker') here.
|
|
1535
1681
|
client = boto3.client("sagemaker", region_name=region)
|
|
1536
1682
|
models = []
|
|
1537
1683
|
paginator = client.get_paginator('list_hub_contents')
|
|
@@ -1650,6 +1796,8 @@ def main():
|
|
|
1650
1796
|
help="Customization technique (determines required columns)")
|
|
1651
1797
|
stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
|
|
1652
1798
|
help="Disable automatic chat-format flattening")
|
|
1799
|
+
stage_hf_parser.add_argument("--take", type=int, default=None,
|
|
1800
|
+
help="Take only the first N records from the dataset")
|
|
1653
1801
|
|
|
1654
1802
|
# ── validate ──────────────────────────────────────────────────────────────
|
|
1655
1803
|
validate_parser = subparsers.add_parser("validate",
|
|
Binary file
|
|
Binary file
|
|
Binary file
|