@aws/ml-container-creator 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. package/README.md +1 -1
  2. package/bin/cli.js +1 -1
  3. package/config/tune-catalog.json +303 -1
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +3 -2
  6. package/servers/base-image-picker/index.js +65 -18
  7. package/servers/instance-sizer/index.js +32 -0
  8. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  9. package/servers/lib/catalogs/model-arch-support.json +51 -0
  10. package/servers/lib/catalogs/model-servers.json +2842 -1516
  11. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  12. package/src/app.js +6 -4
  13. package/src/lib/bootstrap-command-handler.js +12 -2
  14. package/src/lib/bootstrap-profile-manager.js +16 -0
  15. package/src/lib/cross-cutting-checker.js +6 -1
  16. package/src/lib/generated/cli-options.js +1 -1
  17. package/src/lib/generated/parameter-matrix.js +1 -1
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +110 -3
  20. package/src/lib/prompt-runner.js +66 -22
  21. package/src/lib/template-variable-resolver.js +8 -0
  22. package/src/lib/train-config-builder.js +339 -0
  23. package/templates/do/.benchmark_writer.py +3 -0
  24. package/templates/do/.eval_helper.py +409 -0
  25. package/templates/do/.register_helper.py +185 -11
  26. package/templates/do/.train_build_request.py +102 -113
  27. package/templates/do/.train_helper.py +433 -0
  28. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  29. package/templates/do/adapter +157 -0
  30. package/templates/do/benchmark +60 -3
  31. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  32. package/templates/do/evaluate +272 -0
  33. package/templates/do/lib/resolve-instance.sh +155 -0
  34. package/templates/do/register +5 -0
  35. package/templates/do/test +1 -0
  36. package/templates/do/train +879 -126
  37. package/templates/do/training/config.yaml +83 -11
  38. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  39. package/templates/do/training/dpo/defaults.yaml +26 -0
  40. package/templates/do/training/dpo/prompts.json +8 -0
  41. package/templates/do/training/dpo/train.py +363 -0
  42. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  43. package/templates/do/training/sft/defaults.yaml +18 -0
  44. package/templates/do/training/sft/prompts.json +7 -0
  45. package/templates/do/training/sft/train.py +310 -0
  46. package/templates/do/tune +11 -2
  47. package/templates/do/.train_poll_parser.py +0 -135
  48. package/templates/do/.train_status_parser.py +0 -187
  49. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -0,0 +1,409 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Model Quality Evaluation Helper.
6
+
7
+ Subcommands:
8
+ evaluate - Run evaluation against deployed endpoint, compute metrics
9
+ eval-write - Write evaluation results to S3 as a JSON record for Athena (Parquet output is not implemented yet)
10
+
11
+ All output is JSON on stdout for bash consumption.
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import math
17
+ import os
18
+ import sys
19
+ import time
20
+
21
+
22
+ # ── Utility functions ─────────────────────────────────────────────────────────
23
+
24
def _error_exit(message):
    """Report a failure as JSON on stdout and terminate with exit status 1.

    The emitted object has the shape ``{"error": true, "message": <message>}``
    so the calling bash wrapper can parse failures the same way it parses
    successful results.
    """
    failure = {"error": True, "message": message}
    print(json.dumps(failure))
    raise SystemExit(1)
28
+
29
+
30
def _output(data):
    """Serialize ``data`` as one JSON line on stdout and exit with status 0.

    This never returns to the caller: the process ends right after the
    result has been printed.
    """
    serialized = json.dumps(data)
    print(serialized)
    raise SystemExit(0)
34
+
35
+
36
+ # ── Endpoint invocation ───────────────────────────────────────────────────────
37
+
38
def _invoke_endpoint(endpoint_name, ic_name, region, payload):
    """Send one JSON request to a SageMaker endpoint and decode the reply.

    The payload is serialized with ``json.dumps`` and sent through the
    ``sagemaker-runtime`` InvokeEndpoint call. When ``ic_name`` is truthy it
    is passed as ``InferenceComponentName`` so the request is routed to that
    inference component. Callers in this file send OpenAI-style chat
    completion payloads.

    Returns:
        The response body parsed as JSON. Any exception raised while invoking
        the endpoint or decoding the body is reported as
        ``{"error": "<exception text>"}`` instead of being raised.
    """
    import boto3

    runtime = boto3.client('sagemaker-runtime', region_name=region)

    request = {
        'EndpointName': endpoint_name,
        'ContentType': 'application/json',
        'Body': json.dumps(payload),
    }
    if ic_name:
        request['InferenceComponentName'] = ic_name

    try:
        raw_body = runtime.invoke_endpoint(**request)['Body'].read()
        return json.loads(raw_body.decode('utf-8'))
    except Exception as exc:
        # Best effort: callers treat an "error" key as "no result".
        return {"error": str(exc)}
64
+
65
+
66
def _score_text(endpoint_name, ic_name, region, prompt, completion):
    """Score a completion by summing the token logprobs the endpoint returns.

    Sends ``prompt`` as the user turn and ``completion`` as the assistant
    turn with ``logprobs`` enabled, then reads the OpenAI-style
    ``choices[0].logprobs.content`` list from the response.

    NOTE(review): the request uses ``max_tokens=1``, so the returned logprobs
    presumably describe the newly generated token(s) rather than the supplied
    completion text — confirm against the serving container's behavior.

    Returns:
        The sum of the ``logprob`` values found, or None when the call
        failed, the response is not a JSON object, or no logprobs are
        present.
    """
    payload = {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ],
        "max_tokens": 1,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 1,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # A valid JSON body is not necessarily an object (list, string, null);
    # such a reply cannot carry choices, so treat it as "no score".
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None

        # The response format varies — only the OpenAI-compatible layout
        # (logprobs.content[*].logprob) is understood here.
        logprobs_data = choices[0].get("logprobs")
        if logprobs_data and "content" in logprobs_data:
            token_logprobs = [t.get("logprob", 0.0) for t in logprobs_data["content"]]
            return sum(token_logprobs) if token_logprobs else None

        return None
    except (AttributeError, KeyError, TypeError, IndexError):
        # AttributeError covers non-dict choice/token entries, which would
        # otherwise abort the whole evaluation run.
        return None
106
+
107
+
108
def _generate_response(endpoint_name, ic_name, region, prompt, max_tokens=256):
    """Generate a completion for ``prompt`` at temperature 0.

    Args:
        endpoint_name: SageMaker endpoint to invoke.
        ic_name: Inference component name, or a falsy value for none.
        region: AWS region used for the runtime client.
        prompt: Text sent as the single user message.
        max_tokens: Generation limit forwarded to the endpoint.

    Returns:
        ``choices[0].message.content`` from the response ("" when the message
        has no content key), or None when the call failed, the response is
        not a JSON object, or it contains no choices.
    """
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # Guard against JSON bodies that are not objects (list, string, null).
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if choices:
            return choices[0].get("message", {}).get("content", "")
        return None
    except (AttributeError, KeyError, TypeError, IndexError):
        # AttributeError covers a null/non-dict choice or message entry.
        return None
131
+
132
+
133
+ # ── Metric computation ────────────────────────────────────────────────────────
134
+
135
def _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Evaluate an SFT model on prompt/reference records.

    Records are visited in order; iteration stops once the record index
    reaches ``samples`` (when that limit is truthy). Records without a
    "prompt" are skipped. For each remaining record the "reference" text, if
    present, is scored through the endpoint and a fresh answer is generated.

    Returns a dict that may contain:
        perplexity            exp(-mean per-token logprob); the token count
                              is estimated as len(reference) // 4, minimum 1
        avg_response_length   mean whitespace-separated word count of the
                              generated answers
        exact_match_accuracy  generations equal to the reference after
                              strip(), divided by all evaluated records
        samples_scored        number of evaluated records (always present)
    """
    per_token_logprobs = []
    word_counts = []
    matches = 0
    evaluated = 0

    for index, row in enumerate(dataset):
        if samples and index >= samples:
            break

        prompt = row.get("prompt", "")
        reference = row.get("reference", "")
        if not prompt:
            continue
        evaluated += 1

        # Perplexity input: summed logprob normalized by a rough token count
        # (about four characters per token).
        if reference:
            logprob_sum = _score_text(endpoint_name, ic_name, region, prompt, reference)
            if logprob_sum is not None:
                token_estimate = max(1, len(reference) // 4)
                per_token_logprobs.append(logprob_sum / token_estimate)

        # Length and exact-match input: a fresh generation for the prompt.
        answer = _generate_response(endpoint_name, ic_name, region, prompt)
        if answer is None:
            continue
        word_counts.append(len(answer.split()))
        if reference and answer.strip() == reference.strip():
            matches += 1

    metrics = {}
    if per_token_logprobs:
        mean_logprob = sum(per_token_logprobs) / len(per_token_logprobs)
        metrics["perplexity"] = round(math.exp(-mean_logprob), 4)
    if word_counts:
        metrics["avg_response_length"] = round(sum(word_counts) / len(word_counts), 1)
    if evaluated > 0:
        metrics["exact_match_accuracy"] = round(matches / evaluated, 4)
    metrics["samples_scored"] = evaluated
    return metrics
189
+
190
+
191
def _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Evaluate a DPO model on prompt/chosen/rejected preference records.

    Records are visited in order; iteration stops once the record index
    reaches ``samples`` (when that limit is truthy). Records missing any of
    "prompt", "chosen" or "rejected" are skipped. A pair only contributes to
    the aggregates when both completions could be scored.

    Returns a dict with:
        reward_accuracy       share of scored pairs where chosen > rejected
        avg_chosen_logprob    mean score of the chosen completions
        avg_rejected_logprob  mean score of the rejected completions
        reward_margin         mean of (chosen - rejected)
        pairs_scored          pairs where both scores were available
        samples_evaluated     records that had all three fields
    The first four keys are only present when at least one pair was scored.
    """
    scored_pairs = []  # (chosen_score, rejected_score) tuples
    evaluated = 0

    for index, row in enumerate(dataset):
        if samples and index >= samples:
            break

        prompt = row.get("prompt", "")
        chosen = row.get("chosen", "")
        rejected = row.get("rejected", "")
        if not (prompt and chosen and rejected):
            continue
        evaluated += 1

        chosen_score = _score_text(endpoint_name, ic_name, region, prompt, chosen)
        rejected_score = _score_text(endpoint_name, ic_name, region, prompt, rejected)
        if chosen_score is None or rejected_score is None:
            continue
        scored_pairs.append((chosen_score, rejected_score))

    metrics = {}
    scored = len(scored_pairs)
    if scored > 0:
        chosen_total = sum(pair[0] for pair in scored_pairs)
        rejected_total = sum(pair[1] for pair in scored_pairs)
        wins = sum(1 for better, worse in scored_pairs if better > worse)

        metrics["reward_accuracy"] = round(wins / scored, 4)
        metrics["avg_chosen_logprob"] = round(chosen_total / scored, 4)
        metrics["avg_rejected_logprob"] = round(rejected_total / scored, 4)
        metrics["reward_margin"] = round((chosen_total - rejected_total) / scored, 4)

    metrics["pairs_scored"] = scored
    metrics["samples_evaluated"] = evaluated
    return metrics
240
+
241
+
242
+ # ── Dataset loading ───────────────────────────────────────────────────────────
243
+
244
def _load_eval_dataset(eval_dataset_path):
    """Load an evaluation dataset from a local JSONL file or an s3:// URI.

    An ``s3://bucket/key`` path is first downloaded to a temporary file,
    which is removed again once the records have been read. Blank lines are
    ignored; every other line must be a JSON object.

    Args:
        eval_dataset_path: Local file path or ``s3://`` URI of a JSONL file.

    Returns:
        A non-empty list of dicts, one per JSONL line.

    Every failure (missing path, download error, unreadable or malformed
    file, empty dataset) is reported through ``_error_exit`` so the caller
    still receives JSON on stdout.
    """
    if not eval_dataset_path:
        _error_exit("No evaluation dataset specified. Use --eval-dataset <path>")

    records = []
    downloaded_path = None
    try:
        if eval_dataset_path.startswith("s3://"):
            import boto3
            import tempfile

            bucket, _, key = eval_dataset_path[len("s3://"):].partition("/")
            # Only the path is needed; download_file opens the target itself,
            # so close the descriptor right away instead of leaking it.
            fd, downloaded_path = tempfile.mkstemp(suffix=".jsonl")
            os.close(fd)
            try:
                boto3.client("s3").download_file(bucket, key, downloaded_path)
            except Exception as e:
                _error_exit(f"Failed to download eval dataset from S3: {e}")
            eval_dataset_path = downloaded_path

        try:
            with open(eval_dataset_path, "r", encoding="utf-8") as f:
                for line_no, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue
                    record = json.loads(line)
                    if not isinstance(record, dict):
                        # The metric code calls record.get(); fail early with
                        # a clear message instead of an AttributeError later.
                        raise ValueError(f"line {line_no} is not a JSON object")
                    records.append(record)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            _error_exit(f"Failed to load eval dataset: {e}")
    finally:
        # Runs even when _error_exit raises SystemExit.
        if downloaded_path is not None:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass

    if not records:
        _error_exit("Evaluation dataset is empty")

    return records
282
+
283
+
284
+ # ── cmd_evaluate ──────────────────────────────────────────────────────────────
285
+
286
def cmd_evaluate(args):
    """Run an evaluation against a deployed endpoint and print the result.

    Reads ``endpoint_name``, ``ic_name``, ``region``, ``technique``,
    ``samples`` and ``eval_dataset`` from ``args``. The region falls back to
    the AWS_DEFAULT_REGION environment variable and then to 'us-east-1'.
    A technique of 'dpo' (case-insensitive) selects the DPO metrics; any
    other value, including none, selects the SFT metrics. ``args.metrics``
    is not consulted here.

    The result (metrics plus metadata) is printed as JSON via ``_output``,
    which ends the process. An invalid ``--samples`` value is reported via
    ``_error_exit`` rather than surfacing as a traceback.
    """
    endpoint_name = args.endpoint_name
    ic_name = args.ic_name
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')
    technique = args.technique or ''

    # An empty value means "no limit"; anything else must be a non-negative
    # integer. A negative limit would otherwise stop before the first record.
    samples = None
    if args.samples:
        try:
            samples = int(args.samples)
        except (TypeError, ValueError):
            _error_exit(f"Invalid --samples value: {args.samples!r} (expected an integer)")
        if samples < 0:
            _error_exit(f"Invalid --samples value: {args.samples!r} (must not be negative)")

    # Load eval dataset
    dataset = _load_eval_dataset(args.eval_dataset)

    # Determine technique and compute metrics
    if technique.lower() == 'dpo':
        metrics = _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples)
    else:
        # Default to SFT metrics (works for any technique)
        metrics = _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples)

    # The DPO metrics report "samples_evaluated", the SFT metrics report
    # "samples_scored"; accept either for the top-level count.
    result = {
        "adapter_name": args.ic_name,
        "technique": technique or "sft",
        "model": os.environ.get("MODEL_NAME", ""),
        "eval_dataset": args.eval_dataset or "",
        "samples_evaluated": metrics.get("samples_evaluated", metrics.get("samples_scored", 0)),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "metrics": metrics,
    }

    _output(result)
319
+
320
+
321
+ # ── cmd_eval_write ────────────────────────────────────────────────────────────
322
+
323
def cmd_eval_write(args):
    """Upload an evaluation result to S3 as a single JSON record.

    Reads the results JSON file named by ``args.results_file`` (as written by
    the evaluate subcommand), flattens it into one record with the metrics
    serialized as a JSON string, and stores it under
    ``evaluations/model=<model>/adapter=<adapter>/<timestamp>.json`` in
    ``args.bucket``. Parquet is not produced: the original MVP note says that
    would require a pyarrow dependency.

    Missing or empty ``adapter_name``, ``technique`` and ``model`` values
    fall back to "unknown" so the S3 key never contains an empty partition
    value. On success ``{"written": true, "s3_uri": ...}`` is printed via
    ``_output``; failures are reported via ``_error_exit``.
    """
    results_file = args.results_file
    bucket = args.bucket
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')

    # Read results
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        _error_exit(f"Failed to read results file: {e}")

    if not isinstance(data, dict):
        _error_exit("Failed to read results file: expected a JSON object")

    # `or` rather than a .get() default: the evaluate subcommand writes an
    # empty "model" when MODEL_NAME is unset, which a default would not catch.
    adapter_name = data.get("adapter_name") or "unknown"
    technique = data.get("technique") or "unknown"
    model = data.get("model") or "unknown"
    timestamp = data.get("timestamp") or time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    # Build the JSON record
    record = {
        "project_name": os.environ.get("PROJECT_NAME", ""),
        "model_name": model,
        "adapter_name": adapter_name,
        "technique": technique,
        "eval_dataset": data.get("eval_dataset", ""),
        "samples_evaluated": data.get("samples_evaluated", 0),
        "metrics": json.dumps(data.get("metrics", {})),
        "timestamp": timestamp,
        "region": region,
    }

    # Colons are replaced in the object name only; the record keeps the
    # original timestamp.
    s3_key = f"evaluations/model={model}/adapter={adapter_name}/{timestamp.replace(':', '-')}.json"
    s3_uri = f"s3://{bucket}/{s3_key}"

    try:
        import boto3
        s3 = boto3.client('s3', region_name=region)
        s3.put_object(
            Bucket=bucket,
            Key=s3_key,
            Body=json.dumps(record),
            ContentType='application/json',
        )
    except Exception as e:
        _error_exit(f"Failed to write to S3: {e}")

    _output({"written": True, "s3_uri": s3_uri})
374
+
375
+
376
+ # ── Main ──────────────────────────────────────────────────────────────────────
377
+
378
def main():
    """Parse the command line and dispatch to the selected subcommand.

    Subcommands:
        evaluate    requires --endpoint-name and --ic-name; optional
                    --region, --technique, --eval-dataset, --samples,
                    --metrics
        eval-write  requires --results-file and --bucket; optional --region
    """
    parser = argparse.ArgumentParser(description='Model Quality Evaluation Helper')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # evaluate
    evaluate_cmd = subparsers.add_parser('evaluate', help='Run evaluation')
    for flag in ('--endpoint-name', '--ic-name'):
        evaluate_cmd.add_argument(flag, required=True)
    evaluate_cmd.add_argument('--region')
    for flag in ('--technique', '--eval-dataset', '--samples', '--metrics'):
        evaluate_cmd.add_argument(flag, default='')

    # eval-write
    write_cmd = subparsers.add_parser('eval-write', help='Write results to S3')
    for flag in ('--results-file', '--bucket'):
        write_cmd.add_argument(flag, required=True)
    write_cmd.add_argument('--region')

    args = parser.parse_args()

    handlers = {
        'evaluate': cmd_evaluate,
        'eval-write': cmd_eval_write,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
    else:
        _error_exit(f"Unknown command: {args.command}")
406
+
407
+
408
+ if __name__ == '__main__':
409
+ main()
@@ -112,6 +112,74 @@ def _truncate_metadata(props):
112
112
  return result
113
113
 
114
114
 
115
def _inject_eval_metrics(metadata, args):
    """Inject evaluation metrics from .mlcc/eval-results/ into metadata.

    Looks in ``<script dir>/../.mlcc/eval-results`` for
    ``<adapter_name>.json`` (adapter name taken from ``args.adapter_name``
    when present). If that file does not exist, the most recently modified
    ``*.json`` file in the directory is used instead. Each entry of the
    file's "metrics" object is added under an ``eval_`` prefix, stringified
    and truncated to MAX_METADATA_VALUE_LEN (G4 AC-3.1, AC-3.2).

    Non-fatal: if no results exist, or the chosen file is unreadable or does
    not contain a metrics object, metadata is returned unchanged.

    NOTE(review): the "most recent file" fallback can attach metrics that
    were produced for a different adapter — confirm this is intended.

    Args:
        metadata: existing metadata dict (may be None); updated in place
        args: parsed args; only the optional ``adapter_name`` attribute is read

    Returns:
        metadata dict with eval metrics injected (or unchanged)
    """
    if metadata is None:
        metadata = {}

    # Convention: .mlcc/eval-results/<adapter-or-ic-name>.json, located
    # relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    eval_results_dir = os.path.join(script_dir, "..", ".mlcc", "eval-results")

    if not os.path.isdir(eval_results_dir):
        return metadata

    # Prioritize: adapter name from args > any available result
    adapter_name = getattr(args, 'adapter_name', '') or ''

    eval_file = None
    if adapter_name:
        candidate = os.path.join(eval_results_dir, f"{adapter_name}.json")
        if os.path.isfile(candidate):
            eval_file = candidate

    # If no specific adapter match, fall back to the most recently modified
    # result file.
    if not eval_file:
        try:
            candidates = [
                os.path.join(eval_results_dir, f)
                for f in os.listdir(eval_results_dir)
                if f.endswith('.json')
            ]
            if candidates:
                eval_file = max(candidates, key=os.path.getmtime)
        except OSError:
            pass

    if not eval_file:
        return metadata

    try:
        with open(eval_file, 'r', encoding='utf-8') as f:
            eval_data = json.load(f)
    except (OSError, ValueError):
        # Non-fatal — skip eval metrics if the file is unreadable, is not
        # valid JSON, or is not valid UTF-8.
        return metadata

    # Valid JSON is not necessarily an object; a list or scalar at either
    # level previously raised AttributeError and broke registration.
    metrics = eval_data.get("metrics") if isinstance(eval_data, dict) else None
    if not isinstance(metrics, dict) or not metrics:
        return metadata

    for metric_name, metric_value in metrics.items():
        metadata[f"eval_{metric_name}"] = str(metric_value)[:MAX_METADATA_VALUE_LEN]
    _warn(f"Injected {len(metrics)} eval metric(s) from {os.path.basename(eval_file)}")

    return metadata
181
+
182
+
115
183
  def _build_metadata(args):
116
184
  """Build customer_metadata_properties dict from CLI args.
117
185
 
@@ -283,7 +351,7 @@ def cmd_register_model(args):
283
351
 
284
352
  # Step 3: Build inference specification
285
353
  container_image = args.container_image or ""
286
- model_data_url = args.model_data_url or ""
354
+ model_data_url = (args.model_data_url or "").rstrip("/")
287
355
 
288
356
  # Step 4: Create Model Package version (AC-1.2, AC-1.7)
289
357
  description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"
@@ -437,7 +505,7 @@ def cmd_register_adapter(args):
437
505
 
438
506
  # Step 3: Build inference specification
439
507
  container_image = args.container_image or ""
440
- model_data_url = args.model_data_url or ""
508
+ model_data_url = (args.model_data_url or "").rstrip("/")
441
509
 
442
510
  # Step 4: Create adapter Model Package version (AC-2.1)
443
511
  technique = args.tune_technique or "unknown"
@@ -463,12 +531,21 @@ def cmd_register_adapter(args):
463
531
  "SupportedContentTypes": ["application/json"],
464
532
  "SupportedResponseMIMETypes": ["application/json"],
465
533
  }
466
- if model_data_url:
534
+ # ModelDataUrl in InferenceSpecification requires a tar.gz object —
535
+ # uncompressed S3 prefixes (adapter directories) are not supported.
536
+ # Store uncompressed paths in metadata instead.
537
+ if model_data_url and model_data_url.endswith(".tar.gz"):
467
538
  create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
468
- elif model_data_url:
539
+
540
+ # Always store model/adapter data URL in metadata for registry queries
541
+ if model_data_url:
469
542
  if not metadata:
470
543
  metadata = {}
471
544
  metadata["modelDataUrl"] = model_data_url[:1024]
545
+
546
+ # Inject evaluation metrics if available (G4 AC-3.1, AC-3.2)
547
+ metadata = _inject_eval_metrics(metadata, args)
548
+
472
549
  if metadata:
473
550
  create_params["CustomerMetadataProperties"] = metadata
474
551
 
@@ -1366,9 +1443,24 @@ def cmd_get_version(args):
1366
1443
  os.environ.setdefault("AWS_REGION", region)
1367
1444
 
1368
1445
  try:
1369
- from sagemaker.core.resources import ModelPackage
1446
+ import boto3
1447
+ sm_client = boto3.client("sagemaker", region_name=region)
1370
1448
 
1371
- pkg = ModelPackage.get(model_package_arn=version_arn)
1449
+ # Use boto3 directly — sagemaker-core v2.14 ModelPackage.get() requires
1450
+ # model_package_name (not ARN) and rejects model_package_arn as unexpected kwarg.
1451
+ pkg_response = sm_client.describe_model_package(ModelPackageName=version_arn)
1452
+
1453
+ # Wrap in a simple namespace for consistent access below
1454
+ class _Pkg:
1455
+ def __init__(self, data):
1456
+ self._data = data
1457
+ self.model_package_arn = data.get("ModelPackageArn", version_arn)
1458
+ self.inference_specification = data.get("InferenceSpecification")
1459
+ self.customer_metadata_properties = data.get("CustomerMetadataProperties", {})
1460
+ self.model_approval_status = data.get("ModelApprovalStatus", "")
1461
+ self.model_package_description = data.get("ModelPackageDescription", "")
1462
+ self.creation_time = data.get("CreationTime")
1463
+ pkg = _Pkg(pkg_response)
1372
1464
 
1373
1465
  # Extract model data URL from inference spec
1374
1466
  model_data_url = ""
@@ -1381,6 +1473,10 @@ def cmd_get_version(args):
1381
1473
  # Get metadata
1382
1474
  metadata = getattr(pkg, "customer_metadata_properties", None) or {}
1383
1475
 
1476
+ # Fallback: modelDataUrl stored in metadata when adapter is uncompressed S3 prefix
1477
+ if not model_data_url and metadata.get("modelDataUrl"):
1478
+ model_data_url = metadata["modelDataUrl"]
1479
+
1384
1480
  # Get status
1385
1481
  status = getattr(pkg, "model_approval_status", "") or ""
1386
1482
 
@@ -1414,6 +1510,7 @@ def cmd_resolve_dataset(args):
1414
1510
 
1415
1511
  Version resolution (AC-2.1, AC-2.4):
1416
1512
  - --version N: resolve the Nth version (ordinal, 1-based) for this name
1513
+ - --version X.Y.Z: resolve by semver string match
1417
1514
  - No --version: resolve latest (existing behavior)
1418
1515
  - If requested version doesn't exist: print available versions and exit 1 (AC-2.5)
1419
1516
 
@@ -1421,14 +1518,20 @@ def cmd_resolve_dataset(args):
1421
1518
  or error if not found.
1422
1519
  """
1423
1520
  name = args.name
1424
- version_ordinal = getattr(args, "version", None)
1521
+ version_spec = getattr(args, "version", None)
1425
1522
 
1426
1523
  if not name:
1427
1524
  _error_exit("--name is required", code="MISSING_ARGUMENT")
1428
1525
 
1429
1526
  # If version is specified, use version-aware resolution
1430
- if version_ordinal is not None:
1431
- return _resolve_dataset_version(name, version_ordinal)
1527
+ if version_spec is not None:
1528
+ # Determine if it's an ordinal (pure integer) or semver string
1529
+ try:
1530
+ version_ordinal = int(version_spec)
1531
+ return _resolve_dataset_version(name, version_ordinal)
1532
+ except ValueError:
1533
+ # Not an integer — treat as semver string
1534
+ return _resolve_dataset_version_by_semver(name, version_spec)
1432
1535
 
1433
1536
  # No version — resolve latest (existing behavior)
1434
1537
  # Try SageMaker AI Registry API first
@@ -1545,6 +1648,77 @@ def _resolve_dataset_version(name, version_ordinal):
1545
1648
  _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
1546
1649
 
1547
1650
 
1651
def _resolve_dataset_version_by_semver(name, version_str):
    """Resolve one version of a named dataset by exact version-string match.

    The local dataset registry is scanned for the first entry called
    ``name``; its ``versions`` list is then searched for an item whose
    "version" field equals ``version_str`` (e.g. '1.0.0'). An entry without
    a versions list is treated as a single legacy version "1.0.0".

    On a match the resolved record is emitted through ``_output``
    (presumably terminating the process — confirm against its definition).
    When the version is unknown, the available versions are listed on stderr,
    a VERSION_NOT_FOUND JSON object is printed on stdout, and the process
    exits with status 1. An unknown dataset name is reported via
    ``_error_exit`` with code DATASET_NOT_FOUND.

    Args:
        name: Dataset name
        version_str: Version string to match (e.g., '1.0.0', '2.1.0')
    """
    not_found = f"Version {version_str} not found for dataset '{name}'"

    for entry in _load_registry(_DATASETS_REGISTRY):
        if entry.get("name") != name:
            continue

        versions = entry.get("versions", [])

        if not versions:
            # Legacy entry without versions array — treat as having version "1.0.0"
            if version_str == "1.0.0":
                legacy = dict(entry)
                legacy["version"] = "1.0.0"
                legacy["ordinal"] = 1
                legacy.setdefault("arn", None)
                _output(legacy)
            else:
                print(f"Error: {not_found}", file=sys.stderr)
                print("Available versions: 1.0.0", file=sys.stderr)
                print(json.dumps({
                    "error": not_found,
                    "code": "VERSION_NOT_FOUND",
                    "available_versions": [{"ordinal": 1, "version": "1.0.0"}],
                }))
                sys.exit(1)

        # Ordinals are 1-based positions within the versions list.
        for ordinal, item in enumerate(versions, 1):
            label = item.get("version", "")
            if label != version_str:
                continue
            _output({
                "name": name,
                "s3_uri": item.get("s3_uri", entry.get("s3_uri", "")),
                "arn": entry.get("arn"),
                "format": item.get("format", entry.get("format", "jsonl")),
                "technique": item.get("technique", entry.get("technique", "")),
                "version": label,
                "ordinal": ordinal,
                "hash": item.get("hash"),
            })

        # No version matched — list what is available, then fail.
        print(f"Error: {not_found}", file=sys.stderr)
        available = []
        for ordinal, item in enumerate(versions, 1):
            label = item.get("version", f"{ordinal}.0.0")
            available.append({"ordinal": ordinal, "version": label})
            print(f" v{ordinal} ({label})", file=sys.stderr)
        print(json.dumps({
            "error": not_found,
            "code": "VERSION_NOT_FOUND",
            "available_versions": available,
        }))
        sys.exit(1)

    # Dataset name not found at all
    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
1720
+
1721
+
1548
1722
  # ── Subcommand: resolve-evaluator ────────────────────────────────────────────
1549
1723
 
1550
1724
 
@@ -1706,8 +1880,8 @@ def main():
1706
1880
  help="Resolve a registered dataset by name",
1707
1881
  )
1708
1882
  resolve_dataset_parser.add_argument("--name", required=True, help="Dataset name to resolve")
1709
- resolve_dataset_parser.add_argument("--version", type=int, default=None,
1710
- help="Version ordinal to resolve (e.g., 2 for the 2nd version). Default: latest.")
1883
+ resolve_dataset_parser.add_argument("--version", type=str, default=None,
1884
+ help="Version to resolve: ordinal (e.g., 2) or semver (e.g., 1.0.0). Default: latest.")
1711
1885
 
1712
1886
  # ── resolve-evaluator ─────────────────────────────────────────────────
1713
1887
  resolve_evaluator_parser = subparsers.add_parser(