@aws/ml-container-creator 0.10.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. package/LICENSE-THIRD-PARTY +9304 -0
  2. package/bin/cli.js +2 -0
  3. package/config/bootstrap-e2e-stack.json +341 -0
  4. package/config/bootstrap-stack.json +40 -3
  5. package/config/parameter-schema-v2.json +33 -22
  6. package/config/tune-catalog.json +1781 -0
  7. package/infra/ci-harness/buildspec.yml +1 -0
  8. package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
  9. package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
  10. package/infra/ci-harness/lib/ci-harness-stack.ts +851 -7
  11. package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
  12. package/package.json +53 -67
  13. package/servers/base-image-picker/index.js +121 -121
  14. package/servers/e2e-status/index.js +297 -0
  15. package/servers/e2e-status/manifest.json +14 -0
  16. package/servers/e2e-status/package.json +15 -0
  17. package/servers/endpoint-picker/LICENSE +202 -0
  18. package/servers/endpoint-picker/index.js +536 -0
  19. package/servers/endpoint-picker/manifest.json +14 -0
  20. package/servers/endpoint-picker/package.json +18 -0
  21. package/servers/hyperpod-cluster-picker/index.js +125 -125
  22. package/servers/instance-sizer/index.js +166 -153
  23. package/servers/instance-sizer/lib/instance-ranker.js +120 -76
  24. package/servers/instance-sizer/lib/model-resolver.js +61 -61
  25. package/servers/instance-sizer/lib/quota-resolver.js +113 -113
  26. package/servers/instance-sizer/lib/vram-estimator.js +31 -31
  27. package/servers/lib/bedrock-client.js +38 -38
  28. package/servers/lib/catalogs/instances.json +27 -0
  29. package/servers/lib/catalogs/model-servers.json +201 -3
  30. package/servers/lib/custom-validators.js +13 -13
  31. package/servers/lib/dynamic-resolver.js +4 -4
  32. package/servers/marketplace-picker/index.js +342 -0
  33. package/servers/marketplace-picker/manifest.json +14 -0
  34. package/servers/marketplace-picker/package.json +18 -0
  35. package/servers/model-picker/index.js +382 -382
  36. package/servers/region-picker/index.js +56 -56
  37. package/servers/workload-picker/LICENSE +202 -0
  38. package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
  39. package/servers/workload-picker/index.js +171 -0
  40. package/servers/workload-picker/manifest.json +16 -0
  41. package/servers/workload-picker/package.json +16 -0
  42. package/src/app.js +12 -3
  43. package/src/lib/bootstrap-command-handler.js +609 -15
  44. package/src/lib/bootstrap-config.js +36 -0
  45. package/src/lib/bootstrap-profile-manager.js +48 -41
  46. package/src/lib/ci-register-helpers.js +74 -0
  47. package/src/lib/config-loader.js +3 -0
  48. package/src/lib/config-manager.js +7 -0
  49. package/src/lib/config-validator.js +1 -1
  50. package/src/lib/cuda-resolver.js +17 -8
  51. package/src/lib/generated/cli-options.js +319 -314
  52. package/src/lib/generated/parameter-matrix.js +672 -661
  53. package/src/lib/generated/validation-rules.js +76 -72
  54. package/src/lib/path-prover-brain.js +664 -0
  55. package/src/lib/prompts/infrastructure-prompts.js +2 -2
  56. package/src/lib/prompts/model-prompts.js +6 -0
  57. package/src/lib/prompts/project-prompts.js +12 -0
  58. package/src/lib/secrets-prompt-runner.js +4 -0
  59. package/src/lib/template-manager.js +1 -1
  60. package/src/lib/template-variable-resolver.js +87 -1
  61. package/src/lib/tune-catalog-validator.js +37 -4
  62. package/templates/Dockerfile +9 -0
  63. package/templates/code/adapter_sidecar.py +444 -0
  64. package/templates/code/serve +6 -0
  65. package/templates/code/serve.d/vllm.ejs +1 -1
  66. package/templates/do/.benchmark_writer.py +1476 -0
  67. package/templates/do/.tune_helper.py +982 -57
  68. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  69. package/templates/do/adapter +154 -0
  70. package/templates/do/benchmark +639 -85
  71. package/templates/do/build +5 -0
  72. package/templates/do/clean.d/async-inference.ejs +5 -0
  73. package/templates/do/clean.d/batch-transform.ejs +5 -0
  74. package/templates/do/clean.d/hyperpod-eks.ejs +5 -0
  75. package/templates/do/clean.d/managed-inference.ejs +5 -0
  76. package/templates/do/config +115 -45
  77. package/templates/do/deploy.d/async-inference.ejs +30 -3
  78. package/templates/do/deploy.d/batch-transform.ejs +29 -3
  79. package/templates/do/deploy.d/hyperpod-eks.ejs +4 -0
  80. package/templates/do/deploy.d/managed-inference.ejs +216 -14
  81. package/templates/do/lib/endpoint-config.sh +1 -1
  82. package/templates/do/lib/profile.sh +44 -0
  83. package/templates/do/optimize +106 -37
  84. package/templates/do/push +5 -0
  85. package/templates/do/register +94 -0
  86. package/templates/do/stage +567 -0
  87. package/templates/do/submit +7 -0
  88. package/templates/do/test +14 -0
  89. package/templates/do/tune +382 -59
  90. package/templates/do/validate +44 -4
@@ -0,0 +1,1476 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
6
+
7
+ Subcommands:
8
+ write - Validate, enrich, and write benchmark results to S3 as Parquet
9
+
10
+ All output is JSON on stdout for bash consumption.
11
+ Errors are structured JSON objects — never raw tracebacks.
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import re
18
+ import sys
19
+ from datetime import datetime, timezone
20
+
21
+
22
+ # ── Constants ─────────────────────────────────────────────────────────────────
23
+
24
+ REQUIRED_FIELDS = [
25
+ 'project_name',
26
+ 'model_name',
27
+ 'instance_type',
28
+ 'deployment_config',
29
+ 'region',
30
+ 'metrics',
31
+ ]
32
+
33
+ # Pattern for valid SageMaker instance types: ml.<family>.<size>
34
+ _INSTANCE_TYPE_RE = re.compile(r'^ml\.[a-z0-9]+\.[a-z0-9]+$')
35
+
36
+ # Known model family patterns — maps regex to family label
37
+ # Known model family patterns — maps regex to family label.
38
+ # Patterns are searched against the model identifier (after org/ prefix stripping).
39
+ # Order matters: more specific patterns (e.g., deepseek-r1) must precede generic ones.
40
+ # Version dots are collapsed for family grouping (e.g., Llama-3.1 → llama3).
41
+ _MODEL_FAMILY_PATTERNS = [
42
+ # DeepSeek — must come before qwen/llama because model names may contain those
43
+ # (e.g., "DeepSeek-R1-Distill-Qwen-7B" contains "Qwen")
44
+ (re.compile(r'deepseek[-_.]?r1', re.IGNORECASE), 'deepseek-r1'),
45
+ (re.compile(r'deepseek[-_.]?v3', re.IGNORECASE), 'deepseek-v3'),
46
+ (re.compile(r'deepseek[-_.]?v2', re.IGNORECASE), 'deepseek-v2'),
47
+ (re.compile(r'deepseek[-_.]?coder', re.IGNORECASE), 'deepseek-coder'),
48
+ (re.compile(r'deepseek[-_.]?math', re.IGNORECASE), 'deepseek-math'),
49
+ (re.compile(r'deepseek', re.IGNORECASE), 'deepseek'),
50
+ # Qwen family — version number without dots for family grouping
51
+ (re.compile(r'qwen3', re.IGNORECASE), 'qwen3'),
52
+ (re.compile(r'qwen2', re.IGNORECASE), 'qwen2'),
53
+ (re.compile(r'qwen', re.IGNORECASE), 'qwen'),
54
+ # Llama family — collapse version dots (3.1, 3.2 → llama3)
55
+ (re.compile(r'codellama|code[-_]?llama', re.IGNORECASE), 'codellama'),
56
+ (re.compile(r'llama[-_.]?3', re.IGNORECASE), 'llama3'),
57
+ (re.compile(r'llama[-_.]?2', re.IGNORECASE), 'llama2'),
58
+ (re.compile(r'llama', re.IGNORECASE), 'llama'),
59
+ # Mistral/Mixtral
60
+ (re.compile(r'mixtral', re.IGNORECASE), 'mixtral'),
61
+ (re.compile(r'mistral', re.IGNORECASE), 'mistral'),
62
+ # Microsoft Phi
63
+ (re.compile(r'phi[-_.]?3', re.IGNORECASE), 'phi3'),
64
+ (re.compile(r'phi[-_.]?2', re.IGNORECASE), 'phi2'),
65
+ # Google Gemma
66
+ (re.compile(r'gemma[-_.]?2', re.IGNORECASE), 'gemma2'),
67
+ (re.compile(r'gemma', re.IGNORECASE), 'gemma'),
68
+ # Others
69
+ (re.compile(r'falcon', re.IGNORECASE), 'falcon'),
70
+ (re.compile(r'starcoder', re.IGNORECASE), 'starcoder'),
71
+ (re.compile(r'gpt[-_.]?oss', re.IGNORECASE), 'gpt-oss'),
72
+ ]
73
+
74
# Approximate on-demand $/hr for common SageMaker AI instances.
# Keys are instance specs without the "ml." prefix.
INSTANCE_PRICING_USD_PER_HOUR = {
    # g5 family
    'g5.xlarge':     1.408,
    'g5.2xlarge':    1.52,
    'g5.4xlarge':    2.03,
    'g5.8xlarge':    3.06,
    'g5.12xlarge':   7.09,
    'g5.16xlarge':   5.10,
    'g5.24xlarge':   10.18,
    'g5.48xlarge':   20.36,
    # g6 family
    'g6.xlarge':     1.00,
    'g6.2xlarge':    1.21,
    'g6.4xlarge':    1.62,
    'g6.8xlarge':    2.44,
    'g6.12xlarge':   5.66,
    'g6.16xlarge':   4.07,
    'g6.24xlarge':   7.53,
    'g6.48xlarge':   15.06,
    # g6e family
    'g6e.xlarge':    1.86,
    'g6e.2xlarge':   2.35,
    'g6e.4xlarge':   3.34,
    'g6e.12xlarge':  11.67,
    'g6e.48xlarge':  38.12,
    # p4d / p5 / trn2
    'p4d.24xlarge':  37.69,
    'p5.48xlarge':   65.85,
    'trn2.48xlarge': 21.50,
}
101
+
102
+
103
+ # ── Utility functions ─────────────────────────────────────────────────────────
104
+
105
+
106
+ def _error_exit(message):
107
+ """Print JSON error to stdout and exit with code 1."""
108
+ print(json.dumps({"error": message}))
109
+ sys.exit(1)
110
+
111
+
112
+ def _output(data):
113
+ """Print JSON result to stdout."""
114
+ print(json.dumps(data))
115
+ sys.exit(0)
116
+
117
+
118
+ # ── Derived field computation ─────────────────────────────────────────────────
119
+
120
+
121
def derive_model_family(model_name):
    """Map a model identifier to a lowercase family label.

    Examples:
        "Qwen/Qwen3-4B" → "qwen3"
        "meta-llama/Llama-3.1-8B" → "llama3"
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" → "deepseek-r1"

    Only the text after the last "/" is inspected. The entries of
    _MODEL_FAMILY_PATTERNS are tried in list order and the first pattern
    that is found anywhere in that text decides the label.

    Args:
        model_name: Model identifier, optionally prefixed with "org/".

    Returns:
        str — the matching family label, "other" when no pattern matches,
        or "unknown" when model_name is empty/None.
    """
    if not model_name:
        return 'unknown'

    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" → "DeepSeek-R1-Distill-Qwen-7B"
    if '/' in model_name:
        short_name = model_name.split('/')[-1]
    else:
        short_name = model_name

    return next(
        (label for regex, label in _MODEL_FAMILY_PATTERNS if regex.search(short_name)),
        'other',
    )


# Alias for test compatibility
compute_model_family = derive_model_family
152
+
153
+
154
def derive_instance_family(instance_type):
    """Return the family segment of an ``ml.<family>.<size>`` instance type.

    Examples:
        "ml.g5.xlarge" → "g5"
        "ml.g6e.2xlarge" → "g6e"
        "ml.p5.48xlarge" → "p5"
        "ml.trn2.xlarge" → "trn2"

    Args:
        instance_type: Instance type string, or an empty value.

    Returns:
        str — the family segment, or "unknown" when the value is empty or
        does not have the ml.<family>.<size> shape (lowercase letters and
        digits only).
    """
    if instance_type:
        parsed = re.match(r'^ml\.([a-z0-9]+)\.[a-z0-9]+$', instance_type)
        if parsed is not None:
            return parsed.group(1)
    return 'unknown'


# Alias for test compatibility
compute_instance_family = derive_instance_family
176
+
177
+
178
def compute_cost_per_1m_tokens(instance_type, tokens_per_second):
    """Estimate the USD cost of generating one million output tokens.

    The hourly price is looked up in INSTANCE_PRICING_USD_PER_HOUR, which
    holds approximate on-demand prices keyed by the instance spec without
    the "ml." prefix.

    Args:
        instance_type: SageMaker AI instance type string.
        tokens_per_second: Output tokens/second throughput.

    Returns:
        float or None — cost rounded to 4 decimals; None when either
        argument is empty/zero, the throughput is not positive, or the
        instance has no entry in the pricing table.
    """
    if not (instance_type and tokens_per_second):
        return None
    if tokens_per_second <= 0:
        return None

    # The pricing table is keyed without the leading "ml."
    if instance_type.startswith('ml.'):
        spec = instance_type.replace('ml.', '', 1)
    else:
        spec = instance_type

    hourly_rate = INSTANCE_PRICING_USD_PER_HOUR.get(spec)
    if hourly_rate is None:
        return None

    # $/hour divided by tokens/hour gives $/token; scale to one million tokens.
    per_token = hourly_rate / (tokens_per_second * 3600)
    return round(per_token * 1_000_000, 4)
205
+
206
+
207
def compute_partition_keys(timestamp):
    """Compute year and month partition keys from a timestamp.

    Args:
        timestamp: One of:
            - datetime object (used as-is, no timezone conversion)
            - ISO 8601 string ("2026-06-09T14:30:22Z",
              "2026-06-09T14:30:22+00:00", "2026-06-09 14:30:22" or the
              date-only form "2026-06-09")
            - Compact string ("20260609T143022Z"), interpreted as UTC
            - None (uses current UTC time)

    Returns:
        tuple (year: str, month: str) — zero-padded strings.

    Any value that cannot be interpreted (unsupported type or unparseable
    string) falls back to the current UTC time rather than raising.
    """
    dt = None
    if isinstance(timestamp, datetime):
        dt = timestamp
    elif isinstance(timestamp, str):
        ts = timestamp.strip()
        try:
            if '-' in ts[:10]:
                # Extended ISO 8601. Accepted with or without a "T" so that
                # date-only and space-separated values keep their own
                # year/month instead of silently becoming "now".
                if ts[-1:] in ('Z', 'z'):
                    # fromisoformat() rejects the "Z" suffix before Python 3.11
                    ts = ts[:-1] + '+00:00'
                dt = datetime.fromisoformat(ts)
            elif 'T' in ts:
                # Compact: 20260609T143022Z
                dt = datetime.strptime(ts.rstrip('Z'), '%Y%m%dT%H%M%S')
                dt = dt.replace(tzinfo=timezone.utc)
        except (ValueError, TypeError):
            dt = None

    if dt is None:
        dt = datetime.now(timezone.utc)

    return (dt.strftime('%Y'), dt.strftime('%m'))
245
+
246
+
247
def compute_s3_path(bucket, project_name, model_name, instance_type, deployment_target, timestamp):
    """Build the S3 URI of the Parquet file for one benchmark run.

    The object key follows the model/instance/target layout:
    results/model=<m>/instance=<i>/target=<t>/run-<project>-<timestamp>.parquet

    Args:
        bucket: S3 bucket name.
        project_name: MCC project name.
        model_name: HuggingFace model ID; "/" is replaced by "_".
        instance_type: SageMaker instance type.
        deployment_target: Deployment target (realtime-inference, etc.).
        timestamp: datetime object for the run timestamp.

    Returns:
        str — full S3 URI. Empty model/instance values become "unknown" and
        an empty target becomes "realtime-inference".
    """
    # "/" would add an extra key level, so it is flattened to "_"
    if model_name:
        model_partition = model_name.replace('/', '_')
    else:
        model_partition = 'unknown'
    instance_partition = instance_type if instance_type else 'unknown'
    target_partition = deployment_target if deployment_target else 'realtime-inference'

    ts_str = timestamp.strftime('%Y%m%dT%H%M%SZ')
    filename = f'run-{project_name}-{ts_str}.parquet'

    return f's3://{bucket}/results/model={model_partition}/instance={instance_partition}/target={target_partition}/{filename}'
271
+
272
+
273
def compute_partition_info(model_name, instance_type, deployment_target):
    """Return the partition values of the model/instance/target scheme.

    Args:
        model_name: HuggingFace model ID (e.g., 'Qwen/Qwen3-0.6B').
        instance_type: SageMaker instance type (e.g., 'ml.g5.xlarge').
        deployment_target: Deployment target (e.g., 'realtime-inference').

    Returns:
        dict with keys: model, instance, target. Empty model/instance values
        become 'unknown'; an empty target becomes 'realtime-inference'.
    """
    if model_name:
        model_value = model_name.replace('/', '_')
    else:
        model_value = 'unknown'

    info = {"model": model_value}
    info["instance"] = instance_type if instance_type else 'unknown'
    info["target"] = deployment_target if deployment_target else 'realtime-inference'
    return info
289
+
290
+
291
def build_s3_path(bucket, project_name, model_name, instance_type, deployment_target, timestamp=None, region=''):
    """Construct the S3 path and partition info for a benchmark run.

    Args:
        bucket: S3 bucket name.
        project_name: MCC project name (embedded in the file name).
        model_name: HuggingFace model ID; "/" is replaced by "_" and an
            empty value becomes "unknown".
        instance_type: SageMaker instance type; empty becomes "unknown".
        deployment_target: Deployment target; empty becomes
            "realtime-inference".
        timestamp: datetime object or None (defaults to now UTC).
        region: AWS region string. Accepted for interface compatibility;
            it does not influence the generated path.

    Returns:
        dict with keys: s3_uri, partition_model, partition_instance,
        partition_target, filename.
    """
    if timestamp is None:
        timestamp = datetime.now(timezone.utc)

    ts_str = timestamp.strftime('%Y%m%dT%H%M%SZ')
    model_partition = model_name.replace('/', '_') if model_name else 'unknown'
    instance_partition = instance_type or 'unknown'
    target_partition = deployment_target or 'realtime-inference'
    filename = f'run-{project_name}-{ts_str}.parquet'

    s3_uri = f's3://{bucket}/results/model={model_partition}/instance={instance_partition}/target={target_partition}/{filename}'

    return {
        's3_uri': s3_uri,
        'partition_model': model_partition,
        'partition_instance': instance_partition,
        'partition_target': target_partition,
        'filename': filename,
    }
323
+
324
+
325
+ def _extract_base_image_version(base_image):
326
+ """Extract version tag from a base image string.
327
+
328
+ Examples:
329
+ "vllm/vllm-openai:v0.8.5" → "v0.8.5"
330
+ "nvcr.io/nvidia/tritonserver:24.01-py3" → "24.01-py3"
331
+ "" → ""
332
+
333
+ Returns:
334
+ str — extracted tag or empty string.
335
+ """
336
+ if not base_image:
337
+ return ''
338
+ if ':' in base_image:
339
+ return base_image.split(':')[-1]
340
+ return ''
341
+
342
+
343
def enrich_records(config, results, run_timestamp=None):
    """Build enriched records from config context and benchmark results.

    Each entry of results['metrics'] becomes one flat record that combines
    the run-level context from config with that entry's measurements.

    Args:
        config: dict with config context fields (project_name, model_name, etc.)
        results: dict with benchmark results (job_name, metrics array).
            Any non-dict value is treated as "no metrics".
        run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.

    Returns:
        list of enriched record dicts (one per metrics entry).

    Notes:
        - Counters and throughputs may arrive either as plain numbers or as
          aiperf-style dicts ({'unit': ..., 'avg': ...}); dicts are reduced
          to their 'avg' value.
        - Latency/throughput distributions that are missing or not dicts
          contribute 0.0 for every statistic.
        - The records carry an error_rate but no status or cost column.
    """
    if run_timestamp is None:
        run_timestamp = datetime.now(timezone.utc)

    model_name = config.get('model_name', '')
    instance_type = config.get('instance_type', '')
    project_name = config.get('project_name', '')
    deployment_config = config.get('deployment_config', '')
    region = config.get('region', '')

    # Derived fields
    model_family = derive_model_family(model_name)

    # Optional context fields
    deployment_target = config.get('deployment_target', 'realtime-inference')
    tensor_parallel_degree = config.get('tensor_parallel_degree', 1)
    quantization = config.get('quantization', 'none')
    enable_lora = config.get('enable_lora', False)
    base_image = config.get('base_image', '')
    mcc_version = config.get('mcc_version', '')
    run_type = config.get('run_type', 'ci')

    if isinstance(results, dict):
        metrics = results.get('metrics') or []
        job_name = results.get('job_name', '')
    else:
        metrics = []
        job_name = ''

    # serving_config is identical for every record of a run, so the JSON blob
    # is built once. Unset ('' / None) parameters are left out.
    serving_config_dict = {
        k: v for k, v in {
            'quantization': quantization,
            'tensor_parallel_degree': tensor_parallel_degree,
            'enable_lora': enable_lora,
            'base_image': base_image,
            'kv_cache_dtype': config.get('kv_cache_dtype', 'auto'),
            'max_model_len': config.get('max_model_len', ''),
            'vllm_version': config.get('vllm_version', ''),
            'gpu_memory_utilization': config.get('gpu_memory_utilization', ''),
            'ic_gpu_count': config.get('ic_gpu_count', ''),
            'ic_copy_count': config.get('ic_copy_count', ''),
            'adapter_name': config.get('adapter_name', ''),
        }.items() if v not in ('', None)
    }
    serving_config_json = json.dumps(serving_config_dict)
    workload = config.get('workload', 'manual')
    streaming = config.get('streaming', True)
    run_timestamp_iso = run_timestamp.isoformat()

    # Helper: unwrap aiperf metric dicts to scalar values
    #   Derived metrics: {'unit': 'requests/sec', 'avg': 9.57} → 9.57
    def scalar(val, stat='avg'):
        if isinstance(val, dict):
            return val.get(stat, 0.0)
        return val if val is not None else 0.0

    # Helper: fetch a per-metric distribution ({'avg': .., 'p50': .., ...}),
    # tolerating a missing key or a non-dict value.
    def stats(entry, key):
        value = entry.get(key, {})
        return value if isinstance(value, dict) else {}

    records = []
    for metric in metrics:
        concurrency = scalar(metric.get('concurrency', 0))
        throughput_rps = scalar(metric.get('request_throughput', 0.0))
        tokens_per_second = scalar(metric.get('output_token_throughput', 0.0))
        # Unwrapped like the other counters so a dict-shaped value cannot
        # break the division below.
        error_count = scalar(metric.get('error_count', 0))
        total_requests = scalar(metric.get('total_requests', 0))
        duration_seconds = scalar(metric.get('duration_seconds', 0))
        input_tokens_mean = metric.get('input_tokens_mean', 0)
        output_tokens_mean = metric.get('output_tokens_mean', 0)

        # Latency / throughput distributions
        ttft = stats(metric, 'time_to_first_token')
        itl = stats(metric, 'inter_token_latency')
        e2e_latency = stats(metric, 'e2e_latency')
        prefill = stats(metric, 'prefill_throughput')
        output_tps = stats(metric, 'output_token_throughput_detail')
        ttst = stats(metric, 'time_to_second_token')

        error_rate = (error_count / total_requests) if total_requests > 0 else 0.0

        record = {
            'project_name': project_name,
            'model_name': model_name,
            'model_family': model_family,
            'instance_type': instance_type,
            'deployment_config': deployment_config,
            'deployment_target': deployment_target,
            'quantization': quantization,
            'tensor_parallel_degree': tensor_parallel_degree,
            'serving_config': serving_config_json,
            'workload': workload,
            'concurrency': concurrency,
            'input_tokens_mean': input_tokens_mean,
            'output_tokens_mean': output_tokens_mean,
            'streaming': streaming,
            'duration_seconds': duration_seconds,
            'request_throughput_rps': throughput_rps,
            'total_token_throughput_tps': scalar(metric.get('total_token_throughput', 0.0)),
            'output_token_throughput_tps': tokens_per_second,
            'request_count': scalar(metric.get('request_count', metric.get('total_requests', 0))),
            'ttft_avg_ms': ttft.get('avg', 0.0),
            'ttft_p50_ms': ttft.get('p50', 0.0),
            'ttft_p90_ms': ttft.get('p90', 0.0),
            'ttft_p99_ms': ttft.get('p99', 0.0),
            'itl_avg_ms': itl.get('avg', 0.0),
            'itl_p50_ms': itl.get('p50', 0.0),
            'itl_p90_ms': itl.get('p90', 0.0),
            'itl_p99_ms': itl.get('p99', 0.0),
            'e2e_latency_avg_ms': e2e_latency.get('avg', 0.0),
            'e2e_latency_p50_ms': e2e_latency.get('p50', 0.0),
            'e2e_latency_p90_ms': e2e_latency.get('p90', 0.0),
            'e2e_latency_p99_ms': e2e_latency.get('p99', 0.0),
            'prefill_tps_avg': prefill.get('avg', 0.0),
            'prefill_tps_p50': prefill.get('p50', 0.0),
            'output_token_tps_avg': output_tps.get('avg', 0.0),
            'output_token_tps_p50': output_tps.get('p50', 0.0),
            'output_token_tps_p90': output_tps.get('p90', 0.0),
            'ttst_p50_ms': ttst.get('p50', 0.0),
            'ttst_p90_ms': ttst.get('p90', 0.0),
            # Prefer the aiperf 'output_sequence_length' / 'input_sequence_length'
            # metrics; fall back to the pre-averaged '*_avg' keys.
            'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
            'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
            'error_rate': error_rate,
            'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
            'run_type': run_type,
            'benchmark_job_name': job_name,
            'mcc_version': mcc_version,
            'run_timestamp': run_timestamp_iso,
            'region': region,
        }
        records.append(record)

    return records
494
+
495
+
496
def validate_input(config, results):
    """Validate a (config, results) pair.

    Convenience wrapper around validate_benchmark_input: the config fields
    and the results' 'metrics' value are combined into one dict, which is
    then validated.

    Args:
        config: dict with config context fields. Ignored if not a dict.
        results: dict with benchmark results (must have 'metrics' key).
            Ignored if not a dict.

    Returns:
        list of {"field": str, "reason": str} dicts for each validation failure.
        Empty list means validation passed.
    """
    combined = dict(config) if isinstance(config, dict) else {}

    results_metrics = results.get('metrics') if isinstance(results, dict) else None
    if results_metrics is not None:
        # metrics from results take precedence over any 'metrics' key in config
        combined['metrics'] = results_metrics

    return validate_benchmark_input(combined)
518
+
519
+
520
+ # ── Validation ────────────────────────────────────────────────────────────────
521
+
522
+
523
def validate_benchmark_input(data):
    """Validate that all required fields are present and valid.

    Rules, applied in REQUIRED_FIELDS order:
        - 'metrics' must be a non-empty list of objects, each with an
          integer (non-bool) 'concurrency'.
        - 'instance_type' must be a non-empty string that matches
          _INSTANCE_TYPE_RE in full.
        - every other required field must be a non-empty string.

    Args:
        data: dict containing the merged benchmark input (config context + results).
            If data is not a dict, returns a single root-level error.

    Returns:
        list of {"field": str, "reason": str} dicts for each validation failure.
        Empty list means validation passed.
    """
    # Guard against non-dict input
    if not isinstance(data, dict):
        return [{"field": "_root", "reason": "input must be a JSON object"}]

    errors = []

    def fail(field, reason):
        errors.append({"field": field, "reason": reason})

    def string_problem(value):
        """Return the reason a required string is invalid, or None if it is fine."""
        if value is None:
            return "required field is missing"
        if not isinstance(value, str) or value.strip() == '':
            return "must be a non-empty string"
        return None

    for field in REQUIRED_FIELDS:
        value = data.get(field)

        if field == 'metrics':
            if value is None:
                fail(field, "required field is missing")
            elif not isinstance(value, list) or len(value) == 0:
                fail(field, "must be a non-empty array")
            else:
                for i, entry in enumerate(value):
                    if not isinstance(entry, dict):
                        fail(f"metrics[{i}]", "each metrics entry must be an object")
                        continue
                    conc = entry.get('concurrency')
                    if conc is None:
                        fail(f"metrics[{i}].concurrency", "required field is missing")
                    elif not isinstance(conc, int) or isinstance(conc, bool):
                        # bool is a subclass of int, so it is excluded explicitly
                        fail(f"metrics[{i}].concurrency", "must be an integer")
            continue

        reason = string_problem(value)
        if reason is None and field == 'instance_type':
            # fullmatch, not match: with match() the pattern's "$" also
            # accepts a trailing newline ("ml.g5.xlarge\n").
            if not _INSTANCE_TYPE_RE.fullmatch(value):
                reason = "must match ml.* pattern (e.g., ml.g5.xlarge)"
        if reason is not None:
            fail(field, reason)

    return errors
617
+
618
+
619
def emit_validation_error(errors):
    """Print the validation failures as one JSON object and exit with status 1.

    Args:
        errors: list of {"field": str, "reason": str} dicts.

    Output format:
        {"error": true, "validation_errors": [...]}

    This function only prints and exits; it performs no S3 write.
    """
    report = {"error": True}
    report["validation_errors"] = errors
    print(json.dumps(report))
    raise SystemExit(1)
636
+
637
+
638
+ # ── Partition Registration ────────────────────────────────────────────────────
639
+
640
+
641
def register_partition(bucket, model, instance, target,
                       glue_database='mlcc_ci', glue_table='benchmark_results',
                       glue_client=None, region='us-east-1'):
    """Register a partition in the Glue catalog via BatchCreatePartition.

    The partition's storage descriptor is copied from the table (columns,
    input/output format, serde) and pointed at the model/instance/target
    prefix under s3://<bucket>/results/. A partition that already exists is
    reported as such and is not treated as an error.

    Args:
        bucket: S3 bucket name.
        model: Model partition value (model name with / replaced by _, e.g., 'Qwen_Qwen3-0.6B').
        instance: Instance partition value (e.g., 'ml.g5.xlarge').
        target: Deployment target partition value (e.g., 'realtime-inference').
        glue_database: Glue database name (default: mlcc_ci).
        glue_table: Glue table name (default: benchmark_results).
        glue_client: Optional pre-configured boto3 Glue client (for testing).
            If None, a new client is created for the given region; boto3 is
            imported only in that case.
        region: AWS region for the Glue client (default: us-east-1).

    Returns:
        dict with keys:
            - registered (bool): True if partition was newly created
            - already_exists (bool): True if partition already existed
            - partition_values (list): [model, instance, target]
            - location (str): S3 location for the partition
            - error (str|None): Error message if registration failed for
              a reason other than already-exists

    Note:
        Failures of the Glue calls are returned in the 'error' field rather
        than raised, so the caller can log a warning and continue.
    """
    partition_values = [model, instance, target]
    location = f's3://{bucket}/results/model={model}/instance={instance}/target={target}/'

    def outcome(registered=False, already_exists=False, error=None):
        """Build the result dict shared by every exit path."""
        return {
            'registered': registered,
            'already_exists': already_exists,
            'partition_values': partition_values,
            'location': location,
            'error': error,
        }

    if glue_client is None:
        # Imported lazily so an injected client works without boto3 installed.
        import boto3
        glue_client = boto3.client('glue', region_name=region)

    # Get table StorageDescriptor to inherit columns/serde
    try:
        table_response = glue_client.get_table(
            DatabaseName=glue_database,
            Name=glue_table,
        )
    except Exception as e:
        error_msg = str(e)
        if 'EntityNotFoundException' in error_msg:
            return outcome(error=f"Table {glue_database}.{glue_table} not found in Glue catalog")
        return outcome(error=f"Failed to get table metadata: {error_msg}")

    # A table without a storage descriptor (or without columns) falls back to
    # the Parquet defaults below instead of raising KeyError.
    table_sd = (table_response.get('Table') or {}).get('StorageDescriptor') or {}

    # Build partition StorageDescriptor inheriting from table
    partition_sd = {
        'Columns': table_sd.get('Columns', []),
        'Location': location,
        'InputFormat': table_sd.get('InputFormat', 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'),
        'OutputFormat': table_sd.get('OutputFormat', 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'),
        'SerdeInfo': table_sd.get('SerdeInfo', {
            'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe',
            'Parameters': {'serialization.format': '1'},
        }),
        'Compressed': table_sd.get('Compressed', True),
    }

    partition_input = {
        'Values': partition_values,
        'StorageDescriptor': partition_sd,
        'Parameters': {
            'classification': 'parquet',
            'parquet.compression': 'SNAPPY',
        },
    }

    try:
        response = glue_client.batch_create_partition(
            DatabaseName=glue_database,
            TableName=glue_table,
            PartitionInputList=[partition_input],
        )
    except Exception as e:
        # Handle AlreadyExistsException thrown as an API exception
        if 'AlreadyExistsException' in str(e):
            return outcome(already_exists=True)
        return outcome(error=f"Failed to register partition: {e}")

    # The batch API can also report per-partition failures in the response
    # body; only one partition is submitted, so the first entry is the one.
    batch_errors = response.get('Errors', [])
    if batch_errors:
        error_detail = batch_errors[0].get('ErrorDetail', {})
        error_code = error_detail.get('ErrorCode', '')

        if error_code == 'AlreadyExistsException':
            return outcome(already_exists=True)

        error_msg = error_detail.get('ErrorMessage', 'unknown error')
        return outcome(error=f"Partition registration failed: {error_code} — {error_msg}")

    return outcome(registered=True)
789
+
790
+
791
+ # ── Parquet Serialization ─────────────────────────────────────────────────────
792
+
793
+
794
def get_parquet_schema():
    """Build the pyarrow schema used for benchmark_results Parquet files.

    The column list mirrors the Athena DDL for the table. The partition
    keys (model, instance, target) are deliberately left out: they are
    carried by the S3 object path and resolved by Glue/Athena partitioning.

    Returns:
        pyarrow.Schema with one field per data column, in the order the
        columns are listed below.
    """
    import pyarrow as pa

    text = pa.string()
    int32 = pa.int32()
    float64 = pa.float64()

    # Identity, model/serving config, the serving-config JSON blob, workload.
    leading_columns = (
        ("project_name", text),
        ("model_name", text),
        ("model_family", text),
        ("instance_type", text),
        ("deployment_config", text),
        ("deployment_target", text),
        ("quantization", text),
        ("tensor_parallel_degree", int32),
        ("serving_config", text),
        ("workload", text),
        ("concurrency", int32),
        ("input_tokens_mean", int32),
        ("output_tokens_mean", int32),
        ("streaming", pa.bool_()),
        ("duration_seconds", int32),
    )

    # Rich metrics: every one of these is stored as a 64-bit float.
    metric_columns = (
        "request_throughput_rps",
        "total_token_throughput_tps",
        "output_token_throughput_tps",
        "request_count",
        "ttft_avg_ms",
        "ttft_p50_ms",
        "ttft_p90_ms",
        "ttft_p99_ms",
        "itl_avg_ms",
        "itl_p50_ms",
        "itl_p90_ms",
        "itl_p99_ms",
        "e2e_latency_avg_ms",
        "e2e_latency_p50_ms",
        "e2e_latency_p90_ms",
        "e2e_latency_p99_ms",
        "prefill_tps_avg",
        "prefill_tps_p50",
        "output_token_tps_avg",
        "output_token_tps_p50",
        "output_token_tps_p90",
        "ttst_p50_ms",
        "ttst_p90_ms",
        "output_sequence_length_avg",
        "input_sequence_length_avg",
        "error_rate",
        "benchmark_duration_sec",
    )

    # Run metadata: all plain strings (run_timestamp included).
    metadata_columns = (
        "run_type",
        "benchmark_job_name",
        "mcc_version",
        "run_timestamp",
        "region",
    )

    fields = [pa.field(name, dtype) for name, dtype in leading_columns]
    fields += [pa.field(name, float64) for name in metric_columns]
    fields += [pa.field(name, text) for name in metadata_columns]
    return pa.schema(fields)
863
+
864
+
865
def _records_to_parquet_table(records):
    """Turn enriched record dicts into a pyarrow Table.

    Args:
        records: List of dicts from enrich_records(). Keys are expected to be
            the column names of the schema from get_parquet_schema(); a key
            that is absent from a record yields a null in that column.

    Returns:
        pyarrow.Table whose columns follow get_parquet_schema().
    """
    import pyarrow as pa
    from datetime import datetime as dt

    schema = get_parquet_schema()

    def _cell(record, column):
        raw = record.get(column)
        # run_timestamp is a string column, so datetime objects are rendered
        # as ISO-8601 text; every other value passes through untouched.
        if column == 'run_timestamp' and isinstance(raw, dt):
            return raw.isoformat()
        return raw

    columns = [
        pa.array([_cell(record, spec.name) for record in records], type=spec.type)
        for spec in schema
    ]
    return pa.table(columns, schema=schema)
899
+
900
+
901
def _upload_to_s3(local_path, bucket, s3_uri, region):
    """Send a local file to the object addressed by an S3 URI.

    Args:
        local_path: Path to the local Parquet file.
        bucket: S3 bucket name.
        s3_uri: Full S3 URI (s3://bucket/key).
        region: AWS region for the S3 client.
    """
    import boto3

    # Removing the first "s3://<bucket>/" occurrence leaves the object key.
    bucket_prefix = f's3://{bucket}/'
    object_key = s3_uri.replace(bucket_prefix, '', 1)

    boto3.client('s3', region_name=region).upload_file(local_path, bucket, object_key)
918
+
919
+
920
+
921
+
922
def _parse_jsonl_to_metrics(jsonl_path, concurrency=None):
    """Parse profile_export.jsonl and aggregate it into the metrics format.

    The JSONL file contains one JSON object per request with:
      - metadata: {session_num, request_start_ns, request_end_ns, ...}
      - metrics: {request_latency: {value, unit}, time_to_first_token: {value, unit}, ...}

    Blank lines, lines that are not valid JSON, and lines that decode to
    anything other than a JSON object are skipped. A ``metadata`` or
    ``metrics`` member that is missing, null, or not an object is treated as
    empty, so a malformed record never aborts the aggregation.

    Args:
        jsonl_path: Path of the per-request JSONL export.
        concurrency: Concurrency level to report in the result. When None,
            the number of parsed request records is reported instead.

    Returns:
        ``{"error": <message>}`` when the file cannot be opened or read;
        ``{"metrics": []}`` when no request record was parsed; otherwise a
        dict compatible with the existing validation/enrichment pipeline:
        ``{"metrics": [{concurrency, request_throughput,
        time_to_first_token: {avg, p50, p90, p99}, ...}]}``.
    """
    import math

    def _percentile(sorted_vals, pct):
        """Linearly interpolated percentile of an ascending list (0.0 if empty)."""
        if not sorted_vals:
            return 0.0
        idx = (pct / 100.0) * (len(sorted_vals) - 1)
        lower = int(math.floor(idx))
        upper = int(math.ceil(idx))
        if lower == upper:
            return sorted_vals[lower]
        frac = idx - lower
        return sorted_vals[lower] * (1 - frac) + sorted_vals[upper] * frac

    def _get_val(metrics_dict, key):
        """Extract scalar value from a metric dict like {value: X, unit: "ms"}."""
        m = metrics_dict.get(key)
        if isinstance(m, dict):
            return m.get('value')
        return m

    def _mean(vals):
        return sum(vals) / len(vals) if vals else 0.0

    def _summary(sorted_vals, pcts):
        """Return {'avg': ..., 'p<N>': ...} for the requested percentiles."""
        stats = {'avg': _mean(sorted_vals)}
        for pct in pcts:
            stats[f'p{pct}'] = _percentile(sorted_vals, pct)
        return stats

    records = []
    try:
        with open(jsonl_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    parsed = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object (e.g. `[]`, `null`, `3`)
                # has no metadata/metrics to read, so it is not a request.
                if isinstance(parsed, dict):
                    records.append(parsed)
    except OSError as e:  # FileNotFoundError and IOError are both OSError
        return {"error": str(e)}

    if not records:
        return {"metrics": []}

    # Per-request series, each paired with the metric keys to try in order
    # (the first key that yields a value wins).
    latencies, ttfts, itls, ttsts = [], [], [], []
    output_tokens, input_tokens = [], []
    prefill_tps, output_tps = [], []
    metric_sources = (
        (latencies, ('request_latency',)),
        (ttfts, ('time_to_first_token', 'time_to_first_output_token')),
        (itls, ('inter_token_latency',)),
        (ttsts, ('time_to_second_token',)),
        (output_tokens, ('output_token_count',)),
        (input_tokens, ('input_sequence_length',)),
        (prefill_tps, ('prefill_throughput_per_user',)),
        (output_tps, ('output_token_throughput_per_user',)),
    )
    start_times = []
    end_times = []

    for rec in records:
        meta = rec.get('metadata')
        if not isinstance(meta, dict):
            meta = {}
        metrics = rec.get('metrics')
        if not isinstance(metrics, dict):
            metrics = {}

        for series, keys in metric_sources:
            for key in keys:
                val = _get_val(metrics, key)
                if val is not None:
                    series.append(val)
                    break

        rs = meta.get('request_start_ns')
        re_ = meta.get('request_end_ns')
        if rs is not None:
            start_times.append(rs)
        if re_ is not None:
            end_times.append(re_)

    # Percentiles need ascending input. The token-count series are only
    # summed/averaged, so they keep their file order.
    for series in (latencies, ttfts, itls, ttsts, prefill_tps, output_tps):
        series.sort()

    # Wall-clock span of the run; falls back to 1s when timestamps are
    # missing or non-increasing, and is clamped to avoid dividing by ~0.
    if start_times and end_times:
        duration_ns = max(end_times) - min(start_times)
        duration_s = duration_ns / 1e9 if duration_ns > 0 else 1.0
    else:
        duration_s = 1.0
    duration_s = max(duration_s, 0.001)

    n = len(records)
    req_throughput = n / duration_s
    total_out_tokens = sum(output_tokens) if output_tokens else 0
    token_throughput = total_out_tokens / duration_s
    if input_tokens:
        total_token_throughput = (total_out_tokens + sum(input_tokens)) / duration_s
    else:
        total_token_throughput = token_throughput

    # No explicit concurrency: report the number of parsed requests. This is
    # a fallback value, not a measurement of overlapping requests.
    conc = concurrency if concurrency is not None else n

    # Build metrics entry matching the schema expected by enrich_records
    entry = {
        'concurrency': conc,
        'request_throughput': req_throughput,
        'output_token_throughput': token_throughput,
        'total_token_throughput': total_token_throughput,
        'total_requests': n,
        'request_count': n,
        'duration_seconds': duration_s,
        'time_to_first_token': _summary(ttfts, (50, 90, 99)),
        'inter_token_latency': _summary(itls, (50, 90, 99)),
        'e2e_latency': _summary(latencies, (50, 90, 99)),
        'request_latency': _summary(latencies, (50, 90, 99)),
        'time_to_second_token': _summary(ttsts, (50, 90)),
        'prefill_throughput': _summary(prefill_tps, (50,)),
        'output_token_throughput_detail': _summary(output_tps, (50, 90)),
        'output_sequence_length': _mean(output_tokens),
        'input_sequence_length': _mean(input_tokens),
        'input_tokens_mean': int(_mean(input_tokens)),
        'output_tokens_mean': int(_mean(output_tokens)),
    }

    return {"metrics": [entry]}
1105
+
1106
+
1107
+ # ── Command: write ────────────────────────────────────────────────────────────
1108
+
1109
+
1110
def cmd_write(args):
    """Validate, enrich, and write benchmark results to S3 as Parquet.

    Validation occurs before any S3 interaction. If validation fails,
    a structured error is emitted and no write occurs.

    Flow:
      1. Load the results file (aggregating per-request data when the path
         ends in ``.jsonl``).
      2. Merge context from the optional config file, the results file and
         the CLI flags (CLI flags win) into one ``input_data`` dict.
      3. Validate ``input_data``.
      4. ``--dry-run``: emit the enriched records and the intended S3 path.
         Otherwise write a Snappy-compressed Parquet file, upload it, and
         attempt a Glue partition registration (best-effort).

    Args:
        args: Parsed CLI namespace. Reads ``results_file``, ``input``,
            ``concurrency``, ``config_file``, ``project_name``, ``workload``,
            ``region``, ``dry_run`` and ``bucket``.

    NOTE(review): ``_error_exit``, ``emit_validation_error`` and ``_output``
    are defined elsewhere in the file; the code below is written as if the
    first two terminate the process (and, per the inline comments, ``_output``
    as well) — confirm.
    """
    # Load benchmark results (JSON or JSONL)
    results_path = args.results_file or args.input
    if not results_path:
        _error_exit("--results-file (or --input) is required")

    if results_path.endswith('.jsonl'):
        # Parse JSONL (per-request data) and aggregate into metrics format
        benchmark_data = _parse_jsonl_to_metrics(results_path, concurrency=getattr(args, 'concurrency', None))
        if 'error' in benchmark_data:
            _error_exit(f"Failed to parse JSONL: {benchmark_data['error']}")
    else:
        try:
            with open(results_path, 'r') as f:
                benchmark_data = json.load(f)
        except FileNotFoundError:
            _error_exit(f"Results file not found: {results_path}")
        except json.JSONDecodeError as e:
            _error_exit(f"Invalid JSON in results file: {e}")
        except Exception as e:
            _error_exit(f"Failed to read results file: {e}")

    # Build the combined input data for validation
    # Merge CLI-provided fields with the benchmark results
    input_data = {}

    # Fields from config file (if provided)
    if args.config_file:
        try:
            config_context = _load_config_file(args.config_file)
            input_data.update(config_context)
        except Exception as e:
            _error_exit(f"Failed to read config file: {e}")

    # Fields from the benchmark results file
    if isinstance(benchmark_data, dict):
        metrics = benchmark_data.get('metrics')
        if metrics is not None:
            input_data['metrics'] = metrics
        else:
            # Single-level benchmark: raw results at top level without a 'metrics' wrapper.
            # Wrap into the expected array format for validation and enrichment.
            # Detect by presence of known metric fields (request_throughput, output_token_throughput, etc.)
            metric_indicators = ['request_throughput', 'output_token_throughput', 'time_to_first_token',
                                 'inter_token_latency', 'request_latency', 'concurrency']
            if any(k in benchmark_data for k in metric_indicators):
                # Use BENCHMARK_CONCURRENCY from config if concurrency not in the results
                # (defaults to 10 when the config did not provide one).
                if 'concurrency' not in benchmark_data:
                    benchmark_data['concurrency'] = int(input_data.get('benchmark_concurrency', 10))
                input_data['metrics'] = [benchmark_data]
        # Also pull any config fields from the results file
        # (values already supplied by the config file are not overwritten).
        # NOTE(review): nesting reconstructed from a flattened listing — this
        # loop is assumed to run for every dict-shaped results file; confirm.
        for field in ['model_name', 'instance_type', 'deployment_config', 'project_name', 'region']:
            if field in benchmark_data and field not in input_data:
                input_data[field] = benchmark_data[field]
    elif isinstance(benchmark_data, list):
        # If the results file is just a raw metrics array
        input_data['metrics'] = benchmark_data

    # CLI args override config file and results file values
    if args.project_name:
        input_data['project_name'] = args.project_name
    if args.workload:
        input_data['workload'] = args.workload
    if args.region:
        input_data['region'] = args.region

    # ── Validate before any S3 interaction ────────────────────────────────
    errors = validate_benchmark_input(input_data)
    if errors:
        emit_validation_error(errors)
        return  # Never reached, but explicit

    # ── Dry-run mode: output enriched records as JSON, skip S3 ──────────────
    if args.dry_run:
        timestamp = datetime.now(timezone.utc)

        # Split input_data back into config and results for enrich_records
        config_context = {k: v for k, v in input_data.items() if k != 'metrics'}
        results_obj = {'metrics': input_data['metrics']}
        if isinstance(benchmark_data, dict) and 'job_name' in benchmark_data:
            results_obj['job_name'] = benchmark_data['job_name']

        enriched_records = enrich_records(config_context, results_obj, timestamp)

        # Compute intended S3 path (use bucket if provided, else placeholder)
        # NOTE(review): the placeholder indexes input_data["region"] directly —
        # assumes validation guarantees a region; confirm.
        bucket = args.bucket or f'mlcc-benchmark-results-<accountId>-{input_data["region"]}'
        s3_path = compute_s3_path(bucket, input_data.get('project_name', ''), input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'), timestamp)
        partition = compute_partition_info(input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'))

        _output({
            "dry_run": True,
            "s3_path": s3_path,
            "partition": partition,
            "record_count": len(enriched_records),
            "records": enriched_records,
        })
        return  # Never reached after _output

    # ── Write to S3 (requires bucket) ─────────────────────────────────────
    if not args.bucket:
        _error_exit("--bucket is required when not using --dry-run")

    # Region falls back to the AWS_REGION environment variable, then to ''.
    region = input_data.get('region', os.environ.get('AWS_REGION', ''))
    timestamp = datetime.now(timezone.utc)

    # Split input_data back into config and results for enrich_records
    config_context = {k: v for k, v in input_data.items() if k != 'metrics'}
    results_obj = {'metrics': input_data['metrics']}
    if isinstance(benchmark_data, dict) and 'job_name' in benchmark_data:
        results_obj['job_name'] = benchmark_data['job_name']

    enriched_records = enrich_records(config_context, results_obj, timestamp)

    if not enriched_records:
        _error_exit("No records produced from benchmark metrics")

    # Compute S3 path
    s3_info = build_s3_path(args.bucket, input_data.get('project_name', ''), input_data.get('model_name', ''), input_data.get('instance_type', ''), input_data.get('deployment_target', 'realtime-inference'), timestamp, region=region)

    # Write Parquet to a temp file then upload to S3
    # (imports are deferred so that a missing pyarrow yields a structured error)
    try:
        import tempfile
        import pyarrow as pa
        import pyarrow.parquet as pq
    except ImportError as e:
        _error_exit(f"Missing dependency: {e}. Install: pip install pyarrow")

    # Build pyarrow table from enriched records
    table = _records_to_parquet_table(enriched_records)

    # Write to temp file with Snappy compression
    tmp_path = None
    try:
        # delete=False: only the name is reserved here; the file is written
        # after the handle is closed and removed in the finally block.
        with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as tmp:
            tmp_path = tmp.name

        pq.write_table(table, tmp_path, compression='snappy')

        # Upload to S3
        _upload_to_s3(tmp_path, args.bucket, s3_info['s3_uri'], region)

    except Exception as e:
        _error_exit(f"Failed to write Parquet to S3: {e}")
    finally:
        # Clean up temp file
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

    # Register partition in Glue catalog to make data immediately queryable.
    # This is best-effort — failure is non-fatal per design doc error handling.
    # Data remains readable via MSCK REPAIR TABLE as a fallback.
    partition_result = None
    try:
        partition_result = register_partition(
            bucket=args.bucket,
            model=s3_info['partition_model'],
            instance=s3_info['partition_instance'],
            target=s3_info['partition_target'],
            region=region,
        )
    except SystemExit:
        # register_partition calls _error_exit on some failures; catch to avoid
        # terminating the process — the Parquet write already succeeded.
        partition_result = {"registered": False, "error": "partition registration failed (non-fatal)"}
    except Exception as e:
        partition_result = {"registered": False, "error": str(e)}

    # Warnings go to stderr so they do not mix with the result passed to _output.
    if partition_result and partition_result.get('error'):
        print(
            f"\u26a0\ufe0f Partition registration warning: {partition_result['error']}",
            file=sys.stderr,
        )

    _output({
        "success": True,
        "s3_uri": s3_info['s3_uri'],
        "partition": {
            "model": s3_info['partition_model'],
            "instance": s3_info['partition_instance'],
            "target": s3_info['partition_target'],
        },
        "rows_written": len(enriched_records),
        "project_name": input_data.get('project_name', ''),
        "run_timestamp": timestamp.isoformat(),
        "partition_registration": partition_result,
    })
1301
+
1302
+
1303
def _load_config_file(config_path):
    """Load configuration context from a do/config shell file or JSON file.

    Supports two formats:
      - JSON file (content starts with ``{``): known keys, in snake_case,
        camelCase or upper-case spelling, are mapped to canonical field
        names. When several spellings of one field are present, the first
        one in the mapping table wins.
      - Shell config file: ``export VAR="value"`` / ``VAR=value`` lines.
        ``${VAR:-default}`` resolves to its default; values that are still
        unresolved shell expansions (``${...}``, ``$(...)``) are skipped.

    Loading is best-effort and never raises: when the file cannot be read or
    parsed, a warning is written to stderr and whatever was collected so far
    (possibly an empty dict) is returned. A JSON ``tensor_parallel_degree``
    that is not an integer is reported and skipped without discarding the
    remaining fields.

    Args:
        config_path: Path of the config file.

    Returns:
        dict with recognized config fields. Values are strings, except
        ``tensor_parallel_degree`` (int) and ``enable_lora`` (bool) from
        JSON input; a JSON null is preserved as None.
    """
    import sys

    # JSON key -> canonical field name. Order matters: first match wins.
    json_field_map = {
        'project_name': 'project_name',
        'projectName': 'project_name',
        'model_name': 'model_name',
        'modelName': 'model_name',
        'MODEL_NAME': 'model_name',
        'instance_type': 'instance_type',
        'instanceType': 'instance_type',
        'INSTANCE_TYPE': 'instance_type',
        'deployment_config': 'deployment_config',
        'deploymentConfig': 'deployment_config',
        'DEPLOYMENT_CONFIG': 'deployment_config',
        'region': 'region',
        'REGION': 'region',
        'deployment_target': 'deployment_target',
        'deploymentTarget': 'deployment_target',
        'tensor_parallel_degree': 'tensor_parallel_degree',
        'tensorParallelDegree': 'tensor_parallel_degree',
        'quantization': 'quantization',
        'enable_lora': 'enable_lora',
        'enableLora': 'enable_lora',
        'base_image': 'base_image',
        'baseImage': 'base_image',
        'base_image_version': 'base_image_version',
        'baseImageVersion': 'base_image_version',
        'mcc_version': 'mcc_version',
        'mccVersion': 'mcc_version',
        'account_id': 'account_id',
        'accountId': 'account_id',
    }
    # Shell variable name -> canonical field name.
    shell_map = {
        'PROJECT_NAME': 'project_name',
        'MODEL_NAME': 'model_name',
        'INSTANCE_TYPE': 'instance_type',
        'DEPLOYMENT_CONFIG': 'deployment_config',
        'DEPLOYMENT_TARGET': 'deployment_target',
        'AWS_REGION': 'region',
        'REGION': 'region',
        'ACCOUNT_ID': 'account_id',
        'MCC_VERSION': 'mcc_version',
        'BASE_IMAGE': 'base_image',
        'BASE_IMAGE_VERSION': 'base_image_version',
        'BENCHMARK_CONCURRENCY': 'benchmark_concurrency',
    }
    # bool("false") is True, so textual flags need an explicit falsy list.
    falsy_words = ('', '0', 'false', 'no', 'off')

    def _warn(detail):
        # stderr only: stdout is used for the tool's JSON output.
        print(f"\u26a0\ufe0f Config file warning ({config_path}): {detail}", file=sys.stderr)

    context = {}

    try:
        with open(config_path, 'r') as f:
            content = f.read().strip()

        # Try JSON first
        if content.startswith('{'):
            data = json.loads(content)
            for source_key, target_key in json_field_map.items():
                if source_key not in data or target_key in context:
                    continue
                val = data[source_key]
                # Keep non-string types for certain fields
                if target_key == 'tensor_parallel_degree':
                    if val is not None:
                        try:
                            val = int(val)
                        except (TypeError, ValueError):
                            _warn(f"ignoring non-integer {source_key}: {val!r}")
                            continue
                    context[target_key] = val
                elif target_key == 'enable_lora':
                    if isinstance(val, str):
                        context[target_key] = val.strip().lower() not in falsy_words
                    else:
                        context[target_key] = bool(val)
                else:
                    context[target_key] = str(val) if val is not None else val
            return context

        # Parse shell-style config (export VAR="value" or VAR="value")
        for line in content.split('\n'):
            line = line.strip()
            if line.startswith('#') or not line:
                continue
            # Remove 'export ' prefix
            if line.startswith('export '):
                line = line[7:]
            # Parse VAR=value or VAR="value"
            if '=' not in line:
                continue
            key, _, value = line.partition('=')
            key = key.strip()
            value = value.strip().strip('"').strip("'")
            # Handle shell default syntax: ${VAR:-default} → extract default
            if value.startswith('${') and ':-' in value:
                value = value.split(':-', 1)[1].rstrip('}')
            # Skip unresolved shell variables (e.g., ${INSTANCE_TYPE})
            if value.startswith('${') or value.startswith('$('):
                continue
            if key in shell_map:
                context[shell_map[key]] = value

    except Exception as exc:
        # Deliberately broad: this loader is best-effort and must not raise,
        # but the failure is reported instead of being silently dropped.
        _warn(f"could not be fully loaded: {exc}")

    return context
1406
+
1407
+
1408
+ # ── CLI entry point ───────────────────────────────────────────────────────────
1409
+
1410
+
1411
def main():
    """Build the CLI, parse the command line and run the chosen subcommand.

    Prints the help text and exits with status 1 when no subcommand is
    given. The only subcommand is ``write``, which is handed to cmd_write().
    """
    parser = argparse.ArgumentParser(
        description='Benchmark Writer — Convert benchmark results to Athena-compatible Parquet'
    )
    commands = parser.add_subparsers(dest='command', help='Available commands')

    # write subcommand: each entry is (option strings, add_argument keywords)
    write_cmd = commands.add_parser('write', help='Write benchmark results to S3')
    write_options = (
        (('--input',), {
            'help': 'Path to benchmark results JSON file (alias for --results-file)',
        }),
        (('--results-file',), {
            'dest': 'results_file',
            'help': 'Path to benchmark results JSON file',
        }),
        (('--config-file',), {
            'dest': 'config_file',
            'help': 'Path to config file (do/config or JSON) for context fields',
        }),
        (('--project-name',), {
            'dest': 'project_name',
            'help': 'MCC project name (human-readable identifier)',
        }),
        (('--workload',), {
            'default': 'manual',
            'help': 'Named workload profile (from workload-picker MCP, default: manual)',
        }),
        (('--concurrency',), {
            'type': int,
            'default': None,
            'help': 'Concurrency level (passed to JSONL aggregation if results are per-request)',
        }),
        (('--bucket',), {
            'help': 'S3 bucket name for results (required unless --dry-run)',
        }),
        (('--region',), {
            'help': 'AWS region',
        }),
        (('--dry-run',), {
            'dest': 'dry_run',
            'action': 'store_true',
            'help': 'Output enriched records as JSON without writing to S3',
        }),
    )
    for option_strings, keywords in write_options:
        write_cmd.add_argument(*option_strings, **keywords)

    args = parser.parse_args()

    # No subcommand given: show usage and signal failure.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    if args.command == 'write':
        cmd_write(args)
1465
+
1466
+
1467
if __name__ == '__main__':
    try:
        main()
    except Exception as exc:
        # SystemExit is not an Exception subclass, so deliberate exits pass
        # straight through. Anything else is reported as a structured JSON
        # error on stdout, so that a raw traceback is NEVER produced.
        failure = {"error": f"unexpected error: {exc}"}
        print(json.dumps(failure))
        sys.exit(1)