@aws/ml-container-creator 0.10.0 → 0.12.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +9304 -0
- package/bin/cli.js +2 -0
- package/config/bootstrap-e2e-stack.json +341 -0
- package/config/bootstrap-stack.json +40 -3
- package/config/parameter-schema-v2.json +33 -22
- package/config/tune-catalog.json +1781 -0
- package/infra/ci-harness/buildspec.yml +1 -0
- package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
- package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +851 -7
- package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
- package/package.json +53 -67
- package/servers/base-image-picker/index.js +121 -121
- package/servers/e2e-status/index.js +297 -0
- package/servers/e2e-status/manifest.json +14 -0
- package/servers/e2e-status/package.json +15 -0
- package/servers/endpoint-picker/LICENSE +202 -0
- package/servers/endpoint-picker/index.js +536 -0
- package/servers/endpoint-picker/manifest.json +14 -0
- package/servers/endpoint-picker/package.json +18 -0
- package/servers/hyperpod-cluster-picker/index.js +125 -125
- package/servers/instance-sizer/index.js +166 -153
- package/servers/instance-sizer/lib/instance-ranker.js +120 -76
- package/servers/instance-sizer/lib/model-resolver.js +61 -61
- package/servers/instance-sizer/lib/quota-resolver.js +113 -113
- package/servers/instance-sizer/lib/vram-estimator.js +31 -31
- package/servers/lib/bedrock-client.js +38 -38
- package/servers/lib/catalogs/instances.json +27 -0
- package/servers/lib/catalogs/model-servers.json +201 -3
- package/servers/lib/custom-validators.js +13 -13
- package/servers/lib/dynamic-resolver.js +4 -4
- package/servers/marketplace-picker/index.js +342 -0
- package/servers/marketplace-picker/manifest.json +14 -0
- package/servers/marketplace-picker/package.json +18 -0
- package/servers/model-picker/index.js +382 -382
- package/servers/region-picker/index.js +56 -56
- package/servers/workload-picker/LICENSE +202 -0
- package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
- package/servers/workload-picker/index.js +171 -0
- package/servers/workload-picker/manifest.json +16 -0
- package/servers/workload-picker/package.json +16 -0
- package/src/app.js +12 -3
- package/src/lib/bootstrap-command-handler.js +609 -15
- package/src/lib/bootstrap-config.js +36 -0
- package/src/lib/bootstrap-profile-manager.js +48 -41
- package/src/lib/ci-register-helpers.js +74 -0
- package/src/lib/config-loader.js +3 -0
- package/src/lib/config-manager.js +7 -0
- package/src/lib/config-validator.js +1 -1
- package/src/lib/cuda-resolver.js +17 -8
- package/src/lib/generated/cli-options.js +319 -314
- package/src/lib/generated/parameter-matrix.js +672 -661
- package/src/lib/generated/validation-rules.js +76 -72
- package/src/lib/path-prover-brain.js +664 -0
- package/src/lib/prompts/infrastructure-prompts.js +2 -2
- package/src/lib/prompts/model-prompts.js +6 -0
- package/src/lib/prompts/project-prompts.js +12 -0
- package/src/lib/secrets-prompt-runner.js +4 -0
- package/src/lib/template-manager.js +1 -1
- package/src/lib/template-variable-resolver.js +87 -1
- package/src/lib/tune-catalog-validator.js +37 -4
- package/templates/Dockerfile +9 -0
- package/templates/code/adapter_sidecar.py +444 -0
- package/templates/code/serve +6 -0
- package/templates/code/serve.d/vllm.ejs +1 -1
- package/templates/do/.benchmark_writer.py +1476 -0
- package/templates/do/.tune_helper.py +982 -57
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/adapter +154 -0
- package/templates/do/benchmark +639 -85
- package/templates/do/build +5 -0
- package/templates/do/clean.d/async-inference.ejs +5 -0
- package/templates/do/clean.d/batch-transform.ejs +5 -0
- package/templates/do/clean.d/hyperpod-eks.ejs +5 -0
- package/templates/do/clean.d/managed-inference.ejs +5 -0
- package/templates/do/config +115 -45
- package/templates/do/deploy.d/async-inference.ejs +30 -3
- package/templates/do/deploy.d/batch-transform.ejs +29 -3
- package/templates/do/deploy.d/hyperpod-eks.ejs +4 -0
- package/templates/do/deploy.d/managed-inference.ejs +216 -14
- package/templates/do/lib/endpoint-config.sh +1 -1
- package/templates/do/lib/profile.sh +44 -0
- package/templates/do/optimize +106 -37
- package/templates/do/push +5 -0
- package/templates/do/register +94 -0
- package/templates/do/stage +567 -0
- package/templates/do/submit +7 -0
- package/templates/do/test +14 -0
- package/templates/do/tune +382 -59
- package/templates/do/validate +44 -4
|
@@ -10,30 +10,44 @@ Subcommands:
|
|
|
10
10
|
resolve - Resolve output artifact path from job
|
|
11
11
|
stage-hf - Download HF dataset to S3
|
|
12
12
|
validate - Validate dataset format against schema
|
|
13
|
+
discover - Discover tune-eligible models from JumpStart Hub
|
|
13
14
|
|
|
14
15
|
All output is JSON on stdout for bash consumption.
|
|
15
16
|
"""
|
|
16
17
|
|
|
17
18
|
import argparse
|
|
19
|
+
import fnmatch
|
|
18
20
|
import json
|
|
19
21
|
import os
|
|
22
|
+
import re
|
|
20
23
|
import sys
|
|
21
24
|
import time
|
|
25
|
+
import warnings
|
|
26
|
+
|
|
27
|
+
# Suppress noisy dependency version warnings from requests/urllib3
|
|
28
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
29
|
+
warnings.filterwarnings("ignore", message=".*urllib3.*")
|
|
30
|
+
warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
|
|
22
31
|
|
|
23
32
|
# ── Inline dependency check ───────────────────────────────────────────────────
|
|
24
|
-
MIN_SAGEMAKER_VERSION = "
|
|
33
|
+
MIN_SAGEMAKER_VERSION = "3.0"
|
|
34
|
+
|
|
35
|
+
_GLOB_METACHAR_RE = re.compile(r'[*?\[]')
|
|
25
36
|
|
|
26
37
|
|
|
27
38
|
def _check_sagemaker_sdk():
|
|
28
39
|
"""Verify sagemaker SDK is installed with minimum version."""
|
|
29
40
|
try:
|
|
30
41
|
import sagemaker # noqa: F401
|
|
42
|
+
# SDK v3 removed __version__; use importlib.metadata instead
|
|
43
|
+
from importlib.metadata import version as pkg_version
|
|
31
44
|
from packaging.version import Version
|
|
32
|
-
|
|
45
|
+
installed = pkg_version("sagemaker")
|
|
46
|
+
if Version(installed) < Version(MIN_SAGEMAKER_VERSION):
|
|
33
47
|
_error_exit(
|
|
34
|
-
f"sagemaker SDK version {
|
|
48
|
+
f"sagemaker SDK version {installed} is below minimum "
|
|
35
49
|
f"required version {MIN_SAGEMAKER_VERSION}. "
|
|
36
|
-
f"Please upgrade: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
|
|
50
|
+
f"Please upgrade: pip install --upgrade 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
|
|
37
51
|
)
|
|
38
52
|
except ImportError:
|
|
39
53
|
_error_exit(
|
|
@@ -65,11 +79,37 @@ def cmd_submit(args):
|
|
|
65
79
|
|
|
66
80
|
Returns: {"job_name": str, "job_arn": str, "mlflow_url": str|None}
|
|
67
81
|
"""
|
|
82
|
+
# Suppress SDK rich logging that pollutes stdout (we only want JSON output)
|
|
83
|
+
import logging
|
|
84
|
+
logging.disable(logging.CRITICAL)
|
|
85
|
+
os.environ["SAGEMAKER_LOG_LEVEL"] = "CRITICAL"
|
|
86
|
+
|
|
87
|
+
# Ensure region is set before ANY sagemaker import (v3 creates boto3 clients at import time)
|
|
88
|
+
region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
|
|
89
|
+
if region:
|
|
90
|
+
os.environ["AWS_DEFAULT_REGION"] = region
|
|
91
|
+
os.environ.setdefault("AWS_REGION", region)
|
|
92
|
+
|
|
68
93
|
_check_sagemaker_sdk()
|
|
69
94
|
|
|
70
|
-
from sagemaker.modules.train
|
|
71
|
-
|
|
72
|
-
|
|
95
|
+
# SDK v3 moved trainers from sagemaker.modules.train → sagemaker.train
|
|
96
|
+
# Note: catch Exception (not just ImportError) because SDK v3 AIRHub
|
|
97
|
+
# creates boto3 clients at class-definition time, which can raise
|
|
98
|
+
# NoRegionError if AWS_DEFAULT_REGION is not set despite our best efforts.
|
|
99
|
+
try:
|
|
100
|
+
from sagemaker.train.sft_trainer import SFTTrainer
|
|
101
|
+
from sagemaker.train.dpo_trainer import DPOTrainer
|
|
102
|
+
from sagemaker.train.common import TrainingType
|
|
103
|
+
except Exception:
|
|
104
|
+
try:
|
|
105
|
+
from sagemaker.modules.train.sft_trainer import SFTTrainer
|
|
106
|
+
from sagemaker.modules.train.dpo_trainer import DPOTrainer
|
|
107
|
+
from sagemaker.modules.train.common import TrainingType
|
|
108
|
+
except Exception:
|
|
109
|
+
_error_exit(
|
|
110
|
+
"SFTTrainer not found. Requires sagemaker>=3.0. "
|
|
111
|
+
"Install: pip install --upgrade 'sagemaker>=3.0'"
|
|
112
|
+
)
|
|
73
113
|
|
|
74
114
|
# Technique → Trainer class mapping
|
|
75
115
|
TRAINER_MAP = {
|
|
@@ -88,63 +128,164 @@ def cmd_submit(args):
|
|
|
88
128
|
# Resolve training type
|
|
89
129
|
training_type_map = {
|
|
90
130
|
"lora": TrainingType.LORA,
|
|
91
|
-
"full-rank": TrainingType
|
|
131
|
+
"full-rank": getattr(TrainingType, 'FULL_RANK', None) or getattr(TrainingType, 'FULL', None),
|
|
92
132
|
}
|
|
93
133
|
training_type = training_type_map.get(args.training_type)
|
|
94
134
|
if not training_type:
|
|
95
135
|
_error_exit(f"Unsupported training type: {args.training_type}")
|
|
96
136
|
|
|
97
137
|
# Build hyperparameters dict from optional overrides
|
|
138
|
+
# Map CLI flag names to SDK v3 fine-tuning option names
|
|
98
139
|
hyperparameters = {}
|
|
99
140
|
if args.epochs is not None:
|
|
100
|
-
hyperparameters["
|
|
141
|
+
hyperparameters["max_epochs"] = args.epochs
|
|
101
142
|
if args.learning_rate is not None:
|
|
102
143
|
hyperparameters["learning_rate"] = args.learning_rate
|
|
103
144
|
if args.max_seq_length is not None:
|
|
104
|
-
hyperparameters["
|
|
145
|
+
hyperparameters["dataset_max_len"] = args.max_seq_length
|
|
105
146
|
if args.lora_rank is not None:
|
|
106
147
|
hyperparameters["lora_rank"] = args.lora_rank
|
|
107
148
|
if args.lora_alpha is not None:
|
|
108
149
|
hyperparameters["lora_alpha"] = args.lora_alpha
|
|
109
150
|
if args.batch_size is not None:
|
|
110
|
-
hyperparameters["
|
|
111
|
-
|
|
112
|
-
# Build trainer kwargs
|
|
113
|
-
trainer_kwargs = {
|
|
114
|
-
"model_id": args.model_id,
|
|
115
|
-
"training_type": training_type,
|
|
116
|
-
"train_data_uri": args.dataset_s3_uri,
|
|
117
|
-
"output_path": f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/",
|
|
118
|
-
"role": args.role_arn,
|
|
119
|
-
"job_name": args.job_name,
|
|
120
|
-
}
|
|
151
|
+
hyperparameters["global_batch_size"] = args.batch_size
|
|
121
152
|
|
|
122
|
-
#
|
|
123
|
-
|
|
124
|
-
trainer_kwargs["model_package_group_name"] = args.model_package_group
|
|
153
|
+
# Build trainer kwargs — API differs between SDK v2 and v3
|
|
154
|
+
output_path = f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/"
|
|
125
155
|
|
|
126
|
-
#
|
|
127
|
-
|
|
128
|
-
trainer_kwargs["hyperparameters"] = hyperparameters
|
|
156
|
+
# Detect SDK version to use appropriate API
|
|
157
|
+
sdk_v3 = hasattr(trainer_cls, 'role') # v3 trainers have role as a settable attribute
|
|
129
158
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
trainer_kwargs
|
|
134
|
-
"
|
|
159
|
+
try:
|
|
160
|
+
if sdk_v3:
|
|
161
|
+
# SDK v3 API: positional model, keyword training_dataset, s3_output_path
|
|
162
|
+
trainer_kwargs = {
|
|
163
|
+
"model": args.model_id,
|
|
164
|
+
"training_type": training_type,
|
|
165
|
+
"training_dataset": args.dataset_s3_uri,
|
|
166
|
+
"s3_output_path": output_path,
|
|
135
167
|
}
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
168
|
+
# Accept EULA for gated models (e.g., Meta Llama)
|
|
169
|
+
# SDK v3.12+ accepts accept_eula as a constructor parameter
|
|
170
|
+
if args.accept_eula:
|
|
171
|
+
trainer_kwargs["accept_eula"] = True
|
|
172
|
+
|
|
173
|
+
# Resolve model package group — create if it doesn't exist
|
|
174
|
+
mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
|
|
175
|
+
try:
|
|
176
|
+
import boto3 as _boto3
|
|
177
|
+
_sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
|
|
178
|
+
_sm.describe_model_package_group(ModelPackageGroupName=mpg_name)
|
|
179
|
+
except Exception as _mpg_err:
|
|
180
|
+
if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
|
|
181
|
+
try:
|
|
182
|
+
_sm.create_model_package_group(
|
|
183
|
+
ModelPackageGroupName=mpg_name,
|
|
184
|
+
ModelPackageGroupDescription=f"Fine-tuned models for {args.project_name}",
|
|
185
|
+
)
|
|
186
|
+
except Exception:
|
|
187
|
+
pass # May already exist or lack permissions — let the trainer handle it
|
|
188
|
+
trainer_kwargs["model_package_group"] = mpg_name
|
|
189
|
+
|
|
190
|
+
trainer = trainer_cls(**trainer_kwargs)
|
|
191
|
+
trainer.role = args.role_arn
|
|
192
|
+
trainer.base_job_name = args.job_name
|
|
193
|
+
if hyperparameters:
|
|
194
|
+
# SDK v3 expects hyperparameters with a .to_dict() method
|
|
195
|
+
# Wrap our plain dict to satisfy the interface
|
|
196
|
+
hp_obj = trainer.hyperparameters
|
|
197
|
+
if hp_obj is not None and hasattr(hp_obj, '__dict__'):
|
|
198
|
+
for k, v in hyperparameters.items():
|
|
199
|
+
setattr(hp_obj, k, v)
|
|
200
|
+
else:
|
|
201
|
+
# Fallback: create a simple wrapper
|
|
202
|
+
class _HyperParams:
|
|
203
|
+
def __init__(self, d):
|
|
204
|
+
self._data = d
|
|
205
|
+
for k, v in d.items():
|
|
206
|
+
setattr(self, k, v)
|
|
207
|
+
def to_dict(self):
|
|
208
|
+
return {k: v for k, v in self._data.items() if v is not None}
|
|
209
|
+
trainer.hyperparameters = _HyperParams(hyperparameters)
|
|
210
|
+
|
|
211
|
+
# Use MLCC-owned MLflow app if available (avoids permission issues with Studio apps)
|
|
212
|
+
mlflow_arn = os.environ.get('MLFLOW_APP_ARN', '')
|
|
213
|
+
if mlflow_arn:
|
|
214
|
+
trainer.mlflow_resource_arn = mlflow_arn
|
|
215
|
+
|
|
216
|
+
# Suppress SDK print() output (e.g., "Training Job Name: ...")
|
|
217
|
+
# that pollutes stdout and breaks JSON parsing by the shell script
|
|
218
|
+
import io as _io
|
|
219
|
+
_orig_stdout = sys.stdout
|
|
220
|
+
sys.stdout = _io.StringIO()
|
|
221
|
+
try:
|
|
222
|
+
trainer.train(training_dataset=args.dataset_s3_uri, wait=False)
|
|
223
|
+
finally:
|
|
224
|
+
sys.stdout = _orig_stdout
|
|
225
|
+
else:
|
|
226
|
+
# SDK v2 API: model_id, train_data_uri, output_path, role, job_name
|
|
227
|
+
trainer_kwargs = {
|
|
228
|
+
"model_id": args.model_id,
|
|
229
|
+
"training_type": training_type,
|
|
230
|
+
"train_data_uri": args.dataset_s3_uri,
|
|
231
|
+
"output_path": output_path,
|
|
232
|
+
"role": args.role_arn,
|
|
233
|
+
"job_name": args.job_name,
|
|
139
234
|
}
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
235
|
+
if args.model_package_group:
|
|
236
|
+
trainer_kwargs["model_package_group_name"] = args.model_package_group
|
|
237
|
+
if hyperparameters:
|
|
238
|
+
trainer_kwargs["hyperparameters"] = hyperparameters
|
|
239
|
+
|
|
240
|
+
# Add evaluator config for RLVR/RLAIF techniques
|
|
241
|
+
if technique in ("rlvr", "rlaif"):
|
|
242
|
+
if args.reward_function:
|
|
243
|
+
trainer_kwargs["evaluator_config"] = {"reward_function_arn": args.reward_function}
|
|
244
|
+
elif args.reward_prompt:
|
|
245
|
+
trainer_kwargs["evaluator_config"] = {"reward_prompt_s3_uri": args.reward_prompt}
|
|
246
|
+
|
|
247
|
+
# Accept EULA for gated models (e.g., Meta Llama)
|
|
248
|
+
if args.accept_eula:
|
|
249
|
+
trainer_kwargs["accept_eula"] = True
|
|
250
|
+
|
|
251
|
+
trainer = trainer_cls(**trainer_kwargs)
|
|
252
|
+
# Suppress SDK print() output that pollutes stdout
|
|
253
|
+
import io as _io
|
|
254
|
+
_orig_stdout = sys.stdout
|
|
255
|
+
sys.stdout = _io.StringIO()
|
|
256
|
+
try:
|
|
257
|
+
trainer.train(wait=False)
|
|
258
|
+
finally:
|
|
259
|
+
sys.stdout = _orig_stdout
|
|
144
260
|
|
|
145
261
|
# Extract job info from the trainer
|
|
146
|
-
job_name = trainer
|
|
262
|
+
job_name = getattr(trainer, 'training_job_name', None) or getattr(trainer, 'base_job_name', None)
|
|
147
263
|
job_arn = getattr(trainer, "training_job_arn", None)
|
|
264
|
+
latest_job = getattr(trainer, 'latest_training_job', None)
|
|
265
|
+
if latest_job:
|
|
266
|
+
job_name = job_name or getattr(latest_job, 'name', None) or getattr(latest_job, 'job_name', None)
|
|
267
|
+
job_arn = job_arn or getattr(latest_job, 'arn', None)
|
|
268
|
+
|
|
269
|
+
# If we still don't have the actual job name (SDK appends suffix),
|
|
270
|
+
# query ListTrainingJobs to find it by our base_job_name prefix
|
|
271
|
+
if not job_name or job_name == args.job_name:
|
|
272
|
+
import boto3 as _boto3
|
|
273
|
+
_sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
|
|
274
|
+
try:
|
|
275
|
+
# Brief delay to allow job to register
|
|
276
|
+
time.sleep(2)
|
|
277
|
+
list_resp = _sm.list_training_jobs(
|
|
278
|
+
NameContains=args.job_name,
|
|
279
|
+
SortBy="CreationTime",
|
|
280
|
+
SortOrder="Descending",
|
|
281
|
+
MaxResults=1,
|
|
282
|
+
)
|
|
283
|
+
summaries = list_resp.get("TrainingJobSummaries", [])
|
|
284
|
+
if summaries:
|
|
285
|
+
job_name = summaries[0]["TrainingJobName"]
|
|
286
|
+
job_arn = summaries[0].get("TrainingJobArn", job_arn)
|
|
287
|
+
except Exception:
|
|
288
|
+
pass # Fall back to whatever we have
|
|
148
289
|
|
|
149
290
|
# Attempt to get MLflow URL if available
|
|
150
291
|
mlflow_url = None
|
|
@@ -154,7 +295,7 @@ def cmd_submit(args):
|
|
|
154
295
|
pass
|
|
155
296
|
|
|
156
297
|
_output({
|
|
157
|
-
"job_name": job_name,
|
|
298
|
+
"job_name": job_name or args.job_name,
|
|
158
299
|
"job_arn": job_arn or "",
|
|
159
300
|
"mlflow_url": mlflow_url,
|
|
160
301
|
"model_package_group": args.model_package_group or "",
|
|
@@ -176,8 +317,15 @@ def cmd_submit(args):
|
|
|
176
317
|
)
|
|
177
318
|
elif "ValidationException" in error_msg and "license" in error_msg.lower():
|
|
178
319
|
_error_exit(
|
|
179
|
-
f"Model
|
|
180
|
-
f"
|
|
320
|
+
f"Model requires EULA acceptance. Re-run with --accept-eula flag: "
|
|
321
|
+
f"./do/tune --technique {technique} --accept-eula ... "
|
|
322
|
+
f"Details: {error_msg}"
|
|
323
|
+
)
|
|
324
|
+
elif "ValidationException" in error_msg and "eula" in error_msg.lower():
|
|
325
|
+
_error_exit(
|
|
326
|
+
f"Model requires EULA acceptance. Re-run with --accept-eula flag: "
|
|
327
|
+
f"./do/tune --technique {technique} --accept-eula ... "
|
|
328
|
+
f"Details: {error_msg}"
|
|
181
329
|
)
|
|
182
330
|
else:
|
|
183
331
|
_error_exit(f"Failed to submit training job: {error_msg}")
|
|
@@ -189,6 +337,9 @@ def cmd_submit(args):
|
|
|
189
337
|
def cmd_status(args):
|
|
190
338
|
"""Query job status via DescribeTrainingJob.
|
|
191
339
|
|
|
340
|
+
Falls back to ListTrainingJobs with name-contains if exact name not found
|
|
341
|
+
(SDK v3 appends a timestamp suffix to the base job name).
|
|
342
|
+
|
|
192
343
|
Returns: {"status": str, "failure_reason": str|None,
|
|
193
344
|
"metrics": dict|None, "elapsed_seconds": int}
|
|
194
345
|
"""
|
|
@@ -196,16 +347,36 @@ def cmd_status(args):
|
|
|
196
347
|
|
|
197
348
|
client = boto3.client("sagemaker", region_name=args.region)
|
|
198
349
|
|
|
350
|
+
# Try exact name first
|
|
351
|
+
response = None
|
|
199
352
|
try:
|
|
200
353
|
response = client.describe_training_job(TrainingJobName=args.job_name)
|
|
201
354
|
except client.exceptions.ClientError as e:
|
|
202
355
|
error_code = e.response["Error"]["Code"]
|
|
203
|
-
if error_code
|
|
204
|
-
_error_exit(f"
|
|
205
|
-
|
|
356
|
+
if error_code != "ValidationException":
|
|
357
|
+
_error_exit(f"Failed to describe training job: {e}")
|
|
358
|
+
# Job not found by exact name — try name-contains search
|
|
206
359
|
except Exception as e:
|
|
207
360
|
_error_exit(f"Failed to describe training job: {e}")
|
|
208
361
|
|
|
362
|
+
# Fallback: search by name prefix (SDK appends timestamp suffix)
|
|
363
|
+
if response is None:
|
|
364
|
+
try:
|
|
365
|
+
list_response = client.list_training_jobs(
|
|
366
|
+
NameContains=args.job_name,
|
|
367
|
+
SortBy="CreationTime",
|
|
368
|
+
SortOrder="Descending",
|
|
369
|
+
MaxResults=1,
|
|
370
|
+
)
|
|
371
|
+
summaries = list_response.get("TrainingJobSummaries", [])
|
|
372
|
+
if summaries:
|
|
373
|
+
actual_name = summaries[0]["TrainingJobName"]
|
|
374
|
+
response = client.describe_training_job(TrainingJobName=actual_name)
|
|
375
|
+
else:
|
|
376
|
+
_error_exit(f"Training job not found: {args.job_name}")
|
|
377
|
+
except Exception as e:
|
|
378
|
+
_error_exit(f"Failed to find training job: {e}")
|
|
379
|
+
|
|
209
380
|
status = response.get("TrainingJobStatus", "Unknown")
|
|
210
381
|
failure_reason = response.get("FailureReason")
|
|
211
382
|
|
|
@@ -278,6 +449,14 @@ def cmd_resolve(args):
|
|
|
278
449
|
# Determine output type from training type
|
|
279
450
|
output_type = "adapter" if args.training_type == "lora" else "full-model"
|
|
280
451
|
|
|
452
|
+
# For LoRA adapters, the actual adapter files are in checkpoints/hf/ subdirectory
|
|
453
|
+
# The S3ModelArtifacts path points to the top-level output directory
|
|
454
|
+
if output_type == "adapter":
|
|
455
|
+
# Ensure trailing slash for directory path
|
|
456
|
+
if not artifact_path.endswith("/"):
|
|
457
|
+
artifact_path += "/"
|
|
458
|
+
artifact_path += "checkpoints/hf/"
|
|
459
|
+
|
|
281
460
|
# Try to find model package ARN if a model package group was used
|
|
282
461
|
model_package_arn = None
|
|
283
462
|
if args.model_package_group:
|
|
@@ -306,6 +485,332 @@ def cmd_resolve(args):
|
|
|
306
485
|
# ── Subcommand: stage-hf ─────────────────────────────────────────────────────
|
|
307
486
|
|
|
308
487
|
|
|
488
|
+
def _get_required_columns(technique):
    """Look up the dataset columns a fine-tuning technique requires.

    Args:
        technique: Technique identifier. "sft", "dpo", "rlaif" and "rlvr"
            have dedicated entries.

    Returns:
        list: Required column names. A technique without a dedicated entry
        gets the same pair as "sft" ("prompt", "completion").
    """
    sft_columns = ["prompt", "completion"]
    per_technique = {
        "sft": sft_columns,
        "dpo": ["prompt", "chosen", "rejected"],
        # For the two RL techniques the prompt holds an array of messages.
        "rlaif": ["prompt"],
        "rlvr": ["prompt"],
    }
    try:
        return per_technique[technique]
    except KeyError:
        # Unrecognised technique: fall back to the SFT requirement.
        return sft_columns
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _suggest_column_map(detected_columns, required_columns):
    """Propose a --column-map value from well-known column-name aliases.

    For every required column that is absent from ``detected_columns``, the
    alias list for that column is scanned in order and the first alias that
    is present in ``detected_columns`` is proposed as its source.

    Args:
        detected_columns: Column names found in the dataset.
        required_columns: Column names the technique needs.

    Returns:
        str | None: A string such as "prompt=question,completion=answer", or
        None when no alias matched for any missing column.
    """
    known_aliases = {
        "prompt": ["question", "instruction", "input", "query", "text", "context", "user", "human"],
        "completion": ["answer", "output", "response", "assistant", "target", "label", "reply"],
        "chosen": ["chosen", "preferred", "good", "positive", "accepted"],
        "rejected": ["rejected", "dispreferred", "bad", "negative", "refused"],
    }

    proposed = {}
    for wanted in required_columns:
        if wanted in detected_columns:
            # Nothing to map: the dataset already has this column.
            continue
        candidates = known_aliases.get(wanted, [])
        hit = next((name for name in candidates if name in detected_columns), None)
        if hit is not None:
            proposed[wanted] = hit

    if not proposed:
        return None

    # Render in the same "target=source" syntax that --column-map accepts.
    pieces = [f"{target}={source}" for target, source in proposed.items()]
    return ",".join(pieces)
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _parse_column_map(column_map_str):
    """Parse a column map string into a ``{target: source}`` dict.

    The expected syntax is comma-separated ``target=source`` pairs, for
    example ``"prompt=question,completion=answer"``. Whitespace around
    pairs, targets and sources is ignored. Only the first ``=`` in a pair
    splits it, so a source name may itself contain ``=``.

    Malformed pairs are skipped rather than reported: a pair with no ``=``,
    or with an empty target or source (``"=question"``, ``"prompt="``).
    Skipping the empty-sided pairs keeps an empty-string key or value out of
    the mapping, so a later rename can never produce a column named ``""``.

    Args:
        column_map_str: The raw mapping string; may be None or empty.

    Returns:
        dict: Target column name -> source column name. Empty when the input
        is falsy or contains no well-formed pair.
    """
    if not column_map_str:
        return {}

    mapping = {}
    for pair in column_map_str.split(","):
        target, sep, source = pair.partition("=")
        target = target.strip()
        source = source.strip()
        # No "=" at all, or one side blank: not a usable rename instruction.
        if not sep or not target or not source:
            continue
        mapping[target] = source
    return mapping
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def _apply_column_map(record, column_map):
    """Rename columns of a record according to a ``{target: source}`` map.

    A rename happens only when the source column exists and the target name
    is not already taken; otherwise that entry of the map is ignored.

    Args:
        record: The input record (dict of column name -> value).
        column_map: Mapping of target column name -> source column name.

    Returns:
        dict: ``record`` itself when ``column_map`` is empty/falsy; otherwise
        a shallow copy with the applicable renames performed.
    """
    if not column_map:
        return record

    renamed = dict(record)
    for new_name, old_name in column_map.items():
        if old_name not in renamed:
            continue
        if new_name in renamed:
            # Never overwrite a column that is already present.
            continue
        moved = renamed.pop(old_name)
        renamed[new_name] = moved
    return renamed
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def _detect_chat_columns(record, required_columns, schema_types):
    """Find required columns whose value in ``record`` is chat-formatted.

    A column is examined only if its schema type is "string" and it is
    present in the record; columns typed otherwise (e.g. "array") are never
    reported. Two shapes count as chat format:

    * a dict that has both "role" and "content" keys, and
    * a non-empty list whose first element is such a dict.

    Args:
        record: The first record (dict) after column mapping.
        required_columns: Required column names for the technique.
        schema_types: Dict mapping column name -> expected schema type.

    Returns:
        dict: column name -> detection result, containing only the columns
        detected as chat format. A result is either
        ``{"type": "single_dict"}`` or
        ``{"type": "message_list", "strategy": s, "count": n}`` where ``s``
        is "extract" (one message), "same_role" (every element is a dict
        sharing the first element's role) or "multi_role" (anything else).
    """

    def is_message(candidate):
        return isinstance(candidate, dict) and "role" in candidate and "content" in candidate

    found = {}
    for name in required_columns:
        if schema_types.get(name) != "string" or name not in record:
            continue

        cell = record[name]

        if is_message(cell):
            found[name] = {"type": "single_dict"}
            continue

        if not (isinstance(cell, list) and len(cell) > 0 and is_message(cell[0])):
            continue

        total = len(cell)
        if total == 1:
            strategy = "extract"
        else:
            lead_role = cell[0]["role"]
            uniform = True
            for message in cell:
                if not (isinstance(message, dict) and message.get("role") == lead_role):
                    uniform = False
                    break
            strategy = "same_role" if uniform else "multi_role"

        found[name] = {"type": "message_list", "strategy": strategy, "count": total}

    return found
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
def _flatten_value(value, detection_result):
    """Flatten a chat-format column value to a plain string.

    Handling, in order:

    * a string is returned unchanged; ``None`` and an empty list give "".
    * ``single_dict`` with a dict value: string content is returned as is,
      ``None`` content gives "", other content gives
      ``"<role>: <json content>"``. A dict with no "content" key gives
      ``"<role>: <json of the remaining keys>"``.
    * ``message_list`` with a list value, by strategy:
        - "extract": the first element's content (same rules as above;
          a non-dict element gives "").
        - "same_role": every element's content joined with newlines;
          non-string content is JSON-encoded, missing content or a
          non-dict element contributes an empty line.
        - "multi_role": one ``"<role>: <content>"`` line per element;
          a non-dict element contributes an empty line.
    * anything else falls back to ``str(value)``.

    ``None`` content in a ``single_dict`` yields "" so that it agrees with
    the one-element-list ("extract") case and with a top-level ``None``
    instead of emitting the literal text "role: null".

    Args:
        value: The column value (dict, list, string, or other).
        detection_result: The detection metadata for this column; its
            "type" and "strategy" entries select the handling.

    Returns:
        str: The flattened string value.

    Raises:
        ValueError: If the fallback ``str(value)`` conversion itself fails.
    """
    import json

    # Trivial inputs need no strategy.
    if isinstance(value, str):
        return value
    if value is None:
        return ""
    if isinstance(value, list) and len(value) == 0:
        return ""

    det_type = detection_result.get("type")

    if det_type == "single_dict" and isinstance(value, dict):
        role = value.get("role", "")
        if "content" not in value:
            remaining = {k: v for k, v in value.items() if k != "role"}
            return f"{role}: {json.dumps(remaining)}"
        content = value["content"]
        if content is None:
            # Missing text is an empty string, not the JSON literal "null".
            return ""
        if isinstance(content, str):
            return content
        return f"{role}: {json.dumps(content)}"

    if det_type == "message_list" and isinstance(value, list):
        strategy = detection_result.get("strategy")

        if strategy == "extract":
            first = value[0]
            if not isinstance(first, dict):
                return ""
            content = first.get("content")
            if content is None:
                return ""
            if isinstance(content, str):
                return content
            return f"{first.get('role', '')}: {json.dumps(content)}"

        if strategy == "same_role":
            parts = []
            for elem in value:
                content = elem.get("content") if isinstance(elem, dict) else None
                if content is None or content == "":
                    parts.append("")
                elif isinstance(content, str):
                    parts.append(content)
                else:
                    parts.append(json.dumps(content))
            return "\n".join(parts)

        if strategy == "multi_role":
            lines = []
            for elem in value:
                if not isinstance(elem, dict):
                    lines.append("")
                    continue
                role = elem.get("role", "")
                content = elem.get("content")
                if content is None:
                    content = ""
                elif not isinstance(content, str):
                    content = json.dumps(content)
                lines.append(f"{role}: {content}")
            return "\n".join(lines)

    # Detection metadata did not match the value's actual shape (the
    # metadata comes from the first record only), or the value is a scalar
    # such as an int or bool.
    try:
        return str(value)
    except Exception as e:
        raise ValueError(f"Cannot convert value to string: {e}") from e
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def _flatten_record(record, chat_columns):
    """Return a copy of ``record`` with its chat-format columns flattened.

    Each column named in ``chat_columns`` that is present in the record has
    its value replaced by the result of ``_flatten_value``; all other
    columns are carried over untouched. The input record is not modified.

    Args:
        record: The mapped record dict.
        chat_columns: Detection results from ``_detect_chat_columns``
            (column name -> detection result).

    Returns:
        dict: A shallow copy of the record with the detected columns
        replaced by flat strings.
    """
    result = dict(record)
    for name in chat_columns:
        if name not in result:
            continue
        result[name] = _flatten_value(result[name], chat_columns[name])
    return result
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _log_flatten_info(chat_columns, no_transform):
    """Report detected chat-format columns and their flatten strategy.

    For each detected column a headline is written, followed by one line
    describing the output format when the detection type/strategy is one of
    the known combinations. Everything is written to stderr so that stdout
    stays reserved for the JSON result.

    Args:
        chat_columns: Detection results dict (from ``_detect_chat_columns``).
        no_transform: Whether the --no-transform flag is active. The value
            is accepted but not read here, so the same lines are emitted
            either way.
    """
    for name in chat_columns:
        info = chat_columns[name]
        print(f"\u2139\ufe0f Auto-converted column '{name}' from chat-format to string", file=sys.stderr)

        kind = info.get("type")
        detail = None
        if kind == "single_dict":
            detail = " Format: extracted content field"
        elif kind == "message_list":
            strategy = info.get("strategy")
            total = info.get("count", 0)
            if strategy == "multi_role":
                detail = f" Format: role: content (multi-turn, {total} messages)"
            elif strategy == "same_role":
                detail = f" Format: newline-joined content ({total} messages, same role)"
            elif strategy == "extract":
                detail = " Format: extracted content field"

        if detail is not None:
            print(detail, file=sys.stderr)
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def _get_schema_types(technique):
    """Look up the expected type of each schema column for a technique.

    Args:
        technique: One of 'sft', 'dpo', 'rlaif', 'rlvr'. Any other value
            falls back to the prompt/completion string schema.

    Returns:
        dict: Maps column_name -> expected type ("string" or "array")
    """

    def _all_strings(*columns):
        # Every listed column is expected to hold a plain string.
        return {column: "string" for column in columns}

    by_technique = {
        "sft": _all_strings("prompt", "completion"),
        "dpo": _all_strings("prompt", "chosen", "rejected"),
    }
    # The RL techniques expect the prompt as an array.
    for rl_technique in ("rlaif", "rlvr"):
        by_technique[rl_technique] = {"prompt": "array"}

    return by_technique.get(technique, _all_strings("prompt", "completion"))
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
    """Check that the first record exposes every required column once mapped.

    The column map string is parsed and applied to ``first_record``; the
    result is compared against the columns required by ``technique``.

    Args:
        first_record: First record of the dataset (dict of column -> value).
        technique: Technique name used to look up the required columns.
        column_map_str: Raw --column-map value (may be None).
        dataset_id: Dataset identifier, used only in the suggested command.

    Returns:
        (mapped_record, column_map_dict) when nothing is missing. Otherwise
        _error_exit is called with a message listing required, detected and
        missing columns, a --column-map suggestion and a truncated sample of
        the first record.
    """
    column_map = _parse_column_map(column_map_str)
    mapped = _apply_column_map(first_record, column_map)
    required = _get_required_columns(technique)

    missing = [name for name in required if name not in mapped]
    if not missing:
        return mapped, column_map

    detected = list(first_record.keys())
    command_prefix = f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map "

    report = [
        f"Dataset columns don't match {technique.upper()} requirements.",
        "",
        f" Required columns: {', '.join(required)}",
        f" Detected columns: {', '.join(detected)}",
        f" Missing: {', '.join(missing)}",
        "",
    ]

    # Prefer a concrete mapping guess; otherwise show a fill-in-the-blank one.
    suggestion = _suggest_column_map(detected, required)
    if suggestion:
        report.append(" 💡 Suggested fix:")
        report.append(f"{command_prefix}{suggestion}")
    else:
        report.append(" 💡 Use --column-map to rename columns:")
        placeholder_map = ",".join(f"{name}=<your_column>" for name in missing)
        report.append(f"{command_prefix}{placeholder_map}")

    report.append("")
    report.append(" First record sample:")
    # Only the first five fields are shown, each cut to 80 characters.
    for key, value in list(first_record.items())[:5]:
        text = str(value)
        preview = text[:80] + "..." if len(text) > 80 else text[:80]
        report.append(f" {key}: {preview}")

    _error_exit("\n".join(report))
|
|
812
|
+
|
|
813
|
+
|
|
309
814
|
def cmd_stage_hf(args):
|
|
310
815
|
"""Download HF dataset to S3 using huggingface_hub.
|
|
311
816
|
|
|
@@ -313,6 +818,9 @@ def cmd_stage_hf(args):
|
|
|
313
818
|
|
|
314
819
|
Returns: {"s3_uri": str, "num_records": int}
|
|
315
820
|
"""
|
|
821
|
+
# Suppress HF Hub progress bars — they pollute stdout which must be clean JSON
|
|
822
|
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
|
823
|
+
|
|
316
824
|
try:
|
|
317
825
|
from huggingface_hub import hf_hub_download, HfApi
|
|
318
826
|
except ImportError:
|
|
@@ -352,12 +860,28 @@ def cmd_stage_hf(args):
|
|
|
352
860
|
f"Available files: {', '.join(repo_files[:20])}"
|
|
353
861
|
)
|
|
354
862
|
|
|
863
|
+
# Apply file filter if --hf-file is provided
|
|
864
|
+
hf_file_pattern = getattr(args, 'hf_file', None)
|
|
865
|
+
if hf_file_pattern:
|
|
866
|
+
data_files = _filter_data_files(data_files, hf_file_pattern)
|
|
867
|
+
|
|
355
868
|
# Download and upload to S3
|
|
356
869
|
s3_client = boto3.client("s3", region_name=args.region)
|
|
357
870
|
s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
|
|
358
871
|
num_records = 0
|
|
359
872
|
|
|
360
873
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
874
|
+
# Schema divergence check (skip for single file)
|
|
875
|
+
if len(data_files) > 1:
|
|
876
|
+
column_map = _parse_column_map(getattr(args, 'column_map', None))
|
|
877
|
+
technique = getattr(args, 'technique', 'sft')
|
|
878
|
+
no_transform = getattr(args, 'no_transform', False)
|
|
879
|
+
file_records = _inspect_file_schemas(
|
|
880
|
+
data_files, dataset_id, hf_token, tmpdir,
|
|
881
|
+
column_map, technique, no_transform
|
|
882
|
+
)
|
|
883
|
+
_check_schema_divergence(file_records, dataset_id, technique)
|
|
884
|
+
|
|
361
885
|
for data_file in data_files:
|
|
362
886
|
local_path = hf_hub_download(
|
|
363
887
|
repo_id=dataset_id,
|
|
@@ -367,17 +891,162 @@ def cmd_stage_hf(args):
|
|
|
367
891
|
local_dir=tmpdir,
|
|
368
892
|
)
|
|
369
893
|
|
|
370
|
-
#
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
894
|
+
# Handle Parquet files: convert to JSONL for SageMaker compatibility
|
|
895
|
+
if data_file.endswith(".parquet"):
|
|
896
|
+
try:
|
|
897
|
+
import pyarrow.parquet as pq
|
|
898
|
+
import json as json_mod
|
|
899
|
+
|
|
900
|
+
table = pq.read_table(local_path)
|
|
901
|
+
jsonl_filename = os.path.splitext(os.path.basename(data_file))[0] + ".jsonl"
|
|
902
|
+
jsonl_path = os.path.join(tmpdir, jsonl_filename)
|
|
903
|
+
|
|
904
|
+
# Parse column map and validate against first record
|
|
905
|
+
column_map = _parse_column_map(getattr(args, 'column_map', None))
|
|
906
|
+
technique = getattr(args, 'technique', 'sft')
|
|
907
|
+
no_transform = getattr(args, 'no_transform', False)
|
|
908
|
+
batches = table.to_batches(max_chunksize=1)
|
|
909
|
+
first_record = batches[0].to_pylist()[0] if batches else {}
|
|
910
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
911
|
+
|
|
912
|
+
# Apply column map to first record for detection
|
|
913
|
+
mapped_first = _apply_column_map(first_record, column_map)
|
|
914
|
+
required_columns = _get_required_columns(technique)
|
|
915
|
+
schema_types = _get_schema_types(technique)
|
|
916
|
+
|
|
917
|
+
# Detect chat-format columns on first record
|
|
918
|
+
chat_columns = _detect_chat_columns(mapped_first, required_columns, schema_types)
|
|
919
|
+
|
|
920
|
+
# Log detection results if any chat columns found
|
|
921
|
+
if chat_columns:
|
|
922
|
+
_log_flatten_info(chat_columns, no_transform)
|
|
923
|
+
|
|
924
|
+
# If --no-transform is active and chat-format detected, halt with error
|
|
925
|
+
if no_transform and chat_columns:
|
|
926
|
+
col_name = next(iter(chat_columns))
|
|
927
|
+
det = chat_columns[col_name]
|
|
928
|
+
det_type = det.get("type")
|
|
929
|
+
strategy = det.get("strategy", "")
|
|
930
|
+
if det_type == "single_dict":
|
|
931
|
+
strategy_desc = "single message dict with role+content"
|
|
932
|
+
elif strategy == "extract":
|
|
933
|
+
strategy_desc = "message list (single element)"
|
|
934
|
+
elif strategy == "same_role":
|
|
935
|
+
strategy_desc = f"message list ({det.get('count', 0)} messages, same role)"
|
|
936
|
+
elif strategy == "multi_role":
|
|
937
|
+
strategy_desc = f"message list (multi-turn, {det.get('count', 0)} messages)"
|
|
938
|
+
else:
|
|
939
|
+
strategy_desc = det_type
|
|
940
|
+
_error_exit(
|
|
941
|
+
f"Column '{col_name}' contains chat-format data (detected: {det_type}) but --no-transform is active.\n\n"
|
|
942
|
+
f" Remove --no-transform to enable automatic conversion:\n"
|
|
943
|
+
f" ./do/tune --technique {technique} --dataset hf://{org}/{name} [--column-map ...]\n\n"
|
|
944
|
+
f" Detected format: {strategy_desc}"
|
|
945
|
+
)
|
|
946
|
+
|
|
947
|
+
with open(jsonl_path, "w", encoding="utf-8") as out_f:
|
|
948
|
+
for batch in table.to_batches():
|
|
949
|
+
for row in batch.to_pylist():
|
|
950
|
+
mapped_row = _apply_column_map(row, column_map)
|
|
951
|
+
if chat_columns and not no_transform:
|
|
952
|
+
mapped_row = _flatten_record(mapped_row, chat_columns)
|
|
953
|
+
out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
|
|
954
|
+
num_records += 1
|
|
955
|
+
|
|
956
|
+
# Upload converted JSONL
|
|
957
|
+
s3_key = f"{s3_prefix}/{jsonl_filename}"
|
|
958
|
+
s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
|
|
959
|
+
|
|
960
|
+
except ImportError:
|
|
961
|
+
_error_exit(
|
|
962
|
+
"Dataset is in Parquet format but pyarrow is not installed. "
|
|
963
|
+
"Please install: pip install pyarrow"
|
|
964
|
+
)
|
|
965
|
+
else:
|
|
966
|
+
# JSONL file — validate columns and apply mapping
|
|
967
|
+
import json as json_mod
|
|
968
|
+
column_map = _parse_column_map(getattr(args, 'column_map', None))
|
|
969
|
+
technique = getattr(args, 'technique', 'sft')
|
|
970
|
+
no_transform = getattr(args, 'no_transform', False)
|
|
971
|
+
|
|
972
|
+
# Read first line to validate
|
|
973
|
+
chat_columns = {}
|
|
974
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
|
|
975
|
+
first_line = f.readline().strip()
|
|
976
|
+
if first_line:
|
|
977
|
+
first_record = json_mod.loads(first_line)
|
|
978
|
+
_validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
|
|
979
|
+
|
|
980
|
+
# Apply column map to first record for detection
|
|
981
|
+
mapped_first = _apply_column_map(first_record, column_map)
|
|
982
|
+
required_columns = _get_required_columns(technique)
|
|
983
|
+
schema_types = _get_schema_types(technique)
|
|
984
|
+
|
|
985
|
+
# Detect chat-format columns on first record
|
|
986
|
+
chat_columns = _detect_chat_columns(mapped_first, required_columns, schema_types)
|
|
987
|
+
|
|
988
|
+
# Log detection results if any chat columns found
|
|
989
|
+
if chat_columns:
|
|
990
|
+
_log_flatten_info(chat_columns, no_transform)
|
|
991
|
+
|
|
992
|
+
# If --no-transform is active and chat-format detected, halt with error
|
|
993
|
+
if no_transform and chat_columns:
|
|
994
|
+
col_name = next(iter(chat_columns))
|
|
995
|
+
det = chat_columns[col_name]
|
|
996
|
+
det_type = det.get("type")
|
|
997
|
+
strategy = det.get("strategy", "")
|
|
998
|
+
if det_type == "single_dict":
|
|
999
|
+
strategy_desc = "single message dict with role+content"
|
|
1000
|
+
elif strategy == "extract":
|
|
1001
|
+
strategy_desc = "message list (single element)"
|
|
1002
|
+
elif strategy == "same_role":
|
|
1003
|
+
strategy_desc = f"message list ({det.get('count', 0)} messages, same role)"
|
|
1004
|
+
elif strategy == "multi_role":
|
|
1005
|
+
strategy_desc = f"message list (multi-turn, {det.get('count', 0)} messages)"
|
|
1006
|
+
else:
|
|
1007
|
+
strategy_desc = det_type
|
|
1008
|
+
_error_exit(
|
|
1009
|
+
f"Column '{col_name}' contains chat-format data (detected: {det_type}) but --no-transform is active.\n\n"
|
|
1010
|
+
f" Remove --no-transform to enable automatic conversion:\n"
|
|
1011
|
+
f" ./do/tune --technique {technique} --dataset hf://{org}/{name} [--column-map ...]\n\n"
|
|
1012
|
+
f" Detected format: {strategy_desc}"
|
|
1013
|
+
)
|
|
1014
|
+
|
|
1015
|
+
# Rewrite the file with mapped (and optionally flattened) columns
|
|
1016
|
+
should_flatten = bool(chat_columns) and not no_transform
|
|
1017
|
+
if column_map or should_flatten:
|
|
1018
|
+
mapped_path = local_path + ".mapped"
|
|
1019
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
|
|
1020
|
+
open(mapped_path, "w", encoding="utf-8") as f_out:
|
|
1021
|
+
for line in f_in:
|
|
1022
|
+
line = line.strip()
|
|
1023
|
+
if not line:
|
|
1024
|
+
continue
|
|
1025
|
+
record = json_mod.loads(line)
|
|
1026
|
+
mapped_record = _apply_column_map(record, column_map)
|
|
1027
|
+
if should_flatten:
|
|
1028
|
+
mapped_record = _flatten_record(mapped_record, chat_columns)
|
|
1029
|
+
f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
|
|
1030
|
+
num_records += 1
|
|
1031
|
+
local_path = mapped_path
|
|
1032
|
+
else:
|
|
1033
|
+
# Count records
|
|
1034
|
+
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
|
|
1035
|
+
for line in f:
|
|
1036
|
+
if line.strip():
|
|
1037
|
+
num_records += 1
|
|
1038
|
+
|
|
1039
|
+
# Upload to S3
|
|
1040
|
+
s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
|
|
1041
|
+
s3_client.upload_file(local_path, args.output_bucket, s3_key)
|
|
1042
|
+
|
|
1043
|
+
# Use the first file's name for the S3 URI (JSONL extension for Parquet conversions)
|
|
1044
|
+
first_file = data_files[0]
|
|
1045
|
+
if first_file.endswith(".parquet"):
|
|
1046
|
+
output_filename = os.path.splitext(os.path.basename(first_file))[0] + ".jsonl"
|
|
1047
|
+
else:
|
|
1048
|
+
output_filename = os.path.basename(first_file)
|
|
1049
|
+
s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
|
|
381
1050
|
|
|
382
1051
|
_output({
|
|
383
1052
|
"s3_uri": s3_uri,
|
|
@@ -475,9 +1144,194 @@ def _find_data_files(repo_files, split):
|
|
|
475
1144
|
if data_jsonl:
|
|
476
1145
|
return sorted(data_jsonl)
|
|
477
1146
|
|
|
1147
|
+
# Final fallback: any JSONL/JSON file in the repo root (single-file datasets)
|
|
1148
|
+
root_data = [f for f in repo_files if "/" not in f and (f.endswith(".jsonl") or f.endswith(".json")) and not f.startswith(".")]
|
|
1149
|
+
if root_data:
|
|
1150
|
+
return sorted(root_data)
|
|
1151
|
+
|
|
478
1152
|
return []
|
|
479
1153
|
|
|
480
1154
|
|
|
1155
|
+
def _is_glob_pattern(pattern):
    """Tell whether ``pattern`` should be treated as a glob.

    A pattern counts as a glob when the module-level _GLOB_METACHAR_RE finds
    a match in it (documented in this file as the metacharacters *, ?, [).

    Returns:
        bool: True when a glob metacharacter is present, False otherwise.
    """
    return _GLOB_METACHAR_RE.search(pattern) is not None
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
def _filter_data_files(data_files, pattern):
    """Filter data files by glob or substring pattern.

    If the pattern is empty or None, the input list is returned unchanged.
    If the pattern contains glob metacharacters (per _is_glob_pattern), it is
    matched against the full relative path. Otherwise a substring match on
    the basename is performed.

    Glob matching is case-sensitive on every platform: the paths come from a
    remote repository listing, so they must not be case-normalised according
    to the local operating system (which fnmatch.fnmatch would do).

    Args:
        data_files: List of file paths from _find_data_files
        pattern: The filter pattern string (may be None or empty)

    Returns:
        list: File paths that match the pattern, in their original order

    Raises:
        Whatever _error_exit raises when no file matches; the message lists
        the available files.
    """
    if not pattern:
        return data_files

    if _is_glob_pattern(pattern):
        # fnmatchcase: same glob syntax as fnmatch, without os.path.normcase.
        matched = [path for path in data_files if fnmatch.fnmatchcase(path, pattern)]
    else:
        matched = [path for path in data_files if pattern in os.path.basename(path)]

    if not matched:
        file_list = "\n".join(f" • {path}" for path in data_files)
        _error_exit(
            f"No files matched pattern '{pattern}'.\n\n"
            f"Available files:\n{file_list}"
        )

    return matched
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
def _inspect_file_schemas(data_files, dataset_id, hf_token, tmpdir,
                          column_map, technique, no_transform):
    """Inspect the first record of each file to extract effective column sets.

    Downloads each file, reads its first record, applies the column map and
    (unless --no-transform is active) chat-format flattening, then records
    the resulting column names.

    Args:
        data_files: List of file paths to inspect
        dataset_id: HF dataset identifier for downloads
        hf_token: Authentication token
        tmpdir: Temporary directory for downloads
        column_map: Parsed column mapping dict
        technique: Technique name for required columns and schema types
        no_transform: Whether --no-transform is active

    Returns:
        list: [(filename, set_of_column_names), ...] for each file. A file
        with no records contributes an empty set.
    """
    from huggingface_hub import hf_hub_download

    required_columns = _get_required_columns(technique)
    schema_types = _get_schema_types(technique)
    results = []

    for data_file in data_files:
        local_path = hf_hub_download(
            repo_id=dataset_id,
            filename=data_file,
            repo_type="dataset",
            token=hf_token,
            local_dir=tmpdir,
        )

        first_record = _read_first_record_for_schema(local_path, data_file)

        # Apply column mapping
        mapped_record = _apply_column_map(first_record, column_map)

        # Apply flattening if --no-transform is not active
        if not no_transform:
            chat_columns = _detect_chat_columns(mapped_record, required_columns, schema_types)
            if chat_columns:
                mapped_record = _flatten_record(mapped_record, chat_columns)

        results.append((data_file, set(mapped_record.keys())))

    return results


def _read_first_record_for_schema(local_path, data_file):
    """Return the first record of a downloaded data file, or {} if it has none.

    Parquet files (by ``data_file`` extension) are read with pyarrow; every
    other file is treated as JSON Lines.
    """
    if data_file.endswith(".parquet"):
        try:
            import pyarrow.parquet as pq
        except ImportError:
            _error_exit(
                "Dataset is in Parquet format but pyarrow is not installed. "
                "Please install: pip install pyarrow"
            )
            return {}

        table = pq.read_table(local_path)
        # Slice out row 0 directly; splitting the whole table into one-row
        # batches would allocate one RecordBatch per row just to read one.
        rows = table.slice(0, 1).to_pylist()
        return rows[0] if rows else {}

    import json as json_mod

    with open(local_path, "r", encoding="utf-8", errors="replace") as f:
        # Use the first non-blank line: blank lines are skipped when the file
        # is staged, so a leading blank line must not hide the first record.
        for line in f:
            line = line.strip()
            if line:
                return json_mod.loads(line)
    return {}
|
|
1265
|
+
|
|
1266
|
+
|
|
1267
|
+
def _check_schema_divergence(file_records, dataset_id, technique):
    """Check that all files have identical effective columns.

    Args:
        file_records: List of (filename, first_record_columns) tuples where
            first_record_columns is the set of column names after
            column-map and flattening
        dataset_id: The dataset identifier (for error messages)
        technique: The technique name (for error messages)

    Returns:
        None when file_records is empty or every column set equals the first
        file's column set.

    Raises:
        Whatever _error_exit raises when the column sets differ. The message
        contains a per-file column listing, a ?file= remediation command and
        the list of available files.
    """
    if not file_records:
        return None

    # Compare all column sets to the first file's columns
    first_columns = file_records[0][1]
    if all(columns == first_columns for _, columns in file_records):
        return None

    import re as _re
    import shlex

    # Build per-file column listing
    file_listing = "\n\n".join(
        f" \U0001f4c4 {filename}\n"
        f" Columns: {', '.join(sorted(columns))}"
        for filename, columns in file_records
    )

    # Derive the remediation pattern from the first file's basename: its first
    # run of digits if there is one, otherwise the whole extension-less name.
    first_file = file_records[0][0]
    name_without_ext = os.path.splitext(os.path.basename(first_file))[0]
    numeric_match = _re.search(r'\d+', name_without_ext)
    distinctive = numeric_match.group() if numeric_match else name_without_ext
    pattern_suggestion = f"*{distinctive}*"

    # The dataset argument contains '?' and '*', which shells treat as glob
    # characters (zsh aborts on an unmatched glob), so quote it to keep the
    # suggested command copy-paste safe.
    dataset_arg = shlex.quote(f"hf://{dataset_id}?file={pattern_suggestion}")

    available_files = "\n".join(
        f" \u2022 {filename}" for filename, _ in file_records
    )

    message = (
        f"Schema divergence detected in dataset {dataset_id}.\n"
        f"Files have different columns after applying column-map and transforms:\n\n"
        f"{file_listing}\n\n"
        f"\U0001f4a1 Use ?file=<pattern> to select compatible files:\n"
        f" ./do/tune --technique {technique} --dataset {dataset_arg}\n\n"
        f" Available files:\n{available_files}"
    )

    _error_exit(message)
|
|
1333
|
+
|
|
1334
|
+
|
|
481
1335
|
# ── Subcommand: validate ──────────────────────────────────────────────────────
|
|
482
1336
|
|
|
483
1337
|
|
|
@@ -648,6 +1502,53 @@ def _build_expected_format(schema):
|
|
|
648
1502
|
return "Each line must be a JSON object with: {" + ", ".join(fields) + "}"
|
|
649
1503
|
|
|
650
1504
|
|
|
1505
|
+
# ── Subcommand: discover ──────────────────────────────────────────────────────
|
|
1506
|
+
|
|
1507
|
+
|
|
1508
|
+
def cmd_discover(args):
    """Query JumpStart Hub for tune-eligible models matching a family.

    The Hub content name filter is chosen as follows: an explicit --filter
    wins (its help text documents it as overriding the family mapping);
    otherwise the --family value is translated through the built-in
    family -> prefix table. If neither yields a value, _error_exit is called.

    Only hub contents whose status is 'Available' are counted.

    Args:
        args: Parsed CLI namespace; reads ``region``, ``family`` and
            ``filter``.

    Returns: nothing directly; emits {"models": [str], "count": int} via
        _output, where "models" holds at most the first five names and
        "count" is the total number of available matches. Any exception
        raised during the Hub query is reported through _error_exit.
    """
    import boto3

    # Treat an empty AWS_REGION the same as an unset one.
    region = args.region or os.environ.get('AWS_REGION') or 'us-east-1'

    family = args.family or ""
    # Map family names to Hub content name prefixes
    family_prefixes = {
        "qwen-2.5": "huggingface-llm-qwen2-5",
        "qwen-3": "huggingface-reasoning-qwen3",
        "llama-3": "meta-textgeneration-llama-3",
        "deepseek-r1": "deepseek-llm-r1-distill",
        "gpt-oss": "openai-reasoning-gpt-oss",
    }

    # An explicit filter overrides the family mapping (see --filter help).
    prefix = args.filter or family_prefixes.get(family, "")
    if not prefix:
        _error_exit("No family or filter provided for discovery")

    try:
        client = boto3.client("sagemaker", region_name=region)
        paginator = client.get_paginator('list_hub_contents')
        pages = paginator.paginate(
            HubName="SageMakerPublicHub",
            HubContentType="Model",
            NameContains=prefix,
            MaxResults=20
        )
        models = [
            item['HubContentName']
            for page in pages
            for item in page.get('HubContentSummaries', [])
            if item.get('HubContentStatus') == 'Available'
        ]

        _output({"models": models[:5], "count": len(models)})

    except Exception as e:
        # CLI boundary: surface any AWS/SDK failure as a single error message.
        _error_exit(f"Hub discovery failed: {e}")
|
|
1550
|
+
|
|
1551
|
+
|
|
651
1552
|
# ── CLI argument parsing ──────────────────────────────────────────────────────
|
|
652
1553
|
|
|
653
1554
|
|
|
@@ -661,6 +1562,8 @@ def main():
|
|
|
661
1562
|
# ── submit ────────────────────────────────────────────────────────────────
|
|
662
1563
|
submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
|
|
663
1564
|
submit_parser.add_argument("--model-id", required=True, help="Model ID")
|
|
1565
|
+
submit_parser.add_argument("--region", default=None,
|
|
1566
|
+
help="AWS region (defaults to AWS_REGION env var)")
|
|
664
1567
|
submit_parser.add_argument("--technique", required=True,
|
|
665
1568
|
choices=["sft", "dpo", "rlaif", "rlvr"],
|
|
666
1569
|
help="Customization technique")
|
|
@@ -695,6 +1598,8 @@ def main():
|
|
|
695
1598
|
help="Lambda ARN for reward function (RLVR)")
|
|
696
1599
|
submit_parser.add_argument("--reward-prompt", default=None,
|
|
697
1600
|
help="S3 URI for reward prompt (RLAIF)")
|
|
1601
|
+
submit_parser.add_argument("--accept-eula", action="store_true", default=False,
|
|
1602
|
+
help="Accept model EULA for gated models (e.g., Llama)")
|
|
698
1603
|
|
|
699
1604
|
# ── status ────────────────────────────────────────────────────────────────
|
|
700
1605
|
status_parser = subparsers.add_parser("status", help="Get job status and metrics")
|
|
@@ -725,6 +1630,8 @@ def main():
|
|
|
725
1630
|
help="Hugging Face dataset name")
|
|
726
1631
|
stage_hf_parser.add_argument("--hf-split", default="train",
|
|
727
1632
|
help="Dataset split (default: train)")
|
|
1633
|
+
stage_hf_parser.add_argument("--hf-file", default=None,
|
|
1634
|
+
help="File filter pattern (glob or substring)")
|
|
728
1635
|
stage_hf_parser.add_argument("--output-bucket", required=True,
|
|
729
1636
|
help="S3 bucket for staged dataset")
|
|
730
1637
|
stage_hf_parser.add_argument("--project-name", required=True,
|
|
@@ -733,6 +1640,13 @@ def main():
|
|
|
733
1640
|
help="AWS region")
|
|
734
1641
|
stage_hf_parser.add_argument("--hf-secret-name", default=None,
|
|
735
1642
|
help="Secrets Manager secret name for HF token")
|
|
1643
|
+
stage_hf_parser.add_argument("--column-map", default=None,
|
|
1644
|
+
help="Column mapping (e.g., prompt=question,completion=answer)")
|
|
1645
|
+
stage_hf_parser.add_argument("--technique", default="sft",
|
|
1646
|
+
choices=["sft", "dpo", "rlaif", "rlvr"],
|
|
1647
|
+
help="Customization technique (determines required columns)")
|
|
1648
|
+
stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
|
|
1649
|
+
help="Disable automatic chat-format flattening")
|
|
736
1650
|
|
|
737
1651
|
# ── validate ──────────────────────────────────────────────────────────────
|
|
738
1652
|
validate_parser = subparsers.add_parser("validate",
|
|
@@ -742,6 +1656,16 @@ def main():
|
|
|
742
1656
|
validate_parser.add_argument("--file", default="-",
|
|
743
1657
|
help="Path to dataset file (default: stdin)")
|
|
744
1658
|
|
|
1659
|
+
# ── discover ──────────────────────────────────────────────────────────────
|
|
1660
|
+
discover_parser = subparsers.add_parser("discover",
|
|
1661
|
+
help="Discover tune-eligible models from JumpStart Hub")
|
|
1662
|
+
discover_parser.add_argument("--family", default="",
|
|
1663
|
+
help="Model family name (e.g., qwen-3, llama-3, deepseek-r1)")
|
|
1664
|
+
discover_parser.add_argument("--filter", default="",
|
|
1665
|
+
help="Hub content name prefix filter (overrides family mapping)")
|
|
1666
|
+
discover_parser.add_argument("--region", default="",
|
|
1667
|
+
help="AWS region (default: AWS_REGION env or us-east-1)")
|
|
1668
|
+
|
|
745
1669
|
# ── Parse and dispatch ────────────────────────────────────────────────────
|
|
746
1670
|
args = parser.parse_args()
|
|
747
1671
|
|
|
@@ -755,6 +1679,7 @@ def main():
|
|
|
755
1679
|
"resolve": cmd_resolve,
|
|
756
1680
|
"stage-hf": cmd_stage_hf,
|
|
757
1681
|
"validate": cmd_validate,
|
|
1682
|
+
"discover": cmd_discover,
|
|
758
1683
|
}
|
|
759
1684
|
|
|
760
1685
|
handler = command_map.get(args.command)
|