@aws/ml-container-creator 0.10.0 → 0.10.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +9304 -0
- package/bin/cli.js +2 -0
- package/config/bootstrap-e2e-stack.json +341 -0
- package/config/bootstrap-stack.json +40 -3
- package/config/parameter-schema-v2.json +5 -21
- package/config/tune-catalog.json +1781 -0
- package/infra/ci-harness/buildspec.yml +1 -0
- package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
- package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +837 -7
- package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
- package/package.json +51 -66
- package/servers/base-image-picker/index.js +121 -121
- package/servers/e2e-status/index.js +297 -0
- package/servers/e2e-status/manifest.json +14 -0
- package/servers/e2e-status/package.json +15 -0
- package/servers/endpoint-picker/LICENSE +202 -0
- package/servers/endpoint-picker/index.js +536 -0
- package/servers/endpoint-picker/manifest.json +14 -0
- package/servers/endpoint-picker/package.json +18 -0
- package/servers/hyperpod-cluster-picker/index.js +125 -125
- package/servers/instance-sizer/index.js +138 -138
- package/servers/instance-sizer/lib/instance-ranker.js +76 -76
- package/servers/instance-sizer/lib/model-resolver.js +61 -61
- package/servers/instance-sizer/lib/quota-resolver.js +113 -113
- package/servers/instance-sizer/lib/vram-estimator.js +31 -31
- package/servers/lib/bedrock-client.js +38 -38
- package/servers/lib/catalogs/model-servers.json +201 -3
- package/servers/lib/custom-validators.js +13 -13
- package/servers/lib/dynamic-resolver.js +4 -4
- package/servers/marketplace-picker/index.js +342 -0
- package/servers/marketplace-picker/manifest.json +14 -0
- package/servers/marketplace-picker/package.json +18 -0
- package/servers/model-picker/index.js +382 -382
- package/servers/region-picker/index.js +56 -56
- package/servers/workload-picker/LICENSE +202 -0
- package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
- package/servers/workload-picker/index.js +171 -0
- package/servers/workload-picker/manifest.json +16 -0
- package/servers/workload-picker/package.json +16 -0
- package/src/app.js +4 -2
- package/src/lib/bootstrap-command-handler.js +579 -14
- package/src/lib/bootstrap-config.js +36 -0
- package/src/lib/bootstrap-profile-manager.js +48 -41
- package/src/lib/ci-register-helpers.js +74 -0
- package/src/lib/config-loader.js +3 -0
- package/src/lib/config-manager.js +7 -0
- package/src/lib/cuda-resolver.js +17 -8
- package/src/lib/generated/cli-options.js +315 -315
- package/src/lib/generated/parameter-matrix.js +661 -661
- package/src/lib/generated/validation-rules.js +71 -71
- package/src/lib/path-prover-brain.js +607 -0
- package/src/lib/prompts/project-prompts.js +12 -0
- package/src/lib/template-variable-resolver.js +25 -1
- package/src/lib/tune-catalog-validator.js +37 -4
- package/templates/Dockerfile +9 -0
- package/templates/code/adapter_sidecar.py +444 -0
- package/templates/code/serve +6 -0
- package/templates/code/serve.d/vllm.ejs +1 -1
- package/templates/do/.benchmark_writer.py +1476 -0
- package/templates/do/.tune_helper.py +982 -57
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/adapter +149 -0
- package/templates/do/benchmark +639 -85
- package/templates/do/config +108 -5
- package/templates/do/deploy.d/managed-inference.ejs +192 -11
- package/templates/do/optimize +106 -37
- package/templates/do/register +89 -0
- package/templates/do/test +13 -0
- package/templates/do/tune +378 -59
- package/templates/do/validate +44 -4
|
@@ -0,0 +1,1476 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
|
|
6
|
+
|
|
7
|
+
Subcommands:
|
|
8
|
+
write - Validate, enrich, and write benchmark results to S3 as Parquet
|
|
9
|
+
|
|
10
|
+
All output is JSON on stdout for bash consumption.
|
|
11
|
+
Errors are structured JSON objects — never raw tracebacks.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import sys
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
# Keys that validate_benchmark_input() requires in the merged benchmark input.
# 'metrics' must be a non-empty list of objects; 'instance_type' must also match
# _INSTANCE_TYPE_RE; every other entry must be a non-empty string.
REQUIRED_FIELDS = [
    'project_name',
    'model_name',
    'instance_type',
    'deployment_config',
    'region',
    'metrics',
]

# Pattern for valid SageMaker instance types: ml.<family>.<size>
# <family> and <size> are restricted to lowercase letters and digits.
_INSTANCE_TYPE_RE = re.compile(r'^ml\.[a-z0-9]+\.[a-z0-9]+$')
|
|
35
|
+
|
|
36
|
+
# Known model family patterns — maps regex to family label
|
|
37
|
+
# Known model family patterns — maps regex to family label.
|
|
38
|
+
# Patterns are searched against the model identifier (after org/ prefix stripping).
|
|
39
|
+
# Order matters: more specific patterns (e.g., deepseek-r1) must precede generic ones.
|
|
40
|
+
# Version dots are collapsed for family grouping (e.g., Llama-3.1 → llama3).
|
|
41
|
+
_MODEL_FAMILY_PATTERNS = [
|
|
42
|
+
# DeepSeek — must come before qwen/llama because model names may contain those
|
|
43
|
+
# (e.g., "DeepSeek-R1-Distill-Qwen-7B" contains "Qwen")
|
|
44
|
+
(re.compile(r'deepseek[-_.]?r1', re.IGNORECASE), 'deepseek-r1'),
|
|
45
|
+
(re.compile(r'deepseek[-_.]?v3', re.IGNORECASE), 'deepseek-v3'),
|
|
46
|
+
(re.compile(r'deepseek[-_.]?v2', re.IGNORECASE), 'deepseek-v2'),
|
|
47
|
+
(re.compile(r'deepseek[-_.]?coder', re.IGNORECASE), 'deepseek-coder'),
|
|
48
|
+
(re.compile(r'deepseek[-_.]?math', re.IGNORECASE), 'deepseek-math'),
|
|
49
|
+
(re.compile(r'deepseek', re.IGNORECASE), 'deepseek'),
|
|
50
|
+
# Qwen family — version number without dots for family grouping
|
|
51
|
+
(re.compile(r'qwen3', re.IGNORECASE), 'qwen3'),
|
|
52
|
+
(re.compile(r'qwen2', re.IGNORECASE), 'qwen2'),
|
|
53
|
+
(re.compile(r'qwen', re.IGNORECASE), 'qwen'),
|
|
54
|
+
# Llama family — collapse version dots (3.1, 3.2 → llama3)
|
|
55
|
+
(re.compile(r'codellama|code[-_]?llama', re.IGNORECASE), 'codellama'),
|
|
56
|
+
(re.compile(r'llama[-_.]?3', re.IGNORECASE), 'llama3'),
|
|
57
|
+
(re.compile(r'llama[-_.]?2', re.IGNORECASE), 'llama2'),
|
|
58
|
+
(re.compile(r'llama', re.IGNORECASE), 'llama'),
|
|
59
|
+
# Mistral/Mixtral
|
|
60
|
+
(re.compile(r'mixtral', re.IGNORECASE), 'mixtral'),
|
|
61
|
+
(re.compile(r'mistral', re.IGNORECASE), 'mistral'),
|
|
62
|
+
# Microsoft Phi
|
|
63
|
+
(re.compile(r'phi[-_.]?3', re.IGNORECASE), 'phi3'),
|
|
64
|
+
(re.compile(r'phi[-_.]?2', re.IGNORECASE), 'phi2'),
|
|
65
|
+
# Google Gemma
|
|
66
|
+
(re.compile(r'gemma[-_.]?2', re.IGNORECASE), 'gemma2'),
|
|
67
|
+
(re.compile(r'gemma', re.IGNORECASE), 'gemma'),
|
|
68
|
+
# Others
|
|
69
|
+
(re.compile(r'falcon', re.IGNORECASE), 'falcon'),
|
|
70
|
+
(re.compile(r'starcoder', re.IGNORECASE), 'starcoder'),
|
|
71
|
+
(re.compile(r'gpt[-_.]?oss', re.IGNORECASE), 'gpt-oss'),
|
|
72
|
+
]
|
|
73
|
+
|
|
74
|
+
# Approximate on-demand $/hr for common SageMaker AI instances.
# Keys omit the "ml." prefix: compute_cost_per_1m_tokens() strips it before the
# lookup and returns None for any instance type that is not listed here.
INSTANCE_PRICING_USD_PER_HOUR = {
    'g5.xlarge': 1.408,
    'g5.2xlarge': 1.52,
    'g5.4xlarge': 2.03,
    'g5.8xlarge': 3.06,
    'g5.12xlarge': 7.09,
    'g5.16xlarge': 5.10,
    'g5.24xlarge': 10.18,
    'g5.48xlarge': 20.36,
    'g6.xlarge': 1.00,
    'g6.2xlarge': 1.21,
    'g6.4xlarge': 1.62,
    'g6.8xlarge': 2.44,
    'g6.12xlarge': 5.66,
    'g6.16xlarge': 4.07,
    'g6.24xlarge': 7.53,
    'g6.48xlarge': 15.06,
    'g6e.xlarge': 1.86,
    'g6e.2xlarge': 2.35,
    'g6e.4xlarge': 3.34,
    'g6e.12xlarge': 11.67,
    'g6e.48xlarge': 38.12,
    'p4d.24xlarge': 37.69,
    'p5.48xlarge': 65.85,
    'trn2.48xlarge': 21.50,
}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# ── Utility functions ─────────────────────────────────────────────────────────
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _error_exit(message):
|
|
107
|
+
"""Print JSON error to stdout and exit with code 1."""
|
|
108
|
+
print(json.dumps({"error": message}))
|
|
109
|
+
sys.exit(1)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _output(data):
|
|
113
|
+
"""Print JSON result to stdout."""
|
|
114
|
+
print(json.dumps(data))
|
|
115
|
+
sys.exit(0)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ── Derived field computation ─────────────────────────────────────────────────
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def derive_model_family(model_name):
    """Map a model identifier to its family label.

    Examples:
        "Qwen/Qwen3-4B" → "qwen3"
        "meta-llama/Llama-3.1-8B" → "llama3"
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" → "deepseek-r1"

    Only the part after the last "/" is inspected, so an org prefix such as
    "deepseek-ai/" never influences the result. The entries of
    _MODEL_FAMILY_PATTERNS are tried in order and the first hit wins.

    Args:
        model_name: Model identifier, optionally prefixed with "org/".

    Returns:
        str — the matching family label, "other" when no pattern matches,
        or "unknown" when model_name is empty/None.
    """
    if not model_name:
        return 'unknown'

    # Keep only the final path component of "org/model".
    short_name = model_name
    if '/' in model_name:
        short_name = model_name.split('/')[-1]

    return next(
        (label for regex, label in _MODEL_FAMILY_PATTERNS if regex.search(short_name)),
        'other',
    )


# Alias for test compatibility
compute_model_family = derive_model_family
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def derive_instance_family(instance_type):
    """Return the family segment of a SageMaker instance type.

    Examples:
        "ml.g5.xlarge" → "g5"
        "ml.g6e.2xlarge" → "g6e"
        "ml.p5.48xlarge" → "p5"
        "ml.trn2.xlarge" → "trn2"

    Args:
        instance_type: Instance type in ``ml.<family>.<size>`` form.

    Returns:
        str — the ``<family>`` segment, or "unknown" when the value is
        empty/None or does not have the expected shape.
    """
    if instance_type:
        found = re.match(r'^ml\.([a-z0-9]+)\.[a-z0-9]+$', instance_type)
        if found is not None:
            return found.group(1)
    return 'unknown'


# Alias for test compatibility
compute_instance_family = derive_instance_family
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def compute_cost_per_1m_tokens(instance_type, tokens_per_second):
    """Estimate the USD cost of generating one million output tokens.

    The hourly price comes from INSTANCE_PRICING_USD_PER_HOUR (approximate
    on-demand SageMaker AI pricing), looked up without the "ml." prefix.

    Args:
        instance_type: SageMaker AI instance type string.
        tokens_per_second: Output tokens/second throughput.

    Returns:
        float or None — cost rounded to 4 decimals; None when either argument
        is missing, the throughput is not positive, or the instance type has
        no entry in the pricing table.
    """
    if not instance_type or not tokens_per_second or tokens_per_second <= 0:
        return None

    # The pricing table is keyed without the leading "ml.".
    prefix = 'ml.'
    if instance_type.startswith(prefix):
        pricing_key = instance_type[len(prefix):]
    else:
        pricing_key = instance_type

    hourly_rate = INSTANCE_PRICING_USD_PER_HOUR.get(pricing_key)
    if hourly_rate is None:
        return None

    # $/hour divided by tokens/hour gives $/token; scale to one million tokens.
    per_token = hourly_rate / (tokens_per_second * 3600)
    return round(per_token * 1_000_000, 4)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _parse_partition_timestamp(text):
    """Parse a timestamp string; return a datetime, or None if unparseable."""
    ts = text.strip()
    if not ts:
        return None

    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on,
    # so spell the UTC designator as an explicit offset.
    iso_text = ts[:-1] + '+00:00' if ts.endswith('Z') else ts
    try:
        return datetime.fromisoformat(iso_text)
    except ValueError:
        pass

    # Compact form (20260609T143022Z); it carries no offset and is taken as UTC.
    try:
        compact = datetime.strptime(ts.rstrip('Z'), '%Y%m%dT%H%M%S')
    except ValueError:
        return None
    return compact.replace(tzinfo=timezone.utc)


def compute_partition_keys(timestamp):
    """Compute year and month partition keys from a timestamp.

    Args:
        timestamp: One of:
            - datetime object
            - ISO 8601 string ("2026-06-09T14:30:22Z", "2026-06-09T14:30:22+00:00",
              "2026-06-09 14:30:22" or date-only "2026-06-09")
            - Compact string ("20260609T143022Z")
            - None (uses current UTC time)

    Returns:
        tuple (year: str, month: str) — zero-padded strings. Any value that
        cannot be interpreted (unparseable string, unsupported type) falls
        back to the current UTC time rather than raising.
    """
    dt = None
    if isinstance(timestamp, datetime):
        dt = timestamp
    elif isinstance(timestamp, str):
        # Parse every recognisable form first, so that a valid date without a
        # "T" separator is not silently replaced by the current time.
        dt = _parse_partition_timestamp(timestamp)

    if dt is None:
        dt = datetime.now(timezone.utc)

    return (dt.strftime('%Y'), dt.strftime('%m'))
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def compute_s3_path(bucket, project_name, model_name, instance_type, deployment_target, timestamp):
    """Construct the full S3 URI for a benchmark run Parquet file.

    Uses model/instance/target partitioning scheme:
    ``s3://<bucket>/results/model=<m>/instance=<i>/target=<t>/run-<project>-<ts>.parquet``

    Args:
        bucket: S3 bucket name.
        project_name: MCC project name.
        model_name: HuggingFace model ID; "/" is replaced by "_" in the path.
            Empty/None becomes "unknown".
        instance_type: SageMaker instance type. Empty/None becomes "unknown".
        deployment_target: Deployment target (realtime-inference, etc.).
            Empty/None becomes "realtime-inference".
        timestamp: datetime object for the run timestamp, or None for the
            current UTC time. Timezone-aware values are converted to UTC;
            naive values are used as-is (treated as already being UTC).

    Returns:
        str — full S3 URI.
    """
    if timestamp is None:
        timestamp = datetime.now(timezone.utc)
    elif timestamp.utcoffset() is not None:
        # The filename carries a literal "Z" suffix, so the digits in front of
        # it must really be UTC.
        timestamp = timestamp.astimezone(timezone.utc)

    # Sanitize model name for S3 path (/ → _)
    model_partition = model_name.replace('/', '_') if model_name else 'unknown'
    instance_partition = instance_type or 'unknown'
    target_partition = deployment_target or 'realtime-inference'
    ts_str = timestamp.strftime('%Y%m%dT%H%M%SZ')
    filename = f'run-{project_name}-{ts_str}.parquet'

    return f's3://{bucket}/results/model={model_partition}/instance={instance_partition}/target={target_partition}/{filename}'
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def compute_partition_info(model_name, instance_type, deployment_target):
    """Build the partition values for the model/instance/target scheme.

    Args:
        model_name: HuggingFace model ID (e.g., 'Qwen/Qwen3-0.6B').
        instance_type: SageMaker instance type (e.g., 'ml.g5.xlarge').
        deployment_target: Deployment target (e.g., 'realtime-inference').

    Returns:
        dict with keys: model, instance, target. The model value has "/"
        replaced by "_"; empty/None inputs fall back to 'unknown',
        'unknown' and 'realtime-inference' respectively.
    """
    if model_name:
        model_value = model_name.replace('/', '_')
    else:
        model_value = 'unknown'

    instance_value = instance_type if instance_type else 'unknown'
    target_value = deployment_target if deployment_target else 'realtime-inference'

    return {"model": model_value, "instance": instance_value, "target": target_value}
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def build_s3_path(bucket, project_name, model_name, instance_type, deployment_target, timestamp=None, region=''):
    """Construct the S3 path and partition info for a benchmark run.

    Args:
        bucket: S3 bucket name.
        project_name: MCC project name.
        model_name: HuggingFace model ID; "/" is replaced by "_" in the
            partition value. Empty/None becomes "unknown".
        instance_type: SageMaker instance type. Empty/None becomes "unknown".
        deployment_target: Deployment target. Empty/None becomes
            "realtime-inference".
        timestamp: datetime object or None (defaults to now UTC).
            Timezone-aware values are converted to UTC; naive values are used
            as-is (treated as already being UTC).
        region: AWS region string. Accepted for interface compatibility; it
            is not part of the generated path.

    Returns:
        dict with keys: s3_uri, partition_model, partition_instance, partition_target, filename.
    """
    if timestamp is None:
        timestamp = datetime.now(timezone.utc)
    elif timestamp.utcoffset() is not None:
        # The filename carries a literal "Z" suffix, so the digits in front of
        # it must really be UTC.
        timestamp = timestamp.astimezone(timezone.utc)

    ts_str = timestamp.strftime('%Y%m%dT%H%M%SZ')
    model_partition = model_name.replace('/', '_') if model_name else 'unknown'
    instance_partition = instance_type or 'unknown'
    target_partition = deployment_target or 'realtime-inference'
    filename = f'run-{project_name}-{ts_str}.parquet'

    s3_uri = f's3://{bucket}/results/model={model_partition}/instance={instance_partition}/target={target_partition}/{filename}'

    return {
        's3_uri': s3_uri,
        'partition_model': model_partition,
        'partition_instance': instance_partition,
        'partition_target': target_partition,
        'filename': filename,
    }
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _extract_base_image_version(base_image):
|
|
326
|
+
"""Extract version tag from a base image string.
|
|
327
|
+
|
|
328
|
+
Examples:
|
|
329
|
+
"vllm/vllm-openai:v0.8.5" → "v0.8.5"
|
|
330
|
+
"nvcr.io/nvidia/tritonserver:24.01-py3" → "24.01-py3"
|
|
331
|
+
"" → ""
|
|
332
|
+
|
|
333
|
+
Returns:
|
|
334
|
+
str — extracted tag or empty string.
|
|
335
|
+
"""
|
|
336
|
+
if not base_image:
|
|
337
|
+
return ''
|
|
338
|
+
if ':' in base_image:
|
|
339
|
+
return base_image.split(':')[-1]
|
|
340
|
+
return ''
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def enrich_records(config, results, run_timestamp=None):
    """Build enriched records from config context and benchmark results.

    Each metrics entry becomes one flat record that combines the run-level
    context from *config* with the per-concurrency measurements.

    Args:
        config: dict with config context fields (project_name, model_name, etc.)
        results: dict with benchmark results (job_name, metrics array). A
            non-dict value is treated as having no metrics.
        run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.

    Returns:
        list of enriched record dicts (one per concurrency level).
    """
    if run_timestamp is None:
        run_timestamp = datetime.now(timezone.utc)

    # Run-level context: identical for every record, so it is resolved once.
    project_name = config.get('project_name', '')
    model_name = config.get('model_name', '')
    instance_type = config.get('instance_type', '')
    deployment_config = config.get('deployment_config', '')
    deployment_target = config.get('deployment_target', 'realtime-inference')
    region = config.get('region', '')
    quantization = config.get('quantization', 'none')
    tensor_parallel_degree = config.get('tensor_parallel_degree', 1)
    workload = config.get('workload', 'manual')
    streaming = config.get('streaming', True)
    run_type = config.get('run_type', 'ci')
    mcc_version = config.get('mcc_version', '')
    model_family = derive_model_family(model_name)

    # NOTE(review): the previous revision also computed status, cost (via
    # compute_cost_per_1m_tokens), base_image_version, ci_run_id and account_id
    # but never added them to the record. They are dropped here as dead code;
    # confirm against the Athena table schema that these columns are not expected.

    results_is_dict = isinstance(results, dict)
    metrics = (results.get('metrics') or []) if results_is_dict else []
    if not metrics:
        return []
    job_name = results.get('job_name', '') if results_is_dict else ''

    # serving_config JSON blob built from all available config params; unset
    # ('' or None) entries are left out.
    serving_config = json.dumps({
        k: v for k, v in {
            'quantization': quantization,
            'tensor_parallel_degree': tensor_parallel_degree,
            'enable_lora': config.get('enable_lora', False),
            'base_image': config.get('base_image', ''),
            'kv_cache_dtype': config.get('kv_cache_dtype', 'auto'),
            'max_model_len': config.get('max_model_len', ''),
            'vllm_version': config.get('vllm_version', ''),
            'gpu_memory_utilization': config.get('gpu_memory_utilization', ''),
            'ic_gpu_count': config.get('ic_gpu_count', ''),
            'ic_copy_count': config.get('ic_copy_count', ''),
            'adapter_name': config.get('adapter_name', ''),
        }.items() if v not in ('', None)
    })
    run_timestamp_iso = run_timestamp.isoformat()

    # Unwrap aiperf metric dicts to scalar values:
    #   {'unit': 'requests/sec', 'avg': 9.57} → 9.57
    # Plain numbers pass through; None becomes 0.0.
    def scalar(val, stat='avg'):
        if isinstance(val, dict):
            return val.get(stat, 0.0)
        return val if val is not None else 0.0

    # Read one statistic from a latency block such as
    # {'unit': 'ms', 'avg': 181.9, 'p50': 183.2, ...}; anything that is not a
    # dict (missing block, null) reads as 0.0 instead of raising.
    def stat_of(block, stat):
        return block.get(stat, 0.0) if isinstance(block, dict) else 0.0

    records = []
    for metric in metrics:
        total_requests = scalar(metric.get('total_requests', 0))
        # error_count is unwrapped like the other counters so that a metric
        # dict or null does not break the division below.
        error_count = scalar(metric.get('error_count', 0))
        error_rate = (error_count / total_requests) if total_requests > 0 else 0.0
        duration_seconds = scalar(metric.get('duration_seconds', 0), stat='avg')

        # Latency / throughput distributions
        ttft = metric.get('time_to_first_token', {})
        itl = metric.get('inter_token_latency', {})
        e2e_latency = metric.get('e2e_latency', {})
        prefill = metric.get('prefill_throughput', {})
        output_tps = metric.get('output_token_throughput_detail', {})
        ttst = metric.get('time_to_second_token', {})

        records.append({
            'project_name': project_name,
            'model_name': model_name,
            'model_family': model_family,
            'instance_type': instance_type,
            'deployment_config': deployment_config,
            'deployment_target': deployment_target,
            'quantization': quantization,
            'tensor_parallel_degree': tensor_parallel_degree,
            'serving_config': serving_config,
            'workload': workload,
            'concurrency': scalar(metric.get('concurrency', 0)),
            'input_tokens_mean': metric.get('input_tokens_mean', 0),
            'output_tokens_mean': metric.get('output_tokens_mean', 0),
            'streaming': streaming,
            'duration_seconds': duration_seconds,
            'request_throughput_rps': scalar(metric.get('request_throughput', 0.0)),
            'total_token_throughput_tps': scalar(metric.get('total_token_throughput', 0.0)),
            'output_token_throughput_tps': scalar(metric.get('output_token_throughput', 0.0)),
            'request_count': scalar(metric.get('request_count', metric.get('total_requests', 0))),
            'ttft_avg_ms': stat_of(ttft, 'avg'),
            'ttft_p50_ms': stat_of(ttft, 'p50'),
            'ttft_p90_ms': stat_of(ttft, 'p90'),
            'ttft_p99_ms': stat_of(ttft, 'p99'),
            'itl_avg_ms': stat_of(itl, 'avg'),
            'itl_p50_ms': stat_of(itl, 'p50'),
            'itl_p90_ms': stat_of(itl, 'p90'),
            'itl_p99_ms': stat_of(itl, 'p99'),
            'e2e_latency_avg_ms': stat_of(e2e_latency, 'avg'),
            'e2e_latency_p50_ms': stat_of(e2e_latency, 'p50'),
            'e2e_latency_p90_ms': stat_of(e2e_latency, 'p90'),
            'e2e_latency_p99_ms': stat_of(e2e_latency, 'p99'),
            'prefill_tps_avg': stat_of(prefill, 'avg'),
            'prefill_tps_p50': stat_of(prefill, 'p50'),
            'output_token_tps_avg': stat_of(output_tps, 'avg'),
            'output_token_tps_p50': stat_of(output_tps, 'p50'),
            'output_token_tps_p90': stat_of(output_tps, 'p90'),
            'ttst_p50_ms': stat_of(ttst, 'p50'),
            'ttst_p90_ms': stat_of(ttst, 'p90'),
            # Prefer the aiperf 'output_sequence_length' block and fall back to
            # a pre-averaged value. (This key used to appear twice in the
            # literal; only this later entry ever took effect.)
            'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
            'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
            'error_rate': error_rate,
            'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
            'run_type': run_type,
            'benchmark_job_name': job_name,
            'mcc_version': mcc_version,
            'run_timestamp': run_timestamp_iso,
            'region': region,
        })

    return records
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def validate_input(config, results):
    """Validate a config context together with its benchmark results.

    Convenience wrapper around validate_benchmark_input: the config fields
    and the results' 'metrics' value are combined into one dict, which is
    then validated.

    Args:
        config: dict with config context fields. Non-dict values contribute
            nothing.
        results: dict with benchmark results (must have 'metrics' key).
            Non-dict values contribute nothing.

    Returns:
        list of {"field": str, "reason": str} dicts for each validation failure.
        Empty list means validation passed.
    """
    combined = dict(config) if isinstance(config, dict) else {}

    run_metrics = results.get('metrics') if isinstance(results, dict) else None
    if run_metrics is not None:
        # The results' metrics take precedence over any 'metrics' key in config.
        combined['metrics'] = run_metrics

    return validate_benchmark_input(combined)
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# ── Validation ────────────────────────────────────────────────────────────────
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def validate_benchmark_input(data):
    """Validate that all required fields are present and valid.

    Rules, applied to every name in REQUIRED_FIELDS in order:
      - a missing (or null) field is reported as "required field is missing";
      - 'metrics' must be a non-empty list whose entries are objects with an
        integer 'concurrency' (booleans are rejected);
      - every other field must be a non-empty, non-blank string;
      - 'instance_type' must additionally match _INSTANCE_TYPE_RE in full.

    Args:
        data: dict containing the merged benchmark input (config context + results).
              If data is not a dict, returns a single root-level error.

    Returns:
        list of {"field": str, "reason": str} dicts for each validation failure.
        Empty list means validation passed.
    """
    # Guard against non-dict input
    if not isinstance(data, dict):
        return [{"field": "_root", "reason": "input must be a JSON object"}]

    errors = []

    def fail(field, reason):
        errors.append({"field": field, "reason": reason})

    for field in REQUIRED_FIELDS:
        value = data.get(field)

        if value is None:
            fail(field, "required field is missing")
        elif field == 'metrics':
            if not isinstance(value, list) or not value:
                fail(field, "must be a non-empty array")
                continue
            for i, entry in enumerate(value):
                if not isinstance(entry, dict):
                    fail(f"metrics[{i}]", "each metrics entry must be an object")
                    continue
                conc = entry.get('concurrency')
                if conc is None:
                    fail(f"metrics[{i}].concurrency", "required field is missing")
                elif not isinstance(conc, int) or isinstance(conc, bool):
                    # bool is a subclass of int, so it has to be excluded explicitly.
                    fail(f"metrics[{i}].concurrency", "must be an integer")
        elif not isinstance(value, str) or value.strip() == '':
            fail(field, "must be a non-empty string")
        elif field == 'instance_type' and not _INSTANCE_TYPE_RE.fullmatch(value):
            # fullmatch, not match: "$" alone also matches just before a
            # trailing newline, which would let "ml.g5.xlarge\n" through.
            fail(field, "must match ml.* pattern (e.g., ml.g5.xlarge)")

    return errors
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def emit_validation_error(errors):
    """Report validation failures as JSON on stdout and exit with status 1.

    Args:
        errors: list of {"field": str, "reason": str} dicts.

    Output format:
        {"error": true, "validation_errors": [...]}

    The process exits here, so this function itself performs no S3 write.
    """
    report = {"error": True, "validation_errors": errors}
    print(json.dumps(report))
    raise SystemExit(1)
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
# ── Partition Registration ────────────────────────────────────────────────────
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def register_partition(bucket, model, instance, target,
                       glue_database='mlcc_ci', glue_table='benchmark_results',
                       glue_client=None, region='us-east-1'):
    """Register a partition in the Glue catalog via BatchCreatePartition.

    The partition's StorageDescriptor is inherited from the table (columns,
    formats, serde) and pointed at
    s3://{bucket}/results/model={model}/instance={instance}/target={target}/.
    An already-existing partition is reported as success-without-change
    (idempotent behavior). Every failure path returns a result dict instead
    of raising, so the caller can log a warning and carry on.

    Args:
        bucket: S3 bucket name.
        model: Model partition value (model name with / replaced by _,
            e.g., 'Qwen_Qwen3-0.6B').
        instance: Instance partition value (e.g., 'ml.g5.xlarge').
        target: Deployment target partition value
            (e.g., 'realtime-inference').
        glue_database: Glue database name (default: mlcc_ci).
        glue_table: Glue table name (default: benchmark_results).
        glue_client: Optional pre-configured Glue client (for testing).
            If None, a boto3 client is created for the given region.
        region: AWS region for the Glue client (default: us-east-1).

    Returns:
        dict with keys:
          - registered (bool): True if the partition was newly created
          - already_exists (bool): True if the partition already existed
          - partition_values (list): [model, instance, target]
          - location (str): S3 location for the partition
          - error (str|None): message when registration failed for a
            reason other than already-exists
    """
    if glue_client is None:
        # Imported lazily so callers that inject a client do not need boto3.
        import boto3
        glue_client = boto3.client('glue', region_name=region)

    partition_values = [model, instance, target]
    location = f's3://{bucket}/results/model={model}/instance={instance}/target={target}/'

    def _result(registered=False, already_exists=False, error=None):
        """Build the uniform return payload."""
        return {
            'registered': registered,
            'already_exists': already_exists,
            'partition_values': partition_values,
            'location': location,
            'error': error,
        }

    def _is_error(exc, code):
        """True when exc carries the given AWS error code.

        Prefers the structured response['Error']['Code'] attribute when the
        exception has one; falls back to matching the code in the message.
        """
        response = getattr(exc, 'response', None)
        if isinstance(response, dict):
            structured = (response.get('Error') or {}).get('Code')
            if structured == code:
                return True
        return code in str(exc)

    # Get table StorageDescriptor to inherit columns/serde
    try:
        table_response = glue_client.get_table(
            DatabaseName=glue_database,
            Name=glue_table,
        )
    except Exception as e:
        if _is_error(e, 'EntityNotFoundException'):
            return _result(
                error=f"Table {glue_database}.{glue_table} not found in Glue catalog"
            )
        return _result(error=f"Failed to get table metadata: {e}")

    table_sd = (table_response.get('Table') or {}).get('StorageDescriptor')
    if not table_sd or 'Columns' not in table_sd:
        # Without columns there is nothing to inherit; report instead of
        # raising KeyError so registration stays non-fatal for the caller.
        return _result(
            error=f"Failed to get table metadata: {glue_database}.{glue_table} has no StorageDescriptor columns"
        )

    # Build partition StorageDescriptor inheriting from table
    partition_sd = {
        'Columns': table_sd['Columns'],
        'Location': location,
        'InputFormat': table_sd.get('InputFormat', 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'),
        'OutputFormat': table_sd.get('OutputFormat', 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'),
        'SerdeInfo': table_sd.get('SerdeInfo', {
            'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe',
            'Parameters': {'serialization.format': '1'},
        }),
        'Compressed': table_sd.get('Compressed', True),
    }

    partition_input = {
        'Values': partition_values,
        'StorageDescriptor': partition_sd,
        'Parameters': {
            'classification': 'parquet',
            'parquet.compression': 'SNAPPY',
        },
    }

    try:
        response = glue_client.batch_create_partition(
            DatabaseName=glue_database,
            TableName=glue_table,
            PartitionInputList=[partition_input],
        )
    except Exception as e:
        # Handle AlreadyExistsException thrown as an API exception
        if _is_error(e, 'AlreadyExistsException'):
            return _result(already_exists=True)
        return _result(error=f"Failed to register partition: {e}")

    # The batch API reports per-partition failures in the response body.
    # Only one partition is submitted, so the first entry is the relevant one.
    batch_errors = response.get('Errors', [])
    if batch_errors:
        error_detail = batch_errors[0].get('ErrorDetail', {})
        error_code = error_detail.get('ErrorCode', '')
        if error_code == 'AlreadyExistsException':
            return _result(already_exists=True)
        error_msg = error_detail.get('ErrorMessage', 'unknown error')
        return _result(
            error=f"Partition registration failed: {error_code} — {error_msg}"
        )

    return _result(registered=True)
|
|
789
|
+
|
|
790
|
+
|
|
791
|
+
# ── Parquet Serialization ─────────────────────────────────────────────────────
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def get_parquet_schema():
    """Build the pyarrow schema used for benchmark_results Parquet files.

    The column list is meant to mirror the Athena DDL for the table.
    Partition columns (model, instance, target) are deliberately absent:
    they are encoded in the S3 path and handled by Glue/Athena partitioning.

    Returns:
        pyarrow.Schema with the columns in the order declared below.
    """
    import pyarrow as pa

    text = pa.string()
    int32 = pa.int32()
    float64 = pa.float64()
    boolean = pa.bool_()

    columns = (
        # Identity
        ("project_name", text),

        # Model + Serving Config (queryable columns)
        ("model_name", text),
        ("model_family", text),
        ("instance_type", text),
        ("deployment_config", text),
        ("deployment_target", text),
        ("quantization", text),
        ("tensor_parallel_degree", int32),

        # Full serving config (extensible JSON blob)
        ("serving_config", text),

        # Workload
        ("workload", text),
        ("concurrency", int32),
        ("input_tokens_mean", int32),
        ("output_tokens_mean", int32),
        ("streaming", boolean),
        ("duration_seconds", int32),

        # Rich Metrics
        ("request_throughput_rps", float64),
        ("total_token_throughput_tps", float64),
        ("output_token_throughput_tps", float64),
        ("request_count", float64),
        ("ttft_avg_ms", float64),
        ("ttft_p50_ms", float64),
        ("ttft_p90_ms", float64),
        ("ttft_p99_ms", float64),
        ("itl_avg_ms", float64),
        ("itl_p50_ms", float64),
        ("itl_p90_ms", float64),
        ("itl_p99_ms", float64),
        ("e2e_latency_avg_ms", float64),
        ("e2e_latency_p50_ms", float64),
        ("e2e_latency_p90_ms", float64),
        ("e2e_latency_p99_ms", float64),
        ("prefill_tps_avg", float64),
        ("prefill_tps_p50", float64),
        ("output_token_tps_avg", float64),
        ("output_token_tps_p50", float64),
        ("output_token_tps_p90", float64),
        ("ttst_p50_ms", float64),
        ("ttst_p90_ms", float64),
        ("output_sequence_length_avg", float64),
        ("input_sequence_length_avg", float64),
        ("error_rate", float64),
        ("benchmark_duration_sec", float64),

        # Run Metadata
        ("run_type", text),
        ("benchmark_job_name", text),
        ("mcc_version", text),
        ("run_timestamp", text),
        ("region", text),
    )

    return pa.schema([pa.field(name, dtype) for name, dtype in columns])
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def _records_to_parquet_table(records):
    """Turn enriched record dicts into a pyarrow Table.

    Args:
        records: List of dicts (as produced by enrich_records()). Keys are
            looked up by schema column name; a missing key becomes a null.

    Returns:
        pyarrow.Table built column-by-column against get_parquet_schema().
    """
    import pyarrow as pa
    from datetime import datetime as dt

    schema = get_parquet_schema()

    def _cell(record, column):
        """Fetch one value, stringifying datetime run_timestamp values."""
        raw = record.get(column)
        # The schema stores run_timestamp as a string, so datetimes are
        # converted to ISO-8601; every other value passes through untouched.
        if column == 'run_timestamp' and isinstance(raw, dt):
            return raw.isoformat()
        return raw

    columns = [
        pa.array([_cell(rec, fld.name) for rec in records], type=fld.type)
        for fld in schema
    ]

    return pa.table(columns, schema=schema)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def _upload_to_s3(local_path, bucket, s3_uri, region):
    """Send a local file to S3 at the key named by a full S3 URI.

    Args:
        local_path: Path to the local Parquet file.
        bucket: S3 bucket name.
        s3_uri: Full S3 URI (s3://bucket/key).
        region: AWS region for the S3 client.
    """
    import boto3

    # s3://bucket/key → key (first occurrence of the bucket prefix removed)
    bucket_prefix = f's3://{bucket}/'
    object_key = s3_uri.replace(bucket_prefix, '', 1)

    uploader = boto3.client('s3', region_name=region)
    uploader.upload_file(local_path, bucket, object_key)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def _parse_jsonl_to_metrics(jsonl_path, concurrency=None):
    """Parse profile_export.jsonl and aggregate into metrics format.

    The JSONL file contains one JSON object per request with:
      - metadata: {session_num, request_start_ns, request_end_ns, ...}
      - metrics: {request_latency: {value, unit}, time_to_first_token: {value, unit}, ...}

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped. A null or non-object metadata/metrics entry is
    treated as empty.

    Args:
        jsonl_path: Path to the JSONL file.
        concurrency: Concurrency level to record. When None, the number of
            parsed request records is used instead.

    Returns:
        {"metrics": [entry]} with a single aggregated entry,
        {"metrics": []} when no records were parsed, or
        {"error": "<message>"} when the file cannot be opened or read.
    """
    import math

    def _percentile(sorted_vals, pct):
        """Linearly interpolated percentile of an ascending list (0.0 if empty)."""
        if not sorted_vals:
            return 0.0
        idx = (pct / 100.0) * (len(sorted_vals) - 1)
        lower = int(math.floor(idx))
        upper = int(math.ceil(idx))
        if lower == upper:
            return sorted_vals[lower]
        frac = idx - lower
        return sorted_vals[lower] * (1 - frac) + sorted_vals[upper] * frac

    def _mean(vals):
        """Arithmetic mean, 0.0 for an empty list."""
        return sum(vals) / len(vals) if vals else 0.0

    def _summary(sorted_vals, percentiles):
        """Return {'avg': ..., 'pNN': ...} for the requested percentiles."""
        stats = {'avg': _mean(sorted_vals)}
        for pct in percentiles:
            stats[f'p{pct}'] = _percentile(sorted_vals, pct)
        return stats

    def _get_val(metrics_dict, key):
        """Extract scalar value from a metric dict like {value: X, unit: "ms"}."""
        m = metrics_dict.get(key)
        if isinstance(m, dict):
            return m.get('value')
        return m

    records = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    parsed = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Only JSON objects are request records; arrays/scalars
                # would break the .get() lookups below.
                if isinstance(parsed, dict):
                    records.append(parsed)
    except OSError as e:
        return {"error": str(e)}

    if not records:
        return {"metrics": []}

    # Collect per-request metrics
    latencies = []
    ttfts = []
    itls = []
    ttsts = []
    output_tokens = []
    input_tokens = []
    prefill_tps = []
    output_tps = []
    start_times = []
    end_times = []

    # (metric key in the record, list that collects its values)
    simple_metrics = (
        ('request_latency', latencies),
        ('inter_token_latency', itls),
        ('time_to_second_token', ttsts),
        ('output_token_count', output_tokens),
        ('input_sequence_length', input_tokens),
        ('prefill_throughput_per_user', prefill_tps),
        ('output_token_throughput_per_user', output_tps),
    )

    for rec in records:
        meta = rec.get('metadata')
        if not isinstance(meta, dict):
            meta = {}
        metrics = rec.get('metrics')
        if not isinstance(metrics, dict):
            metrics = {}

        for key, bucket in simple_metrics:
            val = _get_val(metrics, key)
            if val is not None:
                bucket.append(val)

        # TTFT is looked up under either of two key names.
        ttft = _get_val(metrics, 'time_to_first_token')
        if ttft is None:
            ttft = _get_val(metrics, 'time_to_first_output_token')
        if ttft is not None:
            ttfts.append(ttft)

        rs = meta.get('request_start_ns')
        re_ = meta.get('request_end_ns')
        if rs is not None:
            start_times.append(rs)
        if re_ is not None:
            end_times.append(re_)

    # Sort for percentiles
    for series in (latencies, ttfts, itls, ttsts, prefill_tps, output_tps):
        series.sort()

    # Compute system throughput over the wall-clock span of all requests.
    # Falls back to 1 second when timestamps are missing or non-increasing.
    if start_times and end_times:
        duration_ns = max(end_times) - min(start_times)
        duration_s = duration_ns / 1e9 if duration_ns > 0 else 1.0
    else:
        duration_s = 1.0
    duration_s = max(duration_s, 0.001)

    n = len(records)
    req_throughput = n / duration_s
    total_out_tokens = sum(output_tokens) if output_tokens else 0
    token_throughput = total_out_tokens / duration_s

    # Concurrency comes from the argument; otherwise the request count is used.
    conc = concurrency if concurrency is not None else n

    # Build metrics entry matching the schema expected by enrich_records
    entry = {
        'concurrency': conc,
        'request_throughput': req_throughput,
        'output_token_throughput': token_throughput,
        'total_token_throughput': (total_out_tokens + sum(input_tokens)) / duration_s if input_tokens else token_throughput,
        'total_requests': n,
        'request_count': n,
        'duration_seconds': duration_s,
        'time_to_first_token': _summary(ttfts, (50, 90, 99)),
        'inter_token_latency': _summary(itls, (50, 90, 99)),
        'e2e_latency': _summary(latencies, (50, 90, 99)),
        'request_latency': _summary(latencies, (50, 90, 99)),
        'time_to_second_token': _summary(ttsts, (50, 90)),
        'prefill_throughput': _summary(prefill_tps, (50,)),
        'output_token_throughput_detail': _summary(output_tps, (50, 90)),
        'output_sequence_length': _mean(output_tokens),
        'input_sequence_length': _mean(input_tokens),
        'input_tokens_mean': int(sum(input_tokens) / len(input_tokens)) if input_tokens else 0,
        'output_tokens_mean': int(sum(output_tokens) / len(output_tokens)) if output_tokens else 0,
    }

    return {"metrics": [entry]}
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
# ── Command: write ────────────────────────────────────────────────────────────
|
|
1108
|
+
|
|
1109
|
+
|
|
1110
|
+
def cmd_write(args):
    """Validate, enrich, and write benchmark results to S3 as Parquet.

    Validation occurs before any S3 interaction. If validation fails,
    a structured error is emitted and no write occurs.

    Flow:
      1. Load results from a .jsonl file (aggregated per-request data) or a
         JSON file (dict with 'metrics', a flat single-level dict, or a raw
         metrics list).
      2. Merge context: config file first, then fields found in the results
         file (only where not already set), then CLI overrides.
      3. Validate the merged input.
      4. With --dry-run, emit the enriched records and intended S3 path.
      5. Otherwise write Snappy Parquet to a temp file, upload it, then
         attempt Glue partition registration on a best-effort basis.

    Args:
        args: Parsed CLI namespace. Attributes read here: results_file,
            input, concurrency (optional), config_file, project_name,
            workload, region, dry_run, bucket.

    Note:
        _error_exit and _output are defined elsewhere in this file; the
        code below is written as if both terminate the process — TODO
        confirm against their definitions.
    """
    # Load benchmark results (JSON or JSONL)
    results_path = args.results_file or args.input
    if not results_path:
        _error_exit("--results-file (or --input) is required")

    if results_path.endswith('.jsonl'):
        # Parse JSONL (per-request data) and aggregate into metrics format
        benchmark_data = _parse_jsonl_to_metrics(results_path, concurrency=getattr(args, 'concurrency', None))
        if 'error' in benchmark_data:
            _error_exit(f"Failed to parse JSONL: {benchmark_data['error']}")
    else:
        try:
            with open(results_path, 'r') as f:
                benchmark_data = json.load(f)
        except FileNotFoundError:
            _error_exit(f"Results file not found: {results_path}")
        except json.JSONDecodeError as e:
            _error_exit(f"Invalid JSON in results file: {e}")
        except Exception as e:
            _error_exit(f"Failed to read results file: {e}")

    # Build the combined input data for validation
    # Merge CLI-provided fields with the benchmark results
    input_data = {}

    # Fields from config file (if provided)
    if args.config_file:
        try:
            config_context = _load_config_file(args.config_file)
            input_data.update(config_context)
        except Exception as e:
            _error_exit(f"Failed to read config file: {e}")

    # Fields from the benchmark results file
    if isinstance(benchmark_data, dict):
        metrics = benchmark_data.get('metrics')
        if metrics is not None:
            input_data['metrics'] = metrics
        else:
            # Single-level benchmark: raw results at top level without a 'metrics' wrapper.
            # Wrap into the expected array format for validation and enrichment.
            # Detect by presence of known metric fields (request_throughput, output_token_throughput, etc.)
            metric_indicators = ['request_throughput', 'output_token_throughput', 'time_to_first_token',
                                 'inter_token_latency', 'request_latency', 'concurrency']
            if any(k in benchmark_data for k in metric_indicators):
                # Use BENCHMARK_CONCURRENCY from config if concurrency not in the results
                # (falls back to 10 when the config did not provide it either)
                if 'concurrency' not in benchmark_data:
                    benchmark_data['concurrency'] = int(input_data.get('benchmark_concurrency', 10))
                input_data['metrics'] = [benchmark_data]
        # Also pull any config fields from the results file
        # (values already supplied by the config file take precedence)
        for field in ['model_name', 'instance_type', 'deployment_config', 'project_name', 'region']:
            if field in benchmark_data and field not in input_data:
                input_data[field] = benchmark_data[field]
    elif isinstance(benchmark_data, list):
        # If the results file is just a raw metrics array
        input_data['metrics'] = benchmark_data

    # CLI args override config file and results file values
    if args.project_name:
        input_data['project_name'] = args.project_name
    if args.workload:
        input_data['workload'] = args.workload
    if args.region:
        input_data['region'] = args.region

    # ── Validate before any S3 interaction ────────────────────────────────
    errors = validate_benchmark_input(input_data)
    if errors:
        emit_validation_error(errors)
        return  # Never reached, but explicit

    # ── Dry-run mode: output enriched records as JSON, skip S3 ──────────────
    if args.dry_run:
        timestamp = datetime.now(timezone.utc)

        # Split input_data back into config and results for enrich_records
        config_context = {k: v for k, v in input_data.items() if k != 'metrics'}
        results_obj = {'metrics': input_data['metrics']}
        if isinstance(benchmark_data, dict) and 'job_name' in benchmark_data:
            results_obj['job_name'] = benchmark_data['job_name']

        enriched_records = enrich_records(config_context, results_obj, timestamp)

        # Compute intended S3 path (use bucket if provided, else placeholder)
        bucket = args.bucket or f'mlcc-benchmark-results-<accountId>-{input_data["region"]}'
        s3_path = compute_s3_path(bucket, input_data.get('project_name', ''), input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'), timestamp)
        partition = compute_partition_info(input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'))

        _output({
            "dry_run": True,
            "s3_path": s3_path,
            "partition": partition,
            "record_count": len(enriched_records),
            "records": enriched_records,
        })
        return  # Never reached after _output

    # ── Write to S3 (requires bucket) ─────────────────────────────────────
    if not args.bucket:
        _error_exit("--bucket is required when not using --dry-run")

    # Region precedence: merged input, then AWS_REGION env var, then ''.
    region = input_data.get('region', os.environ.get('AWS_REGION', ''))
    timestamp = datetime.now(timezone.utc)

    # Split input_data back into config and results for enrich_records
    config_context = {k: v for k, v in input_data.items() if k != 'metrics'}
    results_obj = {'metrics': input_data['metrics']}
    if isinstance(benchmark_data, dict) and 'job_name' in benchmark_data:
        results_obj['job_name'] = benchmark_data['job_name']

    enriched_records = enrich_records(config_context, results_obj, timestamp)

    if not enriched_records:
        _error_exit("No records produced from benchmark metrics")

    # Compute S3 path
    s3_info = build_s3_path(args.bucket, input_data.get('project_name', ''), input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'), timestamp, region=region)

    # Write Parquet to a temp file then upload to S3
    # (imports are deferred to here so a missing pyarrow yields a structured error)
    try:
        import tempfile
        import pyarrow as pa
        import pyarrow.parquet as pq
    except ImportError as e:
        _error_exit(f"Missing dependency: {e}. Install: pip install pyarrow")

    # Build pyarrow table from enriched records
    table = _records_to_parquet_table(enriched_records)

    # Write to temp file with Snappy compression
    tmp_path = None
    try:
        # The temp file is created only to reserve a path; delete=False keeps
        # it on disk after the handle closes so pyarrow can write to it.
        with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as tmp:
            tmp_path = tmp.name

        pq.write_table(table, tmp_path, compression='snappy')

        # Upload to S3
        _upload_to_s3(tmp_path, args.bucket, s3_info['s3_uri'], region)

    except Exception as e:
        _error_exit(f"Failed to write Parquet to S3: {e}")
    finally:
        # Clean up temp file
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

    # Register partition in Glue catalog to make data immediately queryable.
    # This is best-effort — failure is non-fatal per design doc error handling.
    # Data remains readable via MSCK REPAIR TABLE as a fallback.
    partition_result = None
    try:
        partition_result = register_partition(
            bucket=args.bucket,
            model=s3_info['partition_model'],
            instance=s3_info['partition_instance'],
            target=s3_info['partition_target'],
            region=region,
        )
    except SystemExit:
        # Defensive: if anything inside registration tries to exit the
        # process, swallow it here — the Parquet write already succeeded.
        partition_result = {"registered": False, "error": "partition registration failed (non-fatal)"}
    except Exception as e:
        partition_result = {"registered": False, "error": str(e)}

    # Surface registration problems as a warning on stderr only.
    if partition_result and partition_result.get('error'):
        print(
            f"\u26a0\ufe0f Partition registration warning: {partition_result['error']}",
            file=sys.stderr,
        )

    _output({
        "success": True,
        "s3_uri": s3_info['s3_uri'],
        "partition": {
            "model": s3_info['partition_model'],
            "instance": s3_info['partition_instance'],
            "target": s3_info['partition_target'],
        },
        "rows_written": len(enriched_records),
        "project_name": input_data.get('project_name', ''),
        "run_timestamp": timestamp.isoformat(),
        "partition_registration": partition_result,
    })
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
def _load_config_file(config_path):
    """Load configuration context from a do/config shell file or JSON file.

    Supports two formats:
      - JSON file (content starts with '{'): known keys, in snake_case,
        camelCase or UPPER_CASE spellings, are mapped to canonical names.
        When several spellings of one field are present, the first one in
        the alias table wins.
      - Shell config file: extracts VAR=value / export VAR="value"
        assignments for recognized variable names. ${VAR:-default} yields
        the default; other ${...} or $(...) values are skipped.

    This loader is best-effort: if the file cannot be read or parsed, a
    warning is printed to stderr and whatever was collected so far (possibly
    an empty dict) is returned. It never raises.

    Args:
        config_path: Path to the config file.

    Returns:
        dict with recognized config fields. tensor_parallel_degree is an int
        and enable_lora a bool (JSON input only); other values are strings
        or None.
    """
    import sys

    # JSON key aliases → canonical field name (order defines precedence).
    field_map = {
        'project_name': 'project_name',
        'projectName': 'project_name',
        'model_name': 'model_name',
        'modelName': 'model_name',
        'MODEL_NAME': 'model_name',
        'instance_type': 'instance_type',
        'instanceType': 'instance_type',
        'INSTANCE_TYPE': 'instance_type',
        'deployment_config': 'deployment_config',
        'deploymentConfig': 'deployment_config',
        'DEPLOYMENT_CONFIG': 'deployment_config',
        'region': 'region',
        'REGION': 'region',
        'deployment_target': 'deployment_target',
        'deploymentTarget': 'deployment_target',
        'tensor_parallel_degree': 'tensor_parallel_degree',
        'tensorParallelDegree': 'tensor_parallel_degree',
        'quantization': 'quantization',
        'enable_lora': 'enable_lora',
        'enableLora': 'enable_lora',
        'base_image': 'base_image',
        'baseImage': 'base_image',
        'base_image_version': 'base_image_version',
        'baseImageVersion': 'base_image_version',
        'mcc_version': 'mcc_version',
        'mccVersion': 'mcc_version',
        'account_id': 'account_id',
        'accountId': 'account_id',
    }

    # Shell variable names → canonical field name.
    shell_map = {
        'PROJECT_NAME': 'project_name',
        'MODEL_NAME': 'model_name',
        'INSTANCE_TYPE': 'instance_type',
        'DEPLOYMENT_CONFIG': 'deployment_config',
        'DEPLOYMENT_TARGET': 'deployment_target',
        'AWS_REGION': 'region',
        'REGION': 'region',
        'ACCOUNT_ID': 'account_id',
        'MCC_VERSION': 'mcc_version',
        'BASE_IMAGE': 'base_image',
        'BASE_IMAGE_VERSION': 'base_image_version',
        'BENCHMARK_CONCURRENCY': 'benchmark_concurrency',
    }

    # String spellings that must coerce to False; bool("false") would be True.
    falsy_strings = ('', '0', 'false', 'no', 'off')

    context = {}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            content = f.read().strip()

        # Try JSON first
        if content.startswith('{'):
            data = json.loads(content)
            for source_key, target_key in field_map.items():
                if source_key not in data or target_key in context:
                    continue
                val = data[source_key]
                # Keep non-string types for certain fields
                if target_key == 'tensor_parallel_degree':
                    context[target_key] = int(val) if val is not None else val
                elif target_key == 'enable_lora':
                    if isinstance(val, str):
                        context[target_key] = val.strip().lower() not in falsy_strings
                    else:
                        context[target_key] = bool(val)
                else:
                    context[target_key] = str(val) if val is not None else val
            return context

        # Parse shell-style config (export VAR="value" or VAR="value")
        for line in content.split('\n'):
            line = line.strip()
            if line.startswith('#') or not line:
                continue
            # Remove 'export ' prefix
            if line.startswith('export '):
                line = line[7:]
            # Parse VAR=value or VAR="value"
            if '=' not in line:
                continue
            key, _, value = line.partition('=')
            key = key.strip()
            value = value.strip().strip('"').strip("'")
            # Handle shell default syntax: ${VAR:-default} → extract default
            if value.startswith('${') and ':-' in value:
                value = value.split(':-', 1)[1].rstrip('}')
            # Skip unresolved shell variables (e.g., ${INSTANCE_TYPE})
            if value.startswith('${') or value.startswith('$('):
                continue
            if key in shell_map:
                context[shell_map[key]] = value

    except Exception as exc:
        # Best-effort loader: report the problem instead of hiding it, but
        # still hand back whatever was collected so the caller can proceed.
        print(
            f"Warning: could not fully load config file {config_path}: {exc}",
            file=sys.stderr,
        )

    return context
|
|
1406
|
+
|
|
1407
|
+
|
|
1408
|
+
# ── CLI entry point ───────────────────────────────────────────────────────────
|
|
1409
|
+
|
|
1410
|
+
|
|
1411
|
+
def main():
    """Build the command-line interface and run the selected subcommand.

    Only the ``write`` subcommand is registered; its parsed namespace is
    handed to ``cmd_write``. When no subcommand is supplied, the top-level
    help text is printed and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(
        description='Benchmark Writer — Convert benchmark results to Athena-compatible Parquet'
    )
    commands = cli.add_subparsers(dest='command', help='Available commands')

    # write subcommand: options are declared as (flags, keyword-options) pairs
    # and registered in this order, which is also the order shown in --help.
    write_cmd = commands.add_parser('write', help='Write benchmark results to S3')
    write_options = (
        (('--input',), {
            'help': 'Path to benchmark results JSON file (alias for --results-file)',
        }),
        (('--results-file',), {
            'dest': 'results_file',
            'help': 'Path to benchmark results JSON file',
        }),
        (('--config-file',), {
            'dest': 'config_file',
            'help': 'Path to config file (do/config or JSON) for context fields',
        }),
        (('--project-name',), {
            'dest': 'project_name',
            'help': 'MCC project name (human-readable identifier)',
        }),
        (('--workload',), {
            'default': 'manual',
            'help': 'Named workload profile (from workload-picker MCP, default: manual)',
        }),
        (('--concurrency',), {
            'type': int,
            'default': None,
            'help': 'Concurrency level (passed to JSONL aggregation if results are per-request)',
        }),
        (('--bucket',), {
            'help': 'S3 bucket name for results (required unless --dry-run)',
        }),
        (('--region',), {
            'help': 'AWS region',
        }),
        (('--dry-run',), {
            'dest': 'dry_run',
            'action': 'store_true',
            'help': 'Output enriched records as JSON without writing to S3',
        }),
    )
    for flags, options in write_options:
        write_cmd.add_argument(*flags, **options)

    namespace = cli.parse_args()
    selected = namespace.command

    # No subcommand given: show usage and signal failure to the caller.
    if not selected:
        cli.print_help()
        sys.exit(1)

    if selected == 'write':
        cmd_write(namespace)
|
|
1465
|
+
|
|
1466
|
+
|
|
1467
|
+
if __name__ == '__main__':
    # Script entry point. SystemExit derives from BaseException, so exits
    # requested by main() (or argparse) pass through this handler untouched.
    try:
        main()
    except Exception as exc:
        # Any other unexpected failure is reported as a single JSON object
        # on stdout, followed by a non-zero exit status.
        failure = {"error": f"unexpected error: {exc}"}
        print(json.dumps(failure))
        sys.exit(1)
|