@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -2,65 +2,91 @@
|
|
|
2
2
|
// SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
|
|
4
4
|
/**
|
|
5
|
-
* Tune Config State
|
|
5
|
+
* Tune Config State
|
|
6
6
|
*
|
|
7
|
-
*
|
|
8
|
-
*
|
|
9
|
-
*
|
|
7
|
+
* Manages bash-style config files (do/config) that contain lines like:
|
|
8
|
+
* export VAR_NAME="value"
|
|
9
|
+
*
|
|
10
|
+
* Provides read/write access for tuning job state variables.
|
|
10
11
|
*/
|
|
11
12
|
|
|
12
13
|
import { readFileSync, writeFileSync } from 'node:fs';
|
|
13
14
|
|
|
14
15
|
/**
 * Read a variable value from a bash config file.
 *
 * Scans the file line by line for the first line that, after trimming,
 * starts with `export VAR_NAME=` and returns what follows the `=`:
 *   - export VAR_NAME="value"  -> surrounding double quotes removed, and the
 *     backslash escapes bash honours inside double quotes (\\, \", \$, \`)
 *     are collapsed to the escaped character
 *   - export VAR_NAME='value'  -> surrounding single quotes removed, content
 *     returned verbatim (bash performs no escaping inside single quotes)
 *   - export VAR_NAME=value    -> returned as-is
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The unquoted value, or null if not found
 * @throws {Error} Whatever readFileSync throws when the file cannot be read
 */
export function readConfigVar(configPath, varName) {
  // Loop-invariant: build the search prefix once rather than per line.
  const prefix = `export ${varName}=`;
  const lines = readFileSync(configPath, 'utf8').split('\n');

  for (const line of lines) {
    const trimmed = line.trim();
    if (!trimmed.startsWith(prefix)) {
      continue;
    }

    const raw = trimmed.slice(prefix.length);

    // A quoted value needs both an opening and a closing quote; a lone
    // quote character is not a quoted pair and is returned unchanged.
    if (raw.length >= 2) {
      if (raw.startsWith('"') && raw.endsWith('"')) {
        // Inside double quotes bash drops the backslash in front of \ " $ `
        return raw.slice(1, -1).replace(/\\([\\"$`])/g, '$1');
      }
      if (raw.startsWith('\'') && raw.endsWith('\'')) {
        return raw.slice(1, -1);
      }
    }

    return raw;
  }

  return null;
}
|
|
39
43
|
|
|
40
44
|
/**
 * Write or update a variable in a bash config file.
 *
 * The variable is written as `export VAR_NAME="value"`. Characters that are
 * special inside bash double quotes (backslash, double quote, dollar sign,
 * backtick) are backslash-escaped so that sourcing the file yields exactly
 * the given value and never triggers expansion or command substitution.
 *
 * If a line that (after trimming) starts with `export VAR_NAME=` already
 * exists, the first such line is replaced and the rest of the file is kept.
 * Otherwise the new export line is appended, preceded by a newline when the
 * existing content does not already end with one, and followed by a newline.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to set
 * @param {string} value - Value to assign
 * @throws {Error} Whatever readFileSync/writeFileSync throw on I/O failure
 */
export function updateConfigVar(configPath, varName, value) {
  const content = readFileSync(configPath, 'utf8');
  const prefix = `export ${varName}=`;

  // Escape \ " $ ` — the characters bash still interprets inside "...".
  const escaped = String(value).replace(/[\\"$`]/g, '\\$&');
  const newLine = `export ${varName}="${escaped}"`;

  const lines = content.split('\n');
  const index = lines.findIndex((line) => line.trim().startsWith(prefix));

  if (index !== -1) {
    lines[index] = newLine;
    writeFileSync(configPath, lines.join('\n'), 'utf8');
    return;
  }

  // Append: make sure the new export starts on its own line.
  const separator = content.length > 0 && !content.endsWith('\n') ? '\n' : '';
  writeFileSync(configPath, `${content}${separator}${newLine}\n`, 'utf8');
}
|
|
53
80
|
|
|
54
81
|
/**
|
|
55
|
-
*
|
|
56
|
-
* This mirrors the behavior in do/tune's _submit_job() function.
|
|
82
|
+
* Write tuning job submission state to config.
|
|
57
83
|
*
|
|
58
84
|
* @param {string} configPath - Path to the config file
|
|
59
|
-
* @param {object}
|
|
60
|
-
* @param {string}
|
|
61
|
-
* @param {string}
|
|
62
|
-
* @param {string}
|
|
63
|
-
* @param {string}
|
|
85
|
+
* @param {object} state - Submission state
|
|
86
|
+
* @param {string} state.technique - Tuning technique (e.g., 'sft', 'dpo')
|
|
87
|
+
* @param {string} state.trainingType - Training type (e.g., 'lora', 'full-rank')
|
|
88
|
+
* @param {string} state.datasetPath - Dataset path (S3 or HF URI)
|
|
89
|
+
* @param {string} state.jobName - Generated job name
|
|
64
90
|
*/
|
|
65
91
|
export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
|
|
66
92
|
const techniqueUpper = technique.toUpperCase();
|
|
@@ -71,59 +97,54 @@ export function persistSubmissionState(configPath, { technique, trainingType, da
|
|
|
71
97
|
}
|
|
72
98
|
|
|
73
99
|
/**
 * Record the outcome of a finished tuning job in the config file.
 *
 * Always writes TUNE_OUTPUT_PATH_LATEST and TUNE_OUTPUT_TYPE_LATEST first.
 * Then, depending on the output type:
 *   - 'adapter': writes TUNE_ADAPTER_PATH_<TECHNIQUE>, and when a dataset
 *     slug is given also TUNE_ADAPTER_PATH_<TECHNIQUE>_<SLUG> (slug
 *     upper-cased, hyphens turned into underscores)
 *   - anything else: writes TUNE_MODEL_PATH_<TECHNIQUE>
 *
 * @param {string} configPath - Path to the config file
 * @param {object} state - Completion state
 * @param {string} state.technique - Tuning technique
 * @param {string} state.trainingType - Training type (accepted but not used here)
 * @param {string} state.artifactPath - Output artifact path (S3 URI)
 * @param {string} state.outputType - Output type ('adapter' or 'model')
 * @param {string} [state.datasetSlug] - Dataset slug for named paths
 */
export function persistCompletionState(configPath, { technique, trainingType: _trainingType, artifactPath, outputType, datasetSlug }) {
  const techniqueKey = technique.toUpperCase();
  const setVar = (name, val) => updateConfigVar(configPath, name, val);

  setVar('TUNE_OUTPUT_PATH_LATEST', artifactPath);
  setVar('TUNE_OUTPUT_TYPE_LATEST', outputType);

  if (outputType !== 'adapter') {
    setVar(`TUNE_MODEL_PATH_${techniqueKey}`, artifactPath);
    return;
  }

  setVar(`TUNE_ADAPTER_PATH_${techniqueKey}`, artifactPath);

  if (!datasetSlug) {
    return;
  }

  const slugKey = datasetSlug.toUpperCase().replace(/-/g, '_');
  setVar(`TUNE_ADAPTER_PATH_${techniqueKey}_${slugKey}`, artifactPath);
}
|
|
109
126
|
|
|
110
127
|
/**
 * Build a tuning job name of the form
 * `${projectName}-tune-${technique}-YYYYMMDD-HHMMSS`.
 *
 * The date and time fields come from the timestamp's local-time getters;
 * month, day, hours, minutes and seconds are zero-padded to two digits.
 *
 * @param {string} projectName - Project name
 * @param {string} technique - Tuning technique
 * @param {Date} [timestamp] - Optional timestamp (defaults to new Date())
 * @returns {string} Formatted job name
 */
export function generateJobName(projectName, technique, timestamp) {
  const when = timestamp || new Date();
  const twoDigits = (n) => n.toString().padStart(2, '0');

  const datePart = [
    when.getFullYear().toString(),
    twoDigits(when.getMonth() + 1), // getMonth() is zero-based
    twoDigits(when.getDate()),
  ].join('');

  const timePart = [when.getHours(), when.getMinutes(), when.getSeconds()]
    .map((n) => twoDigits(n))
    .join('');

  return `${projectName}-tune-${technique}-${datePart}-${timePart}`;
}
|
|
@@ -1478,6 +1478,7 @@ def _load_config_file(config_path):
|
|
|
1478
1478
|
'HF_MODEL_ID': 'hf_model_id',
|
|
1479
1479
|
'INSTANCE_TYPE': 'instance_type',
|
|
1480
1480
|
'INSTANCE_POOLS': 'instance_pools',
|
|
1481
|
+
'DEPLOYED_INSTANCE_TYPE': 'deployed_instance_type',
|
|
1481
1482
|
'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
|
|
1482
1483
|
'DEPLOYMENT_CONFIG': 'deployment_config',
|
|
1483
1484
|
'DEPLOYMENT_TARGET': 'deployment_target',
|
|
@@ -1521,6 +1522,8 @@ def _load_config_file(config_path):
|
|
|
1521
1522
|
# BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
|
|
1522
1523
|
if context.get('benchmark_instance_type'):
|
|
1523
1524
|
context['instance_type'] = context.pop('benchmark_instance_type')
|
|
1525
|
+
elif context.get('deployed_instance_type'):
|
|
1526
|
+
context['instance_type'] = context.pop('deployed_instance_type')
|
|
1524
1527
|
# Fall back to INSTANCE_POOLS when neither is set.
|
|
1525
1528
|
# Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
|
|
1526
1529
|
# but always define INSTANCE_POOLS as a JSON array with Priority fields.
|
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Model Quality Evaluation Helper.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
evaluate - Run evaluation against deployed endpoint, compute metrics
|
|
9
|
+
eval-write - Write evaluation results to S3/Athena (Parquet)
|
|
10
|
+
|
|
11
|
+
All output is JSON on stdout for bash consumption.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import math
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
def _error_exit(message):
|
|
25
|
+
"""Print JSON error to stdout and exit."""
|
|
26
|
+
print(json.dumps({"error": True, "message": message}))
|
|
27
|
+
sys.exit(1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _output(data):
|
|
31
|
+
"""Print JSON result to stdout."""
|
|
32
|
+
print(json.dumps(data))
|
|
33
|
+
sys.exit(0)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ── Endpoint invocation ───────────────────────────────────────────────────────
|
|
37
|
+
|
|
38
|
+
def _invoke_endpoint(endpoint_name, ic_name, region, payload):
    """Send a JSON payload to a SageMaker endpoint and return the parsed reply.

    A ``sagemaker-runtime`` client is created for ``region`` and
    ``invoke_endpoint`` is called with the JSON-encoded payload. When
    ``ic_name`` is truthy it is passed as ``InferenceComponentName``.

    Returns:
        The response body decoded as UTF-8 and parsed as JSON. If the call,
        the read, the decode or the JSON parse raises, a dict of the form
        ``{"error": "<exception text>"}`` is returned instead.
    """
    import boto3

    runtime = boto3.client('sagemaker-runtime', region_name=region)

    request = dict(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload),
    )
    if ic_name:
        request['InferenceComponentName'] = ic_name

    try:
        raw = runtime.invoke_endpoint(**request)['Body'].read()
        return json.loads(raw.decode('utf-8'))
    except Exception as exc:
        # Best-effort: callers detect failure via the "error" key.
        return {"error": str(exc)}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _score_text(endpoint_name, ic_name, region, prompt, completion):
    """Score a completion by summing the token logprobs the endpoint reports.

    Sends a chat request whose messages are the user ``prompt`` followed by
    the assistant ``completion``, with ``logprobs`` enabled, and sums the
    ``logprob`` field of every entry in ``choices[0].logprobs.content``.

    NOTE(review): the request uses ``max_tokens=1``; OpenAI-compatible servers
    typically report logprobs only for newly generated tokens, not for the
    supplied assistant message — confirm the serving stack really returns the
    completion-token logprobs here.

    Returns:
        The sum of the reported token logprobs, or None when the call failed,
        the response is not a JSON object, or no logprobs are present.
    """
    payload = {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ],
        "max_tokens": 1,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 1,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # The endpoint may answer with valid JSON that is not an object
    # (null, a list, a string); treat that like any other scoring failure.
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None

        # OpenAI-compatible layout: choices[0].logprobs.content[*].logprob
        logprobs_data = choices[0].get("logprobs")
        if isinstance(logprobs_data, dict) and "content" in logprobs_data:
            token_logprobs = [t.get("logprob", 0.0) for t in logprobs_data["content"]]
            return sum(token_logprobs) if token_logprobs else None

        return None
    except (AttributeError, KeyError, TypeError, IndexError):
        # Malformed or unexpected response shape: scoring is unavailable.
        return None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _generate_response(endpoint_name, ic_name, region, prompt, max_tokens=256):
    """Generate a response from the endpoint for generation-based metrics.

    Sends a single-turn chat request (temperature 0.0) and returns
    ``choices[0].message.content`` from the reply.

    Args:
        endpoint_name: Endpoint to invoke.
        ic_name: Inference component name, forwarded to the invocation.
        region: AWS region, forwarded to the invocation.
        prompt: User message content.
        max_tokens: Generation limit sent in the request (default 256).

    Returns:
        The generated text (an empty string when the message has no
        ``content`` key), or None when the call failed, the response is not
        a JSON object, or it has an unexpected shape.
    """
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    # Valid JSON that is not an object (null, list, ...) counts as a failure.
    if not isinstance(response, dict) or "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if choices:
            return choices[0].get("message", {}).get("content", "")
        return None
    except (AttributeError, KeyError, TypeError, IndexError):
        # e.g. "message": null or a non-dict choice entry
        return None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ── Metric computation ────────────────────────────────────────────────────────
|
|
134
|
+
|
|
135
|
+
def _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Compute SFT evaluation metrics over the dataset.

    Iterates the records (stopping once the record index reaches ``samples``
    when ``samples`` is truthy), skipping records without a ``prompt``.
    For each remaining record the ``reference`` text, if present, is scored
    via logprobs and a response is generated for the prompt.

    Metrics (each present only when data for it was collected):
        perplexity            exp of the negated mean per-token logprob, where
                              the token count is estimated as len(reference)//4
        avg_response_length   mean whitespace-separated word count of responses
        exact_match_accuracy  share of counted records whose stripped response
                              equals the stripped reference
        samples_scored        number of records with a prompt (always present)
    """
    per_token_logprobs = []
    word_counts = []
    matched = 0
    evaluated = 0

    for index, row in enumerate(dataset):
        if samples and index >= samples:
            break

        prompt = row.get("prompt", "")
        reference = row.get("reference", "")
        if not prompt:
            continue

        evaluated += 1

        # Logprob scoring feeds the perplexity estimate.
        if reference:
            total_logprob = _score_text(endpoint_name, ic_name, region, prompt, reference)
            if total_logprob is not None:
                # Rough token estimate: about four characters per token.
                approx_tokens = max(1, len(reference) // 4)
                per_token_logprobs.append(total_logprob / approx_tokens)

        # Generation feeds the length and exact-match metrics.
        generated = _generate_response(endpoint_name, ic_name, region, prompt)
        if generated is None:
            continue
        word_counts.append(len(generated.split()))
        if reference and generated.strip() == reference.strip():
            matched += 1

    metrics = {}

    if per_token_logprobs:
        mean_logprob = sum(per_token_logprobs) / len(per_token_logprobs)
        metrics["perplexity"] = round(math.exp(-mean_logprob), 4)

    if word_counts:
        metrics["avg_response_length"] = round(sum(word_counts) / len(word_counts), 1)

    if evaluated > 0:
        metrics["exact_match_accuracy"] = round(matched / evaluated, 4)

    metrics["samples_scored"] = evaluated
    return metrics
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Compute DPO evaluation metrics over preference pairs.

    Iterates the records (stopping once the record index reaches ``samples``
    when ``samples`` is truthy), skipping any record missing ``prompt``,
    ``chosen`` or ``rejected``. Both completions are scored via logprobs; a
    pair contributes to the aggregates only when both scores are available.

    Metrics:
        reward_accuracy       share of scored pairs where chosen > rejected
        avg_chosen_logprob    mean score of the chosen completions
        avg_rejected_logprob  mean score of the rejected completions
        reward_margin         mean of (chosen - rejected) over scored pairs
        pairs_scored          number of pairs with both scores (always present)
        samples_evaluated     number of complete records seen (always present)
    The first four are present only when at least one pair was scored.
    """
    scored_pairs = []
    evaluated = 0

    for index, row in enumerate(dataset):
        if samples and index >= samples:
            break

        prompt = row.get("prompt", "")
        chosen = row.get("chosen", "")
        rejected = row.get("rejected", "")
        if not (prompt and chosen and rejected):
            continue

        evaluated += 1

        chosen_score = _score_text(endpoint_name, ic_name, region, prompt, chosen)
        rejected_score = _score_text(endpoint_name, ic_name, region, prompt, rejected)
        if chosen_score is None or rejected_score is None:
            continue
        scored_pairs.append((chosen_score, rejected_score))

    metrics = {}
    scored = len(scored_pairs)

    if scored > 0:
        chosen_total = sum(c for c, _ in scored_pairs)
        rejected_total = sum(r for _, r in scored_pairs)
        wins = sum(1 for c, r in scored_pairs if c > r)

        metrics["reward_accuracy"] = round(wins / scored, 4)
        metrics["avg_chosen_logprob"] = round(chosen_total / scored, 4)
        metrics["avg_rejected_logprob"] = round(rejected_total / scored, 4)
        metrics["reward_margin"] = round((chosen_total - rejected_total) / scored, 4)

    metrics["pairs_scored"] = scored
    metrics["samples_evaluated"] = evaluated
    return metrics
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# ── Dataset loading ───────────────────────────────────────────────────────────
|
|
243
|
+
|
|
244
|
+
def _load_eval_dataset(eval_dataset_path):
|
|
245
|
+
"""Load evaluation dataset from local JSONL file or S3.
|
|
246
|
+
|
|
247
|
+
For this MVP, expects a local JSONL file path.
|
|
248
|
+
S3 and HF resolution is handled by the bash wrapper.
|
|
249
|
+
|
|
250
|
+
Returns: list of dicts
|
|
251
|
+
"""
|
|
252
|
+
records = []
|
|
253
|
+
|
|
254
|
+
if not eval_dataset_path:
|
|
255
|
+
_error_exit("No evaluation dataset specified. Use --eval-dataset <path>")
|
|
256
|
+
|
|
257
|
+
# Handle S3 paths by downloading
|
|
258
|
+
if eval_dataset_path.startswith("s3://"):
|
|
259
|
+
import boto3
|
|
260
|
+
import tempfile
|
|
261
|
+
s3 = boto3.client('s3')
|
|
262
|
+
bucket = eval_dataset_path.split('/')[2]
|
|
263
|
+
key = '/'.join(eval_dataset_path.split('/')[3:])
|
|
264
|
+
tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jsonl')
|
|
265
|
+
s3.download_file(bucket, key, tmp.name)
|
|
266
|
+
eval_dataset_path = tmp.name
|
|
267
|
+
|
|
268
|
+
# Load JSONL
|
|
269
|
+
try:
|
|
270
|
+
with open(eval_dataset_path, 'r') as f:
|
|
271
|
+
for line in f:
|
|
272
|
+
line = line.strip()
|
|
273
|
+
if line:
|
|
274
|
+
records.append(json.loads(line))
|
|
275
|
+
except (IOError, json.JSONDecodeError) as e:
|
|
276
|
+
_error_exit(f"Failed to load eval dataset: {e}")
|
|
277
|
+
|
|
278
|
+
if not records:
|
|
279
|
+
_error_exit("Evaluation dataset is empty")
|
|
280
|
+
|
|
281
|
+
return records
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# ── cmd_evaluate ──────────────────────────────────────────────────────────────
|
|
285
|
+
|
|
286
|
+
def cmd_evaluate(args):
    """Run evaluation against a deployed endpoint and print the result as JSON.

    Reads the endpoint, inference component, region, technique, dataset and
    sample limit from ``args``. The region falls back to AWS_DEFAULT_REGION
    and then to 'us-east-1'. DPO metrics are computed when the technique is
    'dpo' (case-insensitive); every other technique uses the SFT metrics.

    ``--samples`` must be an integer >= 0 when given; an empty value or 0
    means no limit. Invalid values are reported through ``_error_exit``.

    The result (adapter name, technique, MODEL_NAME from the environment,
    dataset path, sample count, UTC timestamp and the metrics dict) is
    emitted via ``_output``, which terminates the process.
    """
    endpoint_name = args.endpoint_name
    ic_name = args.ic_name
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')
    technique = args.technique or ''

    samples = None
    if args.samples:
        try:
            samples = int(args.samples)
        except (TypeError, ValueError):
            # Keep stdout JSON-only instead of crashing with a traceback.
            _error_exit(f"Invalid --samples value: {args.samples!r} (expected an integer)")
        else:
            if samples < 0:
                # A negative limit would silently evaluate zero records.
                _error_exit(f"Invalid --samples value: {args.samples!r} (must not be negative)")

    # Load eval dataset
    dataset = _load_eval_dataset(args.eval_dataset)

    # Determine technique and compute metrics
    if technique.lower() == 'dpo':
        metrics = _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples)
    else:
        # Default to SFT metrics (used for every non-DPO technique)
        metrics = _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples)

    # Build result
    result = {
        "adapter_name": args.ic_name,
        "technique": technique or "sft",
        "model": os.environ.get("MODEL_NAME", ""),
        "eval_dataset": args.eval_dataset or "",
        # DPO reports samples_evaluated, SFT reports samples_scored.
        "samples_evaluated": metrics.get("samples_evaluated", metrics.get("samples_scored", 0)),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "metrics": metrics,
    }

    _output(result)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
# ── cmd_eval_write ────────────────────────────────────────────────────────────
|
|
322
|
+
|
|
323
|
+
def cmd_eval_write(args):
    """Upload one evaluation result record to S3 as a JSON object.

    Reads the results JSON file named by ``args.results_file``, flattens it
    into a single record (the metrics dict is stored as a JSON string) and
    puts it at
    ``evaluations/model=<model>/adapter=<adapter>/<timestamp>.json`` in
    ``args.bucket``, with colons in the timestamp replaced by hyphens.
    The region falls back to AWS_DEFAULT_REGION and then to 'us-east-1'.

    The object is written as JSON rather than Parquet (Parquet would need
    pyarrow). On success ``{"written": true, "s3_uri": ...}`` is emitted via
    ``_output``; read or upload failures go through ``_error_exit``.
    """
    results_file = args.results_file
    bucket = args.bucket
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')

    try:
        with open(results_file, 'r') as handle:
            data = json.load(handle)
    except (IOError, json.JSONDecodeError) as e:
        _error_exit(f"Failed to read results file: {e}")

    now_utc = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    adapter_name = data.get("adapter_name", "unknown")
    technique = data.get("technique", "unknown")
    model = data.get("model", "unknown")
    timestamp = data.get("timestamp", now_utc)

    record = {
        "project_name": os.environ.get("PROJECT_NAME", ""),
        "model_name": model,
        "adapter_name": adapter_name,
        "technique": technique,
        "eval_dataset": data.get("eval_dataset", ""),
        "samples_evaluated": data.get("samples_evaluated", 0),
        "metrics": json.dumps(data.get("metrics", {})),
        "timestamp": timestamp,
        "region": region,
    }

    # Colons are replaced so the timestamp is safe to use in the object key.
    key_timestamp = timestamp.replace(':', '-')
    s3_key = f"evaluations/model={model}/adapter={adapter_name}/{key_timestamp}.json"
    s3_uri = f"s3://{bucket}/{s3_key}"

    try:
        import boto3
        client = boto3.client('s3', region_name=region)
        client.put_object(
            Bucket=bucket,
            Key=s3_key,
            Body=json.dumps(record),
            ContentType='application/json',
        )
    except Exception as e:
        _error_exit(f"Failed to write to S3: {e}")

    _output({"written": True, "s3_uri": s3_uri})
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# ── Main ──────────────────────────────────────────────────────────────────────
|
|
377
|
+
|
|
378
|
+
def main():
    """Parse the command line and dispatch to the selected subcommand.

    Subcommands:
        evaluate    requires --endpoint-name and --ic-name; optional --region,
                    --technique, --eval-dataset, --samples, --metrics
        eval-write  requires --results-file and --bucket; optional --region
    """
    parser = argparse.ArgumentParser(description='Model Quality Evaluation Helper')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # evaluate
    evaluate_cmd = subparsers.add_parser('evaluate', help='Run evaluation')
    for flag in ('--endpoint-name', '--ic-name'):
        evaluate_cmd.add_argument(flag, required=True)
    evaluate_cmd.add_argument('--region')
    for flag in ('--technique', '--eval-dataset', '--samples', '--metrics'):
        evaluate_cmd.add_argument(flag, default='')

    # eval-write
    write_cmd = subparsers.add_parser('eval-write', help='Write results to S3')
    for flag in ('--results-file', '--bucket'):
        write_cmd.add_argument(flag, required=True)
    write_cmd.add_argument('--region')

    args = parser.parse_args()

    handlers = {
        'evaluate': cmd_evaluate,
        'eval-write': cmd_eval_write,
    }
    handler = handlers.get(args.command)
    if handler is None:
        _error_exit(f"Unknown command: {args.command}")
    else:
        handler(args)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
if __name__ == '__main__':
|
|
409
|
+
main()
|