mlx_stack-0.1.0-py3-none-any.whl

Files changed (61)
  1. mlx_stack/__init__.py +5 -0
  2. mlx_stack/_version.py +24 -0
  3. mlx_stack/cli/__init__.py +5 -0
  4. mlx_stack/cli/bench.py +221 -0
  5. mlx_stack/cli/config.py +166 -0
  6. mlx_stack/cli/down.py +109 -0
  7. mlx_stack/cli/init.py +180 -0
  8. mlx_stack/cli/install.py +165 -0
  9. mlx_stack/cli/logs.py +234 -0
  10. mlx_stack/cli/main.py +187 -0
  11. mlx_stack/cli/models.py +304 -0
  12. mlx_stack/cli/profile.py +65 -0
  13. mlx_stack/cli/pull.py +134 -0
  14. mlx_stack/cli/recommend.py +397 -0
  15. mlx_stack/cli/status.py +111 -0
  16. mlx_stack/cli/up.py +163 -0
  17. mlx_stack/cli/watch.py +252 -0
  18. mlx_stack/core/__init__.py +1 -0
  19. mlx_stack/core/benchmark.py +1182 -0
  20. mlx_stack/core/catalog.py +560 -0
  21. mlx_stack/core/config.py +471 -0
  22. mlx_stack/core/deps.py +323 -0
  23. mlx_stack/core/hardware.py +304 -0
  24. mlx_stack/core/launchd.py +531 -0
  25. mlx_stack/core/litellm_gen.py +188 -0
  26. mlx_stack/core/log_rotation.py +231 -0
  27. mlx_stack/core/log_viewer.py +386 -0
  28. mlx_stack/core/models.py +639 -0
  29. mlx_stack/core/paths.py +79 -0
  30. mlx_stack/core/process.py +887 -0
  31. mlx_stack/core/pull.py +815 -0
  32. mlx_stack/core/scoring.py +611 -0
  33. mlx_stack/core/stack_down.py +317 -0
  34. mlx_stack/core/stack_init.py +524 -0
  35. mlx_stack/core/stack_status.py +229 -0
  36. mlx_stack/core/stack_up.py +856 -0
  37. mlx_stack/core/watchdog.py +744 -0
  38. mlx_stack/data/__init__.py +1 -0
  39. mlx_stack/data/catalog/__init__.py +1 -0
  40. mlx_stack/data/catalog/deepseek-r1-32b.yaml +46 -0
  41. mlx_stack/data/catalog/deepseek-r1-8b.yaml +45 -0
  42. mlx_stack/data/catalog/gemma3-12b.yaml +45 -0
  43. mlx_stack/data/catalog/gemma3-27b.yaml +45 -0
  44. mlx_stack/data/catalog/gemma3-4b.yaml +45 -0
  45. mlx_stack/data/catalog/llama3.3-8b.yaml +44 -0
  46. mlx_stack/data/catalog/nemotron-49b.yaml +41 -0
  47. mlx_stack/data/catalog/nemotron-8b.yaml +44 -0
  48. mlx_stack/data/catalog/qwen3-8b.yaml +45 -0
  49. mlx_stack/data/catalog/qwen3.5-0.8b.yaml +45 -0
  50. mlx_stack/data/catalog/qwen3.5-14b.yaml +46 -0
  51. mlx_stack/data/catalog/qwen3.5-32b.yaml +45 -0
  52. mlx_stack/data/catalog/qwen3.5-3b.yaml +44 -0
  53. mlx_stack/data/catalog/qwen3.5-72b.yaml +42 -0
  54. mlx_stack/data/catalog/qwen3.5-8b.yaml +45 -0
  55. mlx_stack/py.typed +1 -0
  56. mlx_stack/utils/__init__.py +1 -0
  57. mlx_stack-0.1.0.dist-info/METADATA +397 -0
  58. mlx_stack-0.1.0.dist-info/RECORD +61 -0
  59. mlx_stack-0.1.0.dist-info/WHEEL +4 -0
  60. mlx_stack-0.1.0.dist-info/entry_points.txt +2 -0
  61. mlx_stack-0.1.0.dist-info/licenses/LICENSE +21 -0
mlx_stack/core/benchmark.py
@@ -0,0 +1,1182 @@
"""Benchmark execution engine for mlx-stack.

Runs 3 iterations of 1024-token prompt + 100-token generation against
a vllm-mlx instance, reports mean ± std dev for prompt_tps and gen_tps.
Compares against catalog thresholds: PASS (<15%), WARN (15-30%),
FAIL (>30%). Handles tool-calling benchmarks for capable models.

Supports benchmarking running tier instances and local models via
temporary vllm-mlx instances with full cleanup (including Ctrl+C).
"""

from __future__ import annotations

import json
import math
import os
import shutil
import signal
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import httpx

from mlx_stack.core.catalog import (
    CatalogEntry,
    get_entry_by_id,
    load_catalog,
)
from mlx_stack.core.config import ConfigCorruptError, get_value
from mlx_stack.core.deps import ensure_dependency
from mlx_stack.core.hardware import HardwareProfile, detect_hardware, load_profile
from mlx_stack.core.paths import (
    ensure_data_home,
    get_benchmarks_dir,
    get_data_home,
    get_stacks_dir,
)
from mlx_stack.core.process import (
    HealthCheckError,
    ProcessError,
    check_port_conflict,
    is_process_alive,
    read_pid_file,
    remove_pid_file,
    start_service,
    stop_service,
    wait_for_healthy,
)

# --------------------------------------------------------------------------- #
# Constants
# --------------------------------------------------------------------------- #

# Benchmark parameters
NUM_ITERATIONS = 3
PROMPT_TOKEN_COUNT = 1024
MAX_GENERATION_TOKENS = 100

# Health check timeout for temporary instances (seconds)
TEMP_INSTANCE_HEALTH_TIMEOUT = 120.0

# Port range for temporary instances (avoid 4000, 8000-8002)
TEMP_PORT_START = 8100
TEMP_PORT_END = 8200

# Threshold percentages for comparison
PASS_THRESHOLD = 0.15  # within 15%
WARN_THRESHOLD = 0.30  # within 30%

# Classification labels
CLASSIFICATION_PASS = "PASS"
CLASSIFICATION_WARN = "WARN"
CLASSIFICATION_FAIL = "FAIL"

# Temporary service name prefix
TEMP_SERVICE_PREFIX = "bench-temp"

# Tool-calling benchmark tool definition
TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use.",
                },
            },
            "required": ["location"],
        },
    },
}
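
# Illustrative sketch: the response shape that _run_tool_call_benchmark
# (below) accepts as a successful tool call. Values are hypothetical;
# only the structure is checked.
_EXAMPLE_TOOL_CALL_RESPONSE = {
    "choices": [
        {
            "message": {
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"location": "San Francisco, CA"}',
                        },
                    }
                ]
            }
        }
    ]
}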

# --------------------------------------------------------------------------- #
# Exceptions
# --------------------------------------------------------------------------- #


class BenchmarkError(Exception):
    """Raised when a benchmark operation fails."""


class BenchmarkTargetError(BenchmarkError):
    """Raised when the benchmark target is not found."""


class BenchmarkRunError(BenchmarkError):
    """Raised when a benchmark iteration fails."""


# --------------------------------------------------------------------------- #
# Data classes
# --------------------------------------------------------------------------- #


@dataclass(frozen=True)
class IterationResult:
    """Result of a single benchmark iteration."""

    prompt_tps: float
    gen_tps: float
    prompt_tokens: int
    completion_tokens: int
    total_time: float


@dataclass(frozen=True)
class MetricClassification:
    """Classification for a single metric against catalog threshold."""

    metric: str
    measured: float
    catalog: float
    delta_pct: float
    classification: str  # PASS, WARN, FAIL


@dataclass(frozen=True)
class ToolCallResult:
    """Result of a tool-calling benchmark."""

    success: bool
    round_trip_time: float
    error: str | None = None


@dataclass
class BenchmarkResult_:
    """Complete benchmark result for a model."""

    model_id: str
    quant: str
    iterations: list[IterationResult] = field(default_factory=list)
    prompt_tps_mean: float = 0.0
    prompt_tps_std: float = 0.0
    gen_tps_mean: float = 0.0
    gen_tps_std: float = 0.0
    classifications: list[MetricClassification] = field(default_factory=list)
    tool_call_result: ToolCallResult | None = None
    used_temporary_instance: bool = False
    catalog_data_available: bool = False

    def to_save_dict(self) -> dict[str, Any]:
        """Convert to dict for saving to benchmark JSON file."""
        return {
            "model_id": self.model_id,
            "quant": self.quant,
            "prompt_tps": self.prompt_tps_mean,
            "gen_tps": self.gen_tps_mean,
            "memory_gb": 0.0,  # placeholder, updated if available
        }

# --------------------------------------------------------------------------- #
# Prompt generation
# --------------------------------------------------------------------------- #


def _generate_prompt(token_count: int) -> str:
    """Generate a prompt with approximately the specified token count.

    Uses repeated text to approximate the desired token count. Each word
    is roughly 1-2 tokens, so we generate enough words to cover the count.

    Args:
        token_count: Target number of tokens.

    Returns:
        A text string approximately token_count tokens long.
    """
    # Each word is roughly 1.3 tokens on average for English text.
    # We use a repeating pattern to keep it simple and predictable.
    base_phrase = (
        "The quick brown fox jumps over the lazy dog near the river bank "
        "where flowers bloom in the warm sunshine of a summer afternoon "
    )
    words_needed = int(token_count / 1.3) + 10
    words = base_phrase.split()
    repeated = []
    for i in range(words_needed):
        repeated.append(words[i % len(words)])
    return " ".join(repeated)
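
# A worked example of the sizing arithmetic above (illustrative): at the
# default PROMPT_TOKEN_COUNT of 1024 the prompt is built from
# int(1024 / 1.3) + 10 == 797 words, i.e. the 24-word base phrase cycled
# roughly 33 times.
assert int(1024 / 1.3) + 10 == 797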

# --------------------------------------------------------------------------- #
# Running tier resolution
# --------------------------------------------------------------------------- #


def _load_stack_tiers() -> list[dict[str, Any]]:
    """Load tier definitions from the active stack.

    Returns:
        List of tier dicts, or empty list if no stack is configured.
    """
    stack_path = get_stacks_dir() / "default.yaml"
    if not stack_path.exists():
        return []
    try:
        import yaml

        content = stack_path.read_text(encoding="utf-8")
        stack = yaml.safe_load(content)
        if isinstance(stack, dict):
            tiers = stack.get("tiers", [])
            if isinstance(tiers, list):
                return tiers
    except Exception:
        pass
    return []
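
# Illustrative sketch of the default.yaml shape this loader (and the tier
# helpers below) assumes; keys mirror the tier.get(...) accesses in this
# module, values are hypothetical:
#
#   tiers:
#     - name: fast
#       port: 8000
#       model: qwen3-8b
#       quant: int4
#       source: mlx-community/Qwen3-8B-4bit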

def _find_running_tier(target: str) -> dict[str, Any] | None:
    """Find a running tier by name.

    Checks if the tier's PID file exists and the process is alive.

    Args:
        target: The tier name to look for.

    Returns:
        The tier dict if running, or None.
    """
    tiers = _load_stack_tiers()
    for tier in tiers:
        tier_name = tier.get("name", "")
        if tier_name == target:
            # Check if running
            try:
                pid = read_pid_file(tier_name)
                if pid is not None and is_process_alive(pid):
                    return tier
            except ProcessError:
                pass
    return None


def _get_running_tier_names() -> list[str]:
    """Get names of all running tiers from the active stack."""
    tiers = _load_stack_tiers()
    running = []
    for tier in tiers:
        name = tier.get("name", "")
        try:
            pid = read_pid_file(name)
            if pid is not None and is_process_alive(pid):
                running.append(name)
        except ProcessError:
            pass
    return running


def _get_all_tier_names() -> list[str]:
    """Get names of all tiers in the active stack."""
    tiers = _load_stack_tiers()
    return [t.get("name", "") for t in tiers if t.get("name")]


def _get_used_ports() -> set[int]:
    """Get all ports used by running stack services and LiteLLM."""
    ports: set[int] = set()
    tiers = _load_stack_tiers()
    for tier in tiers:
        port = tier.get("port")
        if port is not None:
            ports.add(int(port))
    # Add LiteLLM port
    try:
        litellm_port = int(get_value("litellm-port"))
        ports.add(litellm_port)
    except (ConfigCorruptError, ValueError, TypeError):
        ports.add(4000)  # default
    return ports


def _find_temp_port(used_ports: set[int]) -> int:
    """Find an available port for a temporary vllm-mlx instance.

    Avoids ports used by the running stack and LiteLLM.

    Args:
        used_ports: Set of ports currently in use.

    Returns:
        An available port number.

    Raises:
        BenchmarkError: If no free port can be found.
    """
    for port in range(TEMP_PORT_START, TEMP_PORT_END):
        if port in used_ports:
            continue
        # Verify the port is actually free
        conflict = check_port_conflict(port)
        if conflict is None:
            return port

    msg = (
        f"Could not find an available port in range "
        f"{TEMP_PORT_START}-{TEMP_PORT_END} for temporary benchmark instance."
    )
    raise BenchmarkError(msg)

# --------------------------------------------------------------------------- #
# Benchmark iterations
# --------------------------------------------------------------------------- #


def _run_single_iteration(
    port: int,
    model_name: str,
    prompt: str,
    max_tokens: int = MAX_GENERATION_TOKENS,
    host: str = "127.0.0.1",
) -> IterationResult:
    """Run a single benchmark iteration against a vllm-mlx instance.

    Uses streaming to separately measure prompt processing time (time to
    first token) and generation time (first token to last token). This
    gives distinct prompt_tps and gen_tps measurements.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: The model identifier for the API request.
        prompt: The prompt text.
        max_tokens: Maximum tokens to generate.
        host: The host to connect to.

    Returns:
        An IterationResult with timing data.

    Raises:
        BenchmarkRunError: If the API request fails.
    """
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }

    try:
        start_time = time.monotonic()
        first_token_time: float | None = None
        completion_tokens = 0
        prompt_tokens = 0
        chunk_count = 0

        with httpx.stream("POST", url, json=payload, timeout=300.0) as response:
            if response.status_code != 200:
                body = response.read().decode("utf-8", errors="replace")[:200]
                msg = (
                    f"API request failed with status {response.status_code}: "
                    f"{body}"
                )
                raise BenchmarkRunError(msg)

            for line in response.iter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[6:].strip()
                if data_str == "[DONE]":
                    break

                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue

                # Check for usage in the final chunk
                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = usage.get("prompt_tokens", prompt_tokens)
                    completion_tokens = usage.get("completion_tokens", completion_tokens)

                # Track first content token for TTFT measurement.
                # Thinking models (e.g. Qwen3) emit reasoning_content for
                # <think> tokens instead of content; check both fields so
                # first_token_time is set and TPS is calculated correctly.
                choices = chunk.get("choices", [])
                if choices:
                    delta = choices[0].get("delta", {})
                    content = delta.get("content") or delta.get("reasoning_content")
                    if content and first_token_time is None:
                        first_token_time = time.monotonic()
                    if content:
                        chunk_count += 1

        end_time = time.monotonic()
        total_time = end_time - start_time

        # Calculate distinct prompt_tps and gen_tps from timing phases:
        # - prompt_time = time from request start to first generated token
        # - gen_time = time from first generated token to last token
        if first_token_time is not None and prompt_tokens > 0:
            prompt_time = first_token_time - start_time
            gen_time = end_time - first_token_time

            prompt_tps = prompt_tokens / prompt_time if prompt_time > 0 else 0.0
            gen_tps = completion_tokens / gen_time if gen_time > 0 else 0.0
        else:
            prompt_tps = 0.0
            gen_tps = 0.0

        return IterationResult(
            prompt_tps=prompt_tps,
            gen_tps=gen_tps,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_time=total_time,
        )

    except BenchmarkRunError:
        raise
    except httpx.TimeoutException:
        msg = "Benchmark request timed out after 300s"
        raise BenchmarkRunError(msg) from None
    except httpx.HTTPError as exc:
        msg = f"HTTP error during benchmark: {exc}"
        raise BenchmarkRunError(msg) from None
    except (json.JSONDecodeError, KeyError) as exc:
        msg = f"Could not parse benchmark response: {exc}"
        raise BenchmarkRunError(msg) from None
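
# For reference, the OpenAI-compatible SSE stream parsed above looks roughly
# like this (illustrative chunks, heavily abbreviated):
#
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
#   data: {"choices": [{"delta": {"content": " world"}}]}
#   data: {"choices": [], "usage": {"prompt_tokens": 1024, "completion_tokens": 100}}
#   data: [DONE]
#
# first_token_time is stamped on the first chunk whose delta carries
# content (or reasoning_content); the usage chunk supplies the token
# counts for the prompt_tps / gen_tps division.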

def _run_iterations(
    port: int,
    model_name: str,
    num_iterations: int = NUM_ITERATIONS,
) -> list[IterationResult]:
    """Run multiple benchmark iterations.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: Model identifier for API requests.
        num_iterations: Number of iterations to run.

    Returns:
        List of IterationResult from each iteration.

    Raises:
        BenchmarkRunError: If any iteration fails.
    """
    prompt = _generate_prompt(PROMPT_TOKEN_COUNT)
    results: list[IterationResult] = []

    for _i in range(num_iterations):
        result = _run_single_iteration(
            port=port,
            model_name=model_name,
            prompt=prompt,
            max_tokens=MAX_GENERATION_TOKENS,
        )
        results.append(result)

    return results

# --------------------------------------------------------------------------- #
# Statistics
# --------------------------------------------------------------------------- #


def _compute_stats(values: list[float]) -> tuple[float, float]:
    """Compute mean and sample standard deviation (Bessel-corrected).

    Args:
        values: List of numeric values.

    Returns:
        Tuple of (mean, std_dev).
    """
    if not values:
        return 0.0, 0.0

    n = len(values)
    mean = sum(values) / n

    if n < 2:
        return mean, 0.0

    variance = sum((v - mean) ** 2 for v in values) / (n - 1)
    std_dev = math.sqrt(variance)
    return mean, std_dev
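
# A quick worked example (illustrative): three gen_tps samples of 90.0,
# 100.0 and 110.0 give mean 100.0 and sample variance
# (100 + 0 + 100) / (3 - 1) == 100, so the reported std dev is 10.0.
assert _compute_stats([90.0, 100.0, 110.0]) == (100.0, 10.0)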

# --------------------------------------------------------------------------- #
# Catalog comparison
# --------------------------------------------------------------------------- #


def _classify_metric(
    metric_name: str,
    measured: float,
    catalog_value: float,
) -> MetricClassification:
    """Classify a metric against the catalog threshold.

    PASS: within 15% of catalog value
    WARN: 15-30% below catalog value
    FAIL: more than 30% below catalog value

    Args:
        metric_name: Name of the metric (e.g. "prompt_tps").
        measured: Measured value.
        catalog_value: Catalog reference value.

    Returns:
        A MetricClassification instance.
    """
    if catalog_value <= 0:
        return MetricClassification(
            metric=metric_name,
            measured=measured,
            catalog=catalog_value,
            delta_pct=0.0,
            classification=CLASSIFICATION_PASS,
        )

    delta_pct = (catalog_value - measured) / catalog_value

    if delta_pct <= PASS_THRESHOLD:
        classification = CLASSIFICATION_PASS
    elif delta_pct <= WARN_THRESHOLD:
        classification = CLASSIFICATION_WARN
    else:
        classification = CLASSIFICATION_FAIL

    return MetricClassification(
        metric=metric_name,
        measured=measured,
        catalog=catalog_value,
        delta_pct=delta_pct * 100,  # Store as percentage
        classification=classification,
    )
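
# Worked examples of the banding above (illustrative), against a catalog
# reference of 100 tokens/sec:
#   measured 95.0 -> delta  5% -> PASS  (within 15%)
#   measured 80.0 -> delta 20% -> WARN  (15-30% below)
#   measured 60.0 -> delta 40% -> FAIL  (more than 30% below)
# Measurements above the catalog value give a negative delta, hence PASS.
assert _classify_metric("gen_tps", 80.0, 100.0).classification == CLASSIFICATION_WARN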

def _compare_against_catalog(
    prompt_tps_mean: float,
    gen_tps_mean: float,
    entry: CatalogEntry,
    profile: HardwareProfile,
) -> list[MetricClassification]:
    """Compare benchmark results against catalog thresholds.

    Args:
        prompt_tps_mean: Mean prompt tokens/sec.
        gen_tps_mean: Mean generation tokens/sec.
        entry: The catalog entry for the model.
        profile: The hardware profile for benchmark lookup.

    Returns:
        List of MetricClassification, empty if no catalog data.
    """
    profile_id = profile.profile_id
    if profile_id not in entry.benchmarks:
        return []

    bench = entry.benchmarks[profile_id]
    classifications: list[MetricClassification] = []

    classifications.append(
        _classify_metric("prompt_tps", prompt_tps_mean, bench.prompt_tps)
    )
    classifications.append(
        _classify_metric("gen_tps", gen_tps_mean, bench.gen_tps)
    )

    return classifications

# --------------------------------------------------------------------------- #
# Tool-calling benchmark
# --------------------------------------------------------------------------- #


def _run_tool_call_benchmark(
    port: int,
    model_name: str,
    host: str = "127.0.0.1",
) -> ToolCallResult:
    """Run a tool-calling benchmark.

    Sends a function-calling request with a tool definition and verifies
    the response contains a valid tool call structure.

    Args:
        port: Port of the vllm-mlx instance.
        model_name: Model identifier.
        host: Host to connect to.

    Returns:
        A ToolCallResult with timing and validity information.
    """
    url = f"http://{host}:{port}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": "What is the current weather in San Francisco, CA?",
            }
        ],
        "tools": [TOOL_DEFINITION],
        "tool_choice": "auto",
        "max_tokens": 1024,
        "temperature": 0.0,
    }

    try:
        start_time = time.monotonic()
        response = httpx.post(url, json=payload, timeout=60.0)
        round_trip_time = time.monotonic() - start_time

        if response.status_code != 200:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error=f"API returned status {response.status_code}",
            )

        data = response.json()
        choices = data.get("choices", [])
        if not choices:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="No choices in response",
            )

        message = choices[0].get("message", {})
        tool_calls = message.get("tool_calls", [])
        if not tool_calls:
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="No tool calls in response",
            )

        # Validate tool call structure
        tool_call = tool_calls[0]
        fn = tool_call.get("function", {})
        fn_name = fn.get("name", "")
        fn_args_str = fn.get("arguments", "")

        if fn_name != "get_weather":
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error=f"Wrong function name: {fn_name}",
            )

        # Try to parse arguments as JSON
        try:
            fn_args = json.loads(fn_args_str) if isinstance(fn_args_str, str) else fn_args_str
            if "location" not in fn_args:
                return ToolCallResult(
                    success=False,
                    round_trip_time=round_trip_time,
                    error="Missing 'location' in tool call arguments",
                )
        except (json.JSONDecodeError, TypeError):
            return ToolCallResult(
                success=False,
                round_trip_time=round_trip_time,
                error="Could not parse tool call arguments as JSON",
            )

        return ToolCallResult(
            success=True,
            round_trip_time=round_trip_time,
        )

    except httpx.TimeoutException:
        return ToolCallResult(
            success=False,
            round_trip_time=60.0,
            error="Tool call request timed out",
        )
    except Exception as exc:  # httpx.HTTPError and anything unexpected
        return ToolCallResult(
            success=False,
            round_trip_time=0.0,
            error=f"Tool call request failed: {exc}",
        )

# --------------------------------------------------------------------------- #
# Temporary instance management
# --------------------------------------------------------------------------- #


def _resolve_model_source(model_id: str, quant: str) -> str:
    """Resolve the model source path for vllm-mlx.

    Checks the local models directory first, then falls back to the catalog
    HuggingFace repo.

    Args:
        model_id: Catalog model ID.
        quant: Quantization level.

    Returns:
        The model source path or HF repo identifier.

    Raises:
        BenchmarkError: If the model cannot be resolved.
    """
    catalog = load_catalog()
    entry = get_entry_by_id(catalog, model_id)
    if entry is None:
        msg = f"Model '{model_id}' not found in catalog."
        raise BenchmarkTargetError(msg)

    if quant not in entry.sources:
        available_quants = ", ".join(sorted(entry.sources.keys()))
        msg = (
            f"Quantization '{quant}' not available for model '{model_id}'. "
            f"Available: {available_quants}"
        )
        raise BenchmarkError(msg)

    source = entry.sources[quant]

    # Check local models directory
    try:
        model_dir = str(get_value("model-dir"))
        models_path = Path(model_dir).expanduser()
    except Exception:  # ConfigCorruptError, missing key, etc.
        models_path = get_data_home() / "models"

    # Check by repo directory name (rsplit also handles repos without a "/")
    repo_dir_name = source.hf_repo.rsplit("/", 1)[-1]
    local_path = models_path / repo_dir_name
    if local_path.exists():
        return str(local_path)

    # Check by model ID
    model_path = models_path / model_id
    if model_path.exists():
        return str(model_path)

    # Use HuggingFace repo directly
    return source.hf_repo
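
# Resolution order, illustrated with hypothetical names (none of these
# paths or repos are implied by the package itself):
#
#   _resolve_model_source("qwen3-8b", "int4") checks, in order:
#     1. <model-dir>/<repo dir name>  e.g. ~/models/Qwen3-8B-4bit
#     2. <model-dir>/<model id>       e.g. ~/models/qwen3-8b
#     3. falls back to the catalog HF repo string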

def _start_temp_instance(
    model_source: str,
    port: int,
    entry: CatalogEntry,
    quant: str,
) -> str:
    """Start a temporary vllm-mlx instance for benchmarking.

    Args:
        model_source: Model source path or HF repo.
        port: Port to bind the instance to.
        entry: Catalog entry for the model.
        quant: Quantization level.

    Returns:
        The service name used for PID management.

    Raises:
        BenchmarkError: If the instance cannot be started.
    """
    service_name = f"{TEMP_SERVICE_PREFIX}-{entry.id}"

    # Ensure vllm-mlx is installed
    ensure_dependency("vllm-mlx")

    vllm_binary = shutil.which("vllm-mlx")
    if vllm_binary is None:
        msg = (
            "vllm-mlx binary not found on PATH after installation. "
            "Try: uv tool install vllm-mlx"
        )
        raise BenchmarkError(msg)

    cmd = [
        vllm_binary,
        "serve", model_source,
        "--port", str(port),
        "--host", "127.0.0.1",
    ]

    # Add tool-calling flags if the model supports it
    if entry.capabilities.tool_calling:
        cmd.append("--enable-auto-tool-choice")
        if entry.capabilities.tool_call_parser:
            cmd.extend(["--tool-call-parser", entry.capabilities.tool_call_parser])

    # Add reasoning parser for thinking models (e.g., Qwen3).
    # Without this flag, thinking models emit <think> tags that break
    # the tool-call parser (e.g., hermes).
    if entry.capabilities.thinking and entry.capabilities.reasoning_parser:
        cmd.extend(["--reasoning-parser", entry.capabilities.reasoning_parser])

    try:
        start_service(
            service_name=service_name,
            cmd=cmd,
            port=port,
        )
    except ProcessError as exc:
        msg = f"Could not start temporary vllm-mlx instance: {exc}"
        raise BenchmarkError(msg) from None

    # Wait for health
    try:
        wait_for_healthy(
            port=port,
            path="/v1/models",
            total_timeout=TEMP_INSTANCE_HEALTH_TIMEOUT,
        )
    except HealthCheckError:
        # Clean up the failed instance
        _cleanup_temp_instance(service_name)
        msg = (
            f"Temporary vllm-mlx instance did not become healthy "
            f"within {TEMP_INSTANCE_HEALTH_TIMEOUT}s."
        )
        raise BenchmarkError(msg) from None

    return service_name
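
# Illustrative sketch: for a tool-calling, thinking-capable model the
# command assembled above comes out roughly as follows (parser names come
# from the catalog entry; "hermes" and "qwen3" are hypothetical here):
#
#   vllm-mlx serve <model_source> --port 8100 --host 127.0.0.1 \
#       --enable-auto-tool-choice --tool-call-parser hermes \
#       --reasoning-parser qwen3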

def _cleanup_temp_instance(service_name: str) -> None:
    """Clean up a temporary vllm-mlx instance.

    Sends SIGTERM then SIGKILL to ensure full cleanup.

    Args:
        service_name: The service name from PID management.
    """
    try:
        stop_service(service_name, grace_period=5.0)
    except Exception:
        pass

    # Double-check: try reading PID and kill directly if still alive
    try:
        pid = read_pid_file(service_name)
        if pid is not None and is_process_alive(pid):
            os.kill(pid, signal.SIGKILL)
            time.sleep(0.5)
    except (ProcessError, OSError):
        pass

    # Remove PID file
    remove_pid_file(service_name)

# --------------------------------------------------------------------------- #
# Save benchmark results
# --------------------------------------------------------------------------- #


def save_benchmark_results(
    result: BenchmarkResult_,
    profile: HardwareProfile,
) -> Path:
    """Save benchmark results to ~/.mlx-stack/benchmarks/<profile_id>.json.

    Loads existing data (if any) and merges/updates with the new result.

    Args:
        result: The benchmark result to save.
        profile: The hardware profile for path determination.

    Returns:
        Path to the saved file.
    """
    ensure_data_home()
    benchmarks_dir = get_benchmarks_dir()
    benchmarks_dir.mkdir(parents=True, exist_ok=True)

    benchmark_file = benchmarks_dir / f"{profile.profile_id}.json"

    # Load existing data
    existing: dict[str, Any] = {}
    if benchmark_file.exists():
        try:
            existing = json.loads(benchmark_file.read_text(encoding="utf-8"))
            if not isinstance(existing, dict):
                existing = {}
        except (json.JSONDecodeError, OSError):
            existing = {}

    # Merge new result
    existing[result.model_id] = result.to_save_dict()

    # Write back
    benchmark_file.write_text(
        json.dumps(existing, indent=2) + "\n",
        encoding="utf-8",
    )

    return benchmark_file
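
# The saved file maps model IDs to their latest run, one JSON object per
# hardware profile, e.g. (hypothetical values, shape per to_save_dict):
#
#   {
#     "qwen3-8b": {
#       "model_id": "qwen3-8b",
#       "quant": "int4",
#       "prompt_tps": 512.3,
#       "gen_tps": 41.7,
#       "memory_gb": 0.0
#     }
#   }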

# --------------------------------------------------------------------------- #
# Resolve benchmark target
# --------------------------------------------------------------------------- #


@dataclass
class BenchmarkTarget:
    """Resolved target for benchmarking."""

    model_id: str
    quant: str
    port: int
    model_name: str  # The model name used in API calls
    entry: CatalogEntry
    is_running_tier: bool
    temp_service_name: str | None = None  # Set if using a temp instance


def resolve_target(target: str) -> BenchmarkTarget:
    """Resolve a benchmark target to a specific model and port.

    Tries in order:
    1. Running tier by name
    2. Catalog model by ID (starts temp instance)

    Args:
        target: Tier name or model ID.

    Returns:
        A BenchmarkTarget with all needed info.

    Raises:
        BenchmarkTargetError: If the target cannot be resolved.
    """
    # 1. Try as a running tier
    tier = _find_running_tier(target)
    if tier is not None:
        port = tier["port"]
        model_id = tier.get("model", target)
        quant = tier.get("quant", "int4")
        source = tier.get("source", model_id)

        catalog = load_catalog()
        entry = get_entry_by_id(catalog, model_id)
        if entry is None:
            # Still benchmark it even without catalog data
            from mlx_stack.core.catalog import (
                Capabilities,
                CatalogEntry,
                QualityScores,
            )

            entry = CatalogEntry(
                id=model_id,
                name=model_id,
                family="unknown",
                params_b=0.0,
                architecture="unknown",
                min_mlx_lm_version="0.0.0",
                sources={},
                capabilities=Capabilities(
                    tool_calling=False,
                    tool_call_parser=None,
                    thinking=False,
                    reasoning_parser=None,
                    vision=False,
                ),
                quality=QualityScores(overall=0, coding=0, reasoning=0, instruction_following=0),
                benchmarks={},
                tags=[],
            )

        return BenchmarkTarget(
            model_id=model_id,
            quant=quant,
            port=port,
            model_name=source,
            entry=entry,
            is_running_tier=True,
        )

    # 2. Try as a catalog model
    catalog = load_catalog()
    entry = get_entry_by_id(catalog, target)
    if entry is not None:
        # Determine quant
        try:
            quant = str(get_value("default-quant"))
        except Exception:  # ConfigCorruptError, missing key, etc.
            quant = "int4"

        if quant not in entry.sources:
            available = sorted(entry.sources.keys())
            quant = available[0] if available else "int4"

        # Resolve model source
        model_source = _resolve_model_source(target, quant)

        # Find a free port
        used_ports = _get_used_ports()
        port = _find_temp_port(used_ports)

        # Start temp instance
        service_name = _start_temp_instance(model_source, port, entry, quant)

        return BenchmarkTarget(
            model_id=entry.id,
            quant=quant,
            port=port,
            model_name=model_source,
            entry=entry,
            is_running_tier=False,
            temp_service_name=service_name,
        )

    # Neither tier nor model
    tier_names = _get_all_tier_names()
    running_tiers = _get_running_tier_names()

    parts: list[str] = [f"'{target}' is not a running tier or known model."]

    if tier_names:
        parts.append(f"\nValid tier names: {', '.join(tier_names)}")
    if running_tiers:
        parts.append(f"Running tiers: {', '.join(running_tiers)}")

    parts.append("\nUse 'mlx-stack models --catalog' to see available models.")

    raise BenchmarkTargetError("\n".join(parts))

# --------------------------------------------------------------------------- #
# Main benchmark function
# --------------------------------------------------------------------------- #


def run_benchmark(
    target: str,
    save: bool = False,
) -> BenchmarkResult_:
    """Run a complete benchmark against a tier or model.

    This is the main entry point for the benchmark engine.

    Args:
        target: Tier name or model ID.
        save: Whether to persist results to disk.

    Returns:
        A BenchmarkResult_ with all collected data.

    Raises:
        BenchmarkError: If the benchmark fails.
        BenchmarkTargetError: If the target is not found.
    """
    # Auto-install vllm-mlx (needed for API calls even to running tiers)
    ensure_dependency("vllm-mlx")

    resolved = resolve_target(target)
    temp_service = resolved.temp_service_name

    try:
        return _execute_benchmark(resolved, save)
    finally:
        # Always clean up the temp instance: on success, on failure, and on
        # Ctrl+C. The finally clause alone covers all three; a separate
        # except handler would stop the instance twice.
        if temp_service:
            _cleanup_temp_instance(temp_service)

def _execute_benchmark(
    target: BenchmarkTarget,
    save: bool,
) -> BenchmarkResult_:
    """Execute the benchmark iterations and comparison.

    Args:
        target: The resolved benchmark target.
        save: Whether to persist results.

    Returns:
        A BenchmarkResult_.
    """
    # Run iterations
    iterations = _run_iterations(
        port=target.port,
        model_name=target.model_name,
        num_iterations=NUM_ITERATIONS,
    )

    # Compute stats
    prompt_tps_values = [it.prompt_tps for it in iterations]
    gen_tps_values = [it.gen_tps for it in iterations]

    prompt_tps_mean, prompt_tps_std = _compute_stats(prompt_tps_values)
    gen_tps_mean, gen_tps_std = _compute_stats(gen_tps_values)

    # Compare against catalog
    profile = load_profile()
    if profile is None:
        try:
            profile = detect_hardware()
        except Exception:
            profile = None

    classifications: list[MetricClassification] = []
    catalog_data_available = False
    if profile is not None:
        classifications = _compare_against_catalog(
            prompt_tps_mean, gen_tps_mean, target.entry, profile
        )
        catalog_data_available = len(classifications) > 0

    # Tool-calling benchmark
    tool_call_result: ToolCallResult | None = None
    if target.entry.capabilities.tool_calling:
        tool_call_result = _run_tool_call_benchmark(
            port=target.port,
            model_name=target.model_name,
        )

    bench_result = BenchmarkResult_(
        model_id=target.model_id,
        quant=target.quant,
        iterations=iterations,
        prompt_tps_mean=prompt_tps_mean,
        prompt_tps_std=prompt_tps_std,
        gen_tps_mean=gen_tps_mean,
        gen_tps_std=gen_tps_std,
        classifications=classifications,
        tool_call_result=tool_call_result,
        used_temporary_instance=not target.is_running_tier,
        catalog_data_available=catalog_data_available,
    )

    # Save results if requested
    if save and profile is not None:
        save_benchmark_results(bench_result, profile)

    return bench_result
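
# Driving the engine directly, a minimal sketch ("fast" is a hypothetical
# tier name):
#
#   from mlx_stack.core.benchmark import run_benchmark
#
#   result = run_benchmark("fast", save=True)
#   print(
#       f"{result.model_id} ({result.quant}): "
#       f"prompt {result.prompt_tps_mean:.1f} ± {result.prompt_tps_std:.1f} tps, "
#       f"gen {result.gen_tps_mean:.1f} ± {result.gen_tps_std:.1f} tps"
#   )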