@aws/ml-container-creator 0.13.4 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -5
- package/config/parameter-schema-v2.json +32 -4
- package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
- package/infra/ci-harness/package-lock.json +122 -116
- package/infra/ci-harness/package.json +1 -1
- package/package.json +5 -3
- package/pyproject.toml +21 -0
- package/requirements.txt +19 -0
- package/servers/instance-sizer/index.js +72 -4
- package/servers/instance-sizer/lib/model-resolver.js +28 -2
- package/src/app.js +17 -0
- package/src/lib/bootstrap-command-handler.js +33 -23
- package/src/lib/config-loader.js +18 -0
- package/src/lib/config-manager.js +6 -1
- package/src/lib/dataset-slug.js +152 -0
- package/src/lib/generated/cli-options.js +9 -3
- package/src/lib/generated/parameter-matrix.js +14 -3
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +6 -0
- package/src/lib/prompt-runner.js +5 -0
- package/src/lib/prompts/feature-prompts.js +1 -1
- package/src/lib/template-manager.js +0 -7
- package/src/lib/template-variable-resolver.js +51 -1
- package/src/lib/tune-config-state.js +14 -1
- package/templates/do/.adapter_helper.py +451 -0
- package/templates/do/.benchmark_writer.py +22 -0
- package/templates/do/.register_helper.py +1163 -0
- package/templates/do/.stage_helper.py +419 -0
- package/templates/do/.tune_helper.py +379 -65
- package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +427 -27
- package/templates/do/add-ic +85 -3
- package/templates/do/benchmark +173 -15
- package/templates/do/config +24 -0
- package/templates/do/lib/inference-component.sh +56 -3
- package/templates/do/lib/profile.sh +5 -0
- package/templates/do/register +552 -6
- package/templates/do/stage +91 -272
- package/templates/do/test +12 -2
- package/templates/do/tune +264 -12
|
@@ -29,6 +29,12 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
29
29
|
warnings.filterwarnings("ignore", message=".*urllib3.*")
|
|
30
30
|
warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
|
|
31
31
|
|
|
32
|
+
# Suppress ALL logging to prevent sagemaker-core/rich from writing to stdout.
|
|
33
|
+
# This script outputs JSON on stdout — any other stdout output corrupts parsing.
|
|
34
|
+
import logging as _logging
|
|
35
|
+
_logging.disable(_logging.CRITICAL)
|
|
36
|
+
os.environ.setdefault("SAGEMAKER_LOG_LEVEL", "CRITICAL")
|
|
37
|
+
|
|
32
38
|
# ── Inline dependency check ───────────────────────────────────────────────────
|
|
33
39
|
MIN_SAGEMAKER_VERSION = "3.0"
|
|
34
40
|
|
|
@@ -71,6 +77,164 @@ def _output(data):
|
|
|
71
77
|
sys.exit(0)
|
|
72
78
|
|
|
73
79
|
|
|
80
|
+
def _sanitize_for_json(value):
|
|
81
|
+
"""Convert sagemaker-core Unassigned sentinel values to None for JSON serialization.
|
|
82
|
+
|
|
83
|
+
sagemaker-core uses an 'Unassigned' type instead of None for unset fields.
|
|
84
|
+
This function converts any non-standard types to JSON-safe values.
|
|
85
|
+
"""
|
|
86
|
+
if value is None:
|
|
87
|
+
return None
|
|
88
|
+
# Check for Unassigned type from sagemaker-core
|
|
89
|
+
type_name = type(value).__name__
|
|
90
|
+
if type_name == "Unassigned" or type_name == "UnassignedValue":
|
|
91
|
+
return None
|
|
92
|
+
if isinstance(value, (str, int, float, bool)):
|
|
93
|
+
return value
|
|
94
|
+
if isinstance(value, dict):
|
|
95
|
+
return {k: _sanitize_for_json(v) for k, v in value.items()}
|
|
96
|
+
if isinstance(value, list):
|
|
97
|
+
return [_sanitize_for_json(v) for v in value]
|
|
98
|
+
# For other types, try str conversion as fallback
|
|
99
|
+
try:
|
|
100
|
+
# Check if it's JSON serializable as-is
|
|
101
|
+
import json as _json
|
|
102
|
+
_json.dumps(value)
|
|
103
|
+
return value
|
|
104
|
+
except (TypeError, ValueError):
|
|
105
|
+
return str(value) if value else None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ── Registry resolution helpers ───────────────────────────────────────────────
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _resolve_dataset_name(dataset_name):
    """Resolve a registered dataset name to S3 URI (or ARN) via .register_helper.py.

    Calls the resolve-dataset subcommand of .register_helper.py (located next
    to this script) and returns the resolved value. If the response contains
    an 'arn' field (Backlog #023, AI Registry mode), returns the ARN for use
    with SFTTrainer(training_dataset=arn). Otherwise returns the S3 URI for
    backward compatibility.

    Args:
        dataset_name: Name the dataset was registered under.

    Returns:
        The dataset ARN when the helper reports one, else its S3 URI.

    Every failure path (helper missing, timeout, non-zero exit, unparseable
    or non-object JSON, an "error" field, no ARN/URI) is reported through
    _error_exit.
    NOTE(review): this function relies on _error_exit terminating the
    process and never returning — confirm against its definition.
    """
    import subprocess

    script_dir = os.path.dirname(os.path.abspath(__file__))
    helper_path = os.path.join(script_dir, ".register_helper.py")

    if not os.path.exists(helper_path):
        _error_exit(
            f"Cannot resolve dataset '{dataset_name}': .register_helper.py not found. "
            f"Register datasets first with: ./do/register --dataset"
        )

    # Run the helper with the interpreter executing this script so it sees
    # the same environment; "python3" on PATH may be a different install.
    interpreter = sys.executable or "python3"

    try:
        result = subprocess.run(
            [interpreter, helper_path, "resolve-dataset", "--name", dataset_name],
            capture_output=True, text=True, timeout=30
        )
    except subprocess.TimeoutExpired:
        _error_exit(f"Timeout resolving dataset '{dataset_name}' from registry")
    except Exception as e:
        _error_exit(f"Failed to resolve dataset '{dataset_name}': {e}")

    if result.returncode != 0:
        # Surface the helper's stderr: a crash in the helper would otherwise
        # be indistinguishable from a genuinely unregistered dataset.
        helper_stderr = (result.stderr or "").strip()
        detail = f" Helper output: {helper_stderr[-200:]}" if helper_stderr else ""
        _error_exit(
            f"Dataset '{dataset_name}' not found in registry. "
            f"Register it first: ./do/register --dataset --dataset-name {dataset_name} --dataset-s3-uri s3://..."
            f"{detail}"
        )

    # Parse JSON output from resolve-dataset
    try:
        output = json.loads(result.stdout.strip())
    except (json.JSONDecodeError, ValueError):
        output = None

    # Valid JSON that is not an object (list, string, null) is just as
    # unusable as malformed JSON: the lookups below need a dict.
    if not isinstance(output, dict):
        _error_exit(
            f"Failed to parse registry response for dataset '{dataset_name}'. "
            f"Raw output: {result.stdout[:200]}"
        )

    if "error" in output:
        _error_exit(
            f"Dataset '{dataset_name}' not found in registry: {output['error']}. "
            f"Register it first: ./do/register --dataset --dataset-name {dataset_name} --dataset-s3-uri s3://..."
        )

    # Prefer ARN if available (Backlog #023 — AI Registry mode)
    # When arn is present, use it directly with SFTTrainer(training_dataset=arn)
    arn = output.get("arn")
    if arn:
        return arn

    # Fallback: use S3 URI
    s3_uri = output.get("s3_uri", "")
    if not s3_uri:
        _error_exit(
            f"Dataset '{dataset_name}' resolved but has no S3 URI or ARN. "
            f"Re-register with: ./do/register --dataset --dataset-name {dataset_name} --dataset-s3-uri s3://..."
        )

    return s3_uri
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _resolve_evaluator_name(evaluator_name):
    """Resolve a registered evaluator name to type and ARN/URI via .register_helper.py.

    Calls the resolve-evaluator subcommand of .register_helper.py (located
    next to this script) and reads the "type" and "arn_or_uri" fields of its
    JSON reply.

    Args:
        evaluator_name: Name the evaluator was registered under.

    Returns:
        (evaluator_type, arn_or_uri) tuple.
        evaluator_type is "lambda" for RLVR or "model" for RLAIF; it is the
        empty string when the helper's reply has no "type" field.

    Every failure path (helper missing, timeout, non-zero exit, unparseable
    or non-object JSON, an "error" field, no ARN/URI) is reported through
    _error_exit.
    NOTE(review): this function relies on _error_exit terminating the
    process and never returning — confirm against its definition.
    """
    import subprocess

    script_dir = os.path.dirname(os.path.abspath(__file__))
    helper_path = os.path.join(script_dir, ".register_helper.py")

    if not os.path.exists(helper_path):
        _error_exit(
            f"Cannot resolve evaluator '{evaluator_name}': .register_helper.py not found. "
            f"Register evaluators first with: ./do/register --evaluator"
        )

    # Run the helper with the interpreter executing this script so it sees
    # the same environment; "python3" on PATH may be a different install.
    interpreter = sys.executable or "python3"

    try:
        result = subprocess.run(
            [interpreter, helper_path, "resolve-evaluator", "--name", evaluator_name],
            capture_output=True, text=True, timeout=30
        )
    except subprocess.TimeoutExpired:
        _error_exit(f"Timeout resolving evaluator '{evaluator_name}' from registry")
    except Exception as e:
        _error_exit(f"Failed to resolve evaluator '{evaluator_name}': {e}")

    if result.returncode != 0:
        # Surface the helper's stderr: a crash in the helper would otherwise
        # be indistinguishable from a genuinely unregistered evaluator.
        helper_stderr = (result.stderr or "").strip()
        detail = f" Helper output: {helper_stderr[-200:]}" if helper_stderr else ""
        _error_exit(
            f"Evaluator '{evaluator_name}' not found in registry. "
            f"Register it first: ./do/register --evaluator --evaluator-name {evaluator_name} ..."
            f"{detail}"
        )

    # Parse JSON output from resolve-evaluator
    try:
        output = json.loads(result.stdout.strip())
    except (json.JSONDecodeError, ValueError):
        output = None

    # Valid JSON that is not an object (list, string, null) is just as
    # unusable as malformed JSON: the lookups below need a dict.
    if not isinstance(output, dict):
        _error_exit(
            f"Failed to parse registry response for evaluator '{evaluator_name}'. "
            f"Raw output: {result.stdout[:200]}"
        )

    if "error" in output:
        _error_exit(
            f"Evaluator '{evaluator_name}' not found in registry: {output['error']}. "
            f"Register it first: ./do/register --evaluator --evaluator-name {evaluator_name} ..."
        )

    ev_type = output.get("type", "")
    arn_or_uri = output.get("arn_or_uri", "")

    if not arn_or_uri:
        _error_exit(
            f"Evaluator '{evaluator_name}' resolved but has no ARN/URI. "
            f"Re-register with: ./do/register --evaluator --evaluator-name {evaluator_name} ..."
        )

    return ev_type, arn_or_uri
|
|
236
|
+
|
|
237
|
+
|
|
74
238
|
# ── Subcommand: submit ────────────────────────────────────────────────────────
|
|
75
239
|
|
|
76
240
|
|
|
@@ -90,6 +254,26 @@ def cmd_submit(args):
|
|
|
90
254
|
os.environ["AWS_DEFAULT_REGION"] = region
|
|
91
255
|
os.environ.setdefault("AWS_REGION", region)
|
|
92
256
|
|
|
257
|
+
# ── Resolve --dataset-name from registry (AC-2b.4) ────────────────────────
|
|
258
|
+
# --dataset-s3-uri wins if both are provided (backward compatible override)
|
|
259
|
+
if not args.dataset_s3_uri and args.dataset_name:
|
|
260
|
+
resolved_uri = _resolve_dataset_name(args.dataset_name)
|
|
261
|
+
args.dataset_s3_uri = resolved_uri
|
|
262
|
+
elif not args.dataset_s3_uri and not args.dataset_name:
|
|
263
|
+
_error_exit(
|
|
264
|
+
"Either --dataset-s3-uri or --dataset-name is required. "
|
|
265
|
+
"Provide an S3 URI directly or a registered dataset name."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# ── Resolve --evaluator-name from registry (AC-2c.3, AC-2c.4) ────────────
|
|
269
|
+
# --reward-function / --reward-prompt win if provided (backward compatible override)
|
|
270
|
+
if args.evaluator_name and not args.reward_function and not args.reward_prompt:
|
|
271
|
+
ev_type, ev_arn_or_uri = _resolve_evaluator_name(args.evaluator_name)
|
|
272
|
+
if ev_type == "lambda":
|
|
273
|
+
args.reward_function = ev_arn_or_uri
|
|
274
|
+
else:
|
|
275
|
+
args.reward_prompt = ev_arn_or_uri
|
|
276
|
+
|
|
93
277
|
_check_sagemaker_sdk()
|
|
94
278
|
|
|
95
279
|
# SDK v3 moved trainers from sagemaker.modules.train → sagemaker.train
|
|
@@ -171,20 +355,25 @@ def cmd_submit(args):
|
|
|
171
355
|
trainer_kwargs["accept_eula"] = True
|
|
172
356
|
|
|
173
357
|
# Resolve model package group — create if it doesn't exist
|
|
358
|
+
# Using sagemaker-core ModelPackageGroup.create() per SDK v3 policy
|
|
174
359
|
mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
|
|
175
360
|
try:
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
361
|
+
from sagemaker.core.resources import ModelPackageGroup
|
|
362
|
+
from botocore.exceptions import ClientError as _ClientError
|
|
363
|
+
try:
|
|
364
|
+
ModelPackageGroup.get(model_package_group_name=mpg_name)
|
|
365
|
+
except (_ClientError, Exception) as _mpg_err:
|
|
366
|
+
if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
|
|
367
|
+
try:
|
|
368
|
+
ModelPackageGroup.create(
|
|
369
|
+
model_package_group_name=mpg_name,
|
|
370
|
+
model_package_group_description=f"Fine-tuned models for {args.project_name}",
|
|
371
|
+
)
|
|
372
|
+
except Exception:
|
|
373
|
+
pass # May already exist or lack permissions — let the trainer handle it
|
|
374
|
+
except ImportError:
|
|
375
|
+
# sagemaker-core not available — skip MPG creation, let trainer handle it
|
|
376
|
+
pass
|
|
188
377
|
trainer_kwargs["model_package_group"] = mpg_name
|
|
189
378
|
|
|
190
379
|
trainer = trainer_cls(**trainer_kwargs)
|
|
@@ -267,7 +456,9 @@ def cmd_submit(args):
|
|
|
267
456
|
job_arn = job_arn or getattr(latest_job, 'arn', None)
|
|
268
457
|
|
|
269
458
|
# If we still don't have the actual job name (SDK appends suffix),
|
|
270
|
-
# query ListTrainingJobs to find it by our base_job_name prefix
|
|
459
|
+
# query ListTrainingJobs to find it by our base_job_name prefix.
|
|
460
|
+
# Note: list_training_jobs with NameContains filter is not available
|
|
461
|
+
# via sagemaker-core resource API, so boto3 is retained here.
|
|
271
462
|
if not job_name or job_name == args.job_name:
|
|
272
463
|
import boto3 as _boto3
|
|
273
464
|
_sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
|
|
@@ -335,23 +526,28 @@ def cmd_submit(args):
|
|
|
335
526
|
|
|
336
527
|
|
|
337
528
|
def cmd_status(args):
|
|
338
|
-
"""Query job status via
|
|
529
|
+
"""Query job status via sagemaker-core TrainingJob.get().
|
|
339
530
|
|
|
340
|
-
Falls back to ListTrainingJobs with name-contains if exact name not found
|
|
531
|
+
Falls back to boto3 ListTrainingJobs with name-contains if exact name not found
|
|
341
532
|
(SDK v3 appends a timestamp suffix to the base job name).
|
|
342
533
|
|
|
343
534
|
Returns: {"status": str, "failure_reason": str|None,
|
|
344
535
|
"metrics": dict|None, "elapsed_seconds": int}
|
|
345
536
|
"""
|
|
346
|
-
import boto3
|
|
537
|
+
# Set region before any sagemaker import (creates boto3 clients at import time)
|
|
538
|
+
region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
|
|
539
|
+
if region:
|
|
540
|
+
os.environ['AWS_DEFAULT_REGION'] = region
|
|
541
|
+
os.environ.setdefault('AWS_REGION', region)
|
|
347
542
|
|
|
348
|
-
|
|
543
|
+
from sagemaker.core.resources import TrainingJob
|
|
544
|
+
from botocore.exceptions import ClientError
|
|
349
545
|
|
|
350
|
-
# Try exact name first
|
|
351
|
-
|
|
546
|
+
# Try exact name first via sagemaker-core
|
|
547
|
+
job = None
|
|
352
548
|
try:
|
|
353
|
-
|
|
354
|
-
except
|
|
549
|
+
job = TrainingJob.get(training_job_name=args.job_name)
|
|
550
|
+
except ClientError as e:
|
|
355
551
|
error_code = e.response["Error"]["Code"]
|
|
356
552
|
if error_code != "ValidationException":
|
|
357
553
|
_error_exit(f"Failed to describe training job: {e}")
|
|
@@ -360,8 +556,13 @@ def cmd_status(args):
|
|
|
360
556
|
_error_exit(f"Failed to describe training job: {e}")
|
|
361
557
|
|
|
362
558
|
# Fallback: search by name prefix (SDK appends timestamp suffix)
|
|
363
|
-
|
|
559
|
+
# Note: TrainingJob.get_all() with name_contains is not available in
|
|
560
|
+
# sagemaker-core for list operations, so we use boto3 list_training_jobs
|
|
561
|
+
# to find the actual name, then call TrainingJob.get() with it.
|
|
562
|
+
if job is None:
|
|
364
563
|
try:
|
|
564
|
+
import boto3
|
|
565
|
+
client = boto3.client("sagemaker", region_name=args.region)
|
|
365
566
|
list_response = client.list_training_jobs(
|
|
366
567
|
NameContains=args.job_name,
|
|
367
568
|
SortBy="CreationTime",
|
|
@@ -371,18 +572,26 @@ def cmd_status(args):
|
|
|
371
572
|
summaries = list_response.get("TrainingJobSummaries", [])
|
|
372
573
|
if summaries:
|
|
373
574
|
actual_name = summaries[0]["TrainingJobName"]
|
|
374
|
-
|
|
575
|
+
job = TrainingJob.get(training_job_name=actual_name)
|
|
375
576
|
else:
|
|
376
577
|
_error_exit(f"Training job not found: {args.job_name}")
|
|
377
578
|
except Exception as e:
|
|
378
579
|
_error_exit(f"Failed to find training job: {e}")
|
|
379
580
|
|
|
380
|
-
status
|
|
381
|
-
|
|
581
|
+
# Read status attributes directly from the TrainingJob resource object.
|
|
582
|
+
# sagemaker-core returns status values in the same casing as the API
|
|
583
|
+
# (e.g., "InProgress", "Completed", "Failed", "Stopped").
|
|
584
|
+
status = getattr(job, "training_job_status", "Unknown") or "Unknown"
|
|
585
|
+
failure_reason = getattr(job, "failure_reason", None)
|
|
382
586
|
|
|
383
587
|
# Calculate elapsed time
|
|
384
|
-
start_time =
|
|
385
|
-
end_time =
|
|
588
|
+
start_time = getattr(job, "training_start_time", None)
|
|
589
|
+
end_time = getattr(job, "training_end_time", None)
|
|
590
|
+
# Convert Unassigned sentinel to None
|
|
591
|
+
if start_time and type(start_time).__name__ in ("Unassigned", "UnassignedValue"):
|
|
592
|
+
start_time = None
|
|
593
|
+
if end_time and type(end_time).__name__ in ("Unassigned", "UnassignedValue"):
|
|
594
|
+
end_time = None
|
|
386
595
|
elapsed_seconds = 0
|
|
387
596
|
|
|
388
597
|
if start_time:
|
|
@@ -393,24 +602,30 @@ def cmd_status(args):
|
|
|
393
602
|
|
|
394
603
|
# Extract final metrics if available
|
|
395
604
|
metrics = None
|
|
396
|
-
final_metrics =
|
|
605
|
+
final_metrics = getattr(job, "final_metric_data_list", None)
|
|
606
|
+
if final_metrics and type(final_metrics).__name__ in ("Unassigned", "UnassignedValue"):
|
|
607
|
+
final_metrics = None
|
|
397
608
|
if final_metrics:
|
|
398
609
|
metrics = {}
|
|
399
610
|
for metric in final_metrics:
|
|
400
|
-
metrics
|
|
611
|
+
# sagemaker-core returns metrics as objects with snake_case attributes
|
|
612
|
+
metric_name = getattr(metric, "metric_name", None) or metric.get("MetricName", "")
|
|
613
|
+
metric_value = getattr(metric, "value", None) or metric.get("Value", 0)
|
|
614
|
+
metrics[metric_name] = metric_value
|
|
401
615
|
|
|
402
616
|
# Get output path if completed
|
|
403
617
|
output_path = None
|
|
404
618
|
if status == "Completed":
|
|
405
|
-
model_artifacts =
|
|
406
|
-
|
|
619
|
+
model_artifacts = getattr(job, "model_artifacts", None)
|
|
620
|
+
if model_artifacts:
|
|
621
|
+
output_path = getattr(model_artifacts, "s3_model_artifacts", None)
|
|
407
622
|
|
|
408
623
|
_output({
|
|
409
|
-
"status": status,
|
|
410
|
-
"failure_reason": failure_reason,
|
|
411
|
-
"metrics": metrics,
|
|
624
|
+
"status": _sanitize_for_json(status),
|
|
625
|
+
"failure_reason": _sanitize_for_json(failure_reason),
|
|
626
|
+
"metrics": _sanitize_for_json(metrics),
|
|
412
627
|
"elapsed_seconds": elapsed_seconds,
|
|
413
|
-
"output_path": output_path,
|
|
628
|
+
"output_path": _sanitize_for_json(output_path),
|
|
414
629
|
})
|
|
415
630
|
|
|
416
631
|
|
|
@@ -420,28 +635,37 @@ def cmd_status(args):
|
|
|
420
635
|
def cmd_resolve(args):
|
|
421
636
|
"""Resolve artifact path within S3 output directory.
|
|
422
637
|
|
|
638
|
+
Uses sagemaker-core TrainingJob.get() to read model_artifacts and
|
|
639
|
+
output_data_config. Uses ModelPackage for model package lookup.
|
|
640
|
+
|
|
423
641
|
Returns: {"artifact_path": str, "model_package_arn": str|None,
|
|
424
642
|
"output_type": str}
|
|
425
643
|
"""
|
|
426
|
-
import boto3
|
|
644
|
+
# Set region before any sagemaker import (creates boto3 clients at import time)
|
|
645
|
+
region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
|
|
646
|
+
if region:
|
|
647
|
+
os.environ['AWS_DEFAULT_REGION'] = region
|
|
648
|
+
os.environ.setdefault('AWS_REGION', region)
|
|
427
649
|
|
|
428
|
-
|
|
650
|
+
from sagemaker.core.resources import TrainingJob
|
|
429
651
|
|
|
430
652
|
try:
|
|
431
|
-
|
|
653
|
+
job = TrainingJob.get(training_job_name=args.job_name)
|
|
432
654
|
except Exception as e:
|
|
433
655
|
_error_exit(f"Failed to describe training job: {e}")
|
|
434
656
|
|
|
435
|
-
status =
|
|
657
|
+
status = getattr(job, "training_job_status", None)
|
|
436
658
|
if status != "Completed":
|
|
437
659
|
_error_exit(
|
|
438
660
|
f"Cannot resolve artifacts for job in status: {status}. "
|
|
439
661
|
f"Job must be Completed."
|
|
440
662
|
)
|
|
441
663
|
|
|
442
|
-
# Get the S3 model artifacts path
|
|
443
|
-
model_artifacts =
|
|
444
|
-
artifact_path =
|
|
664
|
+
# Get the S3 model artifacts path from TrainingJob resource
|
|
665
|
+
model_artifacts = getattr(job, "model_artifacts", None)
|
|
666
|
+
artifact_path = ""
|
|
667
|
+
if model_artifacts:
|
|
668
|
+
artifact_path = getattr(model_artifacts, "s3_model_artifacts", "") or ""
|
|
445
669
|
|
|
446
670
|
if not artifact_path:
|
|
447
671
|
_error_exit("No model artifacts found in training job output.")
|
|
@@ -461,6 +685,9 @@ def cmd_resolve(args):
|
|
|
461
685
|
model_package_arn = None
|
|
462
686
|
if args.model_package_group:
|
|
463
687
|
try:
|
|
688
|
+
# Use boto3 for list_model_packages since sagemaker-core ModelPackage
|
|
689
|
+
# doesn't have a direct list-by-group method with sort/limit
|
|
690
|
+
import boto3
|
|
464
691
|
mp_client = boto3.client("sagemaker", region_name=args.region)
|
|
465
692
|
packages = mp_client.list_model_packages(
|
|
466
693
|
ModelPackageGroupName=args.model_package_group,
|
|
@@ -476,8 +703,8 @@ def cmd_resolve(args):
|
|
|
476
703
|
pass
|
|
477
704
|
|
|
478
705
|
_output({
|
|
479
|
-
"artifact_path": artifact_path,
|
|
480
|
-
"model_package_arn": model_package_arn,
|
|
706
|
+
"artifact_path": _sanitize_for_json(artifact_path),
|
|
707
|
+
"model_package_arn": _sanitize_for_json(model_package_arn),
|
|
481
708
|
"output_type": output_type,
|
|
482
709
|
})
|
|
483
710
|
|
|
@@ -765,11 +992,12 @@ def _get_schema_types(technique):
|
|
|
765
992
|
return schemas.get(technique, {"prompt": "string", "completion": "string"})
|
|
766
993
|
|
|
767
994
|
|
|
768
|
-
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
|
|
995
|
+
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id, take=None):
|
|
769
996
|
"""Validate that the first record has required columns after mapping.
|
|
770
997
|
|
|
771
998
|
Returns (mapped_record, column_map_dict) on success.
|
|
772
999
|
Calls _error_exit with helpful suggestions on failure.
|
|
1000
|
+
If take is provided, includes --take N in the suggested command.
|
|
773
1001
|
"""
|
|
774
1002
|
column_map = _parse_column_map(column_map_str)
|
|
775
1003
|
mapped = _apply_column_map(first_record, column_map)
|
|
@@ -794,12 +1022,14 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
|
|
|
794
1022
|
if suggestion:
|
|
795
1023
|
lines.append(f"")
|
|
796
1024
|
lines.append(f" 💡 Suggested fix:")
|
|
797
|
-
|
|
1025
|
+
take_suffix = f" --take {take}" if take else ""
|
|
1026
|
+
lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {suggestion}{take_suffix}")
|
|
798
1027
|
else:
|
|
799
1028
|
lines.append(f"")
|
|
800
1029
|
lines.append(f" 💡 Use --column-map to rename columns:")
|
|
801
1030
|
example_map = ",".join(f"{r}=<your_column>" for r in missing)
|
|
802
|
-
|
|
1031
|
+
take_suffix = f" --take {take}" if take else ""
|
|
1032
|
+
lines.append(f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map {example_map}{take_suffix}")
|
|
803
1033
|
|
|
804
1034
|
lines.append(f"")
|
|
805
1035
|
lines.append(f" First record sample:")
|
|
@@ -811,6 +1041,16 @@ def _validate_dataset_columns(first_record, technique, column_map_str, dataset_i
|
|
|
811
1041
|
_error_exit("\n".join(lines))
|
|
812
1042
|
|
|
813
1043
|
|
|
1044
|
+
def _check_empty_fields(record, required_columns):
|
|
1045
|
+
"""Return list of required column names that are empty/blank in this record."""
|
|
1046
|
+
empty = []
|
|
1047
|
+
for col in required_columns:
|
|
1048
|
+
value = record.get(col, "")
|
|
1049
|
+
if value is None or (isinstance(value, str) and not value.strip()):
|
|
1050
|
+
empty.append(col)
|
|
1051
|
+
return empty
|
|
1052
|
+
|
|
1053
|
+
|
|
814
1054
|
def cmd_stage_hf(args):
|
|
815
1055
|
"""Download HF dataset to S3 using huggingface_hub.
|
|
816
1056
|
|
|
@@ -854,21 +1094,35 @@ def cmd_stage_hf(args):
|
|
|
854
1094
|
|
|
855
1095
|
# Find the appropriate data file for the split
|
|
856
1096
|
data_files = _find_data_files(repo_files, split)
|
|
1097
|
+
|
|
1098
|
+
# Apply file filter if --hf-file is provided
|
|
1099
|
+
hf_file_pattern = getattr(args, 'hf_file', None)
|
|
1100
|
+
|
|
1101
|
+
if not data_files and hf_file_pattern:
|
|
1102
|
+
# Split-based lookup found nothing, but user specified a file filter.
|
|
1103
|
+
# Fall back to filtering directly from all data files in the repo.
|
|
1104
|
+
all_data_files = [
|
|
1105
|
+
f for f in repo_files
|
|
1106
|
+
if f.endswith(('.parquet', '.jsonl', '.json'))
|
|
1107
|
+
and not f.startswith('.')
|
|
1108
|
+
]
|
|
1109
|
+
if all_data_files:
|
|
1110
|
+
data_files = _filter_data_files(all_data_files, hf_file_pattern)
|
|
1111
|
+
elif hf_file_pattern and data_files:
|
|
1112
|
+
# Normal case: apply file filter to split-matched results
|
|
1113
|
+
data_files = _filter_data_files(data_files, hf_file_pattern)
|
|
1114
|
+
|
|
857
1115
|
if not data_files:
|
|
858
1116
|
_error_exit(
|
|
859
1117
|
f"No data files found for split '{split}' in dataset {dataset_id}. "
|
|
860
1118
|
f"Available files: {', '.join(repo_files[:20])}"
|
|
861
1119
|
)
|
|
862
1120
|
|
|
863
|
-
# Apply file filter if --hf-file is provided
|
|
864
|
-
hf_file_pattern = getattr(args, 'hf_file', None)
|
|
865
|
-
if hf_file_pattern:
|
|
866
|
-
data_files = _filter_data_files(data_files, hf_file_pattern)
|
|
867
|
-
|
|
868
1121
|
# Download and upload to S3
|
|
869
1122
|
s3_client = boto3.client("s3", region_name=args.region)
|
|
870
1123
|
s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
|
|
871
1124
|
num_records = 0
|
|
1125
|
+
empty_field_counts = {} # Track empty required fields: {field_name: count}
|
|
872
1126
|
|
|
873
1127
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
874
1128
|
# Schema divergence check (skip for single file)
|
|
@@ -907,7 +1161,7 @@ def cmd_stage_hf(args):
|
|
|
907
1161
|
no_transform = getattr(args, 'no_transform', False)
|
|
908
1162
|
batches = table.to_batches(max_chunksize=1)
|
|
909
1163
|
first_record = batches[0].to_pylist()[0] if batches else {}
|
|
910
|
-
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
1164
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
|
|
911
1165
|
|
|
912
1166
|
# Apply column map to first record for detection
|
|
913
1167
|
mapped_first = _apply_column_map(first_record, column_map)
|
|
@@ -944,16 +1198,31 @@ def cmd_stage_hf(args):
|
|
|
944
1198
|
f" Detected format: {strategy_desc}"
|
|
945
1199
|
)
|
|
946
1200
|
|
|
1201
|
+
take_limit = getattr(args, 'take', None)
|
|
947
1202
|
with open(jsonl_path, "w", encoding="utf-8") as out_f:
|
|
948
1203
|
for batch in table.to_batches():
|
|
949
1204
|
for row in batch.to_pylist():
|
|
1205
|
+
if take_limit and num_records >= take_limit:
|
|
1206
|
+
break
|
|
950
1207
|
mapped_row = _apply_column_map(row, column_map)
|
|
951
1208
|
if chat_columns and not no_transform:
|
|
952
1209
|
mapped_row = _flatten_record(mapped_row, chat_columns)
|
|
1210
|
+
# Track empty required fields
|
|
1211
|
+
for col in _check_empty_fields(mapped_row, required_columns):
|
|
1212
|
+
empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
|
|
953
1213
|
out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
|
|
954
1214
|
num_records += 1
|
|
1215
|
+
if take_limit and num_records >= take_limit:
|
|
1216
|
+
break
|
|
955
1217
|
|
|
956
1218
|
# Upload converted JSONL
|
|
1219
|
+
# Verify file has content before uploading
|
|
1220
|
+
file_size = os.path.getsize(jsonl_path)
|
|
1221
|
+
if file_size == 0:
|
|
1222
|
+
_error_exit(
|
|
1223
|
+
f"Converted JSONL file is empty (0 bytes) after processing "
|
|
1224
|
+
f"{num_records} records. This is a bug — please report it."
|
|
1225
|
+
)
|
|
957
1226
|
s3_key = f"{s3_prefix}/{jsonl_filename}"
|
|
958
1227
|
s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
|
|
959
1228
|
|
|
@@ -975,7 +1244,7 @@ def cmd_stage_hf(args):
|
|
|
975
1244
|
first_line = f.readline().strip()
|
|
976
1245
|
if first_line:
|
|
977
1246
|
first_record = json_mod.loads(first_line)
|
|
978
|
-
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
1247
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}", take=getattr(args, 'take', None))
|
|
979
1248
|
|
|
980
1249
|
# Apply column map to first record for detection
|
|
981
1250
|
mapped_first = _apply_column_map(first_record, column_map)
|
|
@@ -1014,11 +1283,14 @@ def cmd_stage_hf(args):
|
|
|
1014
1283
|
|
|
1015
1284
|
# Rewrite the file with mapped (and optionally flattened) columns
|
|
1016
1285
|
should_flatten = bool(chat_columns) and not no_transform
|
|
1017
|
-
|
|
1286
|
+
take_limit = getattr(args, 'take', None)
|
|
1287
|
+
if column_map or should_flatten or take_limit:
|
|
1018
1288
|
mapped_path = local_path + ".mapped"
|
|
1019
1289
|
with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
|
|
1020
1290
|
open(mapped_path, "w", encoding="utf-8") as f_out:
|
|
1021
1291
|
for line in f_in:
|
|
1292
|
+
if take_limit and num_records >= take_limit:
|
|
1293
|
+
break
|
|
1022
1294
|
line = line.strip()
|
|
1023
1295
|
if not line:
|
|
1024
1296
|
continue
|
|
@@ -1026,15 +1298,32 @@ def cmd_stage_hf(args):
|
|
|
1026
1298
|
mapped_record = _apply_column_map(record, column_map)
|
|
1027
1299
|
if should_flatten:
|
|
1028
1300
|
mapped_record = _flatten_record(mapped_record, chat_columns)
|
|
1301
|
+
# Track empty required fields
|
|
1302
|
+
for col in _check_empty_fields(mapped_record, _get_required_columns(technique)):
|
|
1303
|
+
empty_field_counts[col] = empty_field_counts.get(col, 0) + 1
|
|
1029
1304
|
f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
|
|
1030
1305
|
num_records += 1
|
|
1031
1306
|
local_path = mapped_path
|
|
1032
1307
|
else:
|
|
1033
|
-
# Count records
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1308
|
+
# Count records (and truncate if --take specified)
|
|
1309
|
+
take_limit = getattr(args, 'take', None)
|
|
1310
|
+
if take_limit:
|
|
1311
|
+
# Need to rewrite the file truncated
|
|
1312
|
+
mapped_path = local_path + ".mapped"
|
|
1313
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
|
|
1314
|
+
open(mapped_path, "w", encoding="utf-8") as f_out:
|
|
1315
|
+
for line in f_in:
|
|
1316
|
+
if num_records >= take_limit:
|
|
1317
|
+
break
|
|
1318
|
+
if line.strip():
|
|
1319
|
+
f_out.write(line)
|
|
1320
|
+
num_records += 1
|
|
1321
|
+
local_path = mapped_path
|
|
1322
|
+
else:
|
|
1323
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
|
|
1324
|
+
for line in f:
|
|
1325
|
+
if line.strip():
|
|
1326
|
+
num_records += 1
|
|
1038
1327
|
|
|
1039
1328
|
# Upload to S3
|
|
1040
1329
|
s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
|
|
@@ -1048,6 +1337,19 @@ def cmd_stage_hf(args):
|
|
|
1048
1337
|
output_filename = os.path.basename(first_file)
|
|
1049
1338
|
s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
|
|
1050
1339
|
|
|
1340
|
+
# Warn if required columns have many empty values
|
|
1341
|
+
if num_records > 0 and empty_field_counts:
|
|
1342
|
+
for field, count in empty_field_counts.items():
|
|
1343
|
+
pct = (count / num_records) * 100
|
|
1344
|
+
if pct > 30:
|
|
1345
|
+
print(
|
|
1346
|
+
f"\u26a0\ufe0f Warning: {pct:.0f}% of records ({count}/{num_records}) "
|
|
1347
|
+
f"have empty '{field}' after column mapping.\n"
|
|
1348
|
+
f" SageMaker may reject these as invalid samples.\n"
|
|
1349
|
+
f" Consider using a different --column-map or dataset.",
|
|
1350
|
+
file=sys.stderr,
|
|
1351
|
+
)
|
|
1352
|
+
|
|
1051
1353
|
_output({
|
|
1052
1354
|
"s3_uri": s3_uri,
|
|
1053
1355
|
"num_records": num_records,
|
|
@@ -1124,12 +1426,12 @@ def _find_data_files(repo_files, split):
|
|
|
1124
1426
|
if pattern in repo_files:
|
|
1125
1427
|
return [pattern]
|
|
1126
1428
|
|
|
1127
|
-
# Prefix match for sharded files
|
|
1128
|
-
matches =
|
|
1429
|
+
# Prefix match for sharded files (deduplicate via set)
|
|
1430
|
+
matches = set()
|
|
1129
1431
|
for f in repo_files:
|
|
1130
1432
|
for pattern in patterns[4:]:
|
|
1131
1433
|
if pattern in f:
|
|
1132
|
-
matches.
|
|
1434
|
+
matches.add(f)
|
|
1133
1435
|
|
|
1134
1436
|
if matches:
|
|
1135
1437
|
return sorted(matches)
|
|
@@ -1508,6 +1810,10 @@ def _build_expected_format(schema):
|
|
|
1508
1810
|
def cmd_discover(args):
|
|
1509
1811
|
"""Query JumpStart Hub for tune-eligible models matching a family.
|
|
1510
1812
|
|
|
1813
|
+
NOTE: This subcommand intentionally stays on boto3.client('sagemaker')
|
|
1814
|
+
because list_hub_contents / Hub API is NOT available in sagemaker-core.
|
|
1815
|
+
This is a documented exception per the SDK v3 migration policy.
|
|
1816
|
+
|
|
1511
1817
|
Returns: {"models": [str], "count": int}
|
|
1512
1818
|
"""
|
|
1513
1819
|
region = args.region or os.environ.get('AWS_REGION', 'us-east-1')
|
|
@@ -1532,6 +1838,8 @@ def cmd_discover(args):
|
|
|
1532
1838
|
_error_exit("Hub discovery failed: boto3 is not installed. Install with: pip install boto3")
|
|
1533
1839
|
|
|
1534
1840
|
try:
|
|
1841
|
+
# Documented exception: Hub API (list_hub_contents) is not available in
|
|
1842
|
+
# sagemaker-core, so we retain boto3.client('sagemaker') here.
|
|
1535
1843
|
client = boto3.client("sagemaker", region_name=region)
|
|
1536
1844
|
models = []
|
|
1537
1845
|
paginator = client.get_paginator('list_hub_contents')
|
|
@@ -1573,8 +1881,10 @@ def main():
|
|
|
1573
1881
|
submit_parser.add_argument("--training-type", required=True,
|
|
1574
1882
|
choices=["lora", "full-rank"],
|
|
1575
1883
|
help="Training type (lora or full-rank)")
|
|
1576
|
-
submit_parser.add_argument("--dataset-s3-uri", required=
|
|
1577
|
-
help="S3 URI of the training dataset")
|
|
1884
|
+
submit_parser.add_argument("--dataset-s3-uri", required=False, default=None,
|
|
1885
|
+
help="S3 URI of the training dataset (direct override)")
|
|
1886
|
+
submit_parser.add_argument("--dataset-name", default=None,
|
|
1887
|
+
help="Registered dataset name to resolve from registry")
|
|
1578
1888
|
submit_parser.add_argument("--output-bucket", required=True,
|
|
1579
1889
|
help="S3 bucket for output artifacts")
|
|
1580
1890
|
submit_parser.add_argument("--role-arn", required=True,
|
|
@@ -1601,6 +1911,8 @@ def main():
|
|
|
1601
1911
|
help="Lambda ARN for reward function (RLVR)")
|
|
1602
1912
|
submit_parser.add_argument("--reward-prompt", default=None,
|
|
1603
1913
|
help="S3 URI for reward prompt (RLAIF)")
|
|
1914
|
+
submit_parser.add_argument("--evaluator-name", default=None,
|
|
1915
|
+
help="Registered evaluator name to resolve from registry")
|
|
1604
1916
|
submit_parser.add_argument("--accept-eula", action="store_true", default=False,
|
|
1605
1917
|
help="Accept model EULA for gated models (e.g., Llama)")
|
|
1606
1918
|
|
|
@@ -1650,6 +1962,8 @@ def main():
|
|
|
1650
1962
|
help="Customization technique (determines required columns)")
|
|
1651
1963
|
stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
|
|
1652
1964
|
help="Disable automatic chat-format flattening")
|
|
1965
|
+
stage_hf_parser.add_argument("--take", type=int, default=None,
|
|
1966
|
+
help="Take only the first N records from the dataset")
|
|
1653
1967
|
|
|
1654
1968
|
# ── validate ──────────────────────────────────────────────────────────────
|
|
1655
1969
|
validate_parser = subparsers.add_parser("validate",
|