mlx-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_stack/__init__.py +5 -0
- mlx_stack/_version.py +24 -0
- mlx_stack/cli/__init__.py +5 -0
- mlx_stack/cli/bench.py +221 -0
- mlx_stack/cli/config.py +166 -0
- mlx_stack/cli/down.py +109 -0
- mlx_stack/cli/init.py +180 -0
- mlx_stack/cli/install.py +165 -0
- mlx_stack/cli/logs.py +234 -0
- mlx_stack/cli/main.py +187 -0
- mlx_stack/cli/models.py +304 -0
- mlx_stack/cli/profile.py +65 -0
- mlx_stack/cli/pull.py +134 -0
- mlx_stack/cli/recommend.py +397 -0
- mlx_stack/cli/status.py +111 -0
- mlx_stack/cli/up.py +163 -0
- mlx_stack/cli/watch.py +252 -0
- mlx_stack/core/__init__.py +1 -0
- mlx_stack/core/benchmark.py +1182 -0
- mlx_stack/core/catalog.py +560 -0
- mlx_stack/core/config.py +471 -0
- mlx_stack/core/deps.py +323 -0
- mlx_stack/core/hardware.py +304 -0
- mlx_stack/core/launchd.py +531 -0
- mlx_stack/core/litellm_gen.py +188 -0
- mlx_stack/core/log_rotation.py +231 -0
- mlx_stack/core/log_viewer.py +386 -0
- mlx_stack/core/models.py +639 -0
- mlx_stack/core/paths.py +79 -0
- mlx_stack/core/process.py +887 -0
- mlx_stack/core/pull.py +815 -0
- mlx_stack/core/scoring.py +611 -0
- mlx_stack/core/stack_down.py +317 -0
- mlx_stack/core/stack_init.py +524 -0
- mlx_stack/core/stack_status.py +229 -0
- mlx_stack/core/stack_up.py +856 -0
- mlx_stack/core/watchdog.py +744 -0
- mlx_stack/data/__init__.py +1 -0
- mlx_stack/data/catalog/__init__.py +1 -0
- mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
- mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
- mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
- mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
- mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
- mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
- mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
- mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
- mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
- mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
- mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
- mlx_stack/py.typed +1 -0
- mlx_stack/utils/__init__.py +1 -0
- mlx_stack-0.1.0.dist-info/METADATA +397 -0
- mlx_stack-0.1.0.dist-info/RECORD +61 -0
- mlx_stack-0.1.0.dist-info/WHEEL +4 -0
- mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
- mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1182 @@
|
|
|
1
|
+
"""Benchmark execution engine for mlx-stack.
|
|
2
|
+
|
|
3
|
+
Runs 3 iterations of 1024-token prompt + 100-token generation against
|
|
4
|
+
a vllm-mlx instance, reports mean ± std dev for prompt_tps and gen_tps.
|
|
5
|
+
Compares against catalog thresholds: PASS (<15%), WARN (15-30%),
|
|
6
|
+
FAIL (>30%). Handles tool-calling benchmarks for capable models.
|
|
7
|
+
|
|
8
|
+
Supports benchmarking running tier instances and local models via
|
|
9
|
+
temporary vllm-mlx instances with full cleanup (including Ctrl+C).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import math
|
|
16
|
+
import os
|
|
17
|
+
import shutil
|
|
18
|
+
import signal
|
|
19
|
+
import time
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
from mlx_stack.core.catalog import (
|
|
27
|
+
CatalogEntry,
|
|
28
|
+
get_entry_by_id,
|
|
29
|
+
load_catalog,
|
|
30
|
+
)
|
|
31
|
+
from mlx_stack.core.config import ConfigCorruptError, get_value
|
|
32
|
+
from mlx_stack.core.deps import ensure_dependency
|
|
33
|
+
from mlx_stack.core.hardware import HardwareProfile, detect_hardware, load_profile
|
|
34
|
+
from mlx_stack.core.paths import (
|
|
35
|
+
ensure_data_home,
|
|
36
|
+
get_benchmarks_dir,
|
|
37
|
+
get_data_home,
|
|
38
|
+
get_stacks_dir,
|
|
39
|
+
)
|
|
40
|
+
from mlx_stack.core.process import (
|
|
41
|
+
HealthCheckError,
|
|
42
|
+
ProcessError,
|
|
43
|
+
check_port_conflict,
|
|
44
|
+
is_process_alive,
|
|
45
|
+
read_pid_file,
|
|
46
|
+
remove_pid_file,
|
|
47
|
+
start_service,
|
|
48
|
+
stop_service,
|
|
49
|
+
wait_for_healthy,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
# --------------------------------------------------------------------------- #
|
|
53
|
+
# Constants
|
|
54
|
+
# --------------------------------------------------------------------------- #
|
|
55
|
+
|
|
56
|
+
# Benchmark workload: every run sends the same prompt size and generation
# budget so results are comparable across runs and against the catalog.
NUM_ITERATIONS = 3
PROMPT_TOKEN_COUNT = 1024
MAX_GENERATION_TOKENS = 100

# Seconds to wait for a temporary vllm-mlx instance to become healthy.
TEMP_INSTANCE_HEALTH_TIMEOUT = 120.0

# Candidate ports for temporary instances, scanned as
# range(TEMP_PORT_START, TEMP_PORT_END) -- the end value is exclusive.
# The range stays clear of 4000 and 8000-8002.
TEMP_PORT_START = 8100
TEMP_PORT_END = 8200

# Maximum fractional shortfall versus the catalog value for each grade.
PASS_THRESHOLD = 0.15  # shortfall up to 15% -> PASS
WARN_THRESHOLD = 0.30  # shortfall up to 30% -> WARN, anything larger -> FAIL

# Grade labels attached to each metric comparison.
CLASSIFICATION_PASS = "PASS"
CLASSIFICATION_WARN = "WARN"
CLASSIFICATION_FAIL = "FAIL"

# Service-name prefix used for temporary benchmark instances.
TEMP_SERVICE_PREFIX = "bench-temp"

# Function-tool schema (chat-completions ``tools`` format) sent with the
# tool-calling benchmark request.
TOOL_DEFINITION = dict(
    type="function",
    function=dict(
        name="get_weather",
        description="Get the current weather for a given location.",
        parameters=dict(
            type="object",
            properties=dict(
                location=dict(
                    type="string",
                    description="The city and state, e.g. San Francisco, CA",
                ),
                unit=dict(
                    type="string",
                    enum=["celsius", "fahrenheit"],
                    description="The temperature unit to use.",
                ),
            ),
            required=["location"],
        ),
    ),
)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# --------------------------------------------------------------------------- #
|
|
106
|
+
# Exceptions
|
|
107
|
+
# --------------------------------------------------------------------------- #
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class BenchmarkError(Exception):
    """Base exception for failed benchmark operations."""


class BenchmarkTargetError(BenchmarkError):
    """Signals that the requested benchmark target could not be located."""


class BenchmarkRunError(BenchmarkError):
    """Signals that an individual benchmark iteration did not complete."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# --------------------------------------------------------------------------- #
|
|
123
|
+
# Data classes
|
|
124
|
+
# --------------------------------------------------------------------------- #
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass(frozen=True)
class IterationResult:
    """Immutable timing record for one benchmark request.

    Attributes:
        prompt_tps: Prompt tokens processed per second.
        gen_tps: Tokens generated per second.
        prompt_tokens: Prompt token count for the request.
        completion_tokens: Generated token count for the request.
        total_time: Wall-clock seconds spent on the whole request.
    """

    prompt_tps: float
    gen_tps: float
    prompt_tokens: int
    completion_tokens: int
    total_time: float
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass(frozen=True)
class MetricClassification:
    """Outcome of grading one measured metric against its catalog value.

    Attributes:
        metric: Metric name, e.g. ``"prompt_tps"`` or ``"gen_tps"``.
        measured: Value observed in this benchmark.
        catalog: Reference value taken from the catalog.
        delta_pct: Shortfall versus the catalog, in percent; positive means
            the measurement was slower than the catalog value.
        classification: One of ``"PASS"``, ``"WARN"`` or ``"FAIL"``.
    """

    metric: str
    measured: float
    catalog: float
    delta_pct: float
    classification: str
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@dataclass(frozen=True)
class ToolCallResult:
    """Outcome of the tool-calling check.

    Attributes:
        success: True when the model produced a well-formed tool call.
        round_trip_time: Request duration in seconds (may be a placeholder
            value when the request itself failed).
        error: Human-readable failure reason, or ``None`` on success.
    """

    success: bool
    round_trip_time: float
    error: str | None = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@dataclass
class BenchmarkResult_:
    """Aggregated benchmark outcome for one model/quantization pair.

    Holds the raw per-iteration samples together with their summary
    statistics, the catalog comparison, and the optional tool-call result.
    """

    model_id: str
    quant: str
    # Raw per-iteration measurements.
    iterations: list[IterationResult] = field(default_factory=list)
    # Summary statistics across iterations.
    prompt_tps_mean: float = 0.0
    prompt_tps_std: float = 0.0
    gen_tps_mean: float = 0.0
    gen_tps_std: float = 0.0
    # Grades versus the catalog reference values.
    classifications: list[MetricClassification] = field(default_factory=list)
    # Tool-calling check outcome, if one was run.
    tool_call_result: ToolCallResult | None = None
    used_temporary_instance: bool = False
    catalog_data_available: bool = False

    def to_save_dict(self) -> dict[str, Any]:
        """Serialise the summary for the benchmark JSON file.

        Only the mean rates are written; ``memory_gb`` is emitted as ``0.0``.
        """
        return dict(
            model_id=self.model_id,
            quant=self.quant,
            prompt_tps=self.prompt_tps_mean,
            gen_tps=self.gen_tps_mean,
            memory_gb=0.0,  # placeholder, updated if available
        )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# --------------------------------------------------------------------------- #
|
|
186
|
+
# Prompt generation
|
|
187
|
+
# --------------------------------------------------------------------------- #
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _generate_prompt(token_count: int) -> str:
|
|
191
|
+
"""Generate a prompt with approximately the specified token count.
|
|
192
|
+
|
|
193
|
+
Uses repeated text to approximate the desired token count. Each word
|
|
194
|
+
is roughly 1-2 tokens, so we generate enough words to cover the count.
|
|
195
|
+
|
|
196
|
+
Args:
|
|
197
|
+
token_count: Target number of tokens.
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
A text string approximately token_count tokens long.
|
|
201
|
+
"""
|
|
202
|
+
# Each word is roughly 1.3 tokens on average for English text.
|
|
203
|
+
# We use a repeating pattern to keep it simple and predictable.
|
|
204
|
+
base_phrase = (
|
|
205
|
+
"The quick brown fox jumps over the lazy dog near the river bank "
|
|
206
|
+
"where flowers bloom in the warm sunshine of a summer afternoon "
|
|
207
|
+
)
|
|
208
|
+
words_needed = int(token_count / 1.3) + 10
|
|
209
|
+
words = base_phrase.split()
|
|
210
|
+
repeated = []
|
|
211
|
+
for i in range(words_needed):
|
|
212
|
+
repeated.append(words[i % len(words)])
|
|
213
|
+
return " ".join(repeated)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# --------------------------------------------------------------------------- #
|
|
217
|
+
# Running tier resolution
|
|
218
|
+
# --------------------------------------------------------------------------- #
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _load_stack_tiers() -> list[dict[str, Any]]:
    """Load tier definitions from the active stack file.

    Reads ``default.yaml`` from the stacks directory and returns its
    ``tiers`` list. This is best-effort: a missing, unreadable or malformed
    stack file yields an empty list rather than an error.

    Returns:
        The tier mappings from the stack, or an empty list if no usable
        stack is configured. Entries that are not mappings are dropped so
        that callers can safely use ``tier.get(...)`` on every element.
    """
    stack_path = get_stacks_dir() / "default.yaml"
    if not stack_path.exists():
        return []

    try:
        import yaml

        stack = yaml.safe_load(stack_path.read_text(encoding="utf-8"))
    except Exception:
        # Deliberately best-effort: a broken stack file means "no tiers",
        # not a crash of the benchmark command.
        return []

    if not isinstance(stack, dict):
        return []
    tiers = stack.get("tiers", [])
    if not isinstance(tiers, list):
        return []
    # Hand-edited YAML can contain scalars or lists here; callers index
    # tiers with ``.get`` and would otherwise raise AttributeError.
    return [tier for tier in tiers if isinstance(tier, dict)]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _find_running_tier(target: str) -> dict[str, Any] | None:
    """Return the stack tier named ``target`` if its process is alive.

    A tier counts as running when ``read_pid_file`` yields a PID for it and
    ``is_process_alive`` confirms that PID. A ``ProcessError`` raised while
    checking is treated as "not running", and the search moves on to any
    further tier with the same name.

    Args:
        target: Name of the tier to locate.

    Returns:
        The matching tier dict when running, otherwise ``None``.
    """
    candidates = (
        candidate
        for candidate in _load_stack_tiers()
        if candidate.get("name", "") == target
    )
    for candidate in candidates:
        try:
            pid = read_pid_file(target)
            alive = pid is not None and is_process_alive(pid)
        except ProcessError:
            alive = False
        if alive:
            return candidate
    return None
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _get_running_tier_names() -> list[str]:
    """Return the names of stack tiers whose process is currently alive.

    Tiers without a non-empty ``name`` are skipped, the same filter that
    ``_get_all_tier_names`` applies, so ``read_pid_file`` is never asked
    about an empty service name. A ``ProcessError`` while checking a tier is
    treated as "not running".

    Returns:
        Names of running tiers, in stack order.
    """
    running: list[str] = []
    for tier in _load_stack_tiers():
        name = tier.get("name")
        if not name:
            continue
        try:
            pid = read_pid_file(name)
            if pid is not None and is_process_alive(pid):
                running.append(name)
        except ProcessError:
            # An unreadable/invalid PID file means we cannot confirm it runs.
            continue
    return running
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _get_all_tier_names() -> list[str]:
    """Return every non-empty tier name declared in the active stack, in order."""
    names = []
    for tier in _load_stack_tiers():
        name = tier.get("name")
        if name:
            names.append(name)
    return names
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _get_used_ports() -> set[int]:
    """Collect ports claimed by the active stack's tiers and by LiteLLM.

    Tier ports are read from the stack file whether or not the tier is
    currently running. A tier port that cannot be converted to ``int`` is
    ignored rather than aborting the benchmark. If the ``litellm-port``
    config value is unreadable or not an integer, the default 4000 is
    reserved instead.

    Returns:
        The set of ports a temporary instance must avoid.
    """
    ports: set[int] = set()
    for tier in _load_stack_tiers():
        raw_port = tier.get("port")
        if raw_port is None:
            continue
        try:
            ports.add(int(raw_port))
        except (TypeError, ValueError):
            # A malformed port in a hand-edited stack file cannot be bound by
            # anything we launched, so there is nothing to reserve.
            continue

    try:
        ports.add(int(get_value("litellm-port")))
    except (ConfigCorruptError, ValueError, TypeError):
        ports.add(4000)  # default LiteLLM port
    return ports
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _find_temp_port(used_ports: set[int]) -> int:
    """Pick a free port for a temporary vllm-mlx instance.

    Scans from ``TEMP_PORT_START`` up to, but excluding, ``TEMP_PORT_END``.
    Ports listed in ``used_ports`` are skipped without probing; the rest are
    probed with ``check_port_conflict`` and the first one without a conflict
    wins.

    Args:
        used_ports: Ports already claimed by the stack and LiteLLM.

    Returns:
        The first available port in the range.

    Raises:
        BenchmarkError: If every port in the range is taken.
    """
    candidates = (
        candidate
        for candidate in range(TEMP_PORT_START, TEMP_PORT_END)
        if candidate not in used_ports
    )
    for candidate in candidates:
        if check_port_conflict(candidate) is None:
            return candidate

    raise BenchmarkError(
        "Could not find an available port in range "
        f"{TEMP_PORT_START}-{TEMP_PORT_END} for temporary benchmark instance."
    )
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# --------------------------------------------------------------------------- #
|
|
338
|
+
# Benchmark iterations
|
|
339
|
+
# --------------------------------------------------------------------------- #
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _run_single_iteration(
    port: int,
    model_name: str,
    prompt: str,
    max_tokens: int = MAX_GENERATION_TOKENS,
    host: str = "127.0.0.1",
) -> IterationResult:
    """Run one streamed chat completion and derive throughput from it.

    Streaming lets the two phases be timed separately:

    * prompt phase: from request start until the first generated chunk
      arrives (time to first token); ``prompt_tps = prompt_tokens / time``.
    * generation phase: from the first generated chunk until the stream
      ends; ``gen_tps = completion_tokens / time``.

    Token counts come from the ``usage`` object that the request asks for via
    ``stream_options.include_usage``. If the server reports no completion
    count, the number of non-empty content chunks is used as an
    approximation, and that value is also stored in ``completion_tokens``.
    A rate whose token count or duration is unavailable is reported as 0.0.
    The two rates are computed independently of each other.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: The model identifier for the API request.
        prompt: The prompt text.
        max_tokens: Maximum tokens to generate.
        host: The host to connect to.

    Returns:
        An IterationResult with timing data.

    Raises:
        BenchmarkRunError: If the request fails, times out, or returns a
            non-200 status.
    """
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }

    try:
        start_time = time.monotonic()
        first_token_time: float | None = None
        completion_tokens = 0
        prompt_tokens = 0
        content_chunks = 0

        with httpx.stream("POST", url, json=payload, timeout=300.0) as response:
            if response.status_code != 200:
                body = response.read().decode("utf-8", errors="replace")[:200]
                msg = (
                    f"API request failed with status {response.status_code}: "
                    f"{body}"
                )
                raise BenchmarkRunError(msg)

            for line in response.iter_lines():
                # Only SSE data lines carry payloads; the space after the
                # colon is optional per the SSE format.
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break

                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if not isinstance(chunk, dict):
                    continue

                # Usage normally arrives in the final chunk. Ignore
                # non-integer values (e.g. null) so the arithmetic below
                # cannot fail with a TypeError.
                usage = chunk.get("usage")
                if isinstance(usage, dict):
                    reported_prompt = usage.get("prompt_tokens")
                    if isinstance(reported_prompt, int):
                        prompt_tokens = reported_prompt
                    reported_completion = usage.get("completion_tokens")
                    if isinstance(reported_completion, int):
                        completion_tokens = reported_completion

                # Thinking models (e.g. Qwen3) emit reasoning_content for
                # <think> tokens instead of content. Check both fields so
                # the first generated token is detected either way.
                choices = chunk.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                first_choice = choices[0]
                if not isinstance(first_choice, dict):
                    continue
                delta = first_choice.get("delta")
                if not isinstance(delta, dict):
                    continue
                content = delta.get("content") or delta.get("reasoning_content")
                if content:
                    if first_token_time is None:
                        first_token_time = time.monotonic()
                    content_chunks += 1

        end_time = time.monotonic()
        total_time = end_time - start_time

        # Without a usage report, approximate generated tokens by the number
        # of content-bearing chunks so gen_tps is still meaningful.
        if completion_tokens == 0:
            completion_tokens = content_chunks

        prompt_tps = 0.0
        gen_tps = 0.0
        if first_token_time is not None:
            prompt_time = first_token_time - start_time
            gen_time = end_time - first_token_time
            if prompt_tokens > 0 and prompt_time > 0:
                prompt_tps = prompt_tokens / prompt_time
            if completion_tokens > 0 and gen_time > 0:
                gen_tps = completion_tokens / gen_time

        return IterationResult(
            prompt_tps=prompt_tps,
            gen_tps=gen_tps,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_time=total_time,
        )

    except BenchmarkRunError:
        raise
    except httpx.TimeoutException:
        msg = "Benchmark request timed out after 300s"
        raise BenchmarkRunError(msg) from None
    except httpx.HTTPError as exc:
        msg = f"HTTP error during benchmark: {exc}"
        raise BenchmarkRunError(msg) from None
    except (json.JSONDecodeError, KeyError) as exc:
        msg = f"Could not parse benchmark response: {exc}"
        raise BenchmarkRunError(msg) from None
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def _run_iterations(
    port: int,
    model_name: str,
    num_iterations: int = NUM_ITERATIONS,
) -> list[IterationResult]:
    """Repeat the benchmark request ``num_iterations`` times.

    A single prompt of about ``PROMPT_TOKEN_COUNT`` tokens is generated once
    and reused for every iteration. Each request generates at most
    ``MAX_GENERATION_TOKENS`` tokens.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: Model identifier for API requests.
        num_iterations: Number of iterations to run.

    Returns:
        One IterationResult per iteration, in execution order.

    Raises:
        BenchmarkRunError: Propagated from the first iteration that fails.
    """
    shared_prompt = _generate_prompt(PROMPT_TOKEN_COUNT)
    return [
        _run_single_iteration(
            port=port,
            model_name=model_name,
            prompt=shared_prompt,
            max_tokens=MAX_GENERATION_TOKENS,
        )
        for _ in range(num_iterations)
    ]
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
# --------------------------------------------------------------------------- #
|
|
496
|
+
# Statistics
|
|
497
|
+
# --------------------------------------------------------------------------- #
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _compute_stats(values: list[float]) -> tuple[float, float]:
|
|
501
|
+
"""Compute mean and standard deviation.
|
|
502
|
+
|
|
503
|
+
Args:
|
|
504
|
+
values: List of numeric values.
|
|
505
|
+
|
|
506
|
+
Returns:
|
|
507
|
+
Tuple of (mean, std_dev).
|
|
508
|
+
"""
|
|
509
|
+
if not values:
|
|
510
|
+
return 0.0, 0.0
|
|
511
|
+
|
|
512
|
+
n = len(values)
|
|
513
|
+
mean = sum(values) / n
|
|
514
|
+
|
|
515
|
+
if n < 2:
|
|
516
|
+
return mean, 0.0
|
|
517
|
+
|
|
518
|
+
variance = sum((v - mean) ** 2 for v in values) / (n - 1)
|
|
519
|
+
std_dev = math.sqrt(variance)
|
|
520
|
+
return mean, std_dev
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
# --------------------------------------------------------------------------- #
|
|
524
|
+
# Catalog comparison
|
|
525
|
+
# --------------------------------------------------------------------------- #
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _classify_metric(
    metric_name: str,
    measured: float,
    catalog_value: float,
) -> MetricClassification:
    """Grade a measured rate against its catalog reference value.

    The shortfall is ``(catalog_value - measured) / catalog_value``, so a
    measurement at or above the catalog value has a shortfall <= 0 and
    passes. Grades:

    * shortfall <= PASS_THRESHOLD  -> PASS
    * shortfall <= WARN_THRESHOLD  -> WARN
    * otherwise                    -> FAIL

    A non-positive catalog value cannot be compared and is graded PASS with
    a delta of 0.0.

    Args:
        metric_name: Name of the metric (e.g. "prompt_tps").
        measured: Measured value.
        catalog_value: Catalog reference value.

    Returns:
        A MetricClassification whose ``delta_pct`` is the shortfall in
        percent.
    """
    if catalog_value <= 0:
        return MetricClassification(
            metric=metric_name,
            measured=measured,
            catalog=catalog_value,
            delta_pct=0.0,
            classification=CLASSIFICATION_PASS,
        )

    shortfall = (catalog_value - measured) / catalog_value
    # Ordered from strictest to loosest; the first limit that holds wins.
    grading = (
        (PASS_THRESHOLD, CLASSIFICATION_PASS),
        (WARN_THRESHOLD, CLASSIFICATION_WARN),
    )
    grade = next(
        (label for limit, label in grading if shortfall <= limit),
        CLASSIFICATION_FAIL,
    )

    return MetricClassification(
        metric=metric_name,
        measured=measured,
        catalog=catalog_value,
        delta_pct=shortfall * 100,  # stored as a percentage
        classification=grade,
    )
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _compare_against_catalog(
|
|
575
|
+
prompt_tps_mean: float,
|
|
576
|
+
gen_tps_mean: float,
|
|
577
|
+
entry: CatalogEntry,
|
|
578
|
+
profile: HardwareProfile,
|
|
579
|
+
) -> list[MetricClassification]:
|
|
580
|
+
"""Compare benchmark results against catalog thresholds.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
prompt_tps_mean: Mean prompt tokens/sec.
|
|
584
|
+
gen_tps_mean: Mean generation tokens/sec.
|
|
585
|
+
entry: The catalog entry for the model.
|
|
586
|
+
profile: The hardware profile for benchmark lookup.
|
|
587
|
+
|
|
588
|
+
Returns:
|
|
589
|
+
List of MetricClassification, empty if no catalog data.
|
|
590
|
+
"""
|
|
591
|
+
profile_id = profile.profile_id
|
|
592
|
+
if profile_id not in entry.benchmarks:
|
|
593
|
+
return []
|
|
594
|
+
|
|
595
|
+
bench = entry.benchmarks[profile_id]
|
|
596
|
+
classifications: list[MetricClassification] = []
|
|
597
|
+
|
|
598
|
+
classifications.append(
|
|
599
|
+
_classify_metric("prompt_tps", prompt_tps_mean, bench.prompt_tps)
|
|
600
|
+
)
|
|
601
|
+
classifications.append(
|
|
602
|
+
_classify_metric("gen_tps", gen_tps_mean, bench.gen_tps)
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
return classifications
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
# --------------------------------------------------------------------------- #
|
|
609
|
+
# Tool-calling benchmark
|
|
610
|
+
# --------------------------------------------------------------------------- #
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _run_tool_call_benchmark(
    port: int,
    model_name: str,
    host: str = "127.0.0.1",
    timeout: float = 60.0,
) -> ToolCallResult:
    """Check that the model answers a weather question with a valid tool call.

    Sends one non-streaming chat completion that carries ``TOOL_DEFINITION``.
    The reply passes only if the first tool call targets ``get_weather`` and
    its arguments decode to a JSON object that contains ``location``. The
    function never raises; every failure is reported through the returned
    ToolCallResult.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: Model identifier.
        host: Host to connect to.
        timeout: Request timeout in seconds, passed to httpx.

    Returns:
        A ToolCallResult with timing and validity information.
        ``round_trip_time`` is the measured duration of the request,
        including when it timed out. It is ``0.0`` only for other transport
        or parsing failures.
    """
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": "What is the current weather in San Francisco, CA?",
            }
        ],
        "tools": [TOOL_DEFINITION],
        "tool_choice": "auto",
        "max_tokens": 1024,
        "temperature": 0.0,
    }

    start_time = time.monotonic()
    try:
        response = httpx.post(url, json=payload, timeout=timeout)
        round_trip_time = time.monotonic() - start_time

        if response.status_code != 200:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error=f"API returned status {response.status_code}",
            )

        data = response.json()
        choices = data.get("choices", [])
        if not choices:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="No choices in response",
            )

        message = choices[0].get("message", {})
        tool_calls = message.get("tool_calls", [])
        if not tool_calls:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="No tool calls in response",
            )

        # Only the first tool call is validated.
        fn = tool_calls[0].get("function", {})
        fn_name = fn.get("name", "")
        fn_args_raw = fn.get("arguments", "")

        if fn_name != "get_weather":
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error=f"Wrong function name: {fn_name}",
            )

        # Arguments may arrive as a JSON-encoded string or already decoded.
        try:
            fn_args = json.loads(fn_args_raw) if isinstance(fn_args_raw, str) else fn_args_raw
        except (json.JSONDecodeError, TypeError):
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="Could not parse tool call arguments as JSON",
            )

        # Must be a JSON object: "location" in <str/list> would be a
        # substring/element test and give false positives.
        if not isinstance(fn_args, dict):
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="Tool call arguments are not a JSON object",
            )
        if "location" not in fn_args:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="Missing 'location' in tool call arguments",
            )

        return ToolCallResult(
            success=True,
            round_trip_time=round_trip_time,
        )

    except httpx.TimeoutException:
        return ToolCallResult(
            success=False,
            round_trip_time=time.monotonic() - start_time,
            error="Tool call request timed out",
        )
    except Exception as exc:
        # Best-effort check: malformed responses and transport errors are
        # reported, never raised to the caller.
        return ToolCallResult(
            success=False,
            round_trip_time=0.0,
            error=f"Tool call request failed: {exc}",
        )
|
|
722
|
+
|
|
723
|
+
|
|
724
|
+
# --------------------------------------------------------------------------- #
|
|
725
|
+
# Temporary instance management
|
|
726
|
+
# --------------------------------------------------------------------------- #
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
def _resolve_model_source(model_id: str, quant: str) -> str:
    """Resolve the model source path for vllm-mlx.

    Looks in the local models directory first and falls back to the
    catalog's HuggingFace repo. The models directory is the ``model-dir``
    config value when it is set to a non-empty value. Otherwise, including
    when reading the config fails, it is ``<data home>/models``. Two local
    layouts are tried in order:
    ``<models>/<repo basename>`` and then ``<models>/<model_id>``.

    Args:
        model_id: Catalog model ID.
        quant: Quantization level.

    Returns:
        The local model path if one exists, otherwise the HF repo identifier.

    Raises:
        BenchmarkTargetError: If the model is not in the catalog.
        BenchmarkError: If the quantization is not offered for the model.
    """
    catalog = load_catalog()
    entry = get_entry_by_id(catalog, model_id)
    if entry is None:
        msg = f"Model '{model_id}' not found in catalog."
        raise BenchmarkTargetError(msg)

    if quant not in entry.sources:
        available_quants = ", ".join(sorted(entry.sources))
        msg = (
            f"Quantization '{quant}' not available for model '{model_id}'. "
            f"Available: {available_quants}"
        )
        raise BenchmarkError(msg)

    source = entry.sources[quant]

    models_path: Path | None = None
    try:
        configured_dir = get_value("model-dir")
        # An unset/empty value must not become Path("None") or Path("")
        # (the current directory).
        if configured_dir:
            models_path = Path(str(configured_dir)).expanduser()
    except Exception:
        # Best-effort: an unreadable config falls back to the default dir.
        models_path = None
    if models_path is None:
        models_path = get_data_home() / "models"

    # rsplit returns the whole string when there is no "/", so this covers
    # both "org/name" and bare "name" repos.
    repo_dir_name = source.hf_repo.rsplit("/", 1)[-1]
    for candidate in (models_path / repo_dir_name, models_path / model_id):
        if candidate.exists():
            return str(candidate)

    # Nothing local: let vllm-mlx fetch from the HuggingFace repo.
    return source.hf_repo
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
def _start_temp_instance(
    model_source: str,
    port: int,
    entry: CatalogEntry,
    quant: str,
) -> str:
    """Start a temporary vllm-mlx instance for benchmarking.

    Launches ``vllm-mlx serve`` bound to 127.0.0.1 under a PID-managed
    service name, then blocks until ``/v1/models`` responds.

    If startup does not complete, the instance is torn down before the
    exception propagates. That covers a failed health check and also an
    interrupted wait (e.g. Ctrl-C or any unexpected error). The caller has
    no service name yet, so it could not clean up on its own.

    Args:
        model_source: Model source path or HF repo.
        port: Port to bind the instance to.
        entry: Catalog entry for the model. Its capabilities decide the
            tool-calling and reasoning-parser flags.
        quant: Quantization level. Not used by this function (the caller
            passes a ``model_source`` already resolved for a quant). It is
            kept for interface compatibility.

    Returns:
        The service name used for PID management.

    Raises:
        BenchmarkError: If the instance cannot be started or does not become
            healthy within ``TEMP_INSTANCE_HEALTH_TIMEOUT`` seconds.
    """
    service_name = f"{TEMP_SERVICE_PREFIX}-{entry.id}"

    # Ensure vllm-mlx is installed
    ensure_dependency("vllm-mlx")

    vllm_binary = shutil.which("vllm-mlx")
    if vllm_binary is None:
        msg = (
            "vllm-mlx binary not found on PATH after installation. "
            "Try: uv tool install vllm-mlx"
        )
        raise BenchmarkError(msg)

    cmd = [
        vllm_binary,
        "serve", model_source,
        "--port", str(port),
        "--host", "127.0.0.1",
    ]

    caps = entry.capabilities

    # Add tool-calling flags if the model supports it
    if caps.tool_calling:
        cmd.append("--enable-auto-tool-choice")
        if caps.tool_call_parser:
            cmd.extend(["--tool-call-parser", caps.tool_call_parser])

    # Add reasoning parser for thinking models (e.g., Qwen3).
    # Without this flag, thinking models emit <think> tags that break
    # the tool-call parser (e.g., hermes).
    if caps.thinking and caps.reasoning_parser:
        cmd.extend(["--reasoning-parser", caps.reasoning_parser])

    try:
        start_service(
            service_name=service_name,
            cmd=cmd,
            port=port,
        )
    except ProcessError as exc:
        msg = f"Could not start temporary vllm-mlx instance: {exc}"
        raise BenchmarkError(msg) from None

    # The server process now exists: it must not outlive a failed or
    # interrupted startup, because no one else knows its service name yet.
    try:
        wait_for_healthy(
            port=port,
            path="/v1/models",
            total_timeout=TEMP_INSTANCE_HEALTH_TIMEOUT,
        )
    except HealthCheckError:
        _cleanup_temp_instance(service_name)
        msg = (
            "Temporary vllm-mlx instance did not become healthy "
            f"within {TEMP_INSTANCE_HEALTH_TIMEOUT}s."
        )
        raise BenchmarkError(msg) from None
    except BaseException:
        # Ctrl-C during the (potentially long) wait, or any other error.
        _cleanup_temp_instance(service_name)
        raise

    return service_name
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def _cleanup_temp_instance(service_name: str) -> None:
    """Tear down a temporary vllm-mlx instance on a best-effort basis.

    The teardown runs in three steps:
    1. Ask ``stop_service`` to stop the service, with a 5 second grace
       period. Any exception from it is ignored.
    2. If a PID is still recorded and that process is alive, send it
       SIGKILL. ``ProcessError`` and ``OSError`` are ignored.
    3. Remove the service's PID file.

    Args:
        service_name: The service name from PID management.
    """
    # Step 1: graceful stop. Failures are tolerated because step 2 is a
    # fallback.
    try:
        stop_service(service_name, grace_period=5.0)
    except Exception:
        pass

    # Step 2: force-kill whatever PID is still on record for the service.
    try:
        leftover_pid = read_pid_file(service_name)
        survived = leftover_pid is not None and is_process_alive(leftover_pid)
        if survived:
            os.kill(leftover_pid, signal.SIGKILL)
            time.sleep(0.5)  # brief pause after the hard kill
    except (ProcessError, OSError):
        pass

    # Step 3: drop the stale PID bookkeeping.
    remove_pid_file(service_name)
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
# --------------------------------------------------------------------------- #
|
|
891
|
+
# Save benchmark results
|
|
892
|
+
# --------------------------------------------------------------------------- #
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
def save_benchmark_results(
    result: BenchmarkResult_,
    profile: HardwareProfile,
) -> Path:
    """Save benchmark results to ``<benchmarks dir>/<profile_id>.json``.

    The file is a JSON object mapping model IDs to saved result dicts.
    Existing content is loaded and the entry for ``result.model_id`` is
    added or replaced; entries for other models are preserved. A file that
    cannot be read, is not valid UTF-8/JSON, or is not a JSON object is
    treated as empty.

    The new content is first written to a sibling ``.tmp`` file, which is
    then renamed over the target. An interrupted write therefore leaves the
    previously saved results intact instead of a truncated file.

    Args:
        result: The benchmark result to save.
        profile: The hardware profile; its ``profile_id`` names the file.

    Returns:
        Path to the saved file.
    """
    ensure_data_home()
    benchmarks_dir = get_benchmarks_dir()
    benchmarks_dir.mkdir(parents=True, exist_ok=True)

    benchmark_file = benchmarks_dir / f"{profile.profile_id}.json"

    # Load existing data
    existing: dict[str, Any] = {}
    if benchmark_file.exists():
        try:
            loaded = json.loads(benchmark_file.read_text(encoding="utf-8"))
        except (ValueError, OSError):
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError (undecodable bytes in a corrupt file).
            loaded = None
        if isinstance(loaded, dict):
            existing = loaded

    # Merge new result
    existing[result.model_id] = result.to_save_dict()

    # Write back atomically: write a temp file, then rename over the target.
    payload = json.dumps(existing, indent=2) + "\n"
    tmp_file = benchmark_file.with_name(f"{benchmark_file.name}.tmp")
    try:
        tmp_file.write_text(payload, encoding="utf-8")
        tmp_file.replace(benchmark_file)
    except BaseException:
        try:
            tmp_file.unlink(missing_ok=True)
        except OSError:
            pass
        raise

    return benchmark_file
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
# --------------------------------------------------------------------------- #
|
|
939
|
+
# Resolve benchmark target
|
|
940
|
+
# --------------------------------------------------------------------------- #
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
@dataclass
class BenchmarkTarget:
    """Resolved target for benchmarking.

    Built by ``resolve_target`` in one of two forms:
    - An already-running tier, with ``is_running_tier=True``.
    - A catalog model served by a temporary vllm-mlx instance, with
      ``is_running_tier=False`` and ``temp_service_name`` set.
    """

    # Catalog model ID (for a running tier: the tier's "model" value).
    model_id: str
    # Quantization level being benchmarked.
    quant: str
    # Local port of the server that API requests are sent to.
    port: int
    model_name: str  # The model name used in API calls
    # Catalog entry; for a tier model missing from the catalog this is a
    # placeholder with all capabilities disabled.
    entry: CatalogEntry
    # True when benchmarking an existing tier rather than a temp instance.
    is_running_tier: bool
    temp_service_name: str | None = None  # Set if using a temp instance
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def _placeholder_catalog_entry(model_id: str) -> CatalogEntry:
    """Build a minimal catalog entry for a running model absent from the catalog.

    All capabilities are disabled and all quality scores are zero, with no
    sources and no benchmark data. The model can still be benchmarked, but
    the tool-calling check is skipped.
    """
    from mlx_stack.core.catalog import (
        Capabilities,
        CatalogEntry,
        QualityScores,
    )

    return CatalogEntry(
        id=model_id,
        name=model_id,
        family="unknown",
        params_b=0.0,
        architecture="unknown",
        min_mlx_lm_version="0.0.0",
        sources={},
        capabilities=Capabilities(
            tool_calling=False,
            tool_call_parser=None,
            thinking=False,
            reasoning_parser=None,
            vision=False,
        ),
        quality=QualityScores(overall=0, coding=0, reasoning=0, instruction_following=0),
        benchmarks={},
        tags=[],
    )


def _pick_benchmark_quant(entry: CatalogEntry) -> str:
    """Pick the quantization to benchmark for a catalog entry.

    Uses the configured ``default-quant`` (or ``"int4"`` if the config
    cannot be read) when the entry offers it. Otherwise it falls back to
    the alphabetically first available quant, or ``"int4"`` when the entry
    lists no sources.
    """
    try:
        quant = str(get_value("default-quant"))
    except Exception:  # includes ConfigCorruptError; config is optional here
        quant = "int4"

    if quant in entry.sources:
        return quant
    available = sorted(entry.sources.keys())
    return available[0] if available else "int4"


def resolve_target(target: str) -> BenchmarkTarget:
    """Resolve a benchmark target to a specific model and port.

    Tries in order:
    1. Running tier by name: benchmarked in place on the tier's port.
    2. Catalog model by ID: a temporary vllm-mlx instance is started on a
       free port. The returned ``temp_service_name`` must be cleaned up by
       the caller.

    Args:
        target: Tier name or model ID.

    Returns:
        A BenchmarkTarget with all needed info.

    Raises:
        BenchmarkTargetError: If the target is neither a running tier nor a
            known catalog model.
        BenchmarkError: If a temporary instance is needed but its model
            source cannot be resolved or the instance cannot be started.
    """
    # 1. Try as a running tier
    tier = _find_running_tier(target)
    if tier is not None:
        port = tier["port"]
        model_id = tier.get("model", target)
        quant = tier.get("quant", "int4")
        # The tier's served source doubles as the API model name.
        source = tier.get("source", model_id)

        catalog = load_catalog()
        entry = get_entry_by_id(catalog, model_id)
        if entry is None:
            # Still benchmark it even without catalog data
            entry = _placeholder_catalog_entry(model_id)

        return BenchmarkTarget(
            model_id=model_id,
            quant=quant,
            port=port,
            model_name=source,
            entry=entry,
            is_running_tier=True,
        )

    # 2. Try as a catalog model
    catalog = load_catalog()
    entry = get_entry_by_id(catalog, target)
    if entry is not None:
        quant = _pick_benchmark_quant(entry)
        model_source = _resolve_model_source(target, quant)

        # Find a free port and start the temporary instance on it.
        used_ports = _get_used_ports()
        port = _find_temp_port(used_ports)
        service_name = _start_temp_instance(model_source, port, entry, quant)

        return BenchmarkTarget(
            model_id=entry.id,
            quant=quant,
            port=port,
            model_name=model_source,
            entry=entry,
            is_running_tier=False,
            temp_service_name=service_name,
        )

    # Neither tier nor model
    tier_names = _get_all_tier_names()
    running_tiers = _get_running_tier_names()

    parts: list[str] = [f"'{target}' is not a running tier or known model."]
    if tier_names:
        parts.append(f"\nValid tier names: {', '.join(tier_names)}")
    if running_tiers:
        parts.append(f"Running tiers: {', '.join(running_tiers)}")
    parts.append("\nUse 'mlx-stack models --catalog' to see available models.")

    raise BenchmarkTargetError("\n".join(parts))
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
# --------------------------------------------------------------------------- #
|
|
1070
|
+
# Main benchmark function
|
|
1071
|
+
# --------------------------------------------------------------------------- #
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
def run_benchmark(
    target: str,
    save: bool = False,
) -> BenchmarkResult_:
    """Run a complete benchmark against a tier or model.

    This is the main entry point for the benchmark engine. If resolving the
    target started a temporary vllm-mlx instance, that instance is cleaned
    up exactly once when the benchmark finishes, whether it succeeded or
    raised.

    Args:
        target: Tier name or model ID.
        save: Whether to persist results to disk.

    Returns:
        A BenchmarkResult_ with all collected data.

    Raises:
        BenchmarkError: If the benchmark fails.
        BenchmarkTargetError: If the target is not found.
    """
    # Auto-install vllm-mlx (needed for API calls even to running tiers)
    ensure_dependency("vllm-mlx")

    resolved = resolve_target(target)
    temp_service = resolved.temp_service_name

    try:
        return _execute_benchmark(resolved, save)
    finally:
        # Single cleanup point for both success and failure. A separate
        # `except` handler would tear the instance down a second time.
        if temp_service:
            _cleanup_temp_instance(temp_service)
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
def _execute_benchmark(
    target: BenchmarkTarget,
    save: bool,
) -> BenchmarkResult_:
    """Measure throughput, compare it to the catalog, and assemble the result.

    The steps are:
    1. Run ``NUM_ITERATIONS`` requests against the target.
    2. Reduce the per-iteration prompt and generation throughput to a
       (mean, std) pair each.
    3. Classify both means against the catalog when a hardware profile is
       available, either loaded from disk or detected as a fallback.
    4. Run the tool-calling check for models that declare tool-calling
       support.

    Saving happens only when ``save`` is true and a hardware profile was
    obtained.

    Args:
        target: The resolved benchmark target.
        save: Whether to persist results.

    Returns:
        A BenchmarkResult_.
    """
    port = target.port
    model_name = target.model_name

    iterations = _run_iterations(
        port=port,
        model_name=model_name,
        num_iterations=NUM_ITERATIONS,
    )

    # Per-iteration throughput samples, reduced to (mean, std) pairs.
    prompt_samples = [run.prompt_tps for run in iterations]
    gen_samples = [run.gen_tps for run in iterations]
    prompt_tps_mean, prompt_tps_std = _compute_stats(prompt_samples)
    gen_tps_mean, gen_tps_std = _compute_stats(gen_samples)

    # Hardware profile: prefer the stored one, else try live detection.
    hw_profile = load_profile()
    if hw_profile is None:
        try:
            hw_profile = detect_hardware()
        except Exception:
            hw_profile = None

    if hw_profile is None:
        classifications: list[MetricClassification] = []
        catalog_data_available = False
    else:
        classifications = _compare_against_catalog(
            prompt_tps_mean, gen_tps_mean, target.entry, hw_profile
        )
        catalog_data_available = len(classifications) > 0

    tool_call_result: ToolCallResult | None = (
        _run_tool_call_benchmark(port=port, model_name=model_name)
        if target.entry.capabilities.tool_calling
        else None
    )

    bench_result = BenchmarkResult_(
        model_id=target.model_id,
        quant=target.quant,
        iterations=iterations,
        prompt_tps_mean=prompt_tps_mean,
        prompt_tps_std=prompt_tps_std,
        gen_tps_mean=gen_tps_mean,
        gen_tps_std=gen_tps_std,
        classifications=classifications,
        tool_call_result=tool_call_result,
        used_temporary_instance=not target.is_running_tier,
        catalog_data_available=catalog_data_available,
    )

    # Persisting requires a profile (it names the output file). Without
    # one, the result is returned but nothing is written.
    if save and hw_profile is not None:
        save_benchmark_results(bench_result, hw_profile)

    return bench_result
|