@aws/ml-container-creator 0.13.4 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. package/README.md +23 -5
  2. package/config/parameter-schema-v2.json +32 -4
  3. package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
  4. package/infra/ci-harness/package-lock.json +122 -116
  5. package/infra/ci-harness/package.json +1 -1
  6. package/package.json +5 -3
  7. package/pyproject.toml +21 -0
  8. package/requirements.txt +19 -0
  9. package/servers/instance-sizer/index.js +72 -4
  10. package/servers/instance-sizer/lib/model-resolver.js +28 -2
  11. package/src/app.js +17 -0
  12. package/src/lib/bootstrap-command-handler.js +33 -23
  13. package/src/lib/config-loader.js +18 -0
  14. package/src/lib/config-manager.js +6 -1
  15. package/src/lib/dataset-slug.js +152 -0
  16. package/src/lib/generated/cli-options.js +9 -3
  17. package/src/lib/generated/parameter-matrix.js +14 -3
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +6 -0
  20. package/src/lib/prompt-runner.js +5 -0
  21. package/src/lib/prompts/feature-prompts.js +1 -1
  22. package/src/lib/template-manager.js +0 -7
  23. package/src/lib/template-variable-resolver.js +51 -1
  24. package/src/lib/tune-config-state.js +14 -1
  25. package/templates/do/.adapter_helper.py +451 -0
  26. package/templates/do/.benchmark_writer.py +22 -0
  27. package/templates/do/.register_helper.py +1163 -0
  28. package/templates/do/.stage_helper.py +419 -0
  29. package/templates/do/.tune_helper.py +379 -65
  30. package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
  31. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  32. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  33. package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
  34. package/templates/do/adapter +427 -27
  35. package/templates/do/add-ic +85 -3
  36. package/templates/do/benchmark +173 -15
  37. package/templates/do/config +24 -0
  38. package/templates/do/lib/inference-component.sh +56 -3
  39. package/templates/do/lib/profile.sh +5 -0
  40. package/templates/do/register +552 -6
  41. package/templates/do/stage +91 -272
  42. package/templates/do/test +12 -2
  43. package/templates/do/tune +264 -12
@@ -0,0 +1,1163 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Model Package Group helper for model registration.
6
+
7
+ Subcommands:
8
+ create-mpg - Create a Model Package Group (idempotent)
9
+ register-model - Register a model as a versioned Model Package
10
+ register-adapter - Register an adapter as a versioned Model Package linked to base model
11
+
12
+ Uses sagemaker-core ModelPackageGroup and ModelPackage resource APIs (SDK v3).
13
+ No boto3 sagemaker client per NFR-3.
14
+
15
+ All output is JSON on stdout for bash consumption.
16
+ Diagnostic messages go to stderr.
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import logging
22
+ import os
23
+ import sys
24
+ import warnings
25
+
26
# Silence dependency-version chatter raised while third-party packages load.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", message=".*urllib3.*")

# sagemaker-core logs at INFO/WARNING and pollutes stdout; raise the threshold
# on each of its loggers so only errors get through.
for _noisy_logger in ("sagemaker.config", "sagemaker.core", "sagemaker"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger
34
+
35
+
36
+ # ── Constants ─────────────────────────────────────────────────────────────────
37
+
38
+ MAX_METADATA_VALUE_LEN = 256
39
+
40
+
41
+ # ── Utility functions ─────────────────────────────────────────────────────────
42
+
43
+
44
def _error_exit(message, code="REGISTRATION_ERROR", exit_code=1):
    """Report a failure on both streams, then terminate the process.

    Args:
        message: Human-readable description of the failure.
        code: Machine-readable error code placed in the JSON payload.
        exit_code: Status passed to ``sys.exit``.
    """
    # Human-readable line for the operator goes to stderr ...
    print(f"Error: {message}", file=sys.stderr)
    # ... while stdout carries the JSON payload.
    payload = {"error": message, "code": code}
    print(json.dumps(payload))
    sys.exit(exit_code)
49
+
50
+
51
def _output(data):
    """Serialize ``data`` as JSON on stdout and exit with status 0."""
    serialized = json.dumps(data)
    print(serialized)
    sys.exit(0)
55
+
56
+
57
def _warn(message):
    """Write a warning line to stderr, leaving stdout untouched."""
    print("⚠️ {}".format(message), file=sys.stderr)
60
+
61
+
62
+ # ── Dependency check ──────────────────────────────────────────────────────────
63
+
64
+
65
def _check_sagemaker_core():
    """Exit with a MISSING_DEPENDENCY error when sagemaker-core cannot be imported.

    The import is a probe only; the imported names are not used here.
    """
    try:
        from sagemaker.core.resources import ModelPackageGroup, ModelPackage  # noqa: F401
    except ImportError:
        install_hint = (
            "sagemaker-core is not installed. "
            "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
        )
        _error_exit(install_hint, code="MISSING_DEPENDENCY")
75
+
76
+
77
+ # ── Metadata helpers ──────────────────────────────────────────────────────────
78
+
79
+
80
def _truncate_metadata(props):
    """Stringify every metadata value and clip the ones that are too long.

    ``None`` becomes the empty string. A value longer than
    ``MAX_METADATA_VALUE_LEN`` is cut so that, including the trailing '…',
    it is exactly ``MAX_METADATA_VALUE_LEN`` characters; each cut is reported
    through ``_warn``.

    Args:
        props: dict of metadata key-value pairs

    Returns:
        A new dict with the same keys and string values.
    """

    def _clip(key, value):
        text = "" if value is None else str(value)
        if len(text) <= MAX_METADATA_VALUE_LEN:
            return text
        _warn(f"Metadata '{key}' truncated ({len(text)} → {MAX_METADATA_VALUE_LEN} chars)")
        # Reserve one character for the ellipsis marker.
        return text[: MAX_METADATA_VALUE_LEN - 1] + "…"

    return {key: _clip(key, value) for key, value in props.items()}
97
+
98
+
99
def _build_metadata(args):
    """Assemble the customer_metadata_properties dict for a model registration.

    Each metadata key is read from the matching attribute on ``args``; a falsy
    attribute becomes "". When ``args.benchmark_results`` is present, every
    entry of that JSON object (or dict) is added under a ``benchmark_`` prefix.
    The result is passed through ``_truncate_metadata`` so all values are
    strings within the length limit (NFR-1, AC-1.8).
    """
    # (metadata key, argparse attribute) pairs, in output order.
    field_map = (
        ("deploymentConfig", "deployment_config"),
        ("architecture", "architecture"),
        ("backend", "backend"),
        ("instanceType", "instance_type"),
        ("modelName", "model_name"),
        ("baseImage", "base_image"),
        ("modelFormat", "model_format"),
        ("generatorVersion", "generator_version"),
        ("projectName", "project_name"),
    )
    props = {meta_key: getattr(args, attr_name) or "" for meta_key, attr_name in field_map}

    # Benchmark results are optional and may arrive as a JSON string or a dict.
    if getattr(args, "benchmark_results", None):
        try:
            raw_bench = args.benchmark_results
            bench = json.loads(raw_bench) if isinstance(raw_bench, str) else raw_bench
            if isinstance(bench, dict):
                for bench_key, bench_value in bench.items():
                    props[f"benchmark_{bench_key}"] = str(bench_value)
        except (json.JSONDecodeError, TypeError):
            _warn("Could not parse benchmark results, skipping")

    return _truncate_metadata(props)
128
+
129
+
130
def _build_adapter_metadata(args):
    """Assemble the customer_metadata_properties dict for an adapter registration.

    Contains the same deployment fields as a base-model registration plus the
    adapter-specific ones (AC-2.2): isAdapter (always "true"),
    parentModelVersionArn, tuneTechnique and datasetS3Uri. Falsy attributes
    become "". The result is passed through ``_truncate_metadata``.
    """
    # (metadata key, argparse attribute) pairs, in output order.
    common_fields = (
        ("deploymentConfig", "deployment_config"),
        ("architecture", "architecture"),
        ("backend", "backend"),
        ("instanceType", "instance_type"),
        ("modelName", "model_name"),
        ("baseImage", "base_image"),
        ("modelFormat", "model_format"),
        ("generatorVersion", "generator_version"),
        ("projectName", "project_name"),
    )
    props = {meta_key: getattr(args, attr_name) or "" for meta_key, attr_name in common_fields}

    # Adapter-specific metadata (AC-2.2)
    props["isAdapter"] = "true"
    props["parentModelVersionArn"] = args.parent_version_arn or ""
    props["tuneTechnique"] = args.tune_technique or ""
    props["datasetS3Uri"] = args.dataset_s3_uri or ""

    return _truncate_metadata(props)
154
+
155
+
156
+ # ── Subcommand: create-mpg ────────────────────────────────────────────────────
157
+
158
+
159
def cmd_create_mpg(args):
    """Create a Model Package Group named after the project (idempotent).

    If creation fails with an "already exists"-style error, the existing
    group's ARN is looked up instead. If that lookup also fails, the ARN is
    constructed from the region and the STS account ID, and the lookup
    failure is reported on stderr so a constructed ARN is not mistaken for a
    retrieved one.

    Args:
        args: Parsed CLI arguments; reads ``project_name`` and ``region``.

    Emits JSON on stdout and exits: {"mpg_arn": str, "created": bool}
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, us-west-2.
    # The result is exported so later client creation sees it.
    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    print(f"Creating Model Package Group: {project_name}", file=sys.stderr)

    try:
        mpg = ModelPackageGroup.create(
            model_package_group_name=project_name,
            model_package_group_description=f"Models for {project_name}",
        )
        mpg_arn = mpg.model_package_group_arn
        _output({"mpg_arn": mpg_arn, "created": True})
    except Exception as e:
        error_msg = str(e).lower()
        if "already exists" in error_msg or "alreadyexists" in error_msg or "resource in use" in error_msg:
            # MPG already exists — retrieve its ARN
            print(f"Model Package Group '{project_name}' already exists", file=sys.stderr)
            try:
                mpg = ModelPackageGroup.get(model_package_group_name=project_name)
                mpg_arn = mpg.model_package_group_arn
            except Exception as get_err:
                # Surface the lookup failure instead of swallowing it, then
                # construct the ARN from the known pattern.
                print(
                    f"Could not retrieve Model Package Group ARN ({get_err}); "
                    "constructing it from region and account ID",
                    file=sys.stderr,
                )
                account_id = _get_account_id()
                mpg_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{project_name}"
            _output({"mpg_arn": mpg_arn, "created": False})
        else:
            _error_exit(f"Failed to create Model Package Group: {e}", code="MPG_CREATE_FAILED")
201
+
202
+
203
def _get_account_id():
    """Return the caller's AWS account ID from STS, or "unknown" on any failure.

    boto3 is imported lazily; a missing boto3 is treated like any other error.
    """
    try:
        import boto3
        identity = boto3.client("sts").get_caller_identity()
        account = identity["Account"]
    except Exception:
        account = "unknown"
    return account
211
+
212
+
213
+ # ── Subcommand: register-model ────────────────────────────────────────────────
214
+
215
+
216
def cmd_register_model(args):
    """Register a base model as a new version in the project's Model Package Group.

    Ensures the group named ``args.project_name`` exists, creating it when
    needed (AC-1.1), then creates a Model Package version in it
    (AC-1.2, AC-1.7) with the deployment metadata stored in
    customer_metadata_properties (AC-1.3, AC-1.8).

    Emits JSON on stdout and exits:
    {"mpg_arn": str, "model_package_arn": str, "version": int}
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup, ModelPackage

    group_name = args.project_name
    if not group_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, us-west-2.
    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Step 1: make sure the group exists (AC-1.1)
    group_arn = None
    try:
        created_group = ModelPackageGroup.create(
            model_package_group_name=group_name,
            model_package_group_description=f"Models for {group_name}",
        )
        group_arn = created_group.model_package_group_arn
        print(f"Created Model Package Group: {group_name}", file=sys.stderr)
    except Exception as create_err:
        lowered = str(create_err).lower()
        conflict_markers = ("already exists", "alreadyexists", "resource in use")
        if not any(marker in lowered for marker in conflict_markers):
            _error_exit(f"Failed to create Model Package Group: {create_err}", code="MPG_CREATE_FAILED")
        else:
            print(f"Model Package Group '{group_name}' already exists", file=sys.stderr)
            try:
                existing_group = ModelPackageGroup.get(model_package_group_name=group_name)
                group_arn = existing_group.model_package_group_arn
            except Exception:
                # Lookup failed: fall back to the known ARN pattern.
                account_id = _get_account_id()
                group_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{group_name}"

    # Step 2: metadata (AC-1.3, AC-1.8)
    metadata = _build_metadata(args)

    # Step 3: inference specification; ModelDataUrl is included only when given
    primary_container = {"Image": args.container_image or ""}
    model_data_url = args.model_data_url or ""
    if model_data_url:
        primary_container["ModelDataUrl"] = model_data_url
    inference_spec = {
        "Containers": [primary_container],
        "SupportedContentTypes": ["application/json"],
        "SupportedResponseMIMETypes": ["application/json"],
    }

    # Step 4: create the Model Package version (AC-1.2, AC-1.7)
    config_label = args.deployment_config or 'model'
    instance_label = args.instance_type or 'unknown'
    description = f"{config_label} on {instance_label}"

    print(f"Registering model version in {group_name}...", file=sys.stderr)
    try:
        package = ModelPackage.create(
            model_package_group_name=group_name,
            model_package_description=description,
            inference_specification=inference_spec,
            customer_metadata_properties=metadata,
            model_approval_status="Approved",
        )

        package_arn = package.model_package_arn
        # The version number is the trailing path segment of the ARN.
        version = _extract_version_from_arn(package_arn)

        print(f"Registered model version {version}: {package_arn}", file=sys.stderr)
        result = {
            "mpg_arn": group_arn,
            "model_package_arn": package_arn,
            "version": version,
        }
        _output(result)
    except Exception as register_err:
        _error_exit(f"Failed to register model package: {register_err}", code="MODEL_REGISTER_FAILED")
305
+
306
+
307
def _extract_version_from_arn(arn):
    """Extract the version number from a model package ARN.

    ARN format: arn:aws:sagemaker:<region>:<account>:model-package/<group>/<version>

    Args:
        arn: The model package ARN. ``None`` and non-string values are tolerated.

    Returns:
        The trailing path segment as an int, or 0 when the ARN is missing,
        is not a string, or does not end in an integer.
    """
    try:
        # Only the last "/"-separated segment matters.
        return int(arn.rsplit("/", 1)[-1])
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: arn is None or not a str.
        # ValueError: the last segment is not an integer.
        return 0
317
+
318
+
319
+ # ── Subcommand: register-adapter ─────────────────────────────────────────────
320
+
321
+
322
def cmd_register_adapter(args):
    """Register an adapter as a versioned Model Package linked to its base model.

    Creates the MPG if it doesn't exist (reuses AC-1.1 logic), then creates a new
    Model Package version with adapter-specific metadata (AC-2.1, AC-2.2):
    - isAdapter=true
    - parentModelVersionArn (links to base model version)
    - tuneTechnique (sft/dpo/rlvr)
    - datasetS3Uri (training dataset location)

    Before creating, the group's existing versions are scanned. If one already
    carries the same adapter metadata, that version is reported (with
    "deduplicated": true) instead of creating a duplicate. The comparison uses
    the values as they are stored, i.e. after ``_truncate_metadata``, so long
    ARNs or S3 URIs that were clipped at registration time still match.

    Args:
        args: Parsed CLI arguments; ``project_name`` and ``parent_version_arn``
            are required.

    Emits JSON on stdout and exits:
    {"mpg_arn": str, "model_package_arn": str, "version": int, "parent_version_arn": str}
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup, ModelPackage

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    parent_version_arn = args.parent_version_arn
    if not parent_version_arn:
        _error_exit("--parent-version-arn is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, us-west-2.
    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Step 1: Create MPG if it doesn't exist (reuses AC-1.1 logic)
    mpg_arn = None
    try:
        mpg = ModelPackageGroup.create(
            model_package_group_name=project_name,
            model_package_group_description=f"Models for {project_name}",
        )
        mpg_arn = mpg.model_package_group_arn
        print(f"Created Model Package Group: {project_name}", file=sys.stderr)
    except Exception as e:
        error_msg = str(e).lower()
        if "already exists" in error_msg or "alreadyexists" in error_msg or "resource in use" in error_msg:
            print(f"Model Package Group '{project_name}' already exists", file=sys.stderr)
            try:
                mpg = ModelPackageGroup.get(model_package_group_name=project_name)
                mpg_arn = mpg.model_package_group_arn
            except Exception:
                # Lookup failed: fall back to the known ARN pattern.
                account_id = _get_account_id()
                mpg_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{project_name}"
        else:
            _error_exit(f"Failed to create Model Package Group: {e}", code="MPG_CREATE_FAILED")

    # Step 2: Build adapter metadata (AC-2.2)
    metadata = _build_adapter_metadata(args)

    # Step 2.5: Check for existing adapter with same metadata (dedup, Backlog #024)
    # SFTTrainer with model_package_group_name= auto-registers adapters on completion.
    # If do/register also calls register-adapter, we get duplicate versions.
    # Best-effort dedup: compare against the metadata exactly as it would be
    # stored (post-truncation), not against the raw CLI values.
    dedup_keys = ("isAdapter", "parentModelVersionArn", "tuneTechnique", "datasetS3Uri")
    try:
        packages = ModelPackage.get_all(model_package_group_name=project_name)
        for pkg in packages:
            existing_meta = getattr(pkg, "customer_metadata_properties", None) or {}
            if all(existing_meta.get(key) == metadata[key] for key in dedup_keys):
                # Duplicate detected — SFTTrainer likely already registered this
                existing_arn = pkg.model_package_arn
                existing_version = _extract_version_from_arn(existing_arn)
                print(f"Adapter already registered as version {existing_version} (likely by SFTTrainer)", file=sys.stderr)
                print("Supplementing with deployment metadata...", file=sys.stderr)
                # TODO: Update the existing version's metadata with deployment fields
                # For now, output the existing version info instead of creating a duplicate
                _output({
                    "mpg_arn": mpg_arn,
                    "model_package_arn": existing_arn,
                    "version": existing_version,
                    "parent_version_arn": parent_version_arn,
                    "deduplicated": True,
                })
    except Exception as dedup_err:
        # Dedup check is best-effort — proceed with registration if it fails
        print(f"Dedup check failed (non-fatal): {dedup_err}", file=sys.stderr)

    # Step 3: Build inference specification; ModelDataUrl only when provided
    container_image = args.container_image or ""
    model_data_url = args.model_data_url or ""

    inference_spec = {
        "Containers": [
            {
                "Image": container_image,
            }
        ],
        "SupportedContentTypes": ["application/json"],
        "SupportedResponseMIMETypes": ["application/json"],
    }
    if model_data_url:
        inference_spec["Containers"][0]["ModelDataUrl"] = model_data_url

    # Step 4: Create adapter Model Package version (AC-2.1)
    technique = args.tune_technique or "unknown"
    description = f"adapter ({technique}) on {args.instance_type or 'unknown'}, parent: {parent_version_arn}"

    print(f"Registering adapter version in {project_name}...", file=sys.stderr)
    try:
        pkg = ModelPackage.create(
            model_package_group_name=project_name,
            model_package_description=description,
            inference_specification=inference_spec,
            customer_metadata_properties=metadata,
            model_approval_status="Approved",
        )

        model_package_arn = pkg.model_package_arn
        version = _extract_version_from_arn(model_package_arn)

        print(f"Registered adapter version {version}: {model_package_arn}", file=sys.stderr)
        _output({
            "mpg_arn": mpg_arn,
            "model_package_arn": model_package_arn,
            "version": version,
            "parent_version_arn": parent_version_arn,
        })
    except Exception as e:
        _error_exit(f"Failed to register adapter package: {e}", code="ADAPTER_REGISTER_FAILED")
448
+
449
+
450
+ # ── AI Registry + Local Registry Helpers ──────────────────────────────────────
451
+ # Use sagemaker.ai_registry.dataset.DataSet API (SDK v3) when available.
452
+ # Fall back to local JSON-based registry (~/.ml-container-creator/datasets.json)
453
+ # if the import fails (older SDK, Backlog #023).
454
+ # Evaluator API does not exist yet — evaluators always use local JSON.
455
+ # TODO: Once an evaluator registry API is available, upgrade evaluators too.
456
+
457
# Local JSON registries live in a dot-directory under the user's home.
_REGISTRY_DIR = os.path.join(os.path.expanduser("~"), ".ml-container-creator")
_DATASETS_REGISTRY, _EVALUATORS_REGISTRY = (
    os.path.join(_REGISTRY_DIR, _registry_file)
    for _registry_file in ("datasets.json", "evaluators.json")
)
460
+
461
+
462
def _check_ai_registry():
    """Return True when ``sagemaker.ai_registry.dataset.DataSet`` can be imported.

    Any exception during the import counts as "unavailable": the module may be
    missing (ImportError), or it may exist but fail while importing (e.g.
    NoRegionError from a boto3 client created at class-definition time in AIRHub).
    """
    available = True
    try:
        from sagemaker.ai_registry.dataset import DataSet  # noqa: F401
    except Exception:
        available = False
    return available
472
+
473
+
474
def _ensure_registry_dir():
    """Make sure the local registry directory exists; a no-op when it already does."""
    os.makedirs(name=_REGISTRY_DIR, exist_ok=True)
477
+
478
+
479
def _load_registry(path):
    """Load a registry JSON file and return its list of entries.

    Args:
        path: Filesystem path of the registry file.

    Returns:
        The parsed list. An empty list is returned when the file is missing or
        unreadable, is not valid UTF-8/JSON, or holds something other than a
        JSON array.
    """
    try:
        # Open directly rather than checking existence first: avoids a race
        # with concurrent deletion, and FileNotFoundError is an OSError.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return []
    return data if isinstance(data, list) else []
489
+
490
+
491
def _save_registry(path, entries):
    """Save entries to a registry JSON file.

    The data is written to a sibling temporary file first and then moved over
    ``path`` with ``os.replace``. An interrupted or failed write therefore
    leaves the previous registry intact, instead of a truncated file that would
    later load as an empty registry and be overwritten.

    Args:
        path: Destination registry file.
        entries: JSON-serializable list of registry entries.
    """
    _ensure_registry_dir()
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(entries, f, indent=2)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up after a failed write or serialization error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
496
+
497
+
498
+ # ── Subcommand: register-dataset ─────────────────────────────────────────────
499
+
500
+
501
def cmd_register_dataset(args):
    """Register a dataset into SageMaker AI Registry (preferred) or local registry (fallback).

    Uses sagemaker.ai_registry.dataset.DataSet API (SDK v3) when available.
    Falls back to the local JSON registry if the API is not importable or the
    registration call fails (Backlog #023). A successful AI Registry
    registration is also mirrored into the local registry.

    Args:
        args: Parsed CLI arguments; ``name`` and ``s3_uri`` are required.
            ``technique`` may be unset, in which case no customization
            technique is passed to the AI Registry.

    Emits JSON on stdout and exits:
    {"name": str, "s3_uri": str, "format": str, "technique": str, "arn": str|null, "registered": bool}
    """
    name = args.name
    s3_uri = args.s3_uri
    data_format = getattr(args, "format", "jsonl")
    technique = args.technique
    row_count = args.row_count
    column_schema = args.column_schema
    project_name = args.project_name or ""

    # Set region before any sagemaker import (creates boto3 clients at import time)
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")
    if not s3_uri:
        _error_exit("--s3-uri is required", code="MISSING_ARGUMENT")

    # Validate column schema if provided (the parsed value itself is not used)
    if column_schema:
        try:
            json.loads(column_schema)
        except json.JSONDecodeError:
            _error_exit("--column-schema must be valid JSON", code="INVALID_ARGUMENT")

    # Try SageMaker AI Registry API first (Backlog #023)
    if _check_ai_registry():
        try:
            from sagemaker.ai_registry.dataset import DataSet
            from sagemaker.ai_registry.dataset import CustomizationTechnique

            # Map technique string to enum (case-insensitive on the member
            # name). A missing technique maps to None instead of raising,
            # which would otherwise be misreported as an AI Registry failure.
            technique_map = {t.name.lower(): t for t in CustomizationTechnique}
            technique_enum = technique_map.get((technique or "").lower())

            print(f"Registering dataset '{name}' via SageMaker AI Registry...", file=sys.stderr)
            dataset = DataSet.create(
                name=name,
                source=s3_uri,
                customization_technique=technique_enum,
            )
            dataset_arn = dataset.arn

            # Also write to local registry for offline fallback
            _write_dataset_to_local_registry(
                name=name, s3_uri=s3_uri, data_format=data_format,
                technique=technique, row_count=row_count,
                column_schema=column_schema, project_name=project_name,
                arn=dataset_arn,
            )

            print(f"Registered dataset '{name}' → {s3_uri} (ARN: {dataset_arn})", file=sys.stderr)
            _output({
                "name": name,
                "s3_uri": s3_uri,
                "format": data_format,
                "technique": technique,
                "arn": dataset_arn,
                "registered": True,
            })
        except Exception as e:
            _warn(f"AI Registry registration failed: {e}. Falling back to local registry.")
            # Fall through to local registry below
    else:
        _warn(
            "sagemaker.ai_registry.dataset.DataSet not available (older SDK). "
            "Using local registry fallback."
        )

    # Fallback: local JSON registry
    _write_dataset_to_local_registry(
        name=name, s3_uri=s3_uri, data_format=data_format,
        technique=technique, row_count=row_count,
        column_schema=column_schema, project_name=project_name,
        arn=None,
    )

    print(f"Registered dataset '{name}' → {s3_uri} (local registry)", file=sys.stderr)
    _output({
        "name": name,
        "s3_uri": s3_uri,
        "format": data_format,
        "technique": technique,
        "arn": None,
        "registered": True,
    })
598
+
599
+
600
def _write_dataset_to_local_registry(*, name, s3_uri, data_format, technique,
                                     row_count, column_schema, project_name, arn):
    """Upsert one dataset record in the local JSON registry (offline fallback).

    An existing record with the same ``name`` is replaced in place; otherwise
    the record is appended. ``registered_at`` is the current UTC time in ISO
    format with a trailing "Z".
    """
    import datetime

    records = _load_registry(_DATASETS_REGISTRY)

    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")
    record = {
        "name": name,
        "s3_uri": s3_uri,
        "format": data_format,
        "technique": technique,
        "row_count": row_count,
        "column_schema": column_schema,
        "project_name": project_name,
        "arn": arn,
        "registered_at": timestamp,
    }

    # Position of the first record sharing this name, if any.
    match_index = next(
        (idx for idx, current in enumerate(records) if current.get("name") == name),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record

    _save_registry(_DATASETS_REGISTRY, records)
630
+
631
+
632
+ # ── Subcommand: list-datasets ─────────────────────────────────────────────────
633
+
634
+
635
def cmd_list_datasets(args):
    """Print the datasets recorded in the local registry and exit.

    When ``args.technique`` is set, only entries whose ``technique`` equals it
    are included.

    Emits JSON on stdout: {"datasets": [...]}
    """
    datasets = _load_registry(_DATASETS_REGISTRY)
    wanted_technique = getattr(args, 'technique', None)
    if wanted_technique:
        datasets = [item for item in datasets if item.get('technique') == wanted_technique]
    _output({"datasets": datasets})
646
+
647
+
648
+ # ── Subcommand: register-evaluator ───────────────────────────────────────────
649
+
650
+
651
def cmd_register_evaluator(args):
    """Upsert an evaluator record in the local JSON registry.

    Evaluators are Lambda ARN (RLVR) or preference model S3 URI (RLAIF).
    NOTE: The evaluator registry API does not exist yet in the SDK, so
    evaluators always use local JSON. Once such an API is available, this
    should be upgraded to use it (as cmd_register_dataset does with DataSet).

    Emits JSON on stdout and exits:
    {"name": str, "type": str, "arn_or_uri": str, "technique": str, "registered": bool}
    """
    import datetime

    name = args.name
    eval_type = args.eval_type
    arn_or_uri = args.arn_or_uri
    technique = args.technique
    description = args.description or ""
    project_name = args.project_name or ""

    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")
    if not arn_or_uri:
        _error_exit("--arn-or-uri is required", code="MISSING_ARGUMENT")

    records = _load_registry(_EVALUATORS_REGISTRY)

    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")
    record = {
        "name": name,
        "type": eval_type,
        "arn_or_uri": arn_or_uri,
        "technique": technique,
        "description": description,
        "project_name": project_name,
        "registered_at": timestamp,
    }

    # Replace the first record sharing this name, or append a new one.
    match_index = next(
        (idx for idx, current in enumerate(records) if current.get("name") == name),
        None,
    )
    if match_index is None:
        records.append(record)
    else:
        records[match_index] = record

    _save_registry(_EVALUATORS_REGISTRY, records)

    print(f"Registered evaluator '{name}' ({eval_type}) → {arn_or_uri}", file=sys.stderr)
    _output({
        "name": name,
        "type": eval_type,
        "arn_or_uri": arn_or_uri,
        "technique": technique,
        "registered": True,
    })
710
+
711
+
712
+ # ── Subcommand: list-adapters ─────────────────────────────────────────────────
713
+
714
+
715
def _first_container_model_data_url(inference_spec):
    """Return the first container's model data URL from an inference spec, or "".

    Accepts the spec as a dict (PascalCase or snake_case keys) or as an object
    exposing ``containers`` / ``model_data_url`` attributes.
    NOTE(review): the object form assumes sagemaker-core returns a typed shape
    with these attribute names — confirm against the installed SDK.
    """
    if not inference_spec:
        return ""
    if isinstance(inference_spec, dict):
        containers = inference_spec.get("Containers") or inference_spec.get("containers") or []
        if not containers:
            return ""
        return containers[0].get("ModelDataUrl", "") or containers[0].get("model_data_url", "")
    containers = getattr(inference_spec, "containers", None)
    if not isinstance(containers, (list, tuple)) or not containers:
        return ""
    url = getattr(containers[0], "model_data_url", None)
    # Only pass real strings through so the JSON output stays serializable.
    return url if isinstance(url, str) else ""


def cmd_list_adapters(args):
    """List adapter versions from the project's Model Package Group.

    Queries the MPG for versions where customer_metadata_properties.isAdapter == "true".
    Any failure while querying is non-fatal: a note is printed to stderr and an
    empty list is returned.

    Args:
        args: Parsed CLI arguments; reads ``project_name`` (required) and ``region``.

    Emits JSON on stdout and exits:
    {"adapters": [{"arn": str, "version": int, "tuneTechnique": str,
                   "datasetS3Uri": str, "parentModelVersionArn": str,
                   "createdAt": str, "description": str, "modelDataUrl": str}]}
    """
    _check_sagemaker_core()

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, us-west-2.
    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        # List all model packages in the group
        packages = ModelPackage.get_all(model_package_group_name=project_name)

        adapters = []
        for pkg in packages:
            metadata = getattr(pkg, "customer_metadata_properties", None) or {}
            if metadata.get("isAdapter") != "true":
                continue

            # Version is the trailing segment of the ARN
            arn = pkg.model_package_arn
            version = _extract_version_from_arn(arn)

            # Model data URL from the inference spec, when one is exposed
            model_data_url = _first_container_model_data_url(
                getattr(pkg, "inference_specification", None)
            )

            # Creation time, stringified when present
            created_at = ""
            if hasattr(pkg, "creation_time") and pkg.creation_time:
                created_at = str(pkg.creation_time)

            adapters.append({
                "arn": arn,
                "version": version,
                "tuneTechnique": metadata.get("tuneTechnique", ""),
                "datasetS3Uri": metadata.get("datasetS3Uri", ""),
                "parentModelVersionArn": metadata.get("parentModelVersionArn", ""),
                "createdAt": created_at,
                "description": getattr(pkg, "model_package_description", "") or "",
                "modelDataUrl": model_data_url,
            })

        _output({"adapters": adapters})

    except Exception as e:
        error_msg = str(e).lower()
        # Non-fatal: return empty list on API failures
        if "does not exist" in error_msg or "not found" in error_msg:
            print(f"Model Package Group '{project_name}' not found — no registry adapters", file=sys.stderr)
        else:
            print(f"Warning: Could not query registry for adapters: {e}", file=sys.stderr)
        _output({"adapters": []})
783
+
784
+
785
+ # ── Subcommand: list-models ────────────────────────────────────────────────────
786
+
787
+
788
def cmd_list_models(args):
    """Emit the base (non-adapter) model versions registered for a project.

    Every package returned for the Model Package Group named by
    ``args.project_name`` is inspected; packages whose customer metadata has
    ``isAdapter == "true"`` are left out.  Any exception raised while
    querying is treated as non-fatal: a message goes to stderr and an empty
    list is emitted instead.

    Args:
        args: Parsed CLI namespace providing ``project_name`` and ``region``.

    Output (via ``_output``): ``{"models": [...]}`` where each item carries
    the keys ``arn``, ``version``, ``deploymentConfig``, ``modelName``,
    ``instanceType``, ``modelDataUrl``, ``containerImage``, ``createdAt``
    and ``description``.
    """
    _check_sagemaker_core()

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, then us-west-2.
    region = (
        args.region
        or os.environ.get("AWS_DEFAULT_REGION")
        or os.environ.get("AWS_REGION", "us-west-2")
    )
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        models = []
        for package in ModelPackage.get_all(model_package_group_name=project_name):
            props = getattr(package, "customer_metadata_properties", None) or {}
            if props.get("isAdapter") == "true":
                # Adapter versions are excluded; only base models are listed.
                continue

            package_arn = package.model_package_arn
            version = _extract_version_from_arn(package_arn)

            # NOTE(review): only dict-shaped inference specifications are
            # inspected; for any other type both fields stay empty. Confirm
            # the type the SDK actually returns.
            model_data_url = ""
            container_image = ""
            spec = getattr(package, "inference_specification", None)
            if spec and isinstance(spec, dict):
                container_list = spec.get("Containers") or spec.get("containers") or []
                if container_list:
                    primary = container_list[0]
                    model_data_url = primary.get("ModelDataUrl", "") or primary.get("model_data_url", "")
                    container_image = primary.get("Image", "") or primary.get("image", "")

            created = getattr(package, "creation_time", None)
            created_at = str(created) if created else ""

            models.append({
                "arn": package_arn,
                "version": version,
                "deploymentConfig": props.get("deploymentConfig", ""),
                "modelName": props.get("modelName", ""),
                "instanceType": props.get("instanceType", ""),
                "modelDataUrl": model_data_url,
                "containerImage": container_image,
                "createdAt": created_at,
                "description": getattr(package, "model_package_description", "") or "",
            })

        _output({"models": models})

    except Exception as e:
        # Deliberately non-fatal: report on stderr and emit an empty list.
        lowered = str(e).lower()
        group_missing = "does not exist" in lowered or "not found" in lowered
        if group_missing:
            print(f"Model Package Group '{project_name}' not found — no registry models", file=sys.stderr)
        else:
            print(f"Warning: Could not query registry for models: {e}", file=sys.stderr)
        _output({"models": []})
863
+
864
+
865
+ # ── Subcommand: get-version ──────────────────────────────────────────────────
866
+
867
+
868
def cmd_get_version(args):
    """Emit the details of one model package version identified by ARN.

    Args:
        args: Parsed CLI namespace providing ``arn`` and ``region``.

    Output (via ``_output``): a JSON object with the keys ``arn``,
    ``version``, ``status`` (the package's ``model_approval_status``),
    ``description``, ``modelDataUrl`` and ``metadata`` (the package's
    customer metadata properties).  Any exception during the lookup is
    reported through ``_error_exit`` with code ``GET_VERSION_FAILED``.
    """
    _check_sagemaker_core()

    version_arn = args.arn
    if not version_arn:
        _error_exit("--arn is required", code="MISSING_ARGUMENT")

    # Region precedence: --region, AWS_DEFAULT_REGION, AWS_REGION, then us-west-2.
    region = (
        args.region
        or os.environ.get("AWS_DEFAULT_REGION")
        or os.environ.get("AWS_REGION", "us-west-2")
    )
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        package = ModelPackage.get(model_package_arn=version_arn)

        # NOTE(review): only a dict-shaped inference specification is
        # inspected; for any other type the URL stays empty. Confirm the
        # type the SDK actually returns.
        model_data_url = ""
        spec = getattr(package, "inference_specification", None)
        if spec and isinstance(spec, dict):
            container_list = spec.get("Containers") or spec.get("containers") or []
            if container_list:
                primary = container_list[0]
                model_data_url = primary.get("ModelDataUrl", "") or primary.get("model_data_url", "")

        customer_metadata = getattr(package, "customer_metadata_properties", None) or {}
        approval_status = getattr(package, "model_approval_status", "") or ""
        summary = getattr(package, "model_package_description", "") or ""
        version_number = _extract_version_from_arn(version_arn)

        payload = {
            "arn": version_arn,
            "version": version_number,
            "status": approval_status,
            "description": summary,
            "modelDataUrl": model_data_url,
            "metadata": customer_metadata,
        }
        _output(payload)

    except Exception as e:
        _error_exit(f"Failed to get version details for {version_arn}: {e}", code="GET_VERSION_FAILED")
922
+
923
+
924
+ # ── Subcommand: resolve-dataset ──────────────────────────────────────────────
925
+
926
+
927
def cmd_resolve_dataset(args):
    """Resolve a registered dataset by name and emit it as JSON.

    When ``_check_ai_registry()`` is truthy, the dataset is first looked up
    with ``DataSet.get(name=...)``.  If that lookup raises, the failure is
    logged to stderr and the local JSON registry is consulted instead.  The
    ``arn`` key is always present in the output, as ``None`` when unknown
    (Backlog #023).

    Args:
        args: Parsed CLI namespace providing ``name``.

    Output (via ``_output``): for an AI Registry hit, the keys ``name``,
    ``s3_uri``, ``arn``, ``format`` and ``technique``; for a local registry
    hit, the stored entry plus ``arn``.  When nothing matches,
    ``_error_exit`` is called with code ``DATASET_NOT_FOUND``.
    """
    name = args.name
    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")

    # Try SageMaker AI Registry API first
    if _check_ai_registry():
        try:
            from sagemaker.ai_registry.dataset import DataSet

            dataset = DataSet.get(name=name)

            # Normalise the technique defensively. Calling .lower() straight
            # on the attribute raises for None or enum-like values, and the
            # broad except below would then discard a successful lookup.
            technique = getattr(dataset, "customization_technique", None)
            technique = getattr(technique, "value", technique)  # unwrap enum-like values
            technique = str(technique).lower() if technique else ""

            _output({
                "name": getattr(dataset, "name", name),
                "s3_uri": getattr(dataset, "source", ""),
                "arn": getattr(dataset, "arn", None),
                "format": "jsonl",  # AI Registry may not store format
                "technique": technique,
            })
            # `_output` is expected to end the process; the explicit return
            # keeps the flow correct even if it does not.
            return
        except Exception as e:
            # AI Registry lookup failed — fall through to local registry
            print(f"AI Registry lookup failed for '{name}': {e}. Trying local registry.", file=sys.stderr)

    # Fallback: local registry
    for entry in _load_registry(_DATASETS_REGISTRY):
        if entry.get("name") == name:
            # Copy so the loaded registry entry is not mutated, and make
            # sure the arn key is always present (Backlog #023).
            output = dict(entry)
            output.setdefault("arn", None)
            _output(output)
            return

    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
969
+
970
+
971
+ # ── Subcommand: resolve-evaluator ────────────────────────────────────────────
972
+
973
+
974
def cmd_resolve_evaluator(args):
    """Look up an evaluator in the local registry by name and emit it.

    Args:
        args: Parsed CLI namespace providing ``name``.

    Output (via ``_output``): the matching registry entry exactly as stored.
    When no entry has that name, ``_error_exit`` is called with code
    ``EVALUATOR_NOT_FOUND``.
    """
    name = args.name
    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")

    for candidate in _load_registry(_EVALUATORS_REGISTRY):
        if candidate.get("name") != name:
            continue
        _output(candidate)

    _error_exit(f"Evaluator not found: {name}", code="EVALUATOR_NOT_FOUND")
990
+
991
+
992
+ # ── CLI argument parsing ──────────────────────────────────────────────────────
993
+
994
+
995
def main():
    """Parse the command line and run the selected subcommand.

    One subparser is registered per subcommand.  When no subcommand is given
    the help text is printed and the process exits with status 1; otherwise
    the parsed namespace is passed to the matching ``cmd_*`` function.
    """
    parser = argparse.ArgumentParser(
        prog=".register_helper.py",
        description="SageMaker Model Package Group helper for model registration",
    )
    subparsers = parser.add_subparsers(dest="command", help="Subcommand")

    def add_region(target):
        """Attach the optional --region flag."""
        target.add_argument("--region", default=None, help="AWS region")

    def add_registration_options(target, model_data_help):
        """Attach the options shared by register-model and register-adapter."""
        text_options = (
            ("--deployment-config", "Deployment config (e.g., gpu-vllm)"),
            ("--container-image", "Container image URI"),
            ("--model-data-url", model_data_help),
            ("--instance-type", "Instance type (e.g., ml.g5.2xlarge)"),
            ("--architecture", "Architecture (e.g., transformers)"),
            ("--backend", "Backend (e.g., vllm)"),
            ("--model-name", "Model name (e.g., meta-llama/Llama-3.1-8B)"),
            ("--base-image", "Base container image"),
            ("--model-format", "Model format (e.g., safetensors)"),
            ("--generator-version", "Generator version"),
        )
        for flag, help_text in text_options:
            target.add_argument(flag, default="", help=help_text)
        add_region(target)
        target.add_argument("--role-arn", default="", help="IAM execution role ARN")

    # ── create-mpg ────────────────────────────────────────────────────────
    create_mpg = subparsers.add_parser(
        "create-mpg",
        help="Create a Model Package Group (idempotent)",
    )
    create_mpg.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    add_region(create_mpg)

    # ── register-model ────────────────────────────────────────────────────
    register_model = subparsers.add_parser(
        "register-model",
        help="Register a model as a versioned Model Package",
    )
    register_model.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    add_registration_options(register_model, "Model data S3 URI")
    register_model.add_argument("--benchmark-results", default=None, help="Benchmark results JSON string")

    # ── register-adapter ──────────────────────────────────────────────────
    register_adapter = subparsers.add_parser(
        "register-adapter",
        help="Register an adapter as a versioned Model Package linked to base model",
    )
    register_adapter.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    register_adapter.add_argument("--parent-version-arn", required=True, help="Base model version ARN in the same MPG")
    register_adapter.add_argument("--tune-technique", default="", help="Tune technique (sft/dpo/rlvr)")
    register_adapter.add_argument("--dataset-s3-uri", default="", help="Training dataset S3 URI")
    add_registration_options(register_adapter, "Model/adapter data S3 URI")

    # ── register-dataset ──────────────────────────────────────────────────
    register_dataset = subparsers.add_parser(
        "register-dataset",
        help="Register a dataset into the local registry (AI Registry fallback)",
    )
    register_dataset.add_argument("--name", required=True, help="Dataset name (unique identifier)")
    register_dataset.add_argument("--s3-uri", required=True, help="S3 URI of the dataset")
    register_dataset.add_argument(
        "--format",
        default="jsonl",
        choices=["jsonl", "parquet", "csv"],
        help="Dataset format (jsonl/parquet/csv)",
    )
    register_dataset.add_argument(
        "--technique",
        default="sft",
        choices=["sft", "dpo", "rlaif", "rlvr"],
        help="Associated tuning technique",
    )
    register_dataset.add_argument("--row-count", type=int, default=None, help="Number of rows in dataset")
    register_dataset.add_argument("--column-schema", default=None, help="Column schema as JSON string")
    register_dataset.add_argument("--project-name", default=None, help="Project name for context")

    # ── list-datasets ─────────────────────────────────────────────────────
    list_datasets = subparsers.add_parser(
        "list-datasets",
        help="List all registered datasets from the local registry",
    )
    list_datasets.add_argument(
        "--technique",
        default=None,
        choices=["sft", "dpo", "rlaif", "rlvr"],
        help="Filter by tuning technique",
    )

    # ── register-evaluator ────────────────────────────────────────────────
    register_evaluator = subparsers.add_parser(
        "register-evaluator",
        help="Register an evaluator (Lambda ARN or preference model) into the local registry",
    )
    register_evaluator.add_argument("--name", required=True, help="Evaluator name (unique identifier)")
    register_evaluator.add_argument(
        "--type",
        required=True,
        choices=["lambda", "model"],
        help="Evaluator type (lambda/model)",
        dest="eval_type",
    )
    register_evaluator.add_argument(
        "--arn-or-uri",
        required=True,
        help="Lambda ARN (RLVR) or model S3 URI (RLAIF)",
    )
    register_evaluator.add_argument(
        "--technique",
        required=True,
        choices=["rlvr", "rlaif"],
        help="Associated technique (rlvr/rlaif)",
    )
    register_evaluator.add_argument("--description", default="", help="Evaluator description")
    register_evaluator.add_argument("--project-name", default=None, help="Project name for context")

    # ── list-adapters ─────────────────────────────────────────────────────
    list_adapters = subparsers.add_parser(
        "list-adapters",
        help="List adapter versions from the project's Model Package Group",
    )
    list_adapters.add_argument("--project-name", required=True, help="Project name (MPG name)")
    add_region(list_adapters)

    # ── list-models ───────────────────────────────────────────────────────
    list_models = subparsers.add_parser(
        "list-models",
        help="List base model versions (non-adapter) from the project's Model Package Group",
    )
    list_models.add_argument("--project-name", required=True, help="Project name (MPG name)")
    add_region(list_models)

    # ── get-version ───────────────────────────────────────────────────────
    get_version = subparsers.add_parser(
        "get-version",
        help="Get details for a specific model package version by ARN",
    )
    get_version.add_argument("--arn", required=True, help="Model package version ARN")
    add_region(get_version)

    # ── resolve-dataset ───────────────────────────────────────────────────
    resolve_dataset = subparsers.add_parser(
        "resolve-dataset",
        help="Resolve a registered dataset by name",
    )
    resolve_dataset.add_argument("--name", required=True, help="Dataset name to resolve")

    # ── resolve-evaluator ─────────────────────────────────────────────────
    resolve_evaluator = subparsers.add_parser(
        "resolve-evaluator",
        help="Resolve a registered evaluator by name",
    )
    resolve_evaluator.add_argument("--name", required=True, help="Evaluator name to resolve")

    # ── Parse and dispatch ────────────────────────────────────────────────
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Subcommand name -> handler; each handler receives the parsed namespace.
    handlers = {
        "create-mpg": cmd_create_mpg,
        "register-model": cmd_register_model,
        "register-adapter": cmd_register_adapter,
        "register-dataset": cmd_register_dataset,
        "list-datasets": cmd_list_datasets,
        "register-evaluator": cmd_register_evaluator,
        "list-adapters": cmd_list_adapters,
        "list-models": cmd_list_models,
        "get-version": cmd_get_version,
        "resolve-dataset": cmd_resolve_dataset,
        "resolve-evaluator": cmd_resolve_evaluator,
    }
    handler = handlers.get(args.command)
    if handler is None:
        _error_exit(f"Unknown subcommand: {args.command}", code="UNKNOWN_COMMAND")
    else:
        handler(args)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()