dflash-mlx 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/__init__.py +0 -0
- benchmark/benchmark.py +600 -0
- dflash_mlx/__init__.py +5 -0
- dflash_mlx/generate.py +152 -0
- dflash_mlx/kernels.py +761 -0
- dflash_mlx/model.py +324 -0
- dflash_mlx/recurrent_rollback_cache.py +165 -0
- dflash_mlx/runtime.py +1789 -0
- dflash_mlx/serve.py +320 -0
- dflash_mlx-0.1.0.dist-info/METADATA +149 -0
- dflash_mlx-0.1.0.dist-info/RECORD +15 -0
- dflash_mlx-0.1.0.dist-info/WHEEL +5 -0
- dflash_mlx-0.1.0.dist-info/entry_points.txt +4 -0
- dflash_mlx-0.1.0.dist-info/licenses/LICENSE +21 -0
- dflash_mlx-0.1.0.dist-info/top_level.txt +2 -0
benchmark/__init__.py
ADDED
|
File without changes
|
benchmark/benchmark.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
1
|
+
# Copyright 2026 bstnxbt
|
|
2
|
+
# MIT License — see LICENSE file
|
|
3
|
+
# Based on DFlash (arXiv:2602.06036)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import gc
|
|
8
|
+
import json
|
|
9
|
+
import platform
|
|
10
|
+
import re
|
|
11
|
+
import statistics
|
|
12
|
+
import subprocess
|
|
13
|
+
import sys
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import mlx.core as mx
|
|
19
|
+
from mlx_lm import stream_generate as mlx_stream_generate
|
|
20
|
+
from mlx_lm.utils import load as load_pristine_target
|
|
21
|
+
|
|
22
|
+
from dflash_mlx.runtime import (
|
|
23
|
+
generate_dflash_once,
|
|
24
|
+
load_draft_bundle,
|
|
25
|
+
load_target_bundle,
|
|
26
|
+
resolve_model_ref,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
DEFAULT_SCHEDULES: tuple[int, ...] = (8, 16, 32)
|
|
30
|
+
DEFAULT_REPEAT = 3
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _git_hash_short() -> str:
|
|
34
|
+
return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _hardware_info() -> dict[str, str]:
|
|
38
|
+
return {
|
|
39
|
+
"chip": subprocess.check_output(
|
|
40
|
+
["sysctl", "-n", "machdep.cpu.brand_string"], text=True
|
|
41
|
+
).strip(),
|
|
42
|
+
"memory_gb": str(
|
|
43
|
+
int(subprocess.check_output(["sysctl", "-n", "hw.memsize"], text=True).strip())
|
|
44
|
+
// (1024**3)
|
|
45
|
+
),
|
|
46
|
+
"mlx_version": mx.__version__,
|
|
47
|
+
"python": platform.python_version(),
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_thermal_pressure() -> str:
|
|
52
|
+
try:
|
|
53
|
+
out = subprocess.check_output(["pmset", "-g", "therm"], text=True, timeout=2)
|
|
54
|
+
for line in out.splitlines():
|
|
55
|
+
if "CPU_Scheduler_Limit" not in line:
|
|
56
|
+
continue
|
|
57
|
+
val = int(line.strip().split("=")[-1].strip())
|
|
58
|
+
if val == 100:
|
|
59
|
+
return "nominal"
|
|
60
|
+
if val >= 80:
|
|
61
|
+
return "fair"
|
|
62
|
+
if val >= 50:
|
|
63
|
+
return "serious"
|
|
64
|
+
return "critical"
|
|
65
|
+
except Exception:
|
|
66
|
+
pass
|
|
67
|
+
return "unknown"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _warn_if_throttled(thermal_pressure: str) -> None:
|
|
71
|
+
if thermal_pressure == "nominal":
|
|
72
|
+
return
|
|
73
|
+
print(
|
|
74
|
+
f"WARNING: thermal pressure is '{thermal_pressure}' — results may be throttled. "
|
|
75
|
+
"Increase --cooldown or wait for chip to cool.",
|
|
76
|
+
file=sys.stderr,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _slugify_prompt_id(prompt: str) -> str:
|
|
81
|
+
slug = re.sub(r"[^a-z0-9]+", "_", prompt.lower()).strip("_")
|
|
82
|
+
slug = re.sub(r"_+", "_", slug)
|
|
83
|
+
return slug[:48] or "prompt"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _slugify_model_ref(model_ref: str | None) -> str:
|
|
87
|
+
resolved = resolve_model_ref(model_ref, kind="target")
|
|
88
|
+
label = Path(str(resolved)).name or str(resolved)
|
|
89
|
+
label = re.sub(r"[^a-z0-9]+", "-", label.lower())
|
|
90
|
+
label = re.sub(r"-+", "-", label).strip("-")
|
|
91
|
+
return label or "model"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _default_results_path(*, target_model_ref: str | None, max_new_tokens: int) -> Path:
|
|
95
|
+
return Path("benchmark/results") / f"{_slugify_model_ref(target_model_ref)}-{int(max_new_tokens)}.json"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _strip_generation_payload(result: dict[str, Any], *, drop_phase_timings: bool = False) -> dict[str, Any]:
|
|
99
|
+
cleaned = dict(result)
|
|
100
|
+
cleaned.pop("generated_token_ids", None)
|
|
101
|
+
if drop_phase_timings:
|
|
102
|
+
phase_timings = dict(cleaned.pop("phase_timings_us", {}) or {})
|
|
103
|
+
if "prefill" in phase_timings and "prefill_us" not in cleaned:
|
|
104
|
+
cleaned["prefill_us"] = float(phase_timings["prefill"])
|
|
105
|
+
return cleaned
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _format_run_entry(run: dict[str, Any]) -> dict[str, Any]:
|
|
109
|
+
baseline = dict(run["baseline"])
|
|
110
|
+
dflash = dict(run["dflash"])
|
|
111
|
+
return {
|
|
112
|
+
"run": int(run["run_index"]),
|
|
113
|
+
"thermal_pressure": str(run.get("thermal_pressure", "unknown")),
|
|
114
|
+
"baseline": {
|
|
115
|
+
"ttft_ms": float(run["baseline_ttft_ms"]),
|
|
116
|
+
"generation_tps": float(run["baseline_generation_tps"]),
|
|
117
|
+
"peak_memory_gb": baseline.get("peak_memory_gb"),
|
|
118
|
+
},
|
|
119
|
+
"dflash": {
|
|
120
|
+
"ttft_ms": float(run["dflash_ttft_ms"]),
|
|
121
|
+
"generation_tps": float(run["dflash_generation_tps"]),
|
|
122
|
+
"tokens_per_cycle": float(dflash.get("tokens_per_cycle", 0.0)),
|
|
123
|
+
"cycles": int(dflash.get("cycles_completed", 0)),
|
|
124
|
+
"acceptance_ratio": float(dflash.get("acceptance_ratio", 0.0)),
|
|
125
|
+
"acceptance_first_20_avg": float(dflash.get("acceptance_first_20_avg", 0.0)),
|
|
126
|
+
"acceptance_last_20_avg": float(dflash.get("acceptance_last_20_avg", 0.0)),
|
|
127
|
+
"peak_memory_gb": dflash.get("peak_memory_gb"),
|
|
128
|
+
},
|
|
129
|
+
"speedup": float(run["generation_speedup_vs_baseline"]) if run["generation_speedup_vs_baseline"] is not None else None,
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _build_config(
|
|
134
|
+
*,
|
|
135
|
+
prompt: str,
|
|
136
|
+
prompt_tokens: int,
|
|
137
|
+
max_new_tokens: int,
|
|
138
|
+
block_tokens: int,
|
|
139
|
+
repeat: int,
|
|
140
|
+
cooldown: int,
|
|
141
|
+
target_model: str,
|
|
142
|
+
draft_model: str,
|
|
143
|
+
) -> dict[str, Any]:
|
|
144
|
+
return {
|
|
145
|
+
"target_model": target_model,
|
|
146
|
+
"draft_model": draft_model,
|
|
147
|
+
"max_new_tokens": int(max_new_tokens),
|
|
148
|
+
"block_tokens": int(block_tokens),
|
|
149
|
+
"cooldown": int(cooldown),
|
|
150
|
+
"prompt": prompt,
|
|
151
|
+
"prompt_tokens": int(prompt_tokens),
|
|
152
|
+
"prompt_id": _slugify_prompt_id(prompt),
|
|
153
|
+
"repeat": int(repeat),
|
|
154
|
+
"git_hash": _git_hash_short(),
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _build_single_case_report(
|
|
159
|
+
*,
|
|
160
|
+
prompt: str,
|
|
161
|
+
max_new_tokens: int,
|
|
162
|
+
block_tokens: int,
|
|
163
|
+
repeat: int,
|
|
164
|
+
cooldown: int,
|
|
165
|
+
runs: list[dict[str, Any]],
|
|
166
|
+
target_model: str,
|
|
167
|
+
draft_model: str,
|
|
168
|
+
) -> dict[str, Any]:
|
|
169
|
+
run_entries = [_format_run_entry(run) for run in runs]
|
|
170
|
+
baseline_tps_values = [float(run["baseline_generation_tps"]) for run in runs]
|
|
171
|
+
dflash_tps_values = [float(run["dflash_generation_tps"]) for run in runs]
|
|
172
|
+
speedup_values = [float(run["generation_speedup_vs_baseline"]) for run in runs if run["generation_speedup_vs_baseline"] is not None]
|
|
173
|
+
acceptance_ratio_values = [float(run["dflash"]["acceptance_ratio"]) for run in runs]
|
|
174
|
+
prompt_tokens = int(runs[0]["baseline"]["prompt_token_count"]) if runs else 0
|
|
175
|
+
return {
|
|
176
|
+
"hardware": _hardware_info(),
|
|
177
|
+
"config": _build_config(
|
|
178
|
+
prompt=prompt,
|
|
179
|
+
prompt_tokens=prompt_tokens,
|
|
180
|
+
max_new_tokens=max_new_tokens,
|
|
181
|
+
block_tokens=block_tokens,
|
|
182
|
+
repeat=repeat,
|
|
183
|
+
cooldown=cooldown,
|
|
184
|
+
target_model=target_model,
|
|
185
|
+
draft_model=draft_model,
|
|
186
|
+
),
|
|
187
|
+
"runs": run_entries,
|
|
188
|
+
"summary": {
|
|
189
|
+
"baseline_tps_median": statistics.median(baseline_tps_values) if baseline_tps_values else None,
|
|
190
|
+
"dflash_tps_median": statistics.median(dflash_tps_values) if dflash_tps_values else None,
|
|
191
|
+
"dflash_tps_min": min(dflash_tps_values) if dflash_tps_values else None,
|
|
192
|
+
"dflash_tps_max": max(dflash_tps_values) if dflash_tps_values else None,
|
|
193
|
+
"speedup_median": statistics.median(speedup_values) if speedup_values else None,
|
|
194
|
+
"acceptance_ratio_median": statistics.median(acceptance_ratio_values) if acceptance_ratio_values else None,
|
|
195
|
+
},
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def get_stop_token_ids(tokenizer: Any) -> list[int]:
|
|
200
|
+
eos_token_ids = list(getattr(tokenizer, "eos_token_ids", None) or [])
|
|
201
|
+
eos_token_id = getattr(tokenizer, "eos_token_id", None)
|
|
202
|
+
if eos_token_id is not None and eos_token_id not in eos_token_ids:
|
|
203
|
+
eos_token_ids.append(int(eos_token_id))
|
|
204
|
+
return eos_token_ids
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _speedup(baseline_elapsed: float, dflash_elapsed: float) -> float | None:
|
|
208
|
+
return baseline_elapsed / dflash_elapsed if dflash_elapsed > 0.0 else None
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _generation_speedup(baseline_tps: float, dflash_tps: float) -> float | None:
|
|
212
|
+
return dflash_tps / baseline_tps if baseline_tps > 0.0 else None
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _ttft_ms_from_baseline(result: dict[str, Any]) -> float:
|
|
216
|
+
return float(result.get("prefill_us", 0.0)) / 1_000.0
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _ttft_ms_from_dflash(result: dict[str, Any]) -> float:
|
|
220
|
+
phase_timings = dict(result.get("phase_timings_us", {}))
|
|
221
|
+
return float(phase_timings.get("prefill", 0.0)) / 1_000.0
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _generation_tps_from_baseline(result: dict[str, Any]) -> float:
|
|
225
|
+
if "generation_tps" in result:
|
|
226
|
+
return float(result["generation_tps"])
|
|
227
|
+
elapsed_us = float(result.get("elapsed_us", 0.0))
|
|
228
|
+
prefill_us = float(result.get("prefill_us", 0.0))
|
|
229
|
+
generation_tokens = int(result.get("generation_tokens", 0))
|
|
230
|
+
generation_us = max(0.0, elapsed_us - prefill_us)
|
|
231
|
+
return (generation_tokens / (generation_us / 1e6)) if generation_us > 0.0 else 0.0
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _generation_tps_from_dflash(result: dict[str, Any]) -> float:
|
|
235
|
+
elapsed_us = float(result.get("elapsed_us", 0.0))
|
|
236
|
+
phase_timings = dict(result.get("phase_timings_us", {}))
|
|
237
|
+
prefill_us = float(phase_timings.get("prefill", 0.0))
|
|
238
|
+
generation_tokens = int(result.get("generation_tokens", 0))
|
|
239
|
+
generation_us = max(0.0, elapsed_us - prefill_us)
|
|
240
|
+
return (generation_tokens / (generation_us / 1e6)) if generation_us > 0.0 else 0.0
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _load_pristine_target_bundle(model_ref: str | None):
|
|
244
|
+
resolved_ref = resolve_model_ref(model_ref, kind="target")
|
|
245
|
+
model, tokenizer, config = load_pristine_target(resolved_ref, lazy=True, return_config=True)
|
|
246
|
+
return model, tokenizer, {"resolved_model_ref": resolved_ref, "config": config}
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _generate_stock_baseline_once(
|
|
250
|
+
*,
|
|
251
|
+
target_model: Any,
|
|
252
|
+
tokenizer: Any,
|
|
253
|
+
prompt: str,
|
|
254
|
+
max_new_tokens: int,
|
|
255
|
+
no_eos: bool,
|
|
256
|
+
) -> dict[str, Any]:
|
|
257
|
+
if hasattr(mx, "reset_peak_memory"):
|
|
258
|
+
try:
|
|
259
|
+
mx.reset_peak_memory()
|
|
260
|
+
except Exception:
|
|
261
|
+
pass
|
|
262
|
+
|
|
263
|
+
original_eos_token_ids = getattr(tokenizer, "eos_token_ids", None)
|
|
264
|
+
original_eos_token_id = getattr(tokenizer, "eos_token_id", None)
|
|
265
|
+
if no_eos:
|
|
266
|
+
try:
|
|
267
|
+
tokenizer.eos_token_ids = set()
|
|
268
|
+
except Exception:
|
|
269
|
+
tokenizer.eos_token_ids = []
|
|
270
|
+
try:
|
|
271
|
+
tokenizer.eos_token_id = None
|
|
272
|
+
except Exception:
|
|
273
|
+
pass
|
|
274
|
+
|
|
275
|
+
generated_token_ids: list[int] = []
|
|
276
|
+
final_response = None
|
|
277
|
+
start_ns = time.perf_counter_ns()
|
|
278
|
+
try:
|
|
279
|
+
for response in mlx_stream_generate(
|
|
280
|
+
target_model,
|
|
281
|
+
tokenizer,
|
|
282
|
+
prompt,
|
|
283
|
+
max_tokens=max_new_tokens,
|
|
284
|
+
):
|
|
285
|
+
final_response = response
|
|
286
|
+
generated_token_ids.append(int(response.token))
|
|
287
|
+
finally:
|
|
288
|
+
if no_eos:
|
|
289
|
+
tokenizer.eos_token_ids = original_eos_token_ids
|
|
290
|
+
tokenizer.eos_token_id = original_eos_token_id
|
|
291
|
+
|
|
292
|
+
elapsed_us = (time.perf_counter_ns() - start_ns) / 1_000.0
|
|
293
|
+
if final_response is None:
|
|
294
|
+
prompt_tokens = len(tokenizer.encode(prompt))
|
|
295
|
+
return {
|
|
296
|
+
"elapsed_us": elapsed_us,
|
|
297
|
+
"prefill_us": 0.0,
|
|
298
|
+
"prompt_token_count": prompt_tokens,
|
|
299
|
+
"generated_token_ids": [],
|
|
300
|
+
"generation_tokens": 0,
|
|
301
|
+
"peak_memory_gb": float(mx.get_peak_memory()) / 1e9 if hasattr(mx, "get_peak_memory") else None,
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
prompt_tokens = int(final_response.prompt_tokens)
|
|
305
|
+
prompt_tps = float(final_response.prompt_tps)
|
|
306
|
+
generation_tokens = int(final_response.generation_tokens)
|
|
307
|
+
generation_tps = float(final_response.generation_tps)
|
|
308
|
+
prefill_us = (prompt_tokens / prompt_tps) * 1e6 if prompt_tps > 0.0 else 0.0
|
|
309
|
+
generation_us = (generation_tokens / generation_tps) * 1e6 if generation_tps > 0.0 else 0.0
|
|
310
|
+
return {
|
|
311
|
+
"elapsed_us": elapsed_us,
|
|
312
|
+
"prefill_us": prefill_us,
|
|
313
|
+
"prompt_token_count": prompt_tokens,
|
|
314
|
+
"generated_token_ids": generated_token_ids,
|
|
315
|
+
"generation_tokens": generation_tokens,
|
|
316
|
+
"generation_tps": generation_tps,
|
|
317
|
+
"peak_memory_gb": float(final_response.peak_memory),
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _release_loaded_models() -> None:
|
|
322
|
+
gc.collect()
|
|
323
|
+
if hasattr(mx, "clear_cache"):
|
|
324
|
+
try:
|
|
325
|
+
mx.clear_cache()
|
|
326
|
+
return
|
|
327
|
+
except Exception:
|
|
328
|
+
pass
|
|
329
|
+
if hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"):
|
|
330
|
+
try:
|
|
331
|
+
mx.metal.clear_cache()
|
|
332
|
+
except Exception:
|
|
333
|
+
pass
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _run_once_sequential(
|
|
337
|
+
*,
|
|
338
|
+
prompt: str,
|
|
339
|
+
max_new_tokens: int,
|
|
340
|
+
block_tokens: int,
|
|
341
|
+
verify_chunk_tokens: int | None,
|
|
342
|
+
use_chat_template: bool,
|
|
343
|
+
target_model_ref: str | None,
|
|
344
|
+
draft_model_ref: str | None,
|
|
345
|
+
quantize_draft: bool,
|
|
346
|
+
no_eos: bool,
|
|
347
|
+
split_sdpa: bool,
|
|
348
|
+
) -> dict[str, Any]:
|
|
349
|
+
pristine_target_model, pristine_tokenizer, pristine_meta = _load_pristine_target_bundle(
|
|
350
|
+
target_model_ref
|
|
351
|
+
)
|
|
352
|
+
try:
|
|
353
|
+
baseline = _generate_stock_baseline_once(
|
|
354
|
+
target_model=pristine_target_model,
|
|
355
|
+
tokenizer=pristine_tokenizer,
|
|
356
|
+
prompt=prompt,
|
|
357
|
+
max_new_tokens=max_new_tokens,
|
|
358
|
+
no_eos=no_eos,
|
|
359
|
+
)
|
|
360
|
+
finally:
|
|
361
|
+
del pristine_target_model
|
|
362
|
+
del pristine_tokenizer
|
|
363
|
+
_release_loaded_models()
|
|
364
|
+
|
|
365
|
+
target_model, tokenizer, target_meta = load_target_bundle(
|
|
366
|
+
target_model_ref,
|
|
367
|
+
lazy=True,
|
|
368
|
+
split_full_attention_sdpa=split_sdpa,
|
|
369
|
+
)
|
|
370
|
+
draft_model, draft_meta = load_draft_bundle(
|
|
371
|
+
draft_model_ref,
|
|
372
|
+
lazy=True,
|
|
373
|
+
quantize_draft=quantize_draft,
|
|
374
|
+
)
|
|
375
|
+
dflash_eos_token_ids = get_stop_token_ids(tokenizer)
|
|
376
|
+
dflash_stop_token_ids = [] if no_eos else dflash_eos_token_ids
|
|
377
|
+
dflash_suppress_token_ids = dflash_eos_token_ids if no_eos else None
|
|
378
|
+
try:
|
|
379
|
+
dflash = generate_dflash_once(
|
|
380
|
+
target_model=target_model,
|
|
381
|
+
tokenizer=tokenizer,
|
|
382
|
+
draft_model=draft_model,
|
|
383
|
+
prompt=prompt,
|
|
384
|
+
max_new_tokens=max_new_tokens,
|
|
385
|
+
use_chat_template=use_chat_template,
|
|
386
|
+
block_tokens=block_tokens,
|
|
387
|
+
verify_chunk_tokens=verify_chunk_tokens,
|
|
388
|
+
stop_token_ids=dflash_stop_token_ids,
|
|
389
|
+
suppress_token_ids=dflash_suppress_token_ids,
|
|
390
|
+
)
|
|
391
|
+
finally:
|
|
392
|
+
del target_model
|
|
393
|
+
del tokenizer
|
|
394
|
+
del draft_model
|
|
395
|
+
_release_loaded_models()
|
|
396
|
+
|
|
397
|
+
baseline_elapsed = float(baseline["elapsed_us"])
|
|
398
|
+
dflash_elapsed = float(dflash["elapsed_us"])
|
|
399
|
+
baseline_generation_tps = _generation_tps_from_baseline(baseline)
|
|
400
|
+
dflash_generation_tps = _generation_tps_from_dflash(dflash)
|
|
401
|
+
return {
|
|
402
|
+
"baseline": _strip_generation_payload(baseline),
|
|
403
|
+
"dflash": _strip_generation_payload(dflash, drop_phase_timings=True),
|
|
404
|
+
"speedup_vs_baseline": _speedup(baseline_elapsed, dflash_elapsed),
|
|
405
|
+
"baseline_ttft_ms": _ttft_ms_from_baseline(baseline),
|
|
406
|
+
"dflash_ttft_ms": _ttft_ms_from_dflash(dflash),
|
|
407
|
+
"baseline_generation_tps": baseline_generation_tps,
|
|
408
|
+
"dflash_generation_tps": dflash_generation_tps,
|
|
409
|
+
"generation_speedup_vs_baseline": _generation_speedup(
|
|
410
|
+
baseline_generation_tps,
|
|
411
|
+
dflash_generation_tps,
|
|
412
|
+
),
|
|
413
|
+
"token_match": baseline["generated_token_ids"] == dflash["generated_token_ids"],
|
|
414
|
+
"target_meta": target_meta,
|
|
415
|
+
"draft_meta": draft_meta,
|
|
416
|
+
"pristine_target_meta": pristine_meta,
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def benchmark_once(
|
|
421
|
+
*,
|
|
422
|
+
prompt: str,
|
|
423
|
+
max_new_tokens: int,
|
|
424
|
+
block_tokens: int,
|
|
425
|
+
verify_chunk_tokens: int | None,
|
|
426
|
+
use_chat_template: bool,
|
|
427
|
+
target_model_ref: str | None,
|
|
428
|
+
draft_model_ref: str | None,
|
|
429
|
+
quantize_draft: bool = False,
|
|
430
|
+
no_eos: bool = False,
|
|
431
|
+
split_sdpa: bool = True,
|
|
432
|
+
cooldown: int = 10,
|
|
433
|
+
) -> dict[str, Any]:
|
|
434
|
+
thermal_pressure = _get_thermal_pressure()
|
|
435
|
+
_warn_if_throttled(thermal_pressure)
|
|
436
|
+
result = _run_once_sequential(
|
|
437
|
+
prompt=prompt,
|
|
438
|
+
max_new_tokens=max_new_tokens,
|
|
439
|
+
block_tokens=block_tokens,
|
|
440
|
+
verify_chunk_tokens=verify_chunk_tokens,
|
|
441
|
+
use_chat_template=use_chat_template,
|
|
442
|
+
target_model_ref=target_model_ref,
|
|
443
|
+
draft_model_ref=draft_model_ref,
|
|
444
|
+
quantize_draft=quantize_draft,
|
|
445
|
+
no_eos=no_eos,
|
|
446
|
+
split_sdpa=split_sdpa,
|
|
447
|
+
)
|
|
448
|
+
target_meta = result.pop("target_meta")
|
|
449
|
+
draft_meta = result.pop("draft_meta")
|
|
450
|
+
result.pop("pristine_target_meta", None)
|
|
451
|
+
result["run_index"] = 1
|
|
452
|
+
result["thermal_pressure"] = thermal_pressure
|
|
453
|
+
return _build_single_case_report(
|
|
454
|
+
prompt=prompt,
|
|
455
|
+
max_new_tokens=max_new_tokens,
|
|
456
|
+
block_tokens=block_tokens,
|
|
457
|
+
repeat=1,
|
|
458
|
+
cooldown=cooldown,
|
|
459
|
+
runs=[result],
|
|
460
|
+
target_model=target_meta["resolved_model_ref"],
|
|
461
|
+
draft_model=draft_meta["resolved_model_ref"],
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def benchmark_matrix(
|
|
466
|
+
*,
|
|
467
|
+
prompts: tuple[str, ...] = (),
|
|
468
|
+
schedules: tuple[int, ...] = DEFAULT_SCHEDULES,
|
|
469
|
+
repeat: int = DEFAULT_REPEAT,
|
|
470
|
+
block_tokens: int = 16,
|
|
471
|
+
verify_chunk_tokens: int | None = None,
|
|
472
|
+
use_chat_template: bool = False,
|
|
473
|
+
target_model_ref: str | None = None,
|
|
474
|
+
draft_model_ref: str | None = None,
|
|
475
|
+
quantize_draft: bool = False,
|
|
476
|
+
no_eos: bool = False,
|
|
477
|
+
split_sdpa: bool = True,
|
|
478
|
+
cooldown: int = 10,
|
|
479
|
+
) -> dict[str, Any]:
|
|
480
|
+
target_meta: dict[str, Any] | None = None
|
|
481
|
+
draft_meta: dict[str, Any] | None = None
|
|
482
|
+
if len(prompts) != 1 or len(schedules) != 1:
|
|
483
|
+
raise ValueError("benchmark_matrix currently expects exactly one prompt and one schedule.")
|
|
484
|
+
prompt = prompts[0]
|
|
485
|
+
max_new_tokens = schedules[0]
|
|
486
|
+
runs: list[dict[str, Any]] = []
|
|
487
|
+
|
|
488
|
+
for run_index in range(1, repeat + 1):
|
|
489
|
+
thermal_pressure = _get_thermal_pressure()
|
|
490
|
+
_warn_if_throttled(thermal_pressure)
|
|
491
|
+
run = _run_once_sequential(
|
|
492
|
+
prompt=prompt,
|
|
493
|
+
max_new_tokens=max_new_tokens,
|
|
494
|
+
block_tokens=block_tokens,
|
|
495
|
+
verify_chunk_tokens=verify_chunk_tokens,
|
|
496
|
+
use_chat_template=use_chat_template,
|
|
497
|
+
target_model_ref=target_model_ref,
|
|
498
|
+
draft_model_ref=draft_model_ref,
|
|
499
|
+
quantize_draft=quantize_draft,
|
|
500
|
+
no_eos=no_eos,
|
|
501
|
+
split_sdpa=split_sdpa,
|
|
502
|
+
)
|
|
503
|
+
if target_meta is None:
|
|
504
|
+
target_meta = run.pop("target_meta")
|
|
505
|
+
else:
|
|
506
|
+
run.pop("target_meta", None)
|
|
507
|
+
if draft_meta is None:
|
|
508
|
+
draft_meta = run.pop("draft_meta")
|
|
509
|
+
else:
|
|
510
|
+
run.pop("draft_meta", None)
|
|
511
|
+
run.pop("pristine_target_meta", None)
|
|
512
|
+
run["run_index"] = run_index
|
|
513
|
+
run["thermal_pressure"] = thermal_pressure
|
|
514
|
+
runs.append(run)
|
|
515
|
+
if cooldown > 0 and run_index < repeat:
|
|
516
|
+
time.sleep(cooldown)
|
|
517
|
+
|
|
518
|
+
return _build_single_case_report(
|
|
519
|
+
prompt=prompt,
|
|
520
|
+
max_new_tokens=max_new_tokens,
|
|
521
|
+
block_tokens=block_tokens,
|
|
522
|
+
repeat=repeat,
|
|
523
|
+
cooldown=cooldown,
|
|
524
|
+
runs=runs,
|
|
525
|
+
target_model=target_meta["resolved_model_ref"] if target_meta is not None else "",
|
|
526
|
+
draft_model=draft_meta["resolved_model_ref"] if draft_meta is not None else "",
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def main() -> None:
|
|
531
|
+
parser = argparse.ArgumentParser(description="Benchmark baseline MLX vs DFlash MLX runtime.")
|
|
532
|
+
parser.add_argument("--prompt", required=True, help="Prompt to benchmark.")
|
|
533
|
+
parser.add_argument("--max-tokens", type=int, default=64)
|
|
534
|
+
parser.add_argument("--block-tokens", type=int, default=16)
|
|
535
|
+
parser.add_argument("--verify-chunk-tokens", type=int, default=None)
|
|
536
|
+
parser.add_argument("--matrix", action="store_true")
|
|
537
|
+
parser.add_argument(
|
|
538
|
+
"--repeat",
|
|
539
|
+
type=int,
|
|
540
|
+
default=None,
|
|
541
|
+
help="Number of measured runs. Uses matrix mode automatically when > 1.",
|
|
542
|
+
)
|
|
543
|
+
parser.add_argument(
|
|
544
|
+
"--cooldown",
|
|
545
|
+
type=int,
|
|
546
|
+
default=10,
|
|
547
|
+
help="Seconds between runs for thermal stabilization.",
|
|
548
|
+
)
|
|
549
|
+
parser.add_argument("--no-chat-template", action="store_true")
|
|
550
|
+
parser.add_argument("--model", default=None)
|
|
551
|
+
parser.add_argument("--draft", default=None)
|
|
552
|
+
parser.add_argument("--quantize-draft", action="store_true")
|
|
553
|
+
parser.add_argument("--no-eos", action="store_true")
|
|
554
|
+
parser.add_argument(
|
|
555
|
+
"--split-sdpa",
|
|
556
|
+
action=argparse.BooleanOptionalAction,
|
|
557
|
+
default=True,
|
|
558
|
+
help="Enable split_full_attention_sdpa when loading the target model (default: enabled).",
|
|
559
|
+
)
|
|
560
|
+
args = parser.parse_args()
|
|
561
|
+
repeat = args.repeat if args.repeat is not None else (DEFAULT_REPEAT if args.matrix else 1)
|
|
562
|
+
if repeat < 1:
|
|
563
|
+
raise ValueError("--repeat must be >= 1")
|
|
564
|
+
|
|
565
|
+
common_kwargs = {
|
|
566
|
+
"block_tokens": args.block_tokens,
|
|
567
|
+
"verify_chunk_tokens": args.verify_chunk_tokens,
|
|
568
|
+
"use_chat_template": not args.no_chat_template,
|
|
569
|
+
"target_model_ref": args.model,
|
|
570
|
+
"draft_model_ref": args.draft,
|
|
571
|
+
"quantize_draft": args.quantize_draft,
|
|
572
|
+
"no_eos": args.no_eos,
|
|
573
|
+
"split_sdpa": args.split_sdpa,
|
|
574
|
+
"cooldown": args.cooldown,
|
|
575
|
+
}
|
|
576
|
+
if args.matrix or repeat > 1:
|
|
577
|
+
result = benchmark_matrix(
|
|
578
|
+
prompts=(args.prompt,),
|
|
579
|
+
schedules=(args.max_tokens,),
|
|
580
|
+
repeat=repeat,
|
|
581
|
+
**common_kwargs,
|
|
582
|
+
)
|
|
583
|
+
else:
|
|
584
|
+
result = benchmark_once(
|
|
585
|
+
prompt=args.prompt,
|
|
586
|
+
max_new_tokens=args.max_tokens,
|
|
587
|
+
**common_kwargs,
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
output_path = _default_results_path(
|
|
591
|
+
target_model_ref=args.model,
|
|
592
|
+
max_new_tokens=args.max_tokens,
|
|
593
|
+
)
|
|
594
|
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
595
|
+
output_path.write_text(json.dumps(result, indent=2) + "\n")
|
|
596
|
+
print(json.dumps(result, indent=2))
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
if __name__ == "__main__":
|
|
600
|
+
main()
|