@aws/ml-container-creator 1.0.3 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +2 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Model Quality Evaluation Helper.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
evaluate - Run evaluation against deployed endpoint, compute metrics
|
|
9
|
+
eval-write - Write evaluation results to S3/Athena (Parquet)
|
|
10
|
+
|
|
11
|
+
All output is JSON on stdout for bash consumption.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import math
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
def _error_exit(message):
    """Report a failure as a JSON object on stdout and terminate with status 1.

    The emitted object has the shape ``{"error": true, "message": <message>}``
    so the calling bash script can detect failures by parsing stdout.
    """
    failure = {"error": True, "message": message}
    print(json.dumps(failure))
    raise SystemExit(1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _output(data):
    """Write ``data`` to stdout as a single JSON line, then exit with status 0."""
    serialized = json.dumps(data)
    print(serialized)
    raise SystemExit(0)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ── Endpoint invocation ───────────────────────────────────────────────────────
|
|
37
|
+
|
|
38
|
+
# Per-region cache of sagemaker-runtime clients. An evaluation run calls the
# endpoint several times per sample; reusing one client avoids re-creating it
# (and re-resolving config/endpoints) for every request.
_RUNTIME_CLIENTS = {}


def _invoke_endpoint(endpoint_name, ic_name, region, payload):
    """Invoke a SageMaker endpoint and return its decoded JSON body.

    Uses InvokeEndpoint; when ``ic_name`` is set it is passed as
    ``InferenceComponentName`` to route to that inference component.
    ``payload`` should be an OpenAI-compatible chat completion request.

    Args:
        endpoint_name: SageMaker endpoint name.
        ic_name: inference component name, or a falsy value to omit it.
        region: AWS region for the sagemaker-runtime client.
        payload: JSON-serializable request body.

    Returns:
        The parsed response dict, or ``{"error": <message>}`` when the call
        fails, the body is not valid JSON, or the body is not a JSON object.
    """
    client = _RUNTIME_CLIENTS.get(region)
    if client is None:
        # Created outside the try below, as before: a missing boto3 or an
        # invalid client configuration still fails loudly.
        import boto3
        client = boto3.client('sagemaker-runtime', region_name=region)
        _RUNTIME_CLIENTS[region] = client

    kwargs = {
        'EndpointName': endpoint_name,
        'ContentType': 'application/json',
        'Body': json.dumps(payload),
    }
    if ic_name:
        kwargs['InferenceComponentName'] = ic_name

    try:
        response = client.invoke_endpoint(**kwargs)
        body = response['Body'].read().decode('utf-8')
        parsed = json.loads(body)
    except Exception as e:
        return {"error": str(e)}

    # Callers use dict access (.get / "error" in ...) on the result.
    if not isinstance(parsed, dict):
        return {"error": f"Unexpected response body type: {type(parsed).__name__}"}
    return parsed
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _score_text(endpoint_name, ic_name, region, prompt, completion):
    """Request logprobs for a prompt/completion exchange and return their sum.

    Sends ``prompt`` as a user turn and ``completion`` as an assistant turn,
    with ``max_tokens=1``, ``temperature=0`` and ``logprobs`` enabled. It then
    sums the ``logprob`` of every entry in ``choices[0].logprobs.content``; an
    entry without ``logprob`` counts as 0.0.

    NOTE(review): in the OpenAI chat-completions format, ``logprobs`` describes
    the *generated* output tokens. With ``max_tokens=1`` the value returned
    here is therefore most likely the logprob of one newly generated token,
    not of ``completion``. Scoring the completion itself needs a server-side
    extension (e.g. vLLM's ``prompt_logprobs``). Confirm what the serving
    container supports before trusting perplexity or reward metrics built on
    this value.

    Returns:
        Float sum of the returned token logprobs, or None when the request
        fails, the body is not a JSON object, or no logprobs are present.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": completion},
    ]

    payload = {
        "messages": messages,
        "max_tokens": 1,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 1,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # A non-object body (e.g. a JSON list) would otherwise raise
    # AttributeError on .get() and abort the whole evaluation run.
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None

        logprobs_data = choices[0].get("logprobs")
        if not logprobs_data or "content" not in logprobs_data:
            return None

        token_logprobs = [t.get("logprob", 0.0) for t in logprobs_data["content"]]
        return sum(token_logprobs) if token_logprobs else None
    except (AttributeError, KeyError, TypeError, IndexError):
        # Malformed choice / logprob entries (wrong types, nulls): treat as unscored.
        return None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _generate_response(endpoint_name, ic_name, region, prompt, max_tokens=256):
    """Generate a greedy (temperature 0) chat response for ``prompt``.

    Sends ``prompt`` as a single user message and reads
    ``choices[0].message.content`` from the OpenAI-style response.

    Args:
        endpoint_name, ic_name, region: forwarded to _invoke_endpoint.
        prompt: user message text.
        max_tokens: generation limit placed in the request payload.

    Returns:
        The generated text, or "" when the message has no ``content`` key.
        Returns None when the call fails, the body is not a JSON object,
        there are no choices, a choice or message is malformed, or
        ``content`` is not a string (e.g. null).
    """
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # A non-object body would otherwise raise AttributeError on .get() and
    # abort the whole evaluation run.
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None
        content = choices[0].get("message", {}).get("content", "")
    except (AttributeError, KeyError, TypeError, IndexError):
        # e.g. "message": null, or a choice that is not an object
        return None

    # Callers call .split()/.strip() on the result, so only hand back strings.
    return content if isinstance(content, str) else None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ── Metric computation ────────────────────────────────────────────────────────
|
|
134
|
+
|
|
135
|
+
def _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Compute SFT evaluation metrics over prompt/reference records.

    At most ``samples`` dataset rows are read when ``samples`` is truthy;
    rows skipped for lacking a prompt still count toward that limit. For each
    record with a non-empty ``prompt``:
      * if it has a ``reference``, the reference is scored with _score_text,
        and the score is divided by an estimated token count (4 chars/token);
      * a response is generated with _generate_response. Its whitespace word
        count feeds ``avg_response_length``, and it is compared with the
        reference (both stripped) for exact match.

    Returns a dict that may contain:
      perplexity            exp(-mean per-token logprob); None if that overflows
      avg_response_length   mean word count of successful generations
      exact_match_accuracy  exact matches / records that have a reference
      samples_scored        number of records with a prompt (always present)

    NOTE(review): perplexity is only as meaningful as the values _score_text
    returns and the 4-chars-per-token estimate; treat it as a rough signal.
    """
    metrics = {}
    logprob_scores = []
    response_lengths = []
    exact_matches = 0
    with_reference = 0
    total = 0

    for i, record in enumerate(dataset):
        if samples and i >= samples:
            break

        prompt = record.get("prompt", "")
        reference = record.get("reference", "")

        if not prompt:
            continue

        total += 1

        # Score via logprobs (for perplexity)
        if reference:
            with_reference += 1
            score = _score_text(endpoint_name, ic_name, region, prompt, reference)
            if score is not None:
                # score is a summed logprob; normalize by a rough token count
                # estimated from character length (~4 chars/token).
                est_tokens = max(1, len(reference) // 4)
                logprob_scores.append(score / est_tokens)

        # Generate response (for length and exact match)
        generated = _generate_response(endpoint_name, ic_name, region, prompt)
        if generated is not None:
            response_lengths.append(len(generated.split()))
            if reference and generated.strip() == reference.strip():
                exact_matches += 1

    # Compute aggregate metrics
    if logprob_scores:
        avg_logprob = sum(logprob_scores) / len(logprob_scores)
        try:
            metrics["perplexity"] = round(math.exp(-avg_logprob), 4)
        except OverflowError:
            # Very negative logprobs (e.g. server placeholders for masked
            # tokens) overflow exp(). Report null rather than crash, or emit
            # Infinity, which is not valid JSON for bash/jq consumers.
            metrics["perplexity"] = None

    if response_lengths:
        metrics["avg_response_length"] = round(sum(response_lengths) / len(response_lengths), 1)

    # Only records with a reference can exact-match; counting reference-less
    # records in the denominator would understate accuracy (or report 0.0 for
    # a dataset that has no references at all).
    if with_reference > 0:
        metrics["exact_match_accuracy"] = round(exact_matches / with_reference, 4)

    metrics["samples_scored"] = total

    return metrics
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Score preference pairs and summarize how often ``chosen`` outscores ``rejected``.

    Reads at most ``samples`` rows when ``samples`` is truthy. Rows missing
    any of ``prompt`` / ``chosen`` / ``rejected`` are skipped. Each remaining
    row has both candidates scored with _score_text (chosen first). Only pairs
    where both scores are available contribute to the aggregates, and a tie
    counts as a miss.

    Returns a dict with, when at least one pair was scored:
      reward_accuracy, avg_chosen_logprob, avg_rejected_logprob, reward_margin
    and always:
      pairs_scored       pairs with both scores available
      samples_evaluated  rows that had prompt, chosen and rejected
    """
    considered = 0
    wins = 0
    chosen_totals = []
    rejected_totals = []

    for idx, row in enumerate(dataset):
        if samples and idx >= samples:
            break

        prompt = row.get("prompt", "")
        chosen = row.get("chosen", "")
        rejected = row.get("rejected", "")
        if not (prompt and chosen and rejected):
            continue

        considered += 1

        chosen_lp = _score_text(endpoint_name, ic_name, region, prompt, chosen)
        rejected_lp = _score_text(endpoint_name, ic_name, region, prompt, rejected)
        if chosen_lp is None or rejected_lp is None:
            continue

        chosen_totals.append(chosen_lp)
        rejected_totals.append(rejected_lp)
        if chosen_lp > rejected_lp:
            wins += 1

    metrics = {}
    pair_count = len(chosen_totals)
    if pair_count > 0:
        chosen_sum = sum(chosen_totals)
        rejected_sum = sum(rejected_totals)
        metrics["reward_accuracy"] = round(wins / pair_count, 4)
        metrics["avg_chosen_logprob"] = round(chosen_sum / pair_count, 4)
        metrics["avg_rejected_logprob"] = round(rejected_sum / pair_count, 4)
        metrics["reward_margin"] = round((chosen_sum - rejected_sum) / pair_count, 4)

    metrics["pairs_scored"] = pair_count
    metrics["samples_evaluated"] = considered
    return metrics
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# ── Dataset loading ───────────────────────────────────────────────────────────
|
|
243
|
+
|
|
244
|
+
def _load_eval_dataset(eval_dataset_path):
    """Load an evaluation dataset from a local JSONL file or an ``s3://`` URI.

    Every non-blank line is parsed as one JSON value and appended in order.
    ``s3://bucket/key`` URIs are downloaded to a temporary file, which is
    removed again once parsing finishes (successfully or not). HF resolution
    is handled by the bash wrapper.

    Returns:
        List of parsed records.

    Exits via _error_exit when no path is given, the S3 download fails, the
    file cannot be read or decoded as UTF-8 JSONL, or it has no records.
    """
    if not eval_dataset_path:
        _error_exit("No evaluation dataset specified. Use --eval-dataset <path>")

    records = []
    tmp_path = None
    try:
        # Handle S3 paths by downloading
        if eval_dataset_path.startswith("s3://"):
            import tempfile
            import boto3

            bucket, _, key = eval_dataset_path[len("s3://"):].partition("/")
            # mkstemp reserves a unique file; close our handle so boto3 can
            # write to the path itself.
            fd, tmp_path = tempfile.mkstemp(suffix='.jsonl')
            os.close(fd)
            try:
                boto3.client('s3').download_file(bucket, key, tmp_path)
            except Exception as e:  # e.g. botocore ClientError: report as JSON, not a traceback
                _error_exit(f"Failed to download eval dataset from {eval_dataset_path}: {e}")
            eval_dataset_path = tmp_path

        # Load JSONL
        try:
            with open(eval_dataset_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        records.append(json.loads(line))
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            _error_exit(f"Failed to load eval dataset: {e}")
    finally:
        # Runs even when _error_exit raises SystemExit, so downloads never leak.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    if not records:
        _error_exit("Evaluation dataset is empty")

    return records
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# ── cmd_evaluate ──────────────────────────────────────────────────────────────
|
|
285
|
+
|
|
286
|
+
def cmd_evaluate(args):
    """Evaluate a deployed endpoint and print the result as JSON.

    Reads from ``args``:
      endpoint_name, ic_name  target endpoint / inference component
      region                  falls back to $AWS_DEFAULT_REGION, then us-east-1
      technique               'dpo' (case-insensitive) selects preference metrics;
                              anything else uses SFT metrics
      eval_dataset            local JSONL path or s3:// URI
      samples                 optional non-negative integer row cap; empty or 0
                              means "all rows"
    ``args.metrics`` is accepted by the CLI but not consulted here.

    Prints (via _output, which exits 0) a JSON object containing
    adapter/technique/model metadata, a UTC timestamp, and the ``metrics``
    dict. Invalid ``--samples`` values exit through _error_exit.
    """
    endpoint_name = args.endpoint_name
    ic_name = args.ic_name
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')
    technique = args.technique or ''

    # argparse delivers --samples as a string. Validate it here so the bash
    # caller gets a JSON error instead of a ValueError traceback; negative
    # values would otherwise silently evaluate zero rows.
    samples = None
    if args.samples:
        try:
            samples = int(args.samples)
        except (TypeError, ValueError):
            _error_exit(f"--samples must be a non-negative integer, got {args.samples!r}")
        if samples < 0:
            _error_exit(f"--samples must be a non-negative integer, got {args.samples!r}")

    # Load eval dataset
    dataset = _load_eval_dataset(args.eval_dataset)

    # Determine technique and compute metrics
    if technique.lower() == 'dpo':
        metrics = _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples)
    else:
        # Default to SFT metrics (works for any technique)
        metrics = _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples)

    # Build result
    result = {
        "adapter_name": ic_name,
        "technique": technique or "sft",
        "model": os.environ.get("MODEL_NAME", ""),
        "eval_dataset": args.eval_dataset or "",
        "samples_evaluated": metrics.get("samples_evaluated", metrics.get("samples_scored", 0)),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "metrics": metrics,
    }

    _output(result)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
# ── cmd_eval_write ────────────────────────────────────────────────────────────
|
|
322
|
+
|
|
323
|
+
def cmd_eval_write(args):
    """Upload one evaluation result to S3 as a JSON document for Athena.

    Reads the results JSON written by ``evaluate`` (``args.results_file``) and
    stores a flattened record at::

        s3://<bucket>/evaluations/model=<model>/adapter=<adapter>/<timestamp>.json

    The record's ``metrics`` field is a JSON-encoded string. Output is JSON
    rather than Parquet, because Parquet would need a pyarrow dependency.
    Partition values are percent-encoded so names containing '/' (e.g.
    Hugging Face ids like 'org/model') stay a single path segment. Missing or
    empty model/adapter/technique values become "unknown".

    Prints ``{"written": true, "s3_uri": ...}`` via _output on success;
    read/parse/upload failures exit through _error_exit.
    """
    from urllib.parse import quote

    results_file = args.results_file
    bucket = args.bucket
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')

    # Read results
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        _error_exit(f"Failed to read results file: {e}")

    if not isinstance(data, dict):
        _error_exit("Failed to read results file: expected a JSON object")

    # `or` (not a .get default) so empty strings, which evaluate writes when
    # MODEL_NAME is unset, do not produce an empty `model=` partition.
    adapter_name = data.get("adapter_name") or "unknown"
    technique = data.get("technique") or "unknown"
    model = data.get("model") or "unknown"
    timestamp = str(data.get("timestamp") or time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()))

    record = {
        "project_name": os.environ.get("PROJECT_NAME", ""),
        "model_name": model,
        "adapter_name": adapter_name,
        "technique": technique,
        "eval_dataset": data.get("eval_dataset", ""),
        "samples_evaluated": data.get("samples_evaluated", 0),
        "metrics": json.dumps(data.get("metrics", {})),
        "timestamp": timestamp,
        "region": region,
    }

    s3_key = (
        f"evaluations/model={quote(str(model), safe='')}"
        f"/adapter={quote(str(adapter_name), safe='')}"
        f"/{timestamp.replace(':', '-')}.json"
    )
    s3_uri = f"s3://{bucket}/{s3_key}"

    try:
        import boto3
        s3 = boto3.client('s3', region_name=region)
        s3.put_object(
            Bucket=bucket,
            Key=s3_key,
            Body=json.dumps(record),
            ContentType='application/json',
        )
    except Exception as e:
        _error_exit(f"Failed to write to S3: {e}")

    _output({"written": True, "s3_uri": s3_uri})
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# ── Main ──────────────────────────────────────────────────────────────────────
|
|
377
|
+
|
|
378
|
+
def main():
    """Build the CLI, parse ``sys.argv`` and hand off to the chosen subcommand."""
    parser = argparse.ArgumentParser(description='Model Quality Evaluation Helper')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # `evaluate`: score a deployed endpoint against an eval dataset.
    evaluate_cmd = subparsers.add_parser('evaluate', help='Run evaluation')
    for flag, options in (
        ('--endpoint-name', {'required': True}),
        ('--ic-name', {'required': True}),
        ('--region', {}),
        ('--technique', {'default': ''}),
        ('--eval-dataset', {'default': ''}),
        ('--samples', {'default': ''}),
        ('--metrics', {'default': ''}),
    ):
        evaluate_cmd.add_argument(flag, **options)

    # `eval-write`: persist a results file to S3.
    write_cmd = subparsers.add_parser('eval-write', help='Write results to S3')
    for flag, options in (
        ('--results-file', {'required': True}),
        ('--bucket', {'required': True}),
        ('--region', {}),
    ):
        write_cmd.add_argument(flag, **options)

    args = parser.parse_args()

    handlers = {
        'evaluate': cmd_evaluate,
        'eval-write': cmd_eval_write,
    }
    handler = handlers.get(args.command)
    if handler is None:
        _error_exit(f"Unknown command: {args.command}")
    else:
        handler(args)


if __name__ == '__main__':
    main()
|
|
@@ -112,6 +112,74 @@ def _truncate_metadata(props):
|
|
|
112
112
|
return result
|
|
113
113
|
|
|
114
114
|
|
|
115
|
+
def _inject_eval_metrics(metadata, args):
    """Merge metrics from a saved evaluation result into model-package metadata.

    Looks in ``<script_dir>/../.mlcc/eval-results/`` for a JSON result file:
      1. ``<adapter_name>.json`` when ``args.adapter_name`` is set and exists;
      2. otherwise the most recently modified ``*.json`` in that directory.
    Each entry of the file's ``metrics`` object is stored as
    ``eval_<name>``, stringified and truncated to MAX_METADATA_VALUE_LEN
    (G4 AC-3.1, AC-3.2). Metrics whose value is null are skipped.

    Non-fatal: a missing directory, no candidate file, or an unreadable or
    malformed file (including non-object JSON) leaves ``metadata`` unchanged.

    Args:
        metadata: existing metadata dict (may be None); updated in place.
        args: parsed args; only ``adapter_name`` is read, if present.

    Returns:
        The metadata dict, with eval metrics injected when found.
    """
    if metadata is None:
        metadata = {}

    # Convention: .mlcc/eval-results/<adapter-or-ic-name>.json, relative to
    # this script's directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    eval_results_dir = os.path.join(script_dir, "..", ".mlcc", "eval-results")

    if not os.path.isdir(eval_results_dir):
        return metadata

    adapter_name = getattr(args, 'adapter_name', '') or ''

    eval_file = None
    if adapter_name:
        candidate = os.path.join(eval_results_dir, f"{adapter_name}.json")
        if os.path.isfile(candidate):
            eval_file = candidate

    # NOTE(review): this fallback may pick up a result produced for a
    # different adapter; the _warn below names the file used so that is
    # at least visible in the output.
    if not eval_file:
        try:
            json_files = [f for f in os.listdir(eval_results_dir) if f.endswith('.json')]
            if json_files:
                newest = max(
                    json_files,
                    key=lambda name: os.path.getmtime(os.path.join(eval_results_dir, name)),
                )
                eval_file = os.path.join(eval_results_dir, newest)
        except OSError:
            pass  # e.g. a file vanished between listdir and getmtime

    if not eval_file:
        return metadata

    try:
        with open(eval_file, 'r', encoding='utf-8') as f:
            eval_data = json.load(f)
    except (OSError, ValueError):
        return metadata  # Non-fatal: skip eval metrics if the file is unreadable

    # Guard the shapes explicitly. A JSON list or a null "metrics" would
    # otherwise raise AttributeError and abort the registration.
    metrics = eval_data.get("metrics") if isinstance(eval_data, dict) else None
    if not isinstance(metrics, dict):
        return metadata

    injected = 0
    for metric_name, metric_value in metrics.items():
        if metric_value is None:
            continue  # avoid storing the literal string "None"
        metadata[f"eval_{metric_name}"] = str(metric_value)[:MAX_METADATA_VALUE_LEN]
        injected += 1

    if injected:
        _warn(f"Injected {injected} eval metric(s) from {os.path.basename(eval_file)}")

    return metadata
|
|
181
|
+
|
|
182
|
+
|
|
115
183
|
def _build_metadata(args):
|
|
116
184
|
"""Build customer_metadata_properties dict from CLI args.
|
|
117
185
|
|
|
@@ -283,7 +351,7 @@ def cmd_register_model(args):
|
|
|
283
351
|
|
|
284
352
|
# Step 3: Build inference specification
|
|
285
353
|
container_image = args.container_image or ""
|
|
286
|
-
model_data_url = args.model_data_url or ""
|
|
354
|
+
model_data_url = (args.model_data_url or "").rstrip("/")
|
|
287
355
|
|
|
288
356
|
# Step 4: Create Model Package version (AC-1.2, AC-1.7)
|
|
289
357
|
description = f"{args.deployment_config or 'model'} on {args.instance_type or 'unknown'}"
|
|
@@ -437,7 +505,7 @@ def cmd_register_adapter(args):
|
|
|
437
505
|
|
|
438
506
|
# Step 3: Build inference specification
|
|
439
507
|
container_image = args.container_image or ""
|
|
440
|
-
model_data_url = args.model_data_url or ""
|
|
508
|
+
model_data_url = (args.model_data_url or "").rstrip("/")
|
|
441
509
|
|
|
442
510
|
# Step 4: Create adapter Model Package version (AC-2.1)
|
|
443
511
|
technique = args.tune_technique or "unknown"
|
|
@@ -463,12 +531,21 @@ def cmd_register_adapter(args):
|
|
|
463
531
|
"SupportedContentTypes": ["application/json"],
|
|
464
532
|
"SupportedResponseMIMETypes": ["application/json"],
|
|
465
533
|
}
|
|
466
|
-
|
|
534
|
+
# ModelDataUrl in InferenceSpecification requires a tar.gz object —
|
|
535
|
+
# uncompressed S3 prefixes (adapter directories) are not supported.
|
|
536
|
+
# Store uncompressed paths in metadata instead.
|
|
537
|
+
if model_data_url and model_data_url.endswith(".tar.gz"):
|
|
467
538
|
create_params["InferenceSpecification"]["Containers"][0]["ModelDataUrl"] = model_data_url
|
|
468
|
-
|
|
539
|
+
|
|
540
|
+
# Always store model/adapter data URL in metadata for registry queries
|
|
541
|
+
if model_data_url:
|
|
469
542
|
if not metadata:
|
|
470
543
|
metadata = {}
|
|
471
544
|
metadata["modelDataUrl"] = model_data_url[:1024]
|
|
545
|
+
|
|
546
|
+
# Inject evaluation metrics if available (G4 AC-3.1, AC-3.2)
|
|
547
|
+
metadata = _inject_eval_metrics(metadata, args)
|
|
548
|
+
|
|
472
549
|
if metadata:
|
|
473
550
|
create_params["CustomerMetadataProperties"] = metadata
|
|
474
551
|
|
|
@@ -1366,9 +1443,24 @@ def cmd_get_version(args):
|
|
|
1366
1443
|
os.environ.setdefault("AWS_REGION", region)
|
|
1367
1444
|
|
|
1368
1445
|
try:
|
|
1369
|
-
|
|
1446
|
+
import boto3
|
|
1447
|
+
sm_client = boto3.client("sagemaker", region_name=region)
|
|
1370
1448
|
|
|
1371
|
-
|
|
1449
|
+
# Use boto3 directly — sagemaker-core v2.14 ModelPackage.get() requires
|
|
1450
|
+
# model_package_name (not ARN) and rejects model_package_arn as unexpected kwarg.
|
|
1451
|
+
pkg_response = sm_client.describe_model_package(ModelPackageName=version_arn)
|
|
1452
|
+
|
|
1453
|
+
# Wrap in a simple namespace for consistent access below
|
|
1454
|
+
class _Pkg:
|
|
1455
|
+
def __init__(self, data):
|
|
1456
|
+
self._data = data
|
|
1457
|
+
self.model_package_arn = data.get("ModelPackageArn", version_arn)
|
|
1458
|
+
self.inference_specification = data.get("InferenceSpecification")
|
|
1459
|
+
self.customer_metadata_properties = data.get("CustomerMetadataProperties", {})
|
|
1460
|
+
self.model_approval_status = data.get("ModelApprovalStatus", "")
|
|
1461
|
+
self.model_package_description = data.get("ModelPackageDescription", "")
|
|
1462
|
+
self.creation_time = data.get("CreationTime")
|
|
1463
|
+
pkg = _Pkg(pkg_response)
|
|
1372
1464
|
|
|
1373
1465
|
# Extract model data URL from inference spec
|
|
1374
1466
|
model_data_url = ""
|
|
@@ -1381,6 +1473,10 @@ def cmd_get_version(args):
|
|
|
1381
1473
|
# Get metadata
|
|
1382
1474
|
metadata = getattr(pkg, "customer_metadata_properties", None) or {}
|
|
1383
1475
|
|
|
1476
|
+
# Fallback: modelDataUrl stored in metadata when adapter is uncompressed S3 prefix
|
|
1477
|
+
if not model_data_url and metadata.get("modelDataUrl"):
|
|
1478
|
+
model_data_url = metadata["modelDataUrl"]
|
|
1479
|
+
|
|
1384
1480
|
# Get status
|
|
1385
1481
|
status = getattr(pkg, "model_approval_status", "") or ""
|
|
1386
1482
|
|
|
@@ -1414,6 +1510,7 @@ def cmd_resolve_dataset(args):
|
|
|
1414
1510
|
|
|
1415
1511
|
Version resolution (AC-2.1, AC-2.4):
|
|
1416
1512
|
- --version N: resolve the Nth version (ordinal, 1-based) for this name
|
|
1513
|
+
- --version X.Y.Z: resolve by semver string match
|
|
1417
1514
|
- No --version: resolve latest (existing behavior)
|
|
1418
1515
|
- If requested version doesn't exist: print available versions and exit 1 (AC-2.5)
|
|
1419
1516
|
|
|
@@ -1421,14 +1518,20 @@ def cmd_resolve_dataset(args):
|
|
|
1421
1518
|
or error if not found.
|
|
1422
1519
|
"""
|
|
1423
1520
|
name = args.name
|
|
1424
|
-
|
|
1521
|
+
version_spec = getattr(args, "version", None)
|
|
1425
1522
|
|
|
1426
1523
|
if not name:
|
|
1427
1524
|
_error_exit("--name is required", code="MISSING_ARGUMENT")
|
|
1428
1525
|
|
|
1429
1526
|
# If version is specified, use version-aware resolution
|
|
1430
|
-
if
|
|
1431
|
-
|
|
1527
|
+
if version_spec is not None:
|
|
1528
|
+
# Determine if it's an ordinal (pure integer) or semver string
|
|
1529
|
+
try:
|
|
1530
|
+
version_ordinal = int(version_spec)
|
|
1531
|
+
return _resolve_dataset_version(name, version_ordinal)
|
|
1532
|
+
except ValueError:
|
|
1533
|
+
# Not an integer — treat as semver string
|
|
1534
|
+
return _resolve_dataset_version_by_semver(name, version_spec)
|
|
1432
1535
|
|
|
1433
1536
|
# No version — resolve latest (existing behavior)
|
|
1434
1537
|
# Try SageMaker AI Registry API first
|
|
@@ -1545,6 +1648,77 @@ def _resolve_dataset_version(name, version_ordinal):
|
|
|
1545
1648
|
_error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
|
|
1546
1649
|
|
|
1547
1650
|
|
|
1651
|
+
def _resolve_dataset_version_by_semver(name, version_str):
    """Resolve a named dataset's version by exact semver string match.

    Finds the first local-registry entry named ``name`` and searches its
    ``versions[]`` for one whose ``version`` equals ``version_str``
    (e.g. '1.0.0'). Entries in ``versions[]`` without a ``version`` field are
    labeled ``<ordinal>.0.0``, the same label shown in the "available
    versions" listing, so every listed version can actually be requested.
    A legacy entry with no ``versions[]`` is treated as '1.0.0'.

    On success prints the resolved record via _output. If the version does
    not exist, prints the available versions to stderr and a
    VERSION_NOT_FOUND JSON object to stdout, then exits 1. An unknown dataset
    name goes to _error_exit with DATASET_NOT_FOUND.

    Args:
        name: Dataset name
        version_str: Semver string to match (e.g., '1.0.0', '2.1.0')
    """
    entries = _load_registry(_DATASETS_REGISTRY)

    for entry in entries:
        if entry.get("name") != name:
            continue

        versions = entry.get("versions", [])

        if not versions:
            # Legacy entry without versions array — treat as having version "1.0.0"
            if version_str == "1.0.0":
                output = dict(entry)
                output["version"] = "1.0.0"
                output["ordinal"] = 1
                if "arn" not in output:
                    output["arn"] = None
                _output(output)
                return  # never fall through into the not-found path
            print(f"Error: Version {version_str} not found for dataset '{name}'", file=sys.stderr)
            print("Available versions: 1.0.0", file=sys.stderr)
            print(json.dumps({
                "error": f"Version {version_str} not found for dataset '{name}'",
                "code": "VERSION_NOT_FOUND",
                "available_versions": [{"ordinal": 1, "version": "1.0.0"}],
            }))
            sys.exit(1)

        # One label per version, shared by matching and the error listing.
        labeled = [(i, v, v.get("version", f"{i}.0.0")) for i, v in enumerate(versions, 1)]

        for ordinal, v, ver in labeled:
            if ver == version_str:
                _output({
                    "name": name,
                    "s3_uri": v.get("s3_uri", entry.get("s3_uri", "")),
                    "arn": entry.get("arn"),
                    "format": v.get("format", entry.get("format", "jsonl")),
                    "technique": v.get("technique", entry.get("technique", "")),
                    "version": ver,
                    "ordinal": ordinal,
                    "hash": v.get("hash"),
                })
                return  # never fall through into the not-found path

        # Version string not found — show available
        print(f"Error: Version {version_str} not found for dataset '{name}'", file=sys.stderr)
        for ordinal, _v, ver in labeled:
            print(f"  v{ordinal} ({ver})", file=sys.stderr)
        print(json.dumps({
            "error": f"Version {version_str} not found for dataset '{name}'",
            "code": "VERSION_NOT_FOUND",
            "available_versions": [{"ordinal": o, "version": ver} for o, _v, ver in labeled],
        }))
        sys.exit(1)

    # Dataset name not found at all
    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
|
|
1720
|
+
|
|
1721
|
+
|
|
1548
1722
|
# ── Subcommand: resolve-evaluator ────────────────────────────────────────────
|
|
1549
1723
|
|
|
1550
1724
|
|
|
@@ -1706,8 +1880,8 @@ def main():
|
|
|
1706
1880
|
help="Resolve a registered dataset by name",
|
|
1707
1881
|
)
|
|
1708
1882
|
resolve_dataset_parser.add_argument("--name", required=True, help="Dataset name to resolve")
|
|
1709
|
-
resolve_dataset_parser.add_argument("--version", type=
|
|
1710
|
-
help="Version
|
|
1883
|
+
resolve_dataset_parser.add_argument("--version", type=str, default=None,
|
|
1884
|
+
help="Version to resolve: ordinal (e.g., 2) or semver (e.g., 1.0.0). Default: latest.")
|
|
1711
1885
|
|
|
1712
1886
|
# ── resolve-evaluator ─────────────────────────────────────────────────
|
|
1713
1887
|
resolve_evaluator_parser = subparsers.add_parser(
|