@aws/ml-container-creator 0.13.5 → 0.15.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/config/parameter-schema-v2.json +33 -5
  2. package/infra/ci-harness/lib/ci-harness-stack.ts +13 -5
  3. package/infra/ci-harness/package-lock.json +121 -111
  4. package/infra/ci-harness/package.json +1 -1
  5. package/package.json +2 -2
  6. package/servers/endpoint-picker/index.js +23 -14
  7. package/servers/instance-sizer/index.js +72 -4
  8. package/servers/instance-sizer/lib/model-resolver.js +28 -2
  9. package/src/app.js +15 -0
  10. package/src/lib/config-loader.js +18 -0
  11. package/src/lib/config-manager.js +6 -1
  12. package/src/lib/dataset-slug.js +152 -0
  13. package/src/lib/generated/cli-options.js +9 -3
  14. package/src/lib/generated/parameter-matrix.js +15 -4
  15. package/src/lib/generated/validation-rules.js +1 -1
  16. package/src/lib/mcp-client.js +15 -1
  17. package/src/lib/mcp-query-runner.js +11 -1
  18. package/src/lib/prompt-runner.js +40 -20
  19. package/src/lib/prompts/feature-prompts.js +1 -1
  20. package/src/lib/template-manager.js +0 -7
  21. package/src/lib/template-variable-resolver.js +51 -1
  22. package/src/lib/tune-config-state.js +14 -1
  23. package/templates/do/.benchmark_writer.py +43 -0
  24. package/templates/do/.register_helper.py +1185 -0
  25. package/templates/do/.tune_helper.py +168 -2
  26. package/templates/do/__pycache__/.adapter_helper.cpython-312.pyc +0 -0
  27. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  28. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  29. package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
  30. package/templates/do/adapter +319 -27
  31. package/templates/do/add-ic +85 -3
  32. package/templates/do/benchmark +28 -8
  33. package/templates/do/config +20 -0
  34. package/templates/do/lib/inference-component.sh +56 -3
  35. package/templates/do/register +557 -6
  36. package/templates/do/test +12 -2
  37. package/templates/do/tune +219 -6
@@ -0,0 +1,1185 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Model Package Group helper for model registration.
6
+
7
+ Subcommands:
8
+ create-mpg - Create a Model Package Group (idempotent)
9
+ register-model - Register a model as a versioned Model Package
10
+ register-adapter - Register an adapter as a versioned Model Package linked to base model
11
+
12
+ Uses sagemaker-core ModelPackageGroup and ModelPackage resource APIs (SDK v3).
13
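+ Model package *versions* are created through the boto3 sagemaker client as a workaround for a sagemaker-core ModelPackage.create() bug (see register-model); NFR-3 otherwise prefers sagemaker-core.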
14
+
15
+ All output is JSON on stdout for bash consumption.
16
+ Diagnostic messages go to stderr.
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import logging
22
+ import os
23
+ import sys
24
+ import warnings
25
+
26
# Suppress noisy dependency version warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", message=".*urllib3.*")

# Suppress sagemaker-core INFO/WARNING logging that pollutes stdout.
# stdout is reserved for the JSON result, so only ERROR and above may pass.
for _noisy_logger in ("sagemaker.config", "sagemaker.core", "sagemaker"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger


# ── Constants ─────────────────────────────────────────────────────────────────

# Maximum length of a single metadata value; longer values are truncated.
MAX_METADATA_VALUE_LEN = 256
39
+
40
+
41
+ # ── Utility functions ─────────────────────────────────────────────────────────
42
+
43
+
44
def _error_exit(message, code="REGISTRATION_ERROR", exit_code=1):
    """Report a failure on both streams and terminate the process.

    Args:
        message: Human-readable error text.
        code: Machine-readable error code placed in the JSON payload.
        exit_code: Process exit status passed to ``sys.exit``.

    The plain-text message goes to stderr; a JSON object with ``error`` and
    ``code`` keys goes to stdout so the calling shell script can parse it.
    """
    print("Error: {}".format(message), file=sys.stderr)
    payload = {"error": message, "code": code}
    print(json.dumps(payload))
    sys.exit(exit_code)
49
+
50
+
51
def _output(data):
    """Emit *data* as a single JSON line on stdout, then exit with status 0.

    Args:
        data: JSON-serializable result object.
    """
    serialized = json.dumps(data)
    print(serialized)
    sys.exit(0)
55
+
56
+
57
def _warn(message):
    """Write a warning line to stderr (stdout stays reserved for JSON).

    Args:
        message: Warning text, prefixed with a warning sign on output.
    """
    print("⚠️ {}".format(message), file=sys.stderr)
60
+
61
+
62
+ # ── Dependency check ──────────────────────────────────────────────────────────
63
+
64
+
65
def _check_sagemaker_core():
    """Exit with MISSING_DEPENDENCY unless the sagemaker-core resources import.

    Attempts to import ``ModelPackageGroup`` and ``ModelPackage``; on
    ``ImportError`` the process is terminated through ``_error_exit``.
    """
    try:
        from sagemaker.core.resources import ModelPackageGroup, ModelPackage  # noqa: F401
    except ImportError:
        install_hint = (
            "sagemaker-core is not installed. "
            "Please install: pip install 'sagemaker>=3.0.0' (includes sagemaker-core)"
        )
        _error_exit(install_hint, code="MISSING_DEPENDENCY")
75
+
76
+
77
+ # ── Metadata helpers ──────────────────────────────────────────────────────────
78
+
79
+
80
def _truncate_metadata(props):
    """Stringify metadata values, drop empty ones and cap their length.

    Args:
        props: dict of metadata key-value pairs (values of any type).

    Returns:
        A new dict whose values are non-empty strings of at most
        ``MAX_METADATA_VALUE_LEN`` characters. Over-long values are cut and
        end with '…'; a warning is written to stderr for each one.
    """
    cleaned = {}
    for name, raw in props.items():
        text = "" if raw is None else str(raw)
        if not text:
            # SageMaker requires min length 1 for metadata values — skip empty
            continue
        original_len = len(text)
        if original_len > MAX_METADATA_VALUE_LEN:
            _warn(f"Metadata '{name}' truncated ({original_len} → {MAX_METADATA_VALUE_LEN} chars)")
            # Keep one slot for the ellipsis so the result is exactly at the cap.
            text = text[: MAX_METADATA_VALUE_LEN - 1] + "…"
        cleaned[name] = text
    return cleaned
99
+
100
+
101
def _build_metadata(args):
    """Assemble the customer_metadata_properties dict for a model registration.

    Args:
        args: Parsed CLI namespace providing the deployment attributes read
            below, and optionally ``benchmark_results`` (JSON string or dict).

    Returns:
        dict of string values, sanitized by ``_truncate_metadata`` (empty
        values dropped, long values truncated per AC-1.8).
    """
    # (metadata key, argparse attribute) in the order they are stored.
    field_sources = (
        ("deploymentConfig", "deployment_config"),
        ("architecture", "architecture"),
        ("backend", "backend"),
        ("instanceType", "instance_type"),
        ("modelName", "model_name"),
        ("baseImage", "base_image"),
        ("modelFormat", "model_format"),
        ("generatorVersion", "generator_version"),
        ("projectName", "project_name"),
    )
    props = {key: getattr(args, attr) or "" for key, attr in field_sources}

    # Add benchmark results if available
    raw_bench = getattr(args, "benchmark_results", None)
    if raw_bench:
        try:
            parsed = json.loads(raw_bench) if isinstance(raw_bench, str) else raw_bench
            if isinstance(parsed, dict):
                for metric, score in parsed.items():
                    props[f"benchmark_{metric}"] = str(score)
        except (json.JSONDecodeError, TypeError):
            _warn("Could not parse benchmark results, skipping")

    return _truncate_metadata(props)
130
+
131
+
132
def _build_adapter_metadata(args):
    """Assemble the customer_metadata_properties dict for an adapter.

    Carries the same deployment fields as a base-model registration plus the
    adapter-specific ones (AC-2.2): isAdapter, parentModelVersionArn,
    tuneTechnique and datasetS3Uri.

    Args:
        args: Parsed CLI namespace providing the attributes read below.

    Returns:
        dict of string values, sanitized by ``_truncate_metadata``.
    """
    # (metadata key, argparse attribute) for the standard deployment fields.
    standard_fields = (
        ("deploymentConfig", "deployment_config"),
        ("architecture", "architecture"),
        ("backend", "backend"),
        ("instanceType", "instance_type"),
        ("modelName", "model_name"),
        ("baseImage", "base_image"),
        ("modelFormat", "model_format"),
        ("generatorVersion", "generator_version"),
        ("projectName", "project_name"),
    )
    props = {key: getattr(args, attr) or "" for key, attr in standard_fields}

    # Adapter-specific metadata (AC-2.2)
    props["isAdapter"] = "true"
    props["parentModelVersionArn"] = args.parent_version_arn or ""
    props["tuneTechnique"] = args.tune_technique or ""
    props["datasetS3Uri"] = args.dataset_s3_uri or ""

    return _truncate_metadata(props)
156
+
157
+
158
+ # ── Subcommand: create-mpg ────────────────────────────────────────────────────
159
+
160
+
161
def cmd_create_mpg(args):
    """Create the project's Model Package Group, tolerating "already exists".

    Args:
        args: Parsed CLI namespace with ``project_name`` and ``region``.

    Emits JSON on stdout and exits: {"mpg_arn": str, "created": bool}.
    ``created`` is False when the group was already present.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup

    group_name = args.project_name
    if not group_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    # Resolve the region and export it so SDK clients created later see it.
    region = (
        args.region
        or os.environ.get("AWS_DEFAULT_REGION")
        or os.environ.get("AWS_REGION", "us-west-2")
    )
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    print(f"Creating Model Package Group: {group_name}", file=sys.stderr)

    try:
        group = ModelPackageGroup.create(
            model_package_group_name=group_name,
            model_package_group_description=f"Models for {group_name}",
        )
        _output({"mpg_arn": group.model_package_group_arn, "created": True})
    except Exception as create_err:
        lowered = str(create_err).lower()
        exists_markers = ("already exists", "alreadyexists", "resource in use")
        if not any(marker in lowered for marker in exists_markers):
            _error_exit(f"Failed to create Model Package Group: {create_err}", code="MPG_CREATE_FAILED")
            return

        # MPG already exists — retrieve its ARN
        print(f"Model Package Group '{group_name}' already exists", file=sys.stderr)
        try:
            existing = ModelPackageGroup.get(model_package_group_name=group_name)
            _output({"mpg_arn": existing.model_package_group_arn, "created": False})
        except Exception:
            # Lookup failed: construct the ARN from the known pattern instead.
            account_id = _get_account_id()
            fallback_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{group_name}"
            _output({"mpg_arn": fallback_arn, "created": False})
203
+
204
+
205
def _get_account_id():
    """Return the caller's AWS account ID via STS, or "unknown" on any failure.

    boto3 is imported lazily inside the ``try`` so a missing package is
    handled the same way as a failed STS call.
    """
    try:
        import boto3
        identity = boto3.client("sts").get_caller_identity()
        return identity["Account"]
    except Exception:
        return "unknown"
213
+
214
+
215
+ # ── Subcommand: register-model ────────────────────────────────────────────────
216
+
217
+
218
def cmd_register_model(args):
    """Register a model as a versioned Model Package in the project's MPG.

    Creates the MPG if it doesn't exist (AC-1.1), then creates a new
    ModelPackageVersion (AC-1.2, AC-1.7). Stores metadata in
    customer_metadata_properties (AC-1.3, AC-1.8).

    Args:
        args: Parsed CLI namespace. Reads ``project_name``, ``region``,
            ``container_image``, ``model_data_url``, ``deployment_config``,
            ``instance_type`` and the fields consumed by ``_build_metadata``.

    Emits JSON on stdout and exits:
        {"mpg_arn": str, "model_package_arn": str, "version": int}
    On failure exits through ``_error_exit`` with MPG_CREATE_FAILED or
    MODEL_REGISTER_FAILED.
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Step 1: Create MPG if it doesn't exist (AC-1.1)
    mpg_arn = None
    try:
        mpg = ModelPackageGroup.create(
            model_package_group_name=project_name,
            model_package_group_description=f"Models for {project_name}",
        )
        mpg_arn = mpg.model_package_group_arn
        print(f"Created Model Package Group: {project_name}", file=sys.stderr)
    except Exception as e:
        error_msg = str(e).lower()
        if "already exists" in error_msg or "alreadyexists" in error_msg or "resource in use" in error_msg:
            print(f"Model Package Group '{project_name}' already exists", file=sys.stderr)
            try:
                mpg = ModelPackageGroup.get(model_package_group_name=project_name)
                mpg_arn = mpg.model_package_group_arn
            except Exception:
                # Construct ARN from known pattern
                account_id = _get_account_id()
                mpg_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{project_name}"
        else:
            _error_exit(f"Failed to create Model Package Group: {e}", code="MPG_CREATE_FAILED")

    # Step 2: Build metadata (AC-1.3, AC-1.8)
    metadata = _build_metadata(args)

    # Step 3: Build inference specification inputs
    container_image = args.container_image or ""
    model_data_url = args.model_data_url or ""

    # Step 4: Create Model Package version (AC-1.2, AC-1.7)
    description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"

    print(f"Registering model version in {project_name}...", file=sys.stderr)
    try:
        # Use boto3 directly — sagemaker-core v2.14 has a KeyError bug in ModelPackage.create()
        # where it tries to read response["ModelPackageName"] but the API returns "ModelPackageArn".
        import boto3
        sm_client = boto3.client("sagemaker", region_name=region)

        create_params = {
            "ModelPackageGroupName": project_name,
            "ModelPackageDescription": description,
            "ModelApprovalStatus": "Approved",
        }
        if container_image:
            container = {"Image": container_image}
            if model_data_url:
                container["ModelDataUrl"] = model_data_url
            create_params["InferenceSpecification"] = {
                "Containers": [container],
                "SupportedContentTypes": ["application/json"],
                "SupportedResponseMIMETypes": ["application/json"],
            }
        elif model_data_url:
            # No container image, so the URL can only travel as metadata.
            # Route it through the shared sanitizer: metadata values are
            # capped at MAX_METADATA_VALUE_LEN, and the previous 1024-char
            # slice let over-long values through to the API.
            metadata.update(_truncate_metadata({"modelDataUrl": model_data_url}))
        if metadata:
            create_params["CustomerMetadataProperties"] = metadata

        response = sm_client.create_model_package(**create_params)
        model_package_arn = response["ModelPackageArn"]

        # Extract version number from ARN (format: .../project-name/version)
        version = _extract_version_from_arn(model_package_arn)

        print(f"Registered model version {version}: {model_package_arn}", file=sys.stderr)
        _output({
            "mpg_arn": mpg_arn,
            "model_package_arn": model_package_arn,
            "version": version,
        })
    except Exception as e:
        _error_exit(f"Failed to register model package: {e}", code="MODEL_REGISTER_FAILED")
315
+
316
+
317
def _extract_version_from_arn(arn):
    """Extract the version number from a model package ARN.

    ARN format: arn:aws:sagemaker:<region>:<account>:model-package/<group>/<version>

    Args:
        arn: Model package ARN string. ``None`` or a non-string value is
            tolerated and treated as "no version".

    Returns:
        The integer after the last '/', or 0 when it is missing or not an
        integer.
    """
    try:
        # rpartition always yields three parts, so [2] is the final segment
        # (or the whole string when there is no '/').
        return int(arn.rpartition("/")[2])
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: arn is None or not a str.
        # ValueError: trailing segment is not an integer.
        return 0
327
+
328
+
329
+ # ── Subcommand: register-adapter ─────────────────────────────────────────────
330
+
331
+
332
def _find_duplicate_adapter(packages, expected_metadata):
    """Return the first adapter package whose identity fields match, else None.

    Args:
        packages: Iterable of model package objects exposing
            ``customer_metadata_properties``.
        expected_metadata: Sanitized metadata dict that a new registration
            would store (output of ``_build_adapter_metadata``).

    Comparison uses the sanitized values because empty fields are dropped
    before storage: a field that is absent on both sides counts as equal.
    """
    identity_keys = ("parentModelVersionArn", "tuneTechnique", "datasetS3Uri")
    for pkg in packages:
        existing = getattr(pkg, "customer_metadata_properties", None) or {}
        if existing.get("isAdapter") != "true":
            continue
        if all(existing.get(key) == expected_metadata.get(key) for key in identity_keys):
            return pkg
    return None


def cmd_register_adapter(args):
    """Register an adapter as a versioned Model Package linked to its base model.

    Creates the MPG if it doesn't exist (reuses AC-1.1 logic), then creates a new
    ModelPackageVersion with adapter-specific metadata (AC-2.1, AC-2.2):
    - isAdapter=true
    - parentModelVersionArn (links to base model version)
    - tuneTechnique (sft/dpo/rlvr)
    - datasetS3Uri (training dataset location)

    If a version with the same adapter identity already exists, that version
    is reported (with "deduplicated": true) instead of creating another one.

    Args:
        args: Parsed CLI namespace. Reads ``project_name``,
            ``parent_version_arn``, ``region``, ``tune_technique``,
            ``dataset_s3_uri``, ``container_image``, ``model_data_url``,
            ``instance_type`` and the fields used by
            ``_build_adapter_metadata``.

    Emits JSON on stdout and exits:
        {"mpg_arn": str, "model_package_arn": str, "version": int, "parent_version_arn": str}
    """
    _check_sagemaker_core()

    from sagemaker.core.resources import ModelPackageGroup, ModelPackage

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    parent_version_arn = args.parent_version_arn
    if not parent_version_arn:
        _error_exit("--parent-version-arn is required", code="MISSING_ARGUMENT")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    # Step 1: Create MPG if it doesn't exist (reuses AC-1.1 logic)
    mpg_arn = None
    try:
        mpg = ModelPackageGroup.create(
            model_package_group_name=project_name,
            model_package_group_description=f"Models for {project_name}",
        )
        mpg_arn = mpg.model_package_group_arn
        print(f"Created Model Package Group: {project_name}", file=sys.stderr)
    except Exception as e:
        error_msg = str(e).lower()
        if "already exists" in error_msg or "alreadyexists" in error_msg or "resource in use" in error_msg:
            print(f"Model Package Group '{project_name}' already exists", file=sys.stderr)
            try:
                mpg = ModelPackageGroup.get(model_package_group_name=project_name)
                mpg_arn = mpg.model_package_group_arn
            except Exception:
                account_id = _get_account_id()
                mpg_arn = f"arn:aws:sagemaker:{region}:{account_id}:model-package-group/{project_name}"
        else:
            _error_exit(f"Failed to create Model Package Group: {e}", code="MPG_CREATE_FAILED")

    # Step 2: Build adapter metadata (AC-2.2)
    metadata = _build_adapter_metadata(args)

    # Step 2.5: Check for existing adapter with same metadata (dedup, Backlog #024)
    # SFTTrainer with model_package_group_name= auto-registers adapters on completion.
    # If do/register also calls register-adapter, we get duplicate versions.
    # Best-effort dedup: compare against the sanitized metadata, i.e. exactly
    # what a new registration would store (empty fields are omitted there).
    try:
        duplicate = _find_duplicate_adapter(
            ModelPackage.get_all(model_package_group_name=project_name),
            metadata,
        )
        if duplicate is not None:
            # Duplicate detected — SFTTrainer likely already registered this
            existing_arn = duplicate.model_package_arn
            existing_version = _extract_version_from_arn(existing_arn)
            print(f"Adapter already registered as version {existing_version} (likely by SFTTrainer)", file=sys.stderr)
            print("Supplementing with deployment metadata...", file=sys.stderr)
            # TODO: Update the existing version's metadata with deployment fields
            # For now, output the existing version info instead of creating a duplicate
            _output({
                "mpg_arn": mpg_arn,
                "model_package_arn": existing_arn,
                "version": existing_version,
                "parent_version_arn": parent_version_arn,
                "deduplicated": True,
            })
    except Exception as dedup_err:
        # Dedup check is best-effort — proceed with registration if it fails
        print(f"Dedup check failed (non-fatal): {dedup_err}", file=sys.stderr)

    # Step 3: Build inference specification inputs
    container_image = args.container_image or ""
    model_data_url = args.model_data_url or ""

    # Step 4: Create adapter Model Package version (AC-2.1)
    technique = args.tune_technique or "unknown"
    description = f"adapter ({technique}) on {args.instance_type or 'unknown'}, parent: {parent_version_arn}"

    print(f"Registering adapter version in {project_name}...", file=sys.stderr)
    try:
        # Use boto3 directly — sagemaker-core v2.14 has a KeyError bug in ModelPackage.create()
        import boto3
        sm_client = boto3.client("sagemaker", region_name=region)

        create_params = {
            "ModelPackageGroupName": project_name,
            "ModelPackageDescription": description,
            "ModelApprovalStatus": "Approved",
        }
        if container_image:
            container = {"Image": container_image}
            if model_data_url:
                container["ModelDataUrl"] = model_data_url
            create_params["InferenceSpecification"] = {
                "Containers": [container],
                "SupportedContentTypes": ["application/json"],
                "SupportedResponseMIMETypes": ["application/json"],
            }
        elif model_data_url:
            # No container image: carry the URL as metadata, honoring the same
            # MAX_METADATA_VALUE_LEN cap as every other metadata value.
            metadata.update(_truncate_metadata({"modelDataUrl": model_data_url}))
        if metadata:
            create_params["CustomerMetadataProperties"] = metadata

        response = sm_client.create_model_package(**create_params)
        model_package_arn = response["ModelPackageArn"]

        version = _extract_version_from_arn(model_package_arn)

        print(f"Registered adapter version {version}: {model_package_arn}", file=sys.stderr)
        _output({
            "mpg_arn": mpg_arn,
            "model_package_arn": model_package_arn,
            "version": version,
            "parent_version_arn": parent_version_arn,
        })
    except Exception as e:
        _error_exit(f"Failed to register adapter package: {e}", code="ADAPTER_REGISTER_FAILED")
464
+
465
+
466
+ # ── AI Registry + Local Registry Helpers ──────────────────────────────────────
467
+ # Use sagemaker.ai_registry.dataset.DataSet API (SDK v3) when available.
468
+ # Fall back to local JSON-based registry (~/.ml-container-creator/datasets.json)
469
+ # if the import fails (older SDK, Backlog #023).
470
+ # Evaluator API does not exist yet — evaluators always use local JSON.
471
+ # TODO: Once an evaluator registry API is available, upgrade evaluators too.
472
+
473
# Local fallback registry lives under the user's home directory.
_REGISTRY_DIR = os.path.join(os.path.expanduser("~"), ".ml-container-creator")
# One JSON file per registry kind inside that directory.
_DATASETS_REGISTRY, _EVALUATORS_REGISTRY = (
    os.path.join(_REGISTRY_DIR, _registry_file)
    for _registry_file in ("datasets.json", "evaluators.json")
)
476
+
477
+
478
def _check_ai_registry():
    """Report whether sagemaker.ai_registry.dataset.DataSet can be imported.

    Returns:
        True when the import succeeds, False on any exception.
    """
    try:
        from sagemaker.ai_registry.dataset import DataSet  # noqa: F401
    except Exception:
        # Covers ImportError (module not installed) as well as errors raised
        # while the module is being imported (e.g., NoRegionError from a
        # boto3 client created at class-definition time in AIRHub).
        return False
    return True
488
+
489
+
490
def _ensure_registry_dir():
    """Make sure the local registry directory exists (no-op if it already does)."""
    os.makedirs(name=_REGISTRY_DIR, exist_ok=True)
493
+
494
+
495
def _load_registry(path):
    """Load a registry JSON file.

    Args:
        path: Filesystem path of the registry file.

    Returns:
        The parsed list of entries. An empty list is returned when the file
        is missing or unreadable, is not valid UTF-8/JSON, or its top-level
        value is not a list.
    """
    try:
        # Explicit encoding: the result must not depend on the host locale.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file (no separate exists() check, which
        # would race with deletion). ValueError covers both JSONDecodeError
        # and UnicodeDecodeError for files with undecodable bytes.
        return []
    return data if isinstance(data, list) else []
505
+
506
+
507
def _save_registry(path, entries):
    """Write *entries* to the registry file at *path* as indented JSON.

    The registry directory is created first if needed; the file is
    overwritten.
    """
    _ensure_registry_dir()
    with open(path, "w") as handle:
        json.dump(obj=entries, fp=handle, indent=2)
512
+
513
+
514
+ # ── Subcommand: register-dataset ─────────────────────────────────────────────
515
+
516
+
517
def cmd_register_dataset(args):
    """Register a dataset into SageMaker AI Registry (preferred) or local registry (fallback).

    Uses sagemaker.ai_registry.dataset.DataSet API (SDK v3) when available.
    Falls back to local JSON registry if the API is not installed or the
    registration call fails (Backlog #023).

    Args:
        args: Parsed CLI namespace. Reads ``name``, ``s3_uri``, ``format``
            (optional), ``technique``, ``row_count``, ``column_schema``,
            ``project_name`` and ``region`` (optional).

    Emits JSON on stdout and exits:
        {"name": str, "s3_uri": str, "format": str, "technique": str, "arn": str|null, "registered": bool}
    """
    name = args.name
    s3_uri = args.s3_uri
    data_format = getattr(args, "format", "jsonl")
    technique = args.technique
    row_count = args.row_count
    column_schema = args.column_schema
    project_name = args.project_name or ""

    # Set region before any sagemaker import (creates boto3 clients at import time)
    region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")
    if not s3_uri:
        _error_exit("--s3-uri is required", code="MISSING_ARGUMENT")

    # Validate column schema if provided
    if column_schema:
        try:
            json.loads(column_schema)
        except json.JSONDecodeError:
            _error_exit("--column-schema must be valid JSON", code="INVALID_ARGUMENT")

    # Fields shared by both local-registry writes below (only the ARN differs).
    local_entry = {
        "name": name,
        "s3_uri": s3_uri,
        "data_format": data_format,
        "technique": technique,
        "row_count": row_count,
        "column_schema": column_schema,
        "project_name": project_name,
    }

    # Try SageMaker AI Registry API first (Backlog #023)
    if _check_ai_registry():
        try:
            from sagemaker.ai_registry.dataset import DataSet
            from sagemaker.ai_registry.dataset import CustomizationTechnique

            # Map technique string to enum; a missing or unknown technique maps to None.
            technique_map = {t.name.lower(): t for t in CustomizationTechnique}
            technique_enum = technique_map.get((technique or "").lower())

            print(f"Registering dataset '{name}' via SageMaker AI Registry...", file=sys.stderr)
            dataset = DataSet.create(
                name=name,
                source=s3_uri,
                customization_technique=technique_enum,
            )
            dataset_arn = dataset.arn
        except Exception as e:
            _warn(f"AI Registry registration failed: {e}. Falling back to local registry.")
            # Fall through to local registry below
        else:
            # The dataset is registered remotely at this point. The local
            # copy is only an offline cache, so a failure to write it must
            # not be reported as a registry failure or abort the command.
            try:
                _write_dataset_to_local_registry(arn=dataset_arn, **local_entry)
            except Exception as cache_err:
                _warn(f"Dataset registered, but local registry update failed: {cache_err}")

            print(f"Registered dataset '{name}' → {s3_uri} (ARN: {dataset_arn})", file=sys.stderr)
            _output({
                "name": name,
                "s3_uri": s3_uri,
                "format": data_format,
                "technique": technique,
                "arn": dataset_arn,
                "registered": True,
            })
    else:
        _warn(
            "sagemaker.ai_registry.dataset.DataSet not available (older SDK). "
            "Using local registry fallback."
        )

    # Fallback: local JSON registry
    _write_dataset_to_local_registry(arn=None, **local_entry)

    print(f"Registered dataset '{name}' → {s3_uri} (local registry)", file=sys.stderr)
    _output({
        "name": name,
        "s3_uri": s3_uri,
        "format": data_format,
        "technique": technique,
        "arn": None,
        "registered": True,
    })
614
+
615
+
616
def _write_dataset_to_local_registry(*, name, s3_uri, data_format, technique,
                                     row_count, column_schema, project_name, arn):
    """Upsert one dataset record into the local JSON registry.

    All arguments are keyword-only and are stored verbatim, together with a
    UTC ``registered_at`` timestamp. A record with the same ``name`` is
    replaced; otherwise the record is appended.
    """
    import datetime

    entries = _load_registry(_DATASETS_REGISTRY)

    now_utc = datetime.datetime.now(datetime.timezone.utc)
    record = {
        "name": name,
        "s3_uri": s3_uri,
        "format": data_format,
        "technique": technique,
        "row_count": row_count,
        "column_schema": column_schema,
        "project_name": project_name,
        "arn": arn,
        "registered_at": now_utc.isoformat().replace("+00:00", "Z"),
    }

    # Upsert: replace the first entry with the same name, or append.
    match_index = next(
        (idx for idx, existing in enumerate(entries) if existing.get("name") == name),
        None,
    )
    if match_index is None:
        entries.append(record)
    else:
        entries[match_index] = record

    _save_registry(_DATASETS_REGISTRY, entries)
646
+
647
+
648
+ # ── Subcommand: list-datasets ─────────────────────────────────────────────────
649
+
650
+
651
def cmd_list_datasets(args):
    """Print the datasets stored in the local registry.

    Args:
        args: Parsed CLI namespace; an optional ``technique`` attribute
            restricts the listing to entries with that exact technique.

    Emits JSON on stdout and exits: {"datasets": [...]}
    """
    datasets = _load_registry(_DATASETS_REGISTRY)
    wanted_technique = getattr(args, 'technique', None)
    if wanted_technique:
        datasets = [item for item in datasets if item.get('technique') == wanted_technique]
    _output({"datasets": datasets})
662
+
663
+
664
+ # ── Subcommand: register-evaluator ───────────────────────────────────────────
665
+
666
+
667
def cmd_register_evaluator(args):
    """Upsert an evaluator record into the local JSON registry.

    Evaluators are Lambda ARN (RLVR) or preference model S3 URI (RLAIF).
    NOTE: The evaluator registry API does not exist yet in the SDK, so
    evaluators always use local JSON. Once such an API is available this
    should be upgraded the same way cmd_register_dataset uses DataSet.

    Args:
        args: Parsed CLI namespace. Reads ``name``, ``eval_type``,
            ``arn_or_uri``, ``technique``, ``description``, ``project_name``.

    Emits JSON on stdout and exits:
        {"name": str, "type": str, "arn_or_uri": str, "technique": str, "registered": bool}
    """
    name = args.name
    eval_type = args.eval_type
    arn_or_uri = args.arn_or_uri
    technique = args.technique
    description = args.description or ""
    project_name = args.project_name or ""

    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")
    if not arn_or_uri:
        _error_exit("--arn-or-uri is required", code="MISSING_ARGUMENT")

    # Load existing evaluators
    entries = _load_registry(_EVALUATORS_REGISTRY)

    # Build evaluator record with a UTC timestamp
    import datetime
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    record = {
        "name": name,
        "type": eval_type,
        "arn_or_uri": arn_or_uri,
        "technique": technique,
        "description": description,
        "project_name": project_name,
        "registered_at": now_utc.isoformat().replace("+00:00", "Z"),
    }

    # Upsert: replace the first entry with the same name, or append.
    match_index = next(
        (idx for idx, existing in enumerate(entries) if existing.get("name") == name),
        None,
    )
    if match_index is None:
        entries.append(record)
    else:
        entries[match_index] = record

    _save_registry(_EVALUATORS_REGISTRY, entries)

    print(f"Registered evaluator '{name}' ({eval_type}) → {arn_or_uri}", file=sys.stderr)
    _output({
        "name": name,
        "type": eval_type,
        "arn_or_uri": arn_or_uri,
        "technique": technique,
        "registered": True,
    })
726
+
727
+
728
+ # ── Subcommand: list-adapters ─────────────────────────────────────────────────
729
+
730
+
731
def cmd_list_adapters(args):
    """List adapter versions from the project's Model Package Group.

    Keeps only packages whose customer_metadata_properties has
    isAdapter == "true". Any exception while querying is non-fatal: a
    message is written to stderr and an empty list is returned.

    Args:
        args: Parsed CLI namespace with ``project_name`` and ``region``.

    Emits JSON on stdout and exits:
        {"adapters": [{"arn": str, "version": int, "tuneTechnique": str,
                       "datasetS3Uri": str, "parentModelVersionArn": str,
                       "createdAt": str, "description": str, "modelDataUrl": str}]}
    """
    _check_sagemaker_core()

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    region = (
        args.region
        or os.environ.get("AWS_DEFAULT_REGION")
        or os.environ.get("AWS_REGION", "us-west-2")
    )
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        # List all model packages in the group
        packages = ModelPackage.get_all(model_package_group_name=project_name)

        adapters = []
        for pkg in packages:
            metadata = getattr(pkg, "customer_metadata_properties", None) or {}
            if metadata.get("isAdapter") != "true":
                continue

            arn = pkg.model_package_arn
            version = _extract_version_from_arn(arn)

            # Model data URL from the first container of the inference spec.
            # NOTE(review): only dict-shaped specs are inspected here; confirm
            # what type sagemaker-core actually returns for this attribute.
            model_data_url = ""
            spec = getattr(pkg, "inference_specification", None)
            if spec and isinstance(spec, dict):
                containers = spec.get("Containers") or spec.get("containers") or []
                if containers:
                    first = containers[0]
                    model_data_url = first.get("ModelDataUrl", "") or first.get("model_data_url", "")

            # Creation time, stringified; empty when absent or falsy.
            stamp = getattr(pkg, "creation_time", None)
            created_at = str(stamp) if stamp else ""

            adapters.append({
                "arn": arn,
                "version": version,
                "tuneTechnique": metadata.get("tuneTechnique", ""),
                "datasetS3Uri": metadata.get("datasetS3Uri", ""),
                "parentModelVersionArn": metadata.get("parentModelVersionArn", ""),
                "createdAt": created_at,
                "description": getattr(pkg, "model_package_description", "") or "",
                "modelDataUrl": model_data_url,
            })

        _output({"adapters": adapters})

    except Exception as query_err:
        # Non-fatal: return empty list on API failures
        lowered = str(query_err).lower()
        group_missing = any(marker in lowered for marker in ("does not exist", "not found"))
        if group_missing:
            print(f"Model Package Group '{project_name}' not found — no registry adapters", file=sys.stderr)
        else:
            print(f"Warning: Could not query registry for adapters: {query_err}", file=sys.stderr)
        _output({"adapters": []})
799
+
800
+
801
+ # ── Subcommand: list-models ────────────────────────────────────────────────────
802
+
803
+
804
def cmd_list_models(args):
    """List base model versions (non-adapter) from the project's Model Package Group.

    Iterates over every model package in the group named after the project
    and skips those whose customer_metadata_properties.isAdapter == "true".
    Any exception raised while querying is treated as non-fatal: a message is
    written to stderr and an empty model list is emitted instead.

    The inference specification is read defensively. It may be a plain dict
    (PascalCase or snake_case keys) or an object exposing ``containers`` /
    ``model_data_url`` / ``image`` attributes (presumably what the typed SDK
    resources return — confirm against the installed sagemaker version).
    Only dicts were inspected before, so object-shaped specs always produced
    empty modelDataUrl and containerImage values.

    Args:
        args: Parsed CLI namespace providing ``project_name`` and ``region``.

    Returns JSON: {"models": [{"arn": str, "version": int, "deploymentConfig": str,
                               "modelName": str, "instanceType": str,
                               "modelDataUrl": str, "containerImage": str,
                               "createdAt": str, "description": str}]}
    """

    def _pick(obj, keys, kind):
        """Return the first non-empty value of type *kind* stored under *keys*.

        Works on dicts (key lookup) and on arbitrary objects (attribute
        lookup). Values of any other type are ignored so that placeholder
        objects never leak into the JSON output.
        """
        for key in keys:
            value = obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
            if isinstance(value, kind) and value:
                return value
        return kind()

    _check_sagemaker_core()

    project_name = args.project_name
    if not project_name:
        _error_exit("--project-name is required", code="MISSING_ARGUMENT")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        # List all model packages in the group
        packages = ModelPackage.get_all(model_package_group_name=project_name)

        models = []
        for pkg in packages:
            metadata = getattr(pkg, "customer_metadata_properties", None)
            if not isinstance(metadata, dict):
                # Missing or placeholder metadata: treat as "no metadata"
                # instead of failing the whole listing on .get().
                metadata = {}
            # Filter out adapters — only include base models
            if metadata.get("isAdapter") == "true":
                continue

            # Extract version from ARN
            arn = pkg.model_package_arn
            version = _extract_version_from_arn(arn)

            # Extract model data URL and container image from the first container
            model_data_url = ""
            container_image = ""
            inference_spec = getattr(pkg, "inference_specification", None)
            if inference_spec:
                containers = _pick(inference_spec, ("Containers", "containers"), list)
                if containers:
                    primary = containers[0]
                    model_data_url = _pick(primary, ("ModelDataUrl", "model_data_url"), str)
                    container_image = _pick(primary, ("Image", "image"), str)

            # Get creation time
            created_at = ""
            if hasattr(pkg, "creation_time") and pkg.creation_time:
                created_at = str(pkg.creation_time)

            models.append({
                "arn": arn,
                "version": version,
                "deploymentConfig": metadata.get("deploymentConfig", ""),
                "modelName": metadata.get("modelName", ""),
                "instanceType": metadata.get("instanceType", ""),
                "modelDataUrl": model_data_url,
                "containerImage": container_image,
                "createdAt": created_at,
                "description": getattr(pkg, "model_package_description", "") or "",
            })

        _output({"models": models})

    except Exception as e:
        error_msg = str(e).lower()
        # Non-fatal: return empty list on API failures
        if "does not exist" in error_msg or "not found" in error_msg:
            print(f"Model Package Group '{project_name}' not found — no registry models", file=sys.stderr)
        else:
            print(f"Warning: Could not query registry for models: {e}", file=sys.stderr)
        _output({"models": []})
879
+
880
+
881
+ # ── Subcommand: get-version ──────────────────────────────────────────────────
882
+
883
+
884
def cmd_get_version(args):
    """Get details for a specific model package version by ARN.

    Fetches the model package and emits its approval status, description,
    customer metadata and the model data URL of its first container. Unlike
    the list subcommands, any failure here is reported through _error_exit
    with code GET_VERSION_FAILED.

    The inference specification is read defensively. It may be a plain dict
    (PascalCase or snake_case keys) or an object exposing ``containers`` /
    ``model_data_url`` attributes (presumably what the typed SDK resources
    return — confirm against the installed sagemaker version). Only dicts
    were inspected before, so object-shaped specs always produced an empty
    modelDataUrl.

    Args:
        args: Parsed CLI namespace providing ``arn`` and ``region``.

    Returns JSON: {"arn": str, "version": int, "status": str, "description": str,
                   "modelDataUrl": str, "metadata": dict}
    """

    def _pick(obj, keys, kind):
        """Return the first non-empty value of type *kind* stored under *keys*.

        Works on dicts (key lookup) and on arbitrary objects (attribute
        lookup). Values of any other type are ignored so that placeholder
        objects never leak into the JSON output.
        """
        for key in keys:
            value = obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
            if isinstance(value, kind) and value:
                return value
        return kind()

    _check_sagemaker_core()

    version_arn = args.arn
    if not version_arn:
        _error_exit("--arn is required", code="MISSING_ARGUMENT")

    region = args.region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION", "us-west-2")
    os.environ["AWS_DEFAULT_REGION"] = region
    os.environ.setdefault("AWS_REGION", region)

    try:
        from sagemaker.core.resources import ModelPackage

        # NOTE(review): confirm that ModelPackage.get accepts the ARN under the
        # ``model_package_arn`` keyword in the installed SDK version.
        pkg = ModelPackage.get(model_package_arn=version_arn)

        # Extract model data URL from the first container, if any
        model_data_url = ""
        inference_spec = getattr(pkg, "inference_specification", None)
        if inference_spec:
            containers = _pick(inference_spec, ("Containers", "containers"), list)
            if containers:
                model_data_url = _pick(containers[0], ("ModelDataUrl", "model_data_url"), str)

        # Get metadata; anything that is not a dict is reported as empty so
        # the "metadata" field keeps its documented type.
        metadata = getattr(pkg, "customer_metadata_properties", None)
        if not isinstance(metadata, dict):
            metadata = {}

        # Get status
        status = getattr(pkg, "model_approval_status", "") or ""

        # Get description
        description = getattr(pkg, "model_package_description", "") or ""

        # Get version from ARN
        version = _extract_version_from_arn(version_arn)

        _output({
            "arn": version_arn,
            "version": version,
            "status": status,
            "description": description,
            "modelDataUrl": model_data_url,
            "metadata": metadata,
        })

    except Exception as e:
        _error_exit(f"Failed to get version details for {version_arn}: {e}", code="GET_VERSION_FAILED")
938
+
939
+
940
+ # ── Subcommand: resolve-dataset ──────────────────────────────────────────────
941
+
942
+
943
def cmd_resolve_dataset(args):
    """Resolve a registered dataset by name.

    Uses SageMaker AI Registry DataSet.get() when available, falls back to
    local JSON registry. Includes ARN in output when available (Backlog #023).

    The technique reported from the AI Registry is normalised defensively:
    a missing or None ``customization_technique`` becomes "", and an
    enum-like value is unwrapped through its ``value`` attribute before
    lower-casing. Calling ``.lower()`` directly raised on such values, which
    silently discarded a successful registry hit in favour of the fallback.

    Each successful emit is followed by an explicit return, so the function
    no longer depends on _output terminating the process to avoid reaching
    the not-found error.

    Args:
        args: Parsed CLI namespace providing ``name``.

    Returns JSON: {"name": str, "s3_uri": str, "arn": str|null, "format": str, "technique": str, ...}
    or error if not found.
    """
    name = args.name
    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")

    # Try SageMaker AI Registry API first
    if _check_ai_registry():
        try:
            from sagemaker.ai_registry.dataset import DataSet

            dataset = DataSet.get(name=name)

            technique = getattr(dataset, "customization_technique", None)
            # Enum-like values carry their text in ``.value``; plain strings pass through.
            technique = getattr(technique, "value", technique)
            technique = technique.lower() if isinstance(technique, str) else ""

            # Build response from AI Registry object
            _output({
                "name": getattr(dataset, "name", name),
                "s3_uri": getattr(dataset, "source", ""),
                "arn": getattr(dataset, "arn", None),
                "format": "jsonl",  # AI Registry may not store format
                "technique": technique,
            })
            return
        except Exception as e:
            # AI Registry lookup failed — fall through to local registry
            print(f"AI Registry lookup failed for '{name}': {e}. Trying local registry.", file=sys.stderr)

    # Fallback: local registry
    entries = _load_registry(_DATASETS_REGISTRY)
    for entry in entries:
        if entry.get("name") == name:
            # Include arn field if present in local registry (Backlog #023)
            output = dict(entry)
            output.setdefault("arn", None)
            _output(output)
            return

    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
985
+
986
+
987
+ # ── Subcommand: resolve-evaluator ────────────────────────────────────────────
988
+
989
+
990
def cmd_resolve_evaluator(args):
    """Look up one evaluator entry in the local evaluators registry by name.

    A matching entry is emitted as stored:
    {"name": str, "type": str, "arn_or_uri": str, "technique": str, ...}

    The EVALUATOR_NOT_FOUND error follows the search unconditionally, so this
    relies on _output not returning after a match (presumably it terminates
    the process — verify against its definition).
    """
    wanted = args.name
    if not wanted:
        _error_exit("--name is required", code="MISSING_ARGUMENT")

    registry = _load_registry(_EVALUATORS_REGISTRY)
    # Lazy filter: entries are examined one at a time, in registry order.
    for record in filter(lambda item: item.get("name") == wanted, registry):
        _output(record)

    _error_exit(f"Evaluator not found: {wanted}", code="EVALUATOR_NOT_FOUND")
1006
+
1007
+
1008
+ # ── CLI argument parsing ──────────────────────────────────────────────────────
1009
+
1010
+
1011
def main():
    """Build the CLI parser, parse the command line and run the chosen subcommand.

    With no subcommand the help text is printed and the process exits with
    status 1. Before dispatching, a region taken from --region or the
    environment is exported so that it is in place before any sagemaker-core
    import.
    """
    parser = argparse.ArgumentParser(
        description="SageMaker Model Package Group helper for model registration",
        prog=".register_helper.py",
    )
    subparsers = parser.add_subparsers(dest="command", help="Subcommand")

    def new_command(name, summary):
        """Create the sub-parser for one subcommand."""
        return subparsers.add_parser(name, help=summary)

    def add_text_options(sub, options):
        """Register optional string flags that default to the empty string."""
        for flag, text in options:
            sub.add_argument(flag, default="", help=text)

    def add_region(sub):
        """Register the optional --region flag."""
        sub.add_argument("--region", default=None, help="AWS region")

    def packaging_options(model_data_help):
        """Flags shared by register-model and register-adapter, in help order."""
        return (
            ("--deployment-config", "Deployment config (e.g., gpu-vllm)"),
            ("--container-image", "Container image URI"),
            ("--model-data-url", model_data_help),
            ("--instance-type", "Instance type (e.g., ml.g5.2xlarge)"),
            ("--architecture", "Architecture (e.g., transformers)"),
            ("--backend", "Backend (e.g., vllm)"),
            ("--model-name", "Model name (e.g., meta-llama/Llama-3.1-8B)"),
            ("--base-image", "Base container image"),
            ("--model-format", "Model format (e.g., safetensors)"),
            ("--generator-version", "Generator version"),
        )

    tuning_techniques = ["sft", "dpo", "rlaif", "rlvr"]

    # ── create-mpg ────────────────────────────────────────────────────────
    sub = new_command("create-mpg", "Create a Model Package Group (idempotent)")
    sub.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    add_region(sub)

    # ── register-model ────────────────────────────────────────────────────
    sub = new_command("register-model", "Register a model as a versioned Model Package")
    sub.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    add_text_options(sub, packaging_options("Model data S3 URI"))
    add_region(sub)
    sub.add_argument("--role-arn", default="", help="IAM execution role ARN")
    sub.add_argument("--benchmark-results", default=None, help="Benchmark results JSON string")

    # ── register-adapter ──────────────────────────────────────────────────
    sub = new_command(
        "register-adapter",
        "Register an adapter as a versioned Model Package linked to base model",
    )
    sub.add_argument("--project-name", required=True, help="Project name (used as MPG name)")
    sub.add_argument("--parent-version-arn", required=True, help="Base model version ARN in the same MPG")
    add_text_options(sub, (
        ("--tune-technique", "Tune technique (sft/dpo/rlvr)"),
        ("--dataset-s3-uri", "Training dataset S3 URI"),
    ))
    add_text_options(sub, packaging_options("Model/adapter data S3 URI"))
    add_region(sub)
    sub.add_argument("--role-arn", default="", help="IAM execution role ARN")

    # ── register-dataset ──────────────────────────────────────────────────
    sub = new_command(
        "register-dataset",
        "Register a dataset into the local registry (AI Registry fallback)",
    )
    sub.add_argument("--name", required=True, help="Dataset name (unique identifier)")
    sub.add_argument("--s3-uri", required=True, help="S3 URI of the dataset")
    sub.add_argument(
        "--format",
        default="jsonl",
        choices=["jsonl", "parquet", "csv"],
        help="Dataset format (jsonl/parquet/csv)",
    )
    sub.add_argument(
        "--technique",
        default="sft",
        choices=tuning_techniques,
        help="Associated tuning technique",
    )
    sub.add_argument("--row-count", type=int, default=None, help="Number of rows in dataset")
    sub.add_argument("--column-schema", default=None, help="Column schema as JSON string")
    sub.add_argument("--project-name", default=None, help="Project name for context")

    # ── list-datasets ─────────────────────────────────────────────────────
    sub = new_command("list-datasets", "List all registered datasets from the local registry")
    sub.add_argument(
        "--technique",
        default=None,
        choices=tuning_techniques,
        help="Filter by tuning technique",
    )

    # ── register-evaluator ────────────────────────────────────────────────
    sub = new_command(
        "register-evaluator",
        "Register an evaluator (Lambda ARN or preference model) into the local registry",
    )
    sub.add_argument("--name", required=True, help="Evaluator name (unique identifier)")
    sub.add_argument(
        "--type",
        dest="eval_type",
        required=True,
        choices=["lambda", "model"],
        help="Evaluator type (lambda/model)",
    )
    sub.add_argument("--arn-or-uri", required=True, help="Lambda ARN (RLVR) or model S3 URI (RLAIF)")
    sub.add_argument(
        "--technique",
        required=True,
        choices=["rlvr", "rlaif"],
        help="Associated technique (rlvr/rlaif)",
    )
    sub.add_argument("--description", default="", help="Evaluator description")
    sub.add_argument("--project-name", default=None, help="Project name for context")

    # ── list-adapters / list-models ───────────────────────────────────────
    for command, summary in (
        ("list-adapters", "List adapter versions from the project's Model Package Group"),
        ("list-models", "List base model versions (non-adapter) from the project's Model Package Group"),
    ):
        sub = new_command(command, summary)
        sub.add_argument("--project-name", required=True, help="Project name (MPG name)")
        add_region(sub)

    # ── get-version ───────────────────────────────────────────────────────
    sub = new_command("get-version", "Get details for a specific model package version by ARN")
    sub.add_argument("--arn", required=True, help="Model package version ARN")
    add_region(sub)

    # ── resolve-dataset ───────────────────────────────────────────────────
    sub = new_command("resolve-dataset", "Resolve a registered dataset by name")
    sub.add_argument("--name", required=True, help="Dataset name to resolve")

    # ── resolve-evaluator ─────────────────────────────────────────────────
    sub = new_command("resolve-evaluator", "Resolve a registered evaluator by name")
    sub.add_argument("--name", required=True, help="Evaluator name to resolve")

    # ── Parse and dispatch ────────────────────────────────────────────────
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Set region before any sagemaker-core import (creates boto3 clients at import time)
    region = (
        getattr(args, 'region', None)
        or os.environ.get('AWS_DEFAULT_REGION')
        or os.environ.get('AWS_REGION')
    )
    if region:
        os.environ['AWS_DEFAULT_REGION'] = region
        os.environ.setdefault('AWS_REGION', region)

    # Built only after the early exits above, once a subcommand is known.
    handlers = {
        "create-mpg": cmd_create_mpg,
        "register-model": cmd_register_model,
        "register-adapter": cmd_register_adapter,
        "register-dataset": cmd_register_dataset,
        "list-datasets": cmd_list_datasets,
        "register-evaluator": cmd_register_evaluator,
        "list-adapters": cmd_list_adapters,
        "list-models": cmd_list_models,
        "get-version": cmd_get_version,
        "resolve-dataset": cmd_resolve_dataset,
        "resolve-evaluator": cmd_resolve_evaluator,
    }
    handler = handlers.get(args.command)
    if handler is None:
        _error_exit(f"Unknown subcommand: {args.command}", code="UNKNOWN_COMMAND")
    else:
        handler(args)
1182
+
1183
+
1184
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()