@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -2,65 +2,91 @@
2
2
  // SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  /**
5
- * Tune Config State Manager
5
+ * Tune Config State
6
6
  *
7
- * JavaScript module that mimics the bash _update_config_var() behavior
8
- * from do/tune for testing purposes. Manages config variables written
9
- * after job submission.
7
+ * Manages bash-style config files (do/config) that contain lines like:
8
+ * export VAR_NAME="value"
9
+ *
10
+ * Provides read/write access for tuning job state variables.
10
11
  */
11
12
 
12
13
  import { readFileSync, writeFileSync } from 'node:fs';
13
14
 
14
15
  /**
15
- * Update or add a config variable in a do/config-style file.
16
- * Mimics the bash _update_config_var() function:
17
- * - If the variable exists (line starts with `export VAR_NAME=`), replace it
18
- * - Otherwise, append a new line
16
+ * Read a variable value from a bash config file.
17
+ * Looks for lines matching: export VAR_NAME="value", export VAR_NAME='value', or export VAR_NAME=value
19
18
  *
20
19
  * @param {string} configPath - Path to the config file
21
- * @param {string} varName - Variable name (e.g., TUNE_JOB_NAME_SFT)
22
- * @param {string} varValue - Variable value
20
+ * @param {string} varName - Variable name to read
21
+ * @returns {string|null} The unquoted value, or null if not found
23
22
  */
24
- export function updateConfigVar(configPath, varName, varValue) {
25
- let content = readFileSync(configPath, 'utf8');
26
- const pattern = new RegExp(`^export ${varName}=.*$`, 'm');
23
+ export function readConfigVar(configPath, varName) {
24
+ const content = readFileSync(configPath, 'utf8');
25
+ const lines = content.split('\n');
27
26
 
28
- if (pattern.test(content)) {
29
- content = content.replace(pattern, `export ${varName}="${varValue}"`);
30
- } else {
31
- if (content.length > 0 && !content.endsWith('\n')) {
32
- content += '\n';
27
+ for (const line of lines) {
28
+ const trimmed = line.trim();
29
+ const prefix = `export ${varName}=`;
30
+ if (trimmed.startsWith(prefix)) {
31
+ let value = trimmed.slice(prefix.length);
32
+ // Strip surrounding quotes (double or single)
33
+ if ((value.startsWith('"') && value.endsWith('"')) ||
34
+ (value.startsWith('\'') && value.endsWith('\''))) {
35
+ value = value.slice(1, -1);
36
+ }
37
+ return value;
33
38
  }
34
- content += `export ${varName}="${varValue}"\n`;
35
39
  }
36
40
 
37
- writeFileSync(configPath, content, 'utf8');
41
+ return null;
38
42
  }
39
43
 
40
44
  /**
41
- * Read a config variable from a do/config-style file.
45
+ * Write or update a variable in a bash config file.
46
+ * If the variable already exists, replaces that line.
47
+ * If not, appends the new export line.
42
48
  *
43
49
  * @param {string} configPath - Path to the config file
44
- * @param {string} varName - Variable name to read
45
- * @returns {string|null} The variable value, or null if not found
50
+ * @param {string} varName - Variable name to set
51
+ * @param {string} value - Value to assign
46
52
  */
47
- export function readConfigVar(configPath, varName) {
53
+ export function updateConfigVar(configPath, varName, value) {
48
54
  const content = readFileSync(configPath, 'utf8');
49
- const pattern = new RegExp(`^export ${varName}="([^"]*)"`, 'm');
50
- const match = content.match(pattern);
51
- return match ? match[1] : null;
55
+ const lines = content.split('\n');
56
+ const prefix = `export ${varName}=`;
57
+ const newLine = `export ${varName}="${value}"`;
58
+
59
+ let found = false;
60
+ for (let i = 0; i < lines.length; i++) {
61
+ if (lines[i].trim().startsWith(prefix)) {
62
+ lines[i] = newLine;
63
+ found = true;
64
+ break;
65
+ }
66
+ }
67
+
68
+ if (found) {
69
+ writeFileSync(configPath, lines.join('\n'), 'utf8');
70
+ } else {
71
+ // Append to end of file
72
+ let appendContent = content;
73
+ if (appendContent.length > 0 && !appendContent.endsWith('\n')) {
74
+ appendContent += '\n';
75
+ }
76
+ appendContent += `${newLine }\n`;
77
+ writeFileSync(configPath, appendContent, 'utf8');
78
+ }
52
79
  }
53
80
 
54
81
  /**
55
- * Simulate the config writes that happen after a successful job submission.
56
- * This mirrors the behavior in do/tune's _submit_job() function.
82
+ * Write tuning job submission state to config.
57
83
  *
58
84
  * @param {string} configPath - Path to the config file
59
- * @param {object} params - Submission parameters
60
- * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
61
- * @param {string} params.trainingType - Training type (lora, full-rank)
62
- * @param {string} params.datasetPath - Dataset path (s3://... or hf://...)
63
- * @param {string} params.jobName - Generated job name
85
+ * @param {object} state - Submission state
86
+ * @param {string} state.technique - Tuning technique (e.g., 'sft', 'dpo')
87
+ * @param {string} state.trainingType - Training type (e.g., 'lora', 'full-rank')
88
+ * @param {string} state.datasetPath - Dataset path (S3 or HF URI)
89
+ * @param {string} state.jobName - Generated job name
64
90
  */
65
91
  export function persistSubmissionState(configPath, { technique, trainingType, datasetPath, jobName }) {
66
92
  const techniqueUpper = technique.toUpperCase();
@@ -71,59 +97,54 @@ export function persistSubmissionState(configPath, { technique, trainingType, da
71
97
  }
72
98
 
73
99
  /**
74
- * Simulate the config writes that happen after a job completes successfully.
75
- * This mirrors the behavior in do/tune's _handle_completion() function.
76
- *
77
- * Writes three levels of tracking (AC-4.1, AC-4.2):
78
- * - Level 1: TUNE_OUTPUT_PATH_LATEST (always the last run, any technique)
79
- * - Level 2: TUNE_ADAPTER_PATH_<TECHNIQUE> (last run per technique)
80
- * - Level 3: TUNE_ADAPTER_PATH_<TECHNIQUE>_<SLUG> (per technique + dataset slug)
100
+ * Write tuning job completion state to config.
81
101
  *
82
102
  * @param {string} configPath - Path to the config file
83
- * @param {object} params - Completion parameters
84
- * @param {string} params.technique - Technique (sft, dpo, rlaif, rlvr)
85
- * @param {string} params.trainingType - Training type (lora, full-rank)
86
- * @param {string} params.artifactPath - S3 path to the output artifact
87
- * @param {string} params.outputType - Output type (adapter, full-model)
88
- * @param {string} [params.datasetSlug] - Optional dataset slug for per-technique-per-dataset tracking
103
+ * @param {object} state - Completion state
104
+ * @param {string} state.technique - Tuning technique
105
+ * @param {string} state.trainingType - Training type
106
+ * @param {string} state.artifactPath - Output artifact path (S3 URI)
107
+ * @param {string} state.outputType - Output type ('adapter' or 'model')
108
+ * @param {string} [state.datasetSlug] - Dataset slug for named paths
89
109
  */
90
- export function persistCompletionState(configPath, { technique, trainingType, artifactPath, outputType, datasetSlug }) {
110
+ export function persistCompletionState(configPath, { technique, trainingType: _trainingType, artifactPath, outputType, datasetSlug }) {
91
111
  const techniqueUpper = technique.toUpperCase();
92
112
 
93
- if (trainingType === 'lora') {
94
- // Level 2: per-technique
113
+ updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
114
+ updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
115
+
116
+ if (outputType === 'adapter') {
95
117
  updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueUpper}`, artifactPath);
96
- // Level 3: per-technique + per-dataset (if slug available)
97
118
  if (datasetSlug) {
98
119
  const slugUpper = datasetSlug.toUpperCase().replace(/-/g, '_');
99
120
  updateConfigVar(configPath, `TUNE_ADAPTER_PATH_${techniqueUpper}_${slugUpper}`, artifactPath);
100
121
  }
101
- } else if (trainingType === 'full-rank') {
122
+ } else {
102
123
  updateConfigVar(configPath, `TUNE_MODEL_PATH_${techniqueUpper}`, artifactPath);
103
124
  }
104
-
105
- // Level 1: latest
106
- updateConfigVar(configPath, 'TUNE_OUTPUT_PATH_LATEST', artifactPath);
107
- updateConfigVar(configPath, 'TUNE_OUTPUT_TYPE_LATEST', outputType);
108
125
  }
109
126
 
110
127
  /**
111
- * Generate a job name following the pattern used by do/tune.
112
- * Pattern: ${projectName}-tune-${technique}-YYYYMMDD-HHMMSS
128
+ * Generate a job name matching pattern: ${projectName}-tune-${technique}-YYYYMMDD-HHMMSS
129
+ * Uses local time for the timestamp.
113
130
  *
114
131
  * @param {string} projectName - Project name
115
- * @param {string} technique - Technique (sft, dpo, rlaif, rlvr)
116
- * @param {Date} [timestamp] - Optional timestamp (defaults to now)
117
- * @returns {string} Generated job name
132
+ * @param {string} technique - Tuning technique
133
+ * @param {Date} [timestamp] - Optional timestamp (defaults to new Date())
134
+ * @returns {string} Formatted job name
118
135
  */
119
- export function generateJobName(projectName, technique, timestamp = new Date()) {
120
- const year = timestamp.getFullYear().toString();
121
- const month = (timestamp.getMonth() + 1).toString().padStart(2, '0');
122
- const day = timestamp.getDate().toString().padStart(2, '0');
123
- const hours = timestamp.getHours().toString().padStart(2, '0');
124
- const minutes = timestamp.getMinutes().toString().padStart(2, '0');
125
- const seconds = timestamp.getSeconds().toString().padStart(2, '0');
136
+ export function generateJobName(projectName, technique, timestamp) {
137
+ const ts = timestamp || new Date();
138
+
139
+ const year = ts.getFullYear().toString();
140
+ const month = (ts.getMonth() + 1).toString().padStart(2, '0');
141
+ const day = ts.getDate().toString().padStart(2, '0');
142
+ const hours = ts.getHours().toString().padStart(2, '0');
143
+ const minutes = ts.getMinutes().toString().padStart(2, '0');
144
+ const seconds = ts.getSeconds().toString().padStart(2, '0');
145
+
126
146
  const dateStr = `${year}${month}${day}`;
127
147
  const timeStr = `${hours}${minutes}${seconds}`;
148
+
128
149
  return `${projectName}-tune-${technique}-${dateStr}-${timeStr}`;
129
150
  }
@@ -1478,6 +1478,7 @@ def _load_config_file(config_path):
1478
1478
  'HF_MODEL_ID': 'hf_model_id',
1479
1479
  'INSTANCE_TYPE': 'instance_type',
1480
1480
  'INSTANCE_POOLS': 'instance_pools',
1481
+ 'DEPLOYED_INSTANCE_TYPE': 'deployed_instance_type',
1481
1482
  'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
1482
1483
  'DEPLOYMENT_CONFIG': 'deployment_config',
1483
1484
  'DEPLOYMENT_TARGET': 'deployment_target',
@@ -1521,6 +1522,8 @@ def _load_config_file(config_path):
1521
1522
  # BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
1522
1523
  if context.get('benchmark_instance_type'):
1523
1524
  context['instance_type'] = context.pop('benchmark_instance_type')
1525
+ elif context.get('deployed_instance_type'):
1526
+ context['instance_type'] = context.pop('deployed_instance_type')
1524
1527
  # Fall back to INSTANCE_POOLS when neither is set.
1525
1528
  # Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
1526
1529
  # but always define INSTANCE_POOLS as a JSON array with Priority fields.
@@ -0,0 +1,409 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Model Quality Evaluation Helper.
6
+
7
+ Subcommands:
8
+ evaluate - Run evaluation against deployed endpoint, compute metrics
9
+ eval-write - Write evaluation results to S3/Athena (Parquet)
10
+
11
+ All output is JSON on stdout for bash consumption.
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import math
17
+ import os
18
+ import sys
19
+ import time
20
+
21
+
22
+ # ── Utility functions ─────────────────────────────────────────────────────────
23
+
24
def _error_exit(message):
    """Emit a JSON error object on stdout, then terminate with exit status 1.

    Errors go to stdout (not stderr) so the calling shell script can parse
    success and failure results from the same stream.
    """
    payload = {"error": True, "message": message}
    print(json.dumps(payload))
    sys.exit(1)
28
+
29
+
30
def _output(data):
    """Serialize ``data`` as JSON on stdout and exit with status 0."""
    serialized = json.dumps(data)
    print(serialized)
    sys.exit(0)
34
+
35
+
36
+ # ── Endpoint invocation ───────────────────────────────────────────────────────
37
+
38
def _invoke_endpoint(endpoint_name, ic_name, region, payload):
    """Invoke a SageMaker endpoint via boto3 and return the parsed JSON body.

    Uses ``InvokeEndpoint``; when ``ic_name`` is set it is passed as
    ``InferenceComponentName`` to route the request to that inference
    component. ``payload`` should be an OpenAI-compatible chat completion
    request; it is sent as a JSON body.

    Returns:
        dict: the decoded response body. On any failure, whether the call
        raises, the body is not valid JSON, or the body decodes to something
        other than a JSON object, returns ``{"error": <message>}``. Callers
        test for the ``"error"`` key and then use ``dict.get`` on the result,
        so a non-object body must not be returned as-is.
    """
    import boto3

    client = boto3.client('sagemaker-runtime', region_name=region)

    kwargs = {
        'EndpointName': endpoint_name,
        'ContentType': 'application/json',
        'Body': json.dumps(payload),
    }
    if ic_name:
        kwargs['InferenceComponentName'] = ic_name

    try:
        response = client.invoke_endpoint(**kwargs)
        body = response['Body'].read().decode('utf-8')
        parsed = json.loads(body)
    except Exception as e:  # best-effort: one failed request must not abort the evaluation run
        return {"error": str(e)}

    if not isinstance(parsed, dict):
        return {"error": f"Unexpected response body type: {type(parsed).__name__}"}
    return parsed
64
+
65
+
66
def _score_text(endpoint_name, ic_name, region, prompt, completion):
    """Score ``completion`` as a reply to ``prompt`` using endpoint logprobs.

    Sends a two-turn chat (user ``prompt``, assistant ``completion``) with
    ``logprobs`` enabled, ``max_tokens=1`` and temperature 0. It then sums the
    ``logprob`` values listed under ``choices[0].logprobs.content``, which is
    the OpenAI-compatible response format.

    NOTE(review): in the OpenAI chat-completions format,
    ``logprobs.content`` describes the *generated* output tokens. With
    ``max_tokens=1`` this sum is therefore likely the logprob of a single
    token produced after the assistant turn, not of the tokens of
    ``completion``. Confirm what the serving container returns (e.g. whether
    it supports prompt/echo logprobs) before treating this as a sequence score.

    Returns:
        float | None: sum of the returned token logprobs, or None when the
        request fails, there are no choices/logprobs, or the response shape
        is unexpected.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": completion},
    ]

    payload = {
        "messages": messages,
        "max_tokens": 1,
        "temperature": 0.0,
        "logprobs": True,
        "top_logprobs": 1,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    if "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None

        logprobs_data = choices[0].get("logprobs")
        if logprobs_data and "content" in logprobs_data:
            token_logprobs = [t.get("logprob", 0.0) for t in logprobs_data["content"]]
            return sum(token_logprobs) if token_logprobs else None

        return None
    except (AttributeError, KeyError, TypeError, IndexError):
        # Non-dict response/choice/token entries or null logprob values:
        # treat the sample as unscorable instead of crashing the whole run.
        return None
106
+
107
+
108
def _generate_response(endpoint_name, ic_name, region, prompt, max_tokens=256):
    """Generate a greedy (temperature 0) chat completion for ``prompt``.

    Used for generation-based metrics (response length, exact match).

    Returns:
        str | None: ``choices[0].message.content``, or ``""`` when the message
        has no ``content`` key. Returns None when the request fails, the
        response has no choices, its shape is unexpected, or the content is
        not a string (e.g. ``null``). Callers call ``str`` methods on any
        non-None result.
    """
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }

    response = _invoke_endpoint(endpoint_name, ic_name, region, payload)

    if "error" in response:
        return None

    try:
        choices = response.get("choices", [])
        if not choices:
            return None
        content = choices[0].get("message", {}).get("content", "")
    except (AttributeError, KeyError, TypeError, IndexError):
        # Non-dict response/choice/message: treat as a failed generation.
        return None

    return content if isinstance(content, str) else None
131
+
132
+
133
+ # ── Metric computation ────────────────────────────────────────────────────────
134
+
135
def _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Compute SFT evaluation metrics against the endpoint.

    Iterates ``dataset`` (stopping at record index ``samples`` when
    ``samples`` is truthy) and skips records without a ``prompt``. For each
    remaining record:

    * when a ``reference`` is present, it is scored with ``_score_text`` and an
      approximate per-token logprob (score / estimated token count) is kept;
    * a greedy response is generated and its whitespace-split length
      recorded; if a ``reference`` is present, the stripped response is
      compared with the stripped reference for exact match.

    Returned keys (each metric only when it could be computed):

    * ``perplexity``: ``exp(-mean per-token logprob)``, or None if that overflows.
    * ``avg_response_length``: mean generated length in words.
    * ``exact_match_accuracy``: matches / records that have a reference. A
      failed generation counts as a miss. Records without a reference are
      excluded because they can never match.
    * ``samples_scored``: number of records with a prompt.
    """
    metrics = {}
    logprob_scores = []
    response_lengths = []
    exact_matches = 0
    exact_match_candidates = 0
    total = 0

    for i, record in enumerate(dataset):
        if samples and i >= samples:
            break

        prompt = record.get("prompt", "")
        reference = record.get("reference", "")

        if not prompt:
            continue

        total += 1

        # Score via logprobs (for perplexity)
        if reference:
            exact_match_candidates += 1
            score = _score_text(endpoint_name, ic_name, region, prompt, reference)
            if score is not None:
                # score is a sum of logprobs; approximate a per-token average
                # by estimating the token count from length (~4 chars/token).
                est_tokens = max(1, len(reference) // 4)
                logprob_scores.append(score / est_tokens)

        # Generate response (for length and exact match)
        generated = _generate_response(endpoint_name, ic_name, region, prompt)
        if generated is not None:
            response_lengths.append(len(generated.split()))
            if reference and generated.strip() == reference.strip():
                exact_matches += 1

    if logprob_scores:
        avg_logprob = sum(logprob_scores) / len(logprob_scores)
        try:
            metrics["perplexity"] = round(math.exp(-avg_logprob), 4)
        except OverflowError:
            # Very large negative logprobs overflow exp(). Report null rather
            # than aborting after all endpoint calls have already been made.
            metrics["perplexity"] = None

    if response_lengths:
        metrics["avg_response_length"] = round(sum(response_lengths) / len(response_lengths), 1)

    if exact_match_candidates > 0:
        metrics["exact_match_accuracy"] = round(exact_matches / exact_match_candidates, 4)

    metrics["samples_scored"] = total

    return metrics
189
+
190
+
191
def _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples):
    """Score preference pairs and summarise how often "chosen" beats "rejected".

    Iterates ``dataset`` (stopping at record index ``samples`` when
    ``samples`` is truthy). Records missing any of ``prompt``, ``chosen`` or
    ``rejected`` are skipped. Both completions of every other record are
    scored with ``_score_text``; a pair counts only when both scores exist.

    Returned keys:

    * ``reward_accuracy``: fraction of scored pairs where chosen > rejected
      (ties count as incorrect).
    * ``avg_chosen_logprob`` / ``avg_rejected_logprob``: mean scores.
    * ``reward_margin``: mean of (chosen - rejected).

    These four are present only when at least one pair was scored, and are
    rounded to 4 places. ``pairs_scored`` and ``samples_evaluated`` are
    always present.
    """
    scored_pairs = []  # (chosen_score, rejected_score) for fully scored pairs
    evaluated = 0

    for index, record in enumerate(dataset):
        if samples and index >= samples:
            break

        prompt = record.get("prompt", "")
        chosen = record.get("chosen", "")
        rejected = record.get("rejected", "")
        if not (prompt and chosen and rejected):
            continue

        evaluated += 1
        score_chosen = _score_text(endpoint_name, ic_name, region, prompt, chosen)
        score_rejected = _score_text(endpoint_name, ic_name, region, prompt, rejected)
        if score_chosen is None or score_rejected is None:
            continue
        scored_pairs.append((score_chosen, score_rejected))

    metrics = {}
    pair_count = len(scored_pairs)
    if pair_count:
        chosen_total = sum(c for c, _ in scored_pairs)
        rejected_total = sum(r for _, r in scored_pairs)
        wins = sum(1 for c, r in scored_pairs if c > r)
        metrics["reward_accuracy"] = round(wins / pair_count, 4)
        metrics["avg_chosen_logprob"] = round(chosen_total / pair_count, 4)
        metrics["avg_rejected_logprob"] = round(rejected_total / pair_count, 4)
        metrics["reward_margin"] = round((chosen_total - rejected_total) / pair_count, 4)

    metrics["pairs_scored"] = pair_count
    metrics["samples_evaluated"] = evaluated
    return metrics
240
+
241
+
242
+ # ── Dataset loading ───────────────────────────────────────────────────────────
243
+
244
def _read_jsonl_records(path):
    """Parse a UTF-8 JSONL file into a list of dicts, exiting with a JSON error on bad input."""
    records = []
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    _error_exit(f"Failed to load eval dataset: line {line_no}: {e}")
                # Metric code calls record.get(...); reject non-objects up front.
                if not isinstance(record, dict):
                    _error_exit(f"Failed to load eval dataset: line {line_no} is not a JSON object")
                records.append(record)
    except (OSError, UnicodeDecodeError) as e:
        _error_exit(f"Failed to load eval dataset: {e}")
    return records


def _load_eval_dataset(eval_dataset_path):
    """Load evaluation records from a JSONL file.

    Accepts a local file path or an ``s3://bucket/key`` URI. An S3 object is
    downloaded to a temporary file, which is deleted once parsing finishes
    (including on error). Other sources (e.g. HF datasets) are expected to be
    resolved to a local path by the bash wrapper before this is called.

    Blank lines are skipped; every other line must be a JSON object.

    Returns:
        list[dict]: one dict per non-blank line, in file order.

    Exits via ``_error_exit`` in any of these cases: no path is given, the S3
    download fails, the file cannot be read or decoded, a line is not valid
    JSON or not a JSON object, or the file contains no records.
    """
    if not eval_dataset_path:
        _error_exit("No evaluation dataset specified. Use --eval-dataset <path>")

    temp_path = None
    try:
        if eval_dataset_path.startswith("s3://"):
            import tempfile

            import boto3

            bucket, _, key = eval_dataset_path[len("s3://"):].partition("/")
            # mkstemp returns an open descriptor: close it immediately so it
            # is not leaked; download_file writes to the path itself.
            fd, temp_path = tempfile.mkstemp(suffix='.jsonl')
            os.close(fd)
            try:
                boto3.client('s3').download_file(bucket, key, temp_path)
            except Exception as e:  # botocore raises many error types; report as JSON
                _error_exit(f"Failed to download eval dataset {eval_dataset_path}: {e}")
            local_path = temp_path
        else:
            local_path = eval_dataset_path

        records = _read_jsonl_records(local_path)
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup of the temporary download

    if not records:
        _error_exit("Evaluation dataset is empty")

    return records
282
+
283
+
284
+ # ── cmd_evaluate ──────────────────────────────────────────────────────────────
285
+
286
def cmd_evaluate(args):
    """Run evaluation against a deployed endpoint and print the result as JSON.

    Loads the eval dataset, then computes metrics. When ``--technique`` is
    ``dpo`` (case-insensitive) DPO metrics are used; otherwise SFT metrics.
    It then emits, via ``_output``, an object with: adapter name, technique,
    model (from ``MODEL_NAME``), dataset, sample count, UTC timestamp and the
    metrics.

    ``--samples`` limits how many dataset records are considered. It must be
    a non-negative integer; empty or ``0`` means "all records". Invalid values
    exit with a JSON error via ``_error_exit`` instead of a traceback.

    NOTE(review): the parser also accepts ``--metrics``, but it is not used
    here; metric selection is driven only by the technique.
    """
    endpoint_name = args.endpoint_name
    ic_name = args.ic_name
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')
    technique = args.technique or ''

    samples = None
    if args.samples:
        try:
            samples = int(args.samples)
        except (TypeError, ValueError):
            _error_exit(f"Invalid --samples value {args.samples!r}: expected a non-negative integer")
        if samples < 0:
            _error_exit(f"Invalid --samples value {args.samples!r}: expected a non-negative integer")
        # 0 keeps its existing meaning of "no limit".
        samples = samples or None

    # Load eval dataset
    dataset = _load_eval_dataset(args.eval_dataset)

    # Determine technique and compute metrics
    if technique.lower() == 'dpo':
        metrics = _compute_dpo_metrics(endpoint_name, ic_name, region, dataset, samples)
    else:
        # Default to SFT metrics (works for any technique)
        metrics = _compute_sft_metrics(endpoint_name, ic_name, region, dataset, samples)

    result = {
        "adapter_name": ic_name,
        "technique": technique or "sft",
        "model": os.environ.get("MODEL_NAME", ""),
        "eval_dataset": args.eval_dataset or "",
        "samples_evaluated": metrics.get("samples_evaluated", metrics.get("samples_scored", 0)),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "metrics": metrics,
    }

    _output(result)
319
+
320
+
321
+ # ── cmd_eval_write ────────────────────────────────────────────────────────────
322
+
323
def cmd_eval_write(args):
    """Upload one evaluation result to S3 as a JSON object for Athena.

    Reads the JSON object produced by ``evaluate`` from ``--results-file``
    and flattens it into a single record; ``metrics`` is stored as a JSON
    string. The record is written to
    ``s3://<bucket>/evaluations/model=<model>/adapter=<adapter>/<timestamp>.json``.
    The object is JSON, not Parquet (Parquet would need pyarrow).

    Missing, empty or null ``adapter_name``/``technique``/``model`` fall back
    to ``"unknown"``, and a missing timestamp falls back to the current UTC
    time. A "/" inside the model or adapter name is replaced by "--" in the
    S3 key only, so each value stays a single Hive-style partition segment.

    Prints ``{"written": true, "s3_uri": ...}`` on success. Exits with a JSON
    error via ``_error_exit`` on a read, parse or upload failure.
    """
    results_file = args.results_file
    bucket = args.bucket
    region = args.region or os.environ.get('AWS_DEFAULT_REGION', 'us-east-1')

    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
        _error_exit(f"Failed to read results file: {e}")

    if not isinstance(data, dict):
        _error_exit("Failed to read results file: expected a JSON object")

    # `or` rather than a .get() default so that empty/null values also fall
    # back (cmd_evaluate writes model="" when MODEL_NAME is unset).
    adapter_name = str(data.get("adapter_name") or "unknown")
    technique = str(data.get("technique") or "unknown")
    model = str(data.get("model") or "unknown")
    timestamp = str(data.get("timestamp") or time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()))

    record = {
        "project_name": os.environ.get("PROJECT_NAME", ""),
        "model_name": model,
        "adapter_name": adapter_name,
        "technique": technique,
        "eval_dataset": data.get("eval_dataset", ""),
        "samples_evaluated": data.get("samples_evaluated", 0),
        "metrics": json.dumps(data.get("metrics", {})),
        "timestamp": timestamp,
        "region": region,
    }

    # Partition values become S3 path segments; an embedded "/" (e.g. an
    # org/model id) would otherwise split them into extra directories.
    model_part = model.replace('/', '--')
    adapter_part = adapter_name.replace('/', '--')
    s3_key = f"evaluations/model={model_part}/adapter={adapter_part}/{timestamp.replace(':', '-')}.json"
    s3_uri = f"s3://{bucket}/{s3_key}"

    try:
        import boto3
        s3 = boto3.client('s3', region_name=region)
        s3.put_object(
            Bucket=bucket,
            Key=s3_key,
            Body=json.dumps(record),
            ContentType='application/json',
        )
    except Exception as e:  # boto3/botocore errors vary; report as JSON for the bash caller
        _error_exit(f"Failed to write to S3: {e}")

    _output({"written": True, "s3_uri": s3_uri})
374
+
375
+
376
+ # ── Main ──────────────────────────────────────────────────────────────────────
377
+
378
def main():
    """Parse command-line arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(description='Model Quality Evaluation Helper')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # evaluate: flags are registered in a fixed order (affects --help output)
    evaluate_cmd = subparsers.add_parser('evaluate', help='Run evaluation')
    for flag in ('--endpoint-name', '--ic-name'):
        evaluate_cmd.add_argument(flag, required=True)
    evaluate_cmd.add_argument('--region')
    for flag in ('--technique', '--eval-dataset', '--samples', '--metrics'):
        evaluate_cmd.add_argument(flag, default='')

    # eval-write
    write_cmd = subparsers.add_parser('eval-write', help='Write results to S3')
    for flag in ('--results-file', '--bucket'):
        write_cmd.add_argument(flag, required=True)
    write_cmd.add_argument('--region')

    args = parser.parse_args()

    handlers = {
        'evaluate': cmd_evaluate,
        'eval-write': cmd_eval_write,
    }
    handler = handlers.get(args.command)
    if handler is None:
        _error_exit(f"Unknown command: {args.command}")
    handler(args)
406
+
407
+
408
if __name__ == "__main__":
    # Script entry point; importing the module does not run anything.
    main()