@aws/ml-container-creator 0.9.1 → 0.10.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. package/LICENSE-THIRD-PARTY +9304 -0
  2. package/bin/cli.js +2 -0
  3. package/config/bootstrap-e2e-stack.json +341 -0
  4. package/config/bootstrap-stack.json +40 -3
  5. package/config/parameter-schema-v2.json +2049 -0
  6. package/config/tune-catalog.json +1781 -0
  7. package/infra/ci-harness/buildspec.yml +1 -0
  8. package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
  9. package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
  10. package/infra/ci-harness/lib/ci-harness-stack.ts +837 -7
  11. package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
  12. package/package.json +53 -68
  13. package/servers/base-image-picker/index.js +121 -121
  14. package/servers/e2e-status/index.js +297 -0
  15. package/servers/e2e-status/manifest.json +14 -0
  16. package/servers/e2e-status/package.json +15 -0
  17. package/servers/endpoint-picker/LICENSE +202 -0
  18. package/servers/endpoint-picker/index.js +536 -0
  19. package/servers/endpoint-picker/manifest.json +14 -0
  20. package/servers/endpoint-picker/package.json +18 -0
  21. package/servers/hyperpod-cluster-picker/index.js +125 -125
  22. package/servers/instance-sizer/index.js +138 -138
  23. package/servers/instance-sizer/lib/instance-ranker.js +76 -76
  24. package/servers/instance-sizer/lib/model-resolver.js +61 -61
  25. package/servers/instance-sizer/lib/quota-resolver.js +113 -113
  26. package/servers/instance-sizer/lib/vram-estimator.js +31 -31
  27. package/servers/lib/bedrock-client.js +38 -38
  28. package/servers/lib/catalogs/jumpstart-public.json +101 -16
  29. package/servers/lib/catalogs/model-servers.json +201 -3
  30. package/servers/lib/catalogs/models.json +182 -26
  31. package/servers/lib/custom-validators.js +13 -13
  32. package/servers/lib/dynamic-resolver.js +4 -4
  33. package/servers/marketplace-picker/index.js +342 -0
  34. package/servers/marketplace-picker/manifest.json +14 -0
  35. package/servers/marketplace-picker/package.json +18 -0
  36. package/servers/model-picker/index.js +382 -382
  37. package/servers/region-picker/index.js +56 -56
  38. package/servers/workload-picker/LICENSE +202 -0
  39. package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
  40. package/servers/workload-picker/index.js +171 -0
  41. package/servers/workload-picker/manifest.json +16 -0
  42. package/servers/workload-picker/package.json +16 -0
  43. package/src/app.js +4 -390
  44. package/src/lib/bootstrap-command-handler.js +710 -1148
  45. package/src/lib/bootstrap-config.js +36 -0
  46. package/src/lib/bootstrap-profile-manager.js +641 -0
  47. package/src/lib/bootstrap-provisioners.js +421 -0
  48. package/src/lib/ci-register-helpers.js +74 -0
  49. package/src/lib/config-loader.js +408 -0
  50. package/src/lib/config-manager.js +66 -1685
  51. package/src/lib/config-mcp-client.js +118 -0
  52. package/src/lib/config-validator.js +634 -0
  53. package/src/lib/cuda-resolver.js +149 -0
  54. package/src/lib/e2e-catalog-validator.js +251 -3
  55. package/src/lib/e2e-ci-recorder.js +103 -0
  56. package/src/lib/generated/cli-options.js +315 -311
  57. package/src/lib/generated/parameter-matrix.js +671 -0
  58. package/src/lib/generated/validation-rules.js +71 -71
  59. package/src/lib/marketplace-flow.js +276 -0
  60. package/src/lib/mcp-query-runner.js +768 -0
  61. package/src/lib/parameter-schema-validator.js +62 -18
  62. package/src/lib/path-prover-brain.js +607 -0
  63. package/src/lib/prompt-runner.js +41 -1504
  64. package/src/lib/prompts/feature-prompts.js +172 -0
  65. package/src/lib/prompts/index.js +48 -0
  66. package/src/lib/prompts/infrastructure-prompts.js +690 -0
  67. package/src/lib/prompts/model-prompts.js +552 -0
  68. package/src/lib/prompts/project-prompts.js +82 -0
  69. package/src/lib/prompts.js +2 -1446
  70. package/src/lib/registry-command-handler.js +135 -3
  71. package/src/lib/secrets-prompt-runner.js +251 -0
  72. package/src/lib/template-variable-resolver.js +422 -0
  73. package/src/lib/tune-catalog-validator.js +37 -4
  74. package/templates/Dockerfile +9 -0
  75. package/templates/code/adapter_sidecar.py +444 -0
  76. package/templates/code/serve +6 -0
  77. package/templates/code/serve.d/vllm.ejs +1 -1
  78. package/templates/do/.benchmark_writer.py +1476 -0
  79. package/templates/do/.tune_helper.py +982 -57
  80. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  81. package/templates/do/adapter +149 -0
  82. package/templates/do/benchmark +639 -85
  83. package/templates/do/config +108 -5
  84. package/templates/do/deploy.d/managed-inference.ejs +192 -11
  85. package/templates/do/optimize +106 -37
  86. package/templates/do/register +89 -0
  87. package/templates/do/test +13 -0
  88. package/templates/do/tune +378 -59
  89. package/templates/do/validate +44 -4
  90. package/config/parameter-schema.json +0 -88
@@ -10,30 +10,44 @@ Subcommands:
10
10
  resolve - Resolve output artifact path from job
11
11
  stage-hf - Download HF dataset to S3
12
12
  validate - Validate dataset format against schema
13
+ discover - Discover tune-eligible models from JumpStart Hub
13
14
 
14
15
  All output is JSON on stdout for bash consumption.
15
16
  """
16
17
 
17
18
  import argparse
19
+ import fnmatch
18
20
  import json
19
21
  import os
22
+ import re
20
23
  import sys
21
24
  import time
25
+ import warnings
26
+
27
+ # Suppress noisy dependency version warnings from requests/urllib3
28
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
29
+ warnings.filterwarnings("ignore", message=".*urllib3.*")
30
+ warnings.filterwarnings("ignore", message=".*charset_normalizer.*")
22
31
 
23
32
  # ── Inline dependency check ───────────────────────────────────────────────────
24
- MIN_SAGEMAKER_VERSION = "2.232.0"
33
+ MIN_SAGEMAKER_VERSION = "3.0"
34
+
35
+ _GLOB_METACHAR_RE = re.compile(r'[*?\[]')
25
36
 
26
37
 
27
38
  def _check_sagemaker_sdk():
28
39
  """Verify sagemaker SDK is installed with minimum version."""
29
40
  try:
30
41
  import sagemaker # noqa: F401
42
+ # SDK v3 removed __version__; use importlib.metadata instead
43
+ from importlib.metadata import version as pkg_version
31
44
  from packaging.version import Version
32
- if Version(sagemaker.__version__) < Version(MIN_SAGEMAKER_VERSION):
45
+ installed = pkg_version("sagemaker")
46
+ if Version(installed) < Version(MIN_SAGEMAKER_VERSION):
33
47
  _error_exit(
34
- f"sagemaker SDK version {sagemaker.__version__} is below minimum "
48
+ f"sagemaker SDK version {installed} is below minimum "
35
49
  f"required version {MIN_SAGEMAKER_VERSION}. "
36
- f"Please upgrade: pip install 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
50
+ f"Please upgrade: pip install --upgrade 'sagemaker>={MIN_SAGEMAKER_VERSION}'"
37
51
  )
38
52
  except ImportError:
39
53
  _error_exit(
@@ -65,11 +79,37 @@ def cmd_submit(args):
65
79
 
66
80
  Returns: {"job_name": str, "job_arn": str, "mlflow_url": str|None}
67
81
  """
82
+ # Suppress SDK rich logging that pollutes stdout (we only want JSON output)
83
+ import logging
84
+ logging.disable(logging.CRITICAL)
85
+ os.environ["SAGEMAKER_LOG_LEVEL"] = "CRITICAL"
86
+
87
+ # Ensure region is set before ANY sagemaker import (v3 creates boto3 clients at import time)
88
+ region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
89
+ if region:
90
+ os.environ["AWS_DEFAULT_REGION"] = region
91
+ os.environ.setdefault("AWS_REGION", region)
92
+
68
93
  _check_sagemaker_sdk()
69
94
 
70
- from sagemaker.modules.train.sft_trainer import SFTTrainer
71
- from sagemaker.modules.train.dpo_trainer import DPOTrainer
72
- from sagemaker.modules.train.common import TrainingType
95
+ # SDK v3 moved trainers from sagemaker.modules.train to sagemaker.train
96
+ # Note: catch Exception (not just ImportError) because SDK v3 AIRHub
97
+ # creates boto3 clients at class-definition time, which can raise
98
+ # NoRegionError if AWS_DEFAULT_REGION is not set despite our best efforts.
99
+ try:
100
+ from sagemaker.train.sft_trainer import SFTTrainer
101
+ from sagemaker.train.dpo_trainer import DPOTrainer
102
+ from sagemaker.train.common import TrainingType
103
+ except Exception:
104
+ try:
105
+ from sagemaker.modules.train.sft_trainer import SFTTrainer
106
+ from sagemaker.modules.train.dpo_trainer import DPOTrainer
107
+ from sagemaker.modules.train.common import TrainingType
108
+ except Exception:
109
+ _error_exit(
110
+ "SFTTrainer not found. Requires sagemaker>=3.0. "
111
+ "Install: pip install --upgrade 'sagemaker>=3.0'"
112
+ )
73
113
 
74
114
  # Technique → Trainer class mapping
75
115
  TRAINER_MAP = {
@@ -88,63 +128,164 @@ def cmd_submit(args):
88
128
  # Resolve training type
89
129
  training_type_map = {
90
130
  "lora": TrainingType.LORA,
91
- "full-rank": TrainingType.FULL_RANK,
131
+ "full-rank": getattr(TrainingType, 'FULL_RANK', None) or getattr(TrainingType, 'FULL', None),
92
132
  }
93
133
  training_type = training_type_map.get(args.training_type)
94
134
  if not training_type:
95
135
  _error_exit(f"Unsupported training type: {args.training_type}")
96
136
 
97
137
  # Build hyperparameters dict from optional overrides
138
+ # Map CLI flag names to SDK v3 fine-tuning option names
98
139
  hyperparameters = {}
99
140
  if args.epochs is not None:
100
- hyperparameters["epochs"] = args.epochs
141
+ hyperparameters["max_epochs"] = args.epochs
101
142
  if args.learning_rate is not None:
102
143
  hyperparameters["learning_rate"] = args.learning_rate
103
144
  if args.max_seq_length is not None:
104
- hyperparameters["max_seq_length"] = args.max_seq_length
145
+ hyperparameters["dataset_max_len"] = args.max_seq_length
105
146
  if args.lora_rank is not None:
106
147
  hyperparameters["lora_rank"] = args.lora_rank
107
148
  if args.lora_alpha is not None:
108
149
  hyperparameters["lora_alpha"] = args.lora_alpha
109
150
  if args.batch_size is not None:
110
- hyperparameters["batch_size"] = args.batch_size
111
-
112
- # Build trainer kwargs
113
- trainer_kwargs = {
114
- "model_id": args.model_id,
115
- "training_type": training_type,
116
- "train_data_uri": args.dataset_s3_uri,
117
- "output_path": f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/",
118
- "role": args.role_arn,
119
- "job_name": args.job_name,
120
- }
151
+ hyperparameters["global_batch_size"] = args.batch_size
121
152
 
122
- # Add model package group for artifact registration
123
- if args.model_package_group:
124
- trainer_kwargs["model_package_group_name"] = args.model_package_group
153
+ # Build trainer kwargs — the constructor API differs between SDK v2 and v3
154
+ output_path = f"s3://{args.output_bucket}/{args.project_name}/tune/{technique}/"
125
155
 
126
- # Add hyperparameters if any were specified
127
- if hyperparameters:
128
- trainer_kwargs["hyperparameters"] = hyperparameters
156
+ # Detect SDK version to use appropriate API
157
+ sdk_v3 = hasattr(trainer_cls, 'role') # v3 trainers have role as a settable attribute
129
158
 
130
- # Add evaluator config for RLVR/RLAIF techniques
131
- if technique in ("rlvr", "rlaif"):
132
- if args.reward_function:
133
- trainer_kwargs["evaluator_config"] = {
134
- "reward_function_arn": args.reward_function
159
+ try:
160
+ if sdk_v3:
161
+ # SDK v3 API: positional model, keyword training_dataset, s3_output_path
162
+ trainer_kwargs = {
163
+ "model": args.model_id,
164
+ "training_type": training_type,
165
+ "training_dataset": args.dataset_s3_uri,
166
+ "s3_output_path": output_path,
135
167
  }
136
- elif args.reward_prompt:
137
- trainer_kwargs["evaluator_config"] = {
138
- "reward_prompt_s3_uri": args.reward_prompt
168
+ # Accept EULA for gated models (e.g., Meta Llama)
169
+ # SDK v3.12+ accepts accept_eula as a constructor parameter
170
+ if args.accept_eula:
171
+ trainer_kwargs["accept_eula"] = True
172
+
173
+ # Resolve model package group — create if it doesn't exist
174
+ mpg_name = args.model_package_group or f"{args.project_name}-tune-models"
175
+ try:
176
+ import boto3 as _boto3
177
+ _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
178
+ _sm.describe_model_package_group(ModelPackageGroupName=mpg_name)
179
+ except Exception as _mpg_err:
180
+ if "does not exist" in str(_mpg_err) or "ValidationException" in str(_mpg_err):
181
+ try:
182
+ _sm.create_model_package_group(
183
+ ModelPackageGroupName=mpg_name,
184
+ ModelPackageGroupDescription=f"Fine-tuned models for {args.project_name}",
185
+ )
186
+ except Exception:
187
+ pass # May already exist or lack permissions — let the trainer handle it
188
+ trainer_kwargs["model_package_group"] = mpg_name
189
+
190
+ trainer = trainer_cls(**trainer_kwargs)
191
+ trainer.role = args.role_arn
192
+ trainer.base_job_name = args.job_name
193
+ if hyperparameters:
194
+ # SDK v3 expects hyperparameters with a .to_dict() method
195
+ # Wrap our plain dict to satisfy the interface
196
+ hp_obj = trainer.hyperparameters
197
+ if hp_obj is not None and hasattr(hp_obj, '__dict__'):
198
+ for k, v in hyperparameters.items():
199
+ setattr(hp_obj, k, v)
200
+ else:
201
+ # Fallback: create a simple wrapper
202
+ class _HyperParams:
203
+ def __init__(self, d):
204
+ self._data = d
205
+ for k, v in d.items():
206
+ setattr(self, k, v)
207
+ def to_dict(self):
208
+ return {k: v for k, v in self._data.items() if v is not None}
209
+ trainer.hyperparameters = _HyperParams(hyperparameters)
210
+
211
+ # Use MLCC-owned MLflow app if available (avoids permission issues with Studio apps)
212
+ mlflow_arn = os.environ.get('MLFLOW_APP_ARN', '')
213
+ if mlflow_arn:
214
+ trainer.mlflow_resource_arn = mlflow_arn
215
+
216
+ # Suppress SDK print() output (e.g., "Training Job Name: ...")
217
+ # that pollutes stdout and breaks JSON parsing by the shell script
218
+ import io as _io
219
+ _orig_stdout = sys.stdout
220
+ sys.stdout = _io.StringIO()
221
+ try:
222
+ trainer.train(training_dataset=args.dataset_s3_uri, wait=False)
223
+ finally:
224
+ sys.stdout = _orig_stdout
225
+ else:
226
+ # SDK v2 API: model_id, train_data_uri, output_path, role, job_name
227
+ trainer_kwargs = {
228
+ "model_id": args.model_id,
229
+ "training_type": training_type,
230
+ "train_data_uri": args.dataset_s3_uri,
231
+ "output_path": output_path,
232
+ "role": args.role_arn,
233
+ "job_name": args.job_name,
139
234
  }
140
-
141
- try:
142
- trainer = trainer_cls(**trainer_kwargs)
143
- trainer.train(wait=False)
235
+ if args.model_package_group:
236
+ trainer_kwargs["model_package_group_name"] = args.model_package_group
237
+ if hyperparameters:
238
+ trainer_kwargs["hyperparameters"] = hyperparameters
239
+
240
+ # Add evaluator config for RLVR/RLAIF techniques
241
+ if technique in ("rlvr", "rlaif"):
242
+ if args.reward_function:
243
+ trainer_kwargs["evaluator_config"] = {"reward_function_arn": args.reward_function}
244
+ elif args.reward_prompt:
245
+ trainer_kwargs["evaluator_config"] = {"reward_prompt_s3_uri": args.reward_prompt}
246
+
247
+ # Accept EULA for gated models (e.g., Meta Llama)
248
+ if args.accept_eula:
249
+ trainer_kwargs["accept_eula"] = True
250
+
251
+ trainer = trainer_cls(**trainer_kwargs)
252
+ # Suppress SDK print() output that pollutes stdout
253
+ import io as _io
254
+ _orig_stdout = sys.stdout
255
+ sys.stdout = _io.StringIO()
256
+ try:
257
+ trainer.train(wait=False)
258
+ finally:
259
+ sys.stdout = _orig_stdout
144
260
 
145
261
  # Extract job info from the trainer
146
- job_name = trainer.training_job_name
262
+ job_name = getattr(trainer, 'training_job_name', None) or getattr(trainer, 'base_job_name', None)
147
263
  job_arn = getattr(trainer, "training_job_arn", None)
264
+ latest_job = getattr(trainer, 'latest_training_job', None)
265
+ if latest_job:
266
+ job_name = job_name or getattr(latest_job, 'name', None) or getattr(latest_job, 'job_name', None)
267
+ job_arn = job_arn or getattr(latest_job, 'arn', None)
268
+
269
+ # If we still don't have the actual job name (SDK appends suffix),
270
+ # query ListTrainingJobs to find it by our base_job_name prefix
271
+ if not job_name or job_name == args.job_name:
272
+ import boto3 as _boto3
273
+ _sm = _boto3.client("sagemaker", region_name=args.region or os.environ.get("AWS_REGION", "us-west-2"))
274
+ try:
275
+ # Brief delay to allow job to register
276
+ time.sleep(2)
277
+ list_resp = _sm.list_training_jobs(
278
+ NameContains=args.job_name,
279
+ SortBy="CreationTime",
280
+ SortOrder="Descending",
281
+ MaxResults=1,
282
+ )
283
+ summaries = list_resp.get("TrainingJobSummaries", [])
284
+ if summaries:
285
+ job_name = summaries[0]["TrainingJobName"]
286
+ job_arn = summaries[0].get("TrainingJobArn", job_arn)
287
+ except Exception:
288
+ pass # Fall back to whatever we have
148
289
 
149
290
  # Attempt to get MLflow URL if available
150
291
  mlflow_url = None
@@ -154,7 +295,7 @@ def cmd_submit(args):
154
295
  pass
155
296
 
156
297
  _output({
157
- "job_name": job_name,
298
+ "job_name": job_name or args.job_name,
158
299
  "job_arn": job_arn or "",
159
300
  "mlflow_url": mlflow_url,
160
301
  "model_package_group": args.model_package_group or "",
@@ -176,8 +317,15 @@ def cmd_submit(args):
176
317
  )
177
318
  elif "ValidationException" in error_msg and "license" in error_msg.lower():
178
319
  _error_exit(
179
- f"Model license not accepted. Accept the model license before "
180
- f"using this model for customization. Details: {error_msg}"
320
+ f"Model requires EULA acceptance. Re-run with --accept-eula flag: "
321
+ f"./do/tune --technique {technique} --accept-eula ... "
322
+ f"Details: {error_msg}"
323
+ )
324
+ elif "ValidationException" in error_msg and "eula" in error_msg.lower():
325
+ _error_exit(
326
+ f"Model requires EULA acceptance. Re-run with --accept-eula flag: "
327
+ f"./do/tune --technique {technique} --accept-eula ... "
328
+ f"Details: {error_msg}"
181
329
  )
182
330
  else:
183
331
  _error_exit(f"Failed to submit training job: {error_msg}")
@@ -189,6 +337,9 @@ def cmd_submit(args):
189
337
  def cmd_status(args):
190
338
  """Query job status via DescribeTrainingJob.
191
339
 
340
+ Falls back to ListTrainingJobs with name-contains if exact name not found
341
+ (SDK v3 appends a timestamp suffix to the base job name).
342
+
192
343
  Returns: {"status": str, "failure_reason": str|None,
193
344
  "metrics": dict|None, "elapsed_seconds": int}
194
345
  """
@@ -196,16 +347,36 @@ def cmd_status(args):
196
347
 
197
348
  client = boto3.client("sagemaker", region_name=args.region)
198
349
 
350
+ # Try exact name first
351
+ response = None
199
352
  try:
200
353
  response = client.describe_training_job(TrainingJobName=args.job_name)
201
354
  except client.exceptions.ClientError as e:
202
355
  error_code = e.response["Error"]["Code"]
203
- if error_code == "ValidationException":
204
- _error_exit(f"Training job not found: {args.job_name}")
205
- _error_exit(f"Failed to describe training job: {e}")
356
+ if error_code != "ValidationException":
357
+ _error_exit(f"Failed to describe training job: {e}")
358
+ # Job not found by exact name — try name-contains search
206
359
  except Exception as e:
207
360
  _error_exit(f"Failed to describe training job: {e}")
208
361
 
362
+ # Fallback: search for jobs whose name contains the requested name (SDK appends timestamp suffix)
363
+ if response is None:
364
+ try:
365
+ list_response = client.list_training_jobs(
366
+ NameContains=args.job_name,
367
+ SortBy="CreationTime",
368
+ SortOrder="Descending",
369
+ MaxResults=1,
370
+ )
371
+ summaries = list_response.get("TrainingJobSummaries", [])
372
+ if summaries:
373
+ actual_name = summaries[0]["TrainingJobName"]
374
+ response = client.describe_training_job(TrainingJobName=actual_name)
375
+ else:
376
+ _error_exit(f"Training job not found: {args.job_name}")
377
+ except Exception as e:
378
+ _error_exit(f"Failed to find training job: {e}")
379
+
209
380
  status = response.get("TrainingJobStatus", "Unknown")
210
381
  failure_reason = response.get("FailureReason")
211
382
 
@@ -278,6 +449,14 @@ def cmd_resolve(args):
278
449
  # Determine output type from training type
279
450
  output_type = "adapter" if args.training_type == "lora" else "full-model"
280
451
 
452
+ # For LoRA adapters, the actual adapter files are in checkpoints/hf/ subdirectory
453
+ # The S3ModelArtifacts path points to the top-level output directory
454
+ if output_type == "adapter":
455
+ # Ensure trailing slash for directory path
456
+ if not artifact_path.endswith("/"):
457
+ artifact_path += "/"
458
+ artifact_path += "checkpoints/hf/"
459
+
281
460
  # Try to find model package ARN if a model package group was used
282
461
  model_package_arn = None
283
462
  if args.model_package_group:
@@ -306,6 +485,332 @@ def cmd_resolve(args):
306
485
  # ── Subcommand: stage-hf ─────────────────────────────────────────────────────
307
486
 
308
487
 
488
+ def _get_required_columns(technique):
489
+ """Return the required column names for a given technique."""
490
+ schemas = {
491
+ "sft": ["prompt", "completion"],
492
+ "dpo": ["prompt", "chosen", "rejected"],
493
+ "rlaif": ["prompt"], # prompt is an array of messages
494
+ "rlvr": ["prompt"], # prompt is an array of messages
495
+ }
496
+ return schemas.get(technique, ["prompt", "completion"])
497
+
498
+
499
+ def _suggest_column_map(detected_columns, required_columns):
500
+ """Suggest a --column-map based on common column name patterns."""
501
+ # Common aliases for each required field
502
+ aliases = {
503
+ "prompt": ["question", "instruction", "input", "query", "text", "context", "user", "human"],
504
+ "completion": ["answer", "output", "response", "assistant", "target", "label", "reply"],
505
+ "chosen": ["chosen", "preferred", "good", "positive", "accepted"],
506
+ "rejected": ["rejected", "dispreferred", "bad", "negative", "refused"],
507
+ }
508
+
509
+ suggestions = {}
510
+ for req_col in required_columns:
511
+ if req_col in detected_columns:
512
+ continue # Already present
513
+ # Check aliases
514
+ for alias in aliases.get(req_col, []):
515
+ if alias in detected_columns:
516
+ suggestions[req_col] = alias
517
+ break
518
+
519
+ if not suggestions:
520
+ return None
521
+
522
+ # Format as --column-map string
523
+ mapping_str = ",".join(f"{k}={v}" for k, v in suggestions.items())
524
+ return mapping_str
525
+
526
+
527
+ def _parse_column_map(column_map_str):
528
+ """Parse a column map string like 'prompt=question,completion=answer' into a dict."""
529
+ if not column_map_str:
530
+ return {}
531
+ mapping = {}
532
+ for pair in column_map_str.split(","):
533
+ pair = pair.strip()
534
+ if "=" not in pair:
535
+ continue
536
+ target, source = pair.split("=", 1)
537
+ mapping[target.strip()] = source.strip()
538
+ return mapping
539
+
540
+
541
+ def _apply_column_map(record, column_map):
542
+ """Apply column mapping to a record: rename source columns to target names."""
543
+ if not column_map:
544
+ return record
545
+ mapped = dict(record)
546
+ for target, source in column_map.items():
547
+ if source in mapped and target not in mapped:
548
+ mapped[target] = mapped.pop(source)
549
+ return mapped
550
+
551
+
552
+ def _detect_chat_columns(record, required_columns, schema_types):
553
+ """Detect which required columns contain chat-format data.
554
+
555
+ Only inspects columns whose schema type is "string". Columns with
556
+ "array" type (RLAIF/RLVR) are excluded from detection entirely.
557
+
558
+ Args:
559
+ record: The first record (dict) after column mapping
560
+ required_columns: List of required column names for the technique
561
+ schema_types: Dict mapping column name -> expected type from schema
562
+
563
+ Returns:
564
+ dict: Maps column_name -> detection_result where detection_result is:
565
+ {"type": "single_dict"} or
566
+ {"type": "message_list", "strategy": "extract"|"same_role"|"multi_role", "count": int}
567
+ Only columns detected as chat-format are included.
568
+ """
569
+ results = {}
570
+ for column in required_columns:
571
+ # Only inspect columns whose schema type is "string"
572
+ if schema_types.get(column) != "string":
573
+ continue
574
+
575
+ # Skip if column is not present in the record
576
+ if column not in record:
577
+ continue
578
+
579
+ value = record[column]
580
+
581
+ # Check for Single_Message_Dict: dict with both "role" and "content" keys
582
+ if isinstance(value, dict) and "role" in value and "content" in value:
583
+ results[column] = {"type": "single_dict"}
584
+ continue
585
+
586
+ # Check for Message_List: non-empty list whose first element is a dict
587
+ # with both "role" and "content" keys
588
+ if isinstance(value, list) and len(value) > 0:
589
+ first_element = value[0]
590
+ if isinstance(first_element, dict) and "role" in first_element and "content" in first_element:
591
+ count = len(value)
592
+ if count == 1:
593
+ strategy = "extract"
594
+ elif all(
595
+ isinstance(elem, dict) and elem.get("role") == first_element["role"]
596
+ for elem in value
597
+ ):
598
+ strategy = "same_role"
599
+ else:
600
+ strategy = "multi_role"
601
+ results[column] = {"type": "message_list", "strategy": strategy, "count": count}
602
+ continue
603
+
604
+ return results
605
+
606
+
607
+ def _flatten_value(value, detection_result):
608
+ """Flatten a chat-format column value to a plain string.
609
+
610
+ Args:
611
+ value: The column value (dict, list, string, or other)
612
+ detection_result: The detection metadata for this column
613
+
614
+ Returns:
615
+ str: The flattened string value
616
+
617
+ Raises:
618
+ ValueError: If the value cannot be converted at all (str() also fails)
619
+ """
620
+ import json
621
+
622
+ # Edge case: string pass-through
623
+ if isinstance(value, str):
624
+ return value
625
+
626
+ # Edge case: None → ""
627
+ if value is None:
628
+ return ""
629
+
630
+ # Edge case: empty list → ""
631
+ if isinstance(value, list) and len(value) == 0:
632
+ return ""
633
+
634
+ det_type = detection_result.get("type")
635
+
636
+ if det_type == "single_dict":
637
+ if isinstance(value, dict):
638
+ role = value.get("role", "")
639
+ if "content" in value:
640
+ content = value["content"]
641
+ if isinstance(content, str):
642
+ return content
643
+ # Non-string content: format as "role: json_content"
644
+ return f"{role}: {json.dumps(content)}"
645
+ else:
646
+ # No content key: format as "role: remaining_values"
647
+ remaining = {k: v for k, v in value.items() if k != "role"}
648
+ return f"{role}: {json.dumps(remaining)}"
649
+
650
+ elif det_type == "message_list":
651
+ strategy = detection_result.get("strategy")
652
+
653
+ if isinstance(value, list) and len(value) > 0:
654
+ if strategy == "extract":
655
+ # Extract single element's content
656
+ elem = value[0]
657
+ if isinstance(elem, dict):
658
+ content = elem.get("content")
659
+ if content is None:
660
+ return ""
661
+ if isinstance(content, str):
662
+ return content
663
+ return f"{elem.get('role', '')}: {json.dumps(content)}"
664
+ return ""
665
+
666
+ elif strategy == "same_role":
667
+ # Join all content fields with newline
668
+ parts = []
669
+ for elem in value:
670
+ if isinstance(elem, dict):
671
+ content = elem.get("content")
672
+ if content is None or content == "":
673
+ parts.append("")
674
+ elif isinstance(content, str):
675
+ parts.append(content)
676
+ else:
677
+ parts.append(json.dumps(content))
678
+ else:
679
+ parts.append("")
680
+ return "\n".join(parts)
681
+
682
+ elif strategy == "multi_role":
683
+ # Format as "role: content" per line
684
+ lines = []
685
+ for elem in value:
686
+ if isinstance(elem, dict):
687
+ role = elem.get("role", "")
688
+ content = elem.get("content")
689
+ if content is None:
690
+ content = ""
691
+ elif not isinstance(content, str):
692
+ content = json.dumps(content)
693
+ lines.append(f"{role}: {content}")
694
+ else:
695
+ lines.append("")
696
+ return "\n".join(lines)
697
+
698
+ # Fallback for unexpected types: int/bool → str()
699
+ try:
700
+ return str(value)
701
+ except Exception as e:
702
+ raise ValueError(f"Cannot convert value to string: {e}")
703
+
704
+
705
+ def _flatten_record(record, chat_columns):
706
+ """Apply flattening to all chat-format columns in a record.
707
+
708
+ Args:
709
+ record: The mapped record dict
710
+ chat_columns: Detection results from _detect_chat_columns
711
+
712
+ Returns:
713
+ dict: The record with chat-format columns replaced by flat strings
714
+ """
715
+ flattened = dict(record)
716
+ for column_name, detection_result in chat_columns.items():
717
+ if column_name in flattened:
718
+ flattened[column_name] = _flatten_value(flattened[column_name], detection_result)
719
+ return flattened
720
+
721
+
722
+ def _log_flatten_info(chat_columns, no_transform):
723
+ """Log auto-flatten detection and strategy information.
724
+
725
+ Logs regardless of --no-transform state (per requirement 6.3/6.4).
726
+ When --no-transform is active, detection still runs for logging purposes.
727
+
728
+ All output goes to stderr to avoid polluting stdout JSON output.
729
+
730
+ Args:
731
+ chat_columns: Detection results dict (from _detect_chat_columns)
732
+ no_transform: Whether --no-transform flag is active
733
+ """
734
+ for column_name, detection_result in chat_columns.items():
735
+ print(f"\u2139\ufe0f Auto-converted column '{column_name}' from chat-format to string", file=sys.stderr)
736
+ det_type = detection_result.get("type")
737
+ if det_type == "single_dict":
738
+ print(" Format: extracted content field", file=sys.stderr)
739
+ elif det_type == "message_list":
740
+ strategy = detection_result.get("strategy")
741
+ count = detection_result.get("count", 0)
742
+ if strategy == "multi_role":
743
+ print(f" Format: role: content (multi-turn, {count} messages)", file=sys.stderr)
744
+ elif strategy == "same_role":
745
+ print(f" Format: newline-joined content ({count} messages, same role)", file=sys.stderr)
746
+ elif strategy == "extract":
747
+ print(" Format: extracted content field", file=sys.stderr)
748
+
749
+
750
+ def _get_schema_types(technique):
751
+ """Return a dict mapping column names to their expected types for a technique.
752
+
753
+ Args:
754
+ technique: One of 'sft', 'dpo', 'rlaif', 'rlvr'
755
+
756
+ Returns:
757
+ dict: Maps column_name -> expected type ("string" or "array")
758
+ """
759
+ schemas = {
760
+ "sft": {"prompt": "string", "completion": "string"},
761
+ "dpo": {"prompt": "string", "chosen": "string", "rejected": "string"},
762
+ "rlaif": {"prompt": "array"},
763
+ "rlvr": {"prompt": "array"},
764
+ }
765
+ return schemas.get(technique, {"prompt": "string", "completion": "string"})
766
+
767
+
768
def _validate_dataset_columns(first_record, technique, column_map_str, dataset_id):
    """Check that the first record has the required columns after mapping.

    Args:
        first_record: The first dataset record (dict), before column mapping.
        technique: Technique name used to look up the required columns.
        column_map_str: Raw --column-map string (may be None or empty).
        dataset_id: Dataset identifier, used in the suggested command line.

    Returns:
        tuple: (mapped_record, column_map_dict) when no required column is
        missing. Otherwise _error_exit is called with a multi-line message
        listing required, detected and missing columns, a --column-map hint,
        and a truncated sample of the first record.
    """
    column_map = _parse_column_map(column_map_str)
    mapped = _apply_column_map(first_record, column_map)
    required = _get_required_columns(technique)

    missing = [name for name in required if name not in mapped]
    if not missing:
        return mapped, column_map

    detected = list(first_record.keys())

    # Header: what was expected versus what the dataset provides.
    report = [
        f"Dataset columns don't match {technique.upper()} requirements.",
        "",
        f" Required columns: {', '.join(required)}",
        f" Detected columns: {', '.join(detected)}",
        f" Missing: {', '.join(missing)}",
        "",
    ]

    # Hint: a concrete mapping when aliases match, otherwise a template.
    command = f" ./do/tune --technique {technique} --dataset hf://{dataset_id} --column-map "
    suggestion = _suggest_column_map(detected, required)
    if suggestion:
        report.append(" 💡 Suggested fix:")
        report.append(command + suggestion)
    else:
        report.append(" 💡 Use --column-map to rename columns:")
        report.append(command + ",".join(f"{r}=<your_column>" for r in missing))

    # Sample: at most five fields, each value clipped to 80 characters.
    report.append("")
    report.append(" First record sample:")
    for key, raw in list(first_record.items())[:5]:
        text = str(raw)
        clipped = text[:80] + ("..." if len(text) > 80 else "")
        report.append(f" {key}: {clipped}")

    _error_exit("\n".join(report))
812
+
813
+
309
814
  def cmd_stage_hf(args):
310
815
  """Download HF dataset to S3 using huggingface_hub.
311
816
 
@@ -313,6 +818,9 @@ def cmd_stage_hf(args):
313
818
 
314
819
  Returns: {"s3_uri": str, "num_records": int}
315
820
  """
821
+ # Suppress HF Hub progress bars — they pollute stdout which must be clean JSON
822
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
823
+
316
824
  try:
317
825
  from huggingface_hub import hf_hub_download, HfApi
318
826
  except ImportError:
@@ -352,12 +860,28 @@ def cmd_stage_hf(args):
352
860
  f"Available files: {', '.join(repo_files[:20])}"
353
861
  )
354
862
 
863
+ # Apply file filter if --hf-file is provided
864
+ hf_file_pattern = getattr(args, 'hf_file', None)
865
+ if hf_file_pattern:
866
+ data_files = _filter_data_files(data_files, hf_file_pattern)
867
+
355
868
  # Download and upload to S3
356
869
  s3_client = boto3.client("s3", region_name=args.region)
357
870
  s3_prefix = f"{args.project_name}/datasets/{org}/{name}/{split}"
358
871
  num_records = 0
359
872
 
360
873
  with tempfile.TemporaryDirectory() as tmpdir:
874
+ # Schema divergence check (skip for single file)
875
+ if len(data_files) > 1:
876
+ column_map = _parse_column_map(getattr(args, 'column_map', None))
877
+ technique = getattr(args, 'technique', 'sft')
878
+ no_transform = getattr(args, 'no_transform', False)
879
+ file_records = _inspect_file_schemas(
880
+ data_files, dataset_id, hf_token, tmpdir,
881
+ column_map, technique, no_transform
882
+ )
883
+ _check_schema_divergence(file_records, dataset_id, technique)
884
+
361
885
  for data_file in data_files:
362
886
  local_path = hf_hub_download(
363
887
  repo_id=dataset_id,
@@ -367,17 +891,162 @@ def cmd_stage_hf(args):
367
891
  local_dir=tmpdir,
368
892
  )
369
893
 
370
- # Count records (lines for JSONL)
371
- with open(local_path, "r") as f:
372
- for line in f:
373
- if line.strip():
374
- num_records += 1
375
-
376
- # Upload to S3
377
- s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
378
- s3_client.upload_file(local_path, args.output_bucket, s3_key)
379
-
380
- s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{os.path.basename(data_files[0])}"
894
+ # Handle Parquet files: convert to JSONL for SageMaker compatibility
895
+ if data_file.endswith(".parquet"):
896
+ try:
897
+ import pyarrow.parquet as pq
898
+ import json as json_mod
899
+
900
+ table = pq.read_table(local_path)
901
+ jsonl_filename = os.path.splitext(os.path.basename(data_file))[0] + ".jsonl"
902
+ jsonl_path = os.path.join(tmpdir, jsonl_filename)
903
+
904
+ # Parse column map and validate against first record
905
+ column_map = _parse_column_map(getattr(args, 'column_map', None))
906
+ technique = getattr(args, 'technique', 'sft')
907
+ no_transform = getattr(args, 'no_transform', False)
908
+ batches = table.to_batches(max_chunksize=1)
909
+ first_record = batches[0].to_pylist()[0] if batches else {}
910
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
911
+
912
+ # Apply column map to first record for detection
913
+ mapped_first = _apply_column_map(first_record, column_map)
914
+ required_columns = _get_required_columns(technique)
915
+ schema_types = _get_schema_types(technique)
916
+
917
+ # Detect chat-format columns on first record
918
+ chat_columns = _detect_chat_columns(mapped_first, required_columns, schema_types)
919
+
920
+ # Log detection results if any chat columns found
921
+ if chat_columns:
922
+ _log_flatten_info(chat_columns, no_transform)
923
+
924
+ # If --no-transform is active and chat-format detected, halt with error
925
+ if no_transform and chat_columns:
926
+ col_name = next(iter(chat_columns))
927
+ det = chat_columns[col_name]
928
+ det_type = det.get("type")
929
+ strategy = det.get("strategy", "")
930
+ if det_type == "single_dict":
931
+ strategy_desc = "single message dict with role+content"
932
+ elif strategy == "extract":
933
+ strategy_desc = "message list (single element)"
934
+ elif strategy == "same_role":
935
+ strategy_desc = f"message list ({det.get('count', 0)} messages, same role)"
936
+ elif strategy == "multi_role":
937
+ strategy_desc = f"message list (multi-turn, {det.get('count', 0)} messages)"
938
+ else:
939
+ strategy_desc = det_type
940
+ _error_exit(
941
+ f"Column '{col_name}' contains chat-format data (detected: {det_type}) but --no-transform is active.\n\n"
942
+ f" Remove --no-transform to enable automatic conversion:\n"
943
+ f" ./do/tune --technique {technique} --dataset hf://{org}/{name} [--column-map ...]\n\n"
944
+ f" Detected format: {strategy_desc}"
945
+ )
946
+
947
+ with open(jsonl_path, "w", encoding="utf-8") as out_f:
948
+ for batch in table.to_batches():
949
+ for row in batch.to_pylist():
950
+ mapped_row = _apply_column_map(row, column_map)
951
+ if chat_columns and not no_transform:
952
+ mapped_row = _flatten_record(mapped_row, chat_columns)
953
+ out_f.write(json_mod.dumps(mapped_row, ensure_ascii=False) + "\n")
954
+ num_records += 1
955
+
956
+ # Upload converted JSONL
957
+ s3_key = f"{s3_prefix}/{jsonl_filename}"
958
+ s3_client.upload_file(jsonl_path, args.output_bucket, s3_key)
959
+
960
+ except ImportError:
961
+ _error_exit(
962
+ "Dataset is in Parquet format but pyarrow is not installed. "
963
+ "Please install: pip install pyarrow"
964
+ )
965
+ else:
966
+ # JSONL file — validate columns and apply mapping
967
+ import json as json_mod
968
+ column_map = _parse_column_map(getattr(args, 'column_map', None))
969
+ technique = getattr(args, 'technique', 'sft')
970
+ no_transform = getattr(args, 'no_transform', False)
971
+
972
+ # Read first line to validate
973
+ chat_columns = {}
974
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f:
975
+ first_line = f.readline().strip()
976
+ if first_line:
977
+ first_record = json_mod.loads(first_line)
978
+ _validate_dataset_columns(first_record, technique, getattr(args, 'column_map', None), f"{org}/{name}")
979
+
980
+ # Apply column map to first record for detection
981
+ mapped_first = _apply_column_map(first_record, column_map)
982
+ required_columns = _get_required_columns(technique)
983
+ schema_types = _get_schema_types(technique)
984
+
985
+ # Detect chat-format columns on first record
986
+ chat_columns = _detect_chat_columns(mapped_first, required_columns, schema_types)
987
+
988
+ # Log detection results if any chat columns found
989
+ if chat_columns:
990
+ _log_flatten_info(chat_columns, no_transform)
991
+
992
+ # If --no-transform is active and chat-format detected, halt with error
993
+ if no_transform and chat_columns:
994
+ col_name = next(iter(chat_columns))
995
+ det = chat_columns[col_name]
996
+ det_type = det.get("type")
997
+ strategy = det.get("strategy", "")
998
+ if det_type == "single_dict":
999
+ strategy_desc = "single message dict with role+content"
1000
+ elif strategy == "extract":
1001
+ strategy_desc = "message list (single element)"
1002
+ elif strategy == "same_role":
1003
+ strategy_desc = f"message list ({det.get('count', 0)} messages, same role)"
1004
+ elif strategy == "multi_role":
1005
+ strategy_desc = f"message list (multi-turn, {det.get('count', 0)} messages)"
1006
+ else:
1007
+ strategy_desc = det_type
1008
+ _error_exit(
1009
+ f"Column '{col_name}' contains chat-format data (detected: {det_type}) but --no-transform is active.\n\n"
1010
+ f" Remove --no-transform to enable automatic conversion:\n"
1011
+ f" ./do/tune --technique {technique} --dataset hf://{org}/{name} [--column-map ...]\n\n"
1012
+ f" Detected format: {strategy_desc}"
1013
+ )
1014
+
1015
+ # Rewrite the file with mapped (and optionally flattened) columns
1016
+ should_flatten = bool(chat_columns) and not no_transform
1017
+ if column_map or should_flatten:
1018
+ mapped_path = local_path + ".mapped"
1019
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f_in, \
1020
+ open(mapped_path, "w", encoding="utf-8") as f_out:
1021
+ for line in f_in:
1022
+ line = line.strip()
1023
+ if not line:
1024
+ continue
1025
+ record = json_mod.loads(line)
1026
+ mapped_record = _apply_column_map(record, column_map)
1027
+ if should_flatten:
1028
+ mapped_record = _flatten_record(mapped_record, chat_columns)
1029
+ f_out.write(json_mod.dumps(mapped_record, ensure_ascii=False) + "\n")
1030
+ num_records += 1
1031
+ local_path = mapped_path
1032
+ else:
1033
+ # Count records
1034
+ with open(local_path, "r", encoding="utf-8", errors="replace") as f:
1035
+ for line in f:
1036
+ if line.strip():
1037
+ num_records += 1
1038
+
1039
+ # Upload to S3
1040
+ s3_key = f"{s3_prefix}/{os.path.basename(data_file)}"
1041
+ s3_client.upload_file(local_path, args.output_bucket, s3_key)
1042
+
1043
+ # Use the first file's name for the S3 URI (JSONL extension for Parquet conversions)
1044
+ first_file = data_files[0]
1045
+ if first_file.endswith(".parquet"):
1046
+ output_filename = os.path.splitext(os.path.basename(first_file))[0] + ".jsonl"
1047
+ else:
1048
+ output_filename = os.path.basename(first_file)
1049
+ s3_uri = f"s3://{args.output_bucket}/{s3_prefix}/{output_filename}"
381
1050
 
382
1051
  _output({
383
1052
  "s3_uri": s3_uri,
@@ -475,9 +1144,194 @@ def _find_data_files(repo_files, split):
475
1144
  if data_jsonl:
476
1145
  return sorted(data_jsonl)
477
1146
 
1147
+ # Final fallback: any JSONL/JSON file in the repo root (single-file datasets)
1148
+ root_data = [f for f in repo_files if "/" not in f and (f.endswith(".jsonl") or f.endswith(".json")) and not f.startswith(".")]
1149
+ if root_data:
1150
+ return sorted(root_data)
1151
+
478
1152
  return []
479
1153
 
480
1154
 
1155
def _is_glob_pattern(pattern):
    """Tell whether ``pattern`` contains glob metacharacters (*, ?, [).

    The decision is delegated to the module-level ``_GLOB_METACHAR_RE``.
    """
    return _GLOB_METACHAR_RE.search(pattern) is not None
1158
+
1159
+
1160
def _filter_data_files(data_files, pattern):
    """Select the data files that match a glob or substring pattern.

    An empty or ``None`` pattern disables filtering. A pattern for which
    ``_is_glob_pattern`` is true is matched with ``fnmatch`` against the full
    relative path; any other pattern is a substring test on the basename.

    Args:
        data_files: List of file paths from _find_data_files
        pattern: The filter pattern string

    Returns:
        list: The paths from ``data_files`` that match, in their original order

    Raises:
        SystemExit: via _error_exit if nothing matches (the message lists the
            available files)
    """
    if not pattern:
        return data_files

    if _is_glob_pattern(pattern):
        # fnmatch.filter is the list-building form of fnmatch.fnmatch.
        selected = fnmatch.filter(data_files, pattern)
    else:
        selected = [path for path in data_files if pattern in os.path.basename(path)]

    if selected:
        return selected

    bullets = "\n".join(f" • {path}" for path in data_files)
    _error_exit(
        f"No files matched pattern '{pattern}'.\n\n"
        f"Available files:\n{bullets}"
    )

    return selected
1194
+
1195
+
1196
def _inspect_file_schemas(data_files, dataset_id, hf_token, tmpdir,
                          column_map, technique, no_transform):
    """Inspect the first record of each file to extract effective column sets.

    Every file is downloaded with ``hf_hub_download``; its first record is
    read, the column map is applied and, unless ``no_transform`` is set,
    detected chat-format columns are flattened. The resulting key set is
    what gets reported for the file. A file without any record contributes
    an empty column set.

    Args:
        data_files: List of file paths to inspect
        dataset_id: HF dataset identifier for downloads
        hf_token: Authentication token
        tmpdir: Temporary directory for downloads
        column_map: Parsed column mapping dict
        technique: Technique name for schema types
        no_transform: Whether --no-transform is active

    Returns:
        list: [(filename, set_of_column_names), ...] for each file

    Raises:
        SystemExit: via _error_exit when a Parquet file is encountered and
            pyarrow cannot be imported
    """
    from huggingface_hub import hf_hub_download

    required_columns = _get_required_columns(technique)
    schema_types = _get_schema_types(technique)
    results = []

    for data_file in data_files:
        local_path = hf_hub_download(
            repo_id=dataset_id,
            filename=data_file,
            repo_type="dataset",
            token=hf_token,
            local_dir=tmpdir,
        )

        first_record = {}

        if data_file.endswith(".parquet"):
            try:
                import pyarrow.parquet as pq

                # Slice down to the first row before converting: splitting
                # the table into one-row batches would allocate a batch per
                # row only to use the first one.
                head_rows = pq.read_table(local_path).slice(0, 1).to_pylist()
                if head_rows:
                    first_record = head_rows[0]
            except ImportError:
                _error_exit(
                    "Dataset is in Parquet format but pyarrow is not installed. "
                    "Please install: pip install pyarrow"
                )
        else:
            import json as json_mod

            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                # Use the first non-blank line so that a leading empty line
                # is not mistaken for a file without columns.
                for raw_line in f:
                    candidate = raw_line.strip()
                    if candidate:
                        first_record = json_mod.loads(candidate)
                        break

        # Apply column mapping
        mapped_record = _apply_column_map(first_record, column_map)

        # Apply flattening if --no-transform is not active
        if not no_transform:
            chat_columns = _detect_chat_columns(mapped_record, required_columns, schema_types)
            if chat_columns:
                mapped_record = _flatten_record(mapped_record, chat_columns)

        results.append((data_file, set(mapped_record.keys())))

    return results
1265
+
1266
+
1267
def _check_schema_divergence(file_records, dataset_id, technique):
    """Check that all files have identical effective columns.

    Args:
        file_records: List of (filename, first_record_columns) tuples where
            first_record_columns is the set of column names after
            column-map and flattening
        dataset_id: The dataset identifier (for error messages)
        technique: The technique name (for error messages)

    Returns:
        None on success (all schemas match, or there is nothing to compare)

    Raises:
        SystemExit: via _error_exit with per-file column listing and
            ?file= remediation suggestion if schemas differ
    """
    import re as _re
    import shlex as _shlex

    if not file_records:
        return None

    # Every file is compared against the first file's column set.
    reference_columns = file_records[0][1]
    if all(columns == reference_columns for _, columns in file_records):
        return None

    # Per-file column listing
    file_listing = "\n\n".join(
        f" \U0001f4c4 {filename}\n"
        f" Columns: {', '.join(sorted(columns))}"
        for filename, columns in file_records
    )

    # Derive the remediation pattern from the first file's basename: its
    # first run of digits if there is one, otherwise the extension-less name.
    stem = os.path.splitext(os.path.basename(file_records[0][0]))[0]
    digits = _re.search(r"\d+", stem)
    pattern_suggestion = f"*{digits.group() if digits else stem}*"

    # The dataset argument contains '?' and '*'. Quote it so the suggested
    # command can be pasted into a shell without being glob-expanded.
    dataset_arg = _shlex.quote(f"hf://{dataset_id}?file={pattern_suggestion}")

    available_files = "\n".join(
        f" \u2022 {filename}" for filename, _ in file_records
    )

    message = (
        f"Schema divergence detected in dataset {dataset_id}.\n"
        f"Files have different columns after applying column-map and transforms:\n\n"
        f"{file_listing}\n\n"
        f"\U0001f4a1 Use ?file=<pattern> to select compatible files:\n"
        f" ./do/tune --technique {technique} --dataset {dataset_arg}\n\n"
        f" Available files:\n{available_files}"
    )

    _error_exit(message)
1333
+
1334
+
481
1335
  # ── Subcommand: validate ──────────────────────────────────────────────────────
482
1336
 
483
1337
 
@@ -648,6 +1502,53 @@ def _build_expected_format(schema):
648
1502
  return "Each line must be a JSON object with: {" + ", ".join(fields) + "}"
649
1503
 
650
1504
 
1505
+ # ── Subcommand: discover ──────────────────────────────────────────────────────
1506
+
1507
+
1508
def cmd_discover(args):
    """Query JumpStart Hub for tune-eligible models matching a family.

    The Hub content name filter is taken from ``args.filter`` when given;
    otherwise it is looked up from ``args.family`` in the family table below.
    Only entries whose ``HubContentStatus`` is ``Available`` are kept.

    Args:
        args: Parsed CLI namespace providing ``family``, ``filter`` and
            ``region``.

    Returns: {"models": [str], "count": int} (emitted through ``_output``;
        ``models`` holds at most the first five names, ``count`` is the
        total number of available matches)

    Raises:
        SystemExit: via _error_exit when no filter can be determined or the
            Hub query fails
    """
    import boto3

    region = args.region or os.environ.get('AWS_REGION', 'us-east-1')

    family = args.family or ""
    # Map family names to Hub content name prefixes
    FAMILY_PREFIX_MAP = {
        "qwen-2.5": "huggingface-llm-qwen2-5",
        "qwen-3": "huggingface-reasoning-qwen3",
        "llama-3": "meta-textgeneration-llama-3",
        "deepseek-r1": "deepseek-llm-r1-distill",
        "gpt-oss": "openai-reasoning-gpt-oss",
    }

    # An explicit --filter wins over the family table, as the CLI help for
    # --filter states ("overrides family mapping").
    prefix = args.filter or FAMILY_PREFIX_MAP.get(family, "")
    if not prefix:
        _error_exit("No family or filter provided for discovery")

    try:
        client = boto3.client("sagemaker", region_name=region)
        paginator = client.get_paginator('list_hub_contents')
        pages = paginator.paginate(
            HubName="SageMakerPublicHub",
            HubContentType="Model",
            NameContains=prefix,
            MaxResults=20
        )
        models = [
            item['HubContentName']
            for page in pages
            for item in page.get('HubContentSummaries', [])
            if item.get('HubContentStatus') == 'Available'
        ]

        _output({"models": models[:5], "count": len(models)})

    except Exception as e:
        # Top-level command boundary: report any failure as a CLI error.
        _error_exit(f"Hub discovery failed: {e}")
1550
+
1551
+
651
1552
  # ── CLI argument parsing ──────────────────────────────────────────────────────
652
1553
 
653
1554
 
@@ -661,6 +1562,8 @@ def main():
661
1562
  # ── submit ────────────────────────────────────────────────────────────────
662
1563
  submit_parser = subparsers.add_parser("submit", help="Submit a customization job")
663
1564
  submit_parser.add_argument("--model-id", required=True, help="Model ID")
1565
+ submit_parser.add_argument("--region", default=None,
1566
+ help="AWS region (defaults to AWS_REGION env var)")
664
1567
  submit_parser.add_argument("--technique", required=True,
665
1568
  choices=["sft", "dpo", "rlaif", "rlvr"],
666
1569
  help="Customization technique")
@@ -695,6 +1598,8 @@ def main():
695
1598
  help="Lambda ARN for reward function (RLVR)")
696
1599
  submit_parser.add_argument("--reward-prompt", default=None,
697
1600
  help="S3 URI for reward prompt (RLAIF)")
1601
+ submit_parser.add_argument("--accept-eula", action="store_true", default=False,
1602
+ help="Accept model EULA for gated models (e.g., Llama)")
698
1603
 
699
1604
  # ── status ────────────────────────────────────────────────────────────────
700
1605
  status_parser = subparsers.add_parser("status", help="Get job status and metrics")
@@ -725,6 +1630,8 @@ def main():
725
1630
  help="Hugging Face dataset name")
726
1631
  stage_hf_parser.add_argument("--hf-split", default="train",
727
1632
  help="Dataset split (default: train)")
1633
+ stage_hf_parser.add_argument("--hf-file", default=None,
1634
+ help="File filter pattern (glob or substring)")
728
1635
  stage_hf_parser.add_argument("--output-bucket", required=True,
729
1636
  help="S3 bucket for staged dataset")
730
1637
  stage_hf_parser.add_argument("--project-name", required=True,
@@ -733,6 +1640,13 @@ def main():
733
1640
  help="AWS region")
734
1641
  stage_hf_parser.add_argument("--hf-secret-name", default=None,
735
1642
  help="Secrets Manager secret name for HF token")
1643
+ stage_hf_parser.add_argument("--column-map", default=None,
1644
+ help="Column mapping (e.g., prompt=question,completion=answer)")
1645
+ stage_hf_parser.add_argument("--technique", default="sft",
1646
+ choices=["sft", "dpo", "rlaif", "rlvr"],
1647
+ help="Customization technique (determines required columns)")
1648
+ stage_hf_parser.add_argument("--no-transform", action="store_true", default=False,
1649
+ help="Disable automatic chat-format flattening")
736
1650
 
737
1651
  # ── validate ──────────────────────────────────────────────────────────────
738
1652
  validate_parser = subparsers.add_parser("validate",
@@ -742,6 +1656,16 @@ def main():
742
1656
  validate_parser.add_argument("--file", default="-",
743
1657
  help="Path to dataset file (default: stdin)")
744
1658
 
1659
+ # ── discover ──────────────────────────────────────────────────────────────
1660
+ discover_parser = subparsers.add_parser("discover",
1661
+ help="Discover tune-eligible models from JumpStart Hub")
1662
+ discover_parser.add_argument("--family", default="",
1663
+ help="Model family name (e.g., qwen-3, llama-3, deepseek-r1)")
1664
+ discover_parser.add_argument("--filter", default="",
1665
+ help="Hub content name prefix filter (overrides family mapping)")
1666
+ discover_parser.add_argument("--region", default="",
1667
+ help="AWS region (default: AWS_REGION env or us-east-1)")
1668
+
745
1669
  # ── Parse and dispatch ────────────────────────────────────────────────────
746
1670
  args = parser.parse_args()
747
1671
 
@@ -755,6 +1679,7 @@ def main():
755
1679
  "resolve": cmd_resolve,
756
1680
  "stage-hf": cmd_stage_hf,
757
1681
  "validate": cmd_validate,
1682
+ "discover": cmd_discover,
758
1683
  }
759
1684
 
760
1685
  handler = command_map.get(args.command)