dflash-mlx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
benchmark/__init__.py ADDED
File without changes
benchmark/benchmark.py ADDED
@@ -0,0 +1,600 @@
1
+ # Copyright 2026 bstnxbt
2
+ # MIT License — see LICENSE file
3
+ # Based on DFlash (arXiv:2602.06036)
4
+
5
+
6
+ import argparse
7
+ import gc
8
+ import json
9
+ import platform
10
+ import re
11
+ import statistics
12
+ import subprocess
13
+ import sys
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ import mlx.core as mx
19
+ from mlx_lm import stream_generate as mlx_stream_generate
20
+ from mlx_lm.utils import load as load_pristine_target
21
+
22
+ from dflash_mlx.runtime import (
23
+ generate_dflash_once,
24
+ load_draft_bundle,
25
+ load_target_bundle,
26
+ resolve_model_ref,
27
+ )
28
+
29
# Default ``schedules`` for benchmark_matrix; each entry is used as a
# max-new-tokens budget for one benchmark case.
DEFAULT_SCHEDULES: tuple[int, ...] = (8, 16, 32)
# Default number of measured runs (benchmark_matrix default, and the CLI
# default when --matrix is given without --repeat).
DEFAULT_REPEAT = 3
31
+
32
+
33
def _git_hash_short() -> str:
    """Return the abbreviated git hash of ``HEAD``, or ``"unknown"``.

    The hash is only metadata for the benchmark report. The benchmark can be
    run from an installed package or from a directory that is not a git
    checkout, and ``git`` itself may be missing; none of those situations
    should abort a report whose measurements have already been collected.

    Returns:
        The output of ``git rev-parse --short HEAD`` with surrounding
        whitespace removed, or ``"unknown"`` when the command cannot be run,
        exits with a non-zero status, or prints nothing.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            text=True,
            # Keep git's "not a git repository" noise out of the benchmark output.
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: git binary missing / not executable.
        # SubprocessError: non-zero exit status (e.g. not inside a repository).
        return "unknown"
    return output.strip() or "unknown"
35
+
36
+
37
def _hardware_info() -> dict[str, str]:
    """Describe the host for the benchmark report.

    Returns:
        A dict with the chip brand string and memory size (both read through
        ``sysctl -n``), the MLX version and the Python version. All values
        are strings.
    """

    def _sysctl_value(key: str) -> str:
        # ``sysctl -n`` prints just the value for the requested key.
        return subprocess.check_output(["sysctl", "-n", key], text=True).strip()

    chip_name = _sysctl_value("machdep.cpu.brand_string")
    # The raw hw.memsize figure is floor-divided by 1024**3 for the report.
    memory_raw = int(_sysctl_value("hw.memsize"))
    memory_gb = memory_raw // (1024**3)
    return {
        "chip": chip_name,
        "memory_gb": str(memory_gb),
        "mlx_version": mx.__version__,
        "python": platform.python_version(),
    }
49
+
50
+
51
def _get_thermal_pressure() -> str:
    """Classify thermal pressure from the output of ``pmset -g therm``.

    The first output line mentioning ``CPU_Scheduler_Limit`` is parsed; the
    integer after its last ``=`` selects the label.

    Returns:
        ``"nominal"`` for exactly 100, ``"fair"`` for other values >= 80,
        ``"serious"`` for values >= 50, ``"critical"`` below that, and
        ``"unknown"`` when the command fails, times out, the value cannot be
        parsed, or no matching line is present.
    """
    try:
        report = subprocess.check_output(["pmset", "-g", "therm"], text=True, timeout=2)
        limit_line = next(
            (row for row in report.splitlines() if "CPU_Scheduler_Limit" in row),
            None,
        )
        if limit_line is not None:
            raw_value = limit_line.strip().rpartition("=")[2]
            limit = int(raw_value.strip())
            if limit == 100:
                return "nominal"
            for floor, label in ((80, "fair"), (50, "serious")):
                if limit >= floor:
                    return label
            return "critical"
    except Exception:
        # Best effort only: any failure is reported as "unknown" below.
        pass
    return "unknown"
68
+
69
+
70
def _warn_if_throttled(thermal_pressure: str) -> None:
    """Print a warning to stderr unless ``thermal_pressure`` is ``"nominal"``.

    Args:
        thermal_pressure: Label as produced by ``_get_thermal_pressure``.
    """
    if thermal_pressure != "nominal":
        message = (
            f"WARNING: thermal pressure is '{thermal_pressure}' — results may be throttled. "
            "Increase --cooldown or wait for chip to cool."
        )
        print(message, file=sys.stderr)
78
+
79
+
80
def _slugify_prompt_id(prompt: str) -> str:
    """Derive a short identifier from ``prompt``.

    The prompt is lower-cased, its runs of ``a-z``/``0-9`` characters are
    joined with single underscores, and the result is cut to 48 characters.

    Returns:
        The slug, or ``"prompt"`` when the prompt contains no usable
        characters.
    """
    words = re.findall(r"[a-z0-9]+", prompt.lower())
    slug = "_".join(words)
    return slug[:48] or "prompt"
84
+
85
+
86
def _slugify_model_ref(model_ref: str | None) -> str:
    """Build a filename-friendly label for the resolved target model.

    The reference is resolved with ``resolve_model_ref(..., kind="target")``;
    the final path component of the result (or the whole string when that
    component is empty) is lower-cased and its ``a-z``/``0-9`` runs are
    joined with single hyphens.

    Returns:
        The label, or ``"model"`` when nothing usable remains.
    """
    resolved_text = str(resolve_model_ref(model_ref, kind="target"))
    basename = Path(resolved_text).name
    source = basename if basename else resolved_text
    pieces = re.findall(r"[a-z0-9]+", source.lower())
    return "-".join(pieces) or "model"
92
+
93
+
94
def _default_results_path(*, target_model_ref: str | None, max_new_tokens: int) -> Path:
    """Return the default JSON report path for a model / token-budget pair.

    The file lives under the relative directory ``benchmark/results`` and is
    named ``<model-slug>-<max_new_tokens>.json``.
    """
    model_slug = _slugify_model_ref(target_model_ref)
    filename = f"{model_slug}-{int(max_new_tokens)}.json"
    return Path("benchmark/results") / filename
96
+
97
+
98
def _strip_generation_payload(result: dict[str, Any], *, drop_phase_timings: bool = False) -> dict[str, Any]:
    """Return a shallow copy of ``result`` without bulky generation data.

    Args:
        result: Raw result dict from a generation run; it is not modified.
        drop_phase_timings: When true, ``phase_timings_us`` is removed as
            well, and its ``"prefill"`` entry is kept as ``prefill_us``
            unless the copy already has a ``prefill_us`` key.

    Returns:
        The trimmed copy; ``generated_token_ids`` is always removed.
    """
    trimmed = {**result}
    trimmed.pop("generated_token_ids", None)
    if not drop_phase_timings:
        return trimmed

    # A missing or falsy timings value is treated as an empty mapping.
    timings = dict(trimmed.pop("phase_timings_us", {}) or {})
    if "prefill_us" not in trimmed and "prefill" in timings:
        trimmed["prefill_us"] = float(timings["prefill"])
    return trimmed
106
+
107
+
108
def _format_run_entry(run: dict[str, Any]) -> dict[str, Any]:
    """Convert one raw run record into its report entry.

    Args:
        run: Record holding the ``baseline`` and ``dflash`` result dicts plus
            the derived ``*_ttft_ms``, ``*_generation_tps`` and
            ``generation_speedup_vs_baseline`` values and ``run_index``.

    Returns:
        A dict with ``run``, ``thermal_pressure``, ``baseline``, ``dflash``
        and ``speedup`` keys. Optional DFlash statistics default to zero and
        ``speedup`` stays ``None`` when no speedup was computed.
    """
    baseline_result = dict(run["baseline"])
    dflash_result = dict(run["dflash"])

    run_number = int(run["run_index"])
    thermal_label = str(run.get("thermal_pressure", "unknown"))

    baseline_entry = {
        "ttft_ms": float(run["baseline_ttft_ms"]),
        "generation_tps": float(run["baseline_generation_tps"]),
        "peak_memory_gb": baseline_result.get("peak_memory_gb"),
    }

    dflash_entry: dict[str, Any] = {
        "ttft_ms": float(run["dflash_ttft_ms"]),
        "generation_tps": float(run["dflash_generation_tps"]),
        "tokens_per_cycle": float(dflash_result.get("tokens_per_cycle", 0.0)),
        "cycles": int(dflash_result.get("cycles_completed", 0)),
    }
    for field in ("acceptance_ratio", "acceptance_first_20_avg", "acceptance_last_20_avg"):
        dflash_entry[field] = float(dflash_result.get(field, 0.0))
    dflash_entry["peak_memory_gb"] = dflash_result.get("peak_memory_gb")

    raw_speedup = run["generation_speedup_vs_baseline"]
    speedup = None if raw_speedup is None else float(raw_speedup)

    return {
        "run": run_number,
        "thermal_pressure": thermal_label,
        "baseline": baseline_entry,
        "dflash": dflash_entry,
        "speedup": speedup,
    }
131
+
132
+
133
def _build_config(
    *,
    prompt: str,
    prompt_tokens: int,
    max_new_tokens: int,
    block_tokens: int,
    repeat: int,
    cooldown: int,
    target_model: str,
    draft_model: str,
) -> dict[str, Any]:
    """Assemble the ``config`` section of a benchmark report.

    Numeric settings are coerced to ``int``; a prompt slug and the current
    short git hash are added alongside the values passed in.
    """
    config: dict[str, Any] = {"target_model": target_model, "draft_model": draft_model}
    config.update(
        max_new_tokens=int(max_new_tokens),
        block_tokens=int(block_tokens),
        cooldown=int(cooldown),
    )
    config["prompt"] = prompt
    config["prompt_tokens"] = int(prompt_tokens)
    config["prompt_id"] = _slugify_prompt_id(prompt)
    config["repeat"] = int(repeat)
    config["git_hash"] = _git_hash_short()
    return config
156
+
157
+
158
def _build_single_case_report(
    *,
    prompt: str,
    max_new_tokens: int,
    block_tokens: int,
    repeat: int,
    cooldown: int,
    runs: list[dict[str, Any]],
    target_model: str,
    draft_model: str,
) -> dict[str, Any]:
    """Build the full report (hardware, config, runs, summary) for one case.

    Args:
        runs: Raw run records as returned by ``_run_once_sequential`` with
            ``run_index`` / ``thermal_pressure`` added by the caller.

    Returns:
        A dict with ``hardware``, ``config``, ``runs`` and ``summary``
        sections. Summary statistics are ``None`` when there is no data;
        runs whose speedup is ``None`` are left out of the speedup median.
    """

    def _median_or_none(values: list[float]) -> float | None:
        # statistics.median raises on empty input, so guard it.
        return statistics.median(values) if values else None

    run_entries = [_format_run_entry(record) for record in runs]

    baseline_tps: list[float] = []
    dflash_tps: list[float] = []
    for record in runs:
        baseline_tps.append(float(record["baseline_generation_tps"]))
    for record in runs:
        dflash_tps.append(float(record["dflash_generation_tps"]))
    speedups = [
        float(record["generation_speedup_vs_baseline"])
        for record in runs
        if record["generation_speedup_vs_baseline"] is not None
    ]
    acceptance_ratios = [float(record["dflash"]["acceptance_ratio"]) for record in runs]

    # The prompt token count is taken from the first run's baseline result.
    if runs:
        prompt_tokens = int(runs[0]["baseline"]["prompt_token_count"])
    else:
        prompt_tokens = 0

    hardware = _hardware_info()
    config = _build_config(
        prompt=prompt,
        prompt_tokens=prompt_tokens,
        max_new_tokens=max_new_tokens,
        block_tokens=block_tokens,
        repeat=repeat,
        cooldown=cooldown,
        target_model=target_model,
        draft_model=draft_model,
    )
    summary = {
        "baseline_tps_median": _median_or_none(baseline_tps),
        "dflash_tps_median": _median_or_none(dflash_tps),
        "dflash_tps_min": min(dflash_tps) if dflash_tps else None,
        "dflash_tps_max": max(dflash_tps) if dflash_tps else None,
        "speedup_median": _median_or_none(speedups),
        "acceptance_ratio_median": _median_or_none(acceptance_ratios),
    }
    return {
        "hardware": hardware,
        "config": config,
        "runs": run_entries,
        "summary": summary,
    }
197
+
198
+
199
def get_stop_token_ids(tokenizer: Any) -> list[int]:
    """Collect the tokenizer's end-of-sequence token ids into one list.

    Starts from ``tokenizer.eos_token_ids`` (if present and non-empty) and
    appends ``tokenizer.eos_token_id`` when it is set and not already
    included.

    Returns:
        A new list; the tokenizer's own collection is not modified.
    """
    stop_ids = [*(getattr(tokenizer, "eos_token_ids", None) or [])]
    primary = getattr(tokenizer, "eos_token_id", None)
    if primary is None or primary in stop_ids:
        return stop_ids
    stop_ids.append(int(primary))
    return stop_ids
205
+
206
+
207
+ def _speedup(baseline_elapsed: float, dflash_elapsed: float) -> float | None:
208
+ return baseline_elapsed / dflash_elapsed if dflash_elapsed > 0.0 else None
209
+
210
+
211
+ def _generation_speedup(baseline_tps: float, dflash_tps: float) -> float | None:
212
+ return dflash_tps / baseline_tps if baseline_tps > 0.0 else None
213
+
214
+
215
def _ttft_ms_from_baseline(result: dict[str, Any]) -> float:
    """Return the baseline result's ``prefill_us`` in milliseconds.

    A missing ``prefill_us`` counts as zero.
    """
    prefill_us = float(result.get("prefill_us", 0.0))
    return prefill_us / 1_000.0
217
+
218
+
219
def _ttft_ms_from_dflash(result: dict[str, Any]) -> float:
    """Return the DFlash prefill phase duration in milliseconds.

    Args:
        result: DFlash result dict; the value is read from the ``"prefill"``
            entry of its ``phase_timings_us`` mapping.

    Returns:
        The prefill time in milliseconds, or ``0.0`` when the timings are
        missing, ``None``/empty, or have no ``"prefill"`` entry.
    """
    # ``or {}`` also covers a key that is present but set to None, matching
    # how _strip_generation_payload reads the same field.
    phase_timings = dict(result.get("phase_timings_us") or {})
    return float(phase_timings.get("prefill", 0.0)) / 1_000.0
222
+
223
+
224
def _generation_tps_from_baseline(result: dict[str, Any]) -> float:
    """Return the baseline generation throughput in tokens per second.

    A ``generation_tps`` value already present in ``result`` is used as is.
    Otherwise the rate is ``generation_tokens`` divided by the time left
    after subtracting ``prefill_us`` from ``elapsed_us``; ``0.0`` is returned
    when that remaining time is not positive.
    """
    try:
        return float(result["generation_tps"])
    except KeyError:
        # No reported rate; derive one from the timings below.
        pass

    total_us = float(result.get("elapsed_us", 0.0))
    prompt_us = float(result.get("prefill_us", 0.0))
    token_count = int(result.get("generation_tokens", 0))
    decode_us = max(0.0, total_us - prompt_us)
    if decode_us > 0.0:
        return token_count / (decode_us / 1e6)
    return 0.0
232
+
233
+
234
def _generation_tps_from_dflash(result: dict[str, Any]) -> float:
    """Return the DFlash generation throughput in tokens per second.

    The rate is ``generation_tokens`` divided by the time remaining after
    subtracting the ``"prefill"`` phase timing from ``elapsed_us``.

    Args:
        result: DFlash result dict with ``elapsed_us``, ``generation_tokens``
            and a ``phase_timings_us`` mapping.

    Returns:
        Tokens per second, or ``0.0`` when the remaining time is not
        positive. Missing or ``None`` phase timings count as zero prefill.
    """
    elapsed_us = float(result.get("elapsed_us", 0.0))
    # ``or {}`` also covers a key that is present but set to None, matching
    # how _strip_generation_payload reads the same field.
    phase_timings = dict(result.get("phase_timings_us") or {})
    prefill_us = float(phase_timings.get("prefill", 0.0))
    generation_tokens = int(result.get("generation_tokens", 0))
    generation_us = max(0.0, elapsed_us - prefill_us)
    if generation_us > 0.0:
        return generation_tokens / (generation_us / 1e6)
    return 0.0
241
+
242
+
243
def _load_pristine_target_bundle(model_ref: str | None):
    """Load the target model through stock ``mlx_lm`` for the baseline run.

    Returns:
        A ``(model, tokenizer, meta)`` tuple where ``meta`` holds the
        resolved model reference and the config returned by the loader.
    """
    resolved_ref = resolve_model_ref(model_ref, kind="target")
    loaded = load_pristine_target(resolved_ref, lazy=True, return_config=True)
    model, tokenizer, config = loaded
    meta = {"resolved_model_ref": resolved_ref, "config": config}
    return model, tokenizer, meta
247
+
248
+
249
def _generate_stock_baseline_once(
    *,
    target_model: Any,
    tokenizer: Any,
    prompt: str,
    max_new_tokens: int,
    no_eos: bool,
) -> dict[str, Any]:
    """Run one stock ``mlx_lm`` streaming generation and collect its stats.

    Args:
        target_model: Model passed straight to ``stream_generate``.
        tokenizer: Tokenizer passed straight to ``stream_generate``. When
            ``no_eos`` is set, its EOS attributes are cleared for the
            duration of the run and restored afterwards.
        prompt: Prompt text, passed to ``stream_generate`` unchanged.
        max_new_tokens: Forwarded as ``max_tokens``.
        no_eos: Clear the tokenizer's EOS ids while generating.

    Returns:
        A dict with ``elapsed_us`` (wall clock around the whole stream),
        ``prefill_us`` (derived from the final response's prompt token count
        and prompt rate), ``prompt_token_count``, ``generated_token_ids``,
        ``generation_tokens``, ``peak_memory_gb`` and, when at least one
        response was produced, ``generation_tps``.

    NOTE(review): this function has no ``use_chat_template`` parameter, so
    the baseline always sees the raw prompt string — confirm this matches how
    the DFlash path formats its prompt.
    """
    if hasattr(mx, "reset_peak_memory"):
        try:
            mx.reset_peak_memory()
        except Exception:
            # Best effort: a failed reset must not abort the benchmark.
            pass

    original_eos_token_ids = getattr(tokenizer, "eos_token_ids", None)
    original_eos_token_id = getattr(tokenizer, "eos_token_id", None)
    # Tracks whether the singular eos_token_id was actually replaced, so the
    # cleanup below only writes back what was changed.
    eos_token_id_overridden = False
    if no_eos:
        try:
            tokenizer.eos_token_ids = set()
        except Exception:
            tokenizer.eos_token_ids = []
        try:
            tokenizer.eos_token_id = None
            eos_token_id_overridden = True
        except Exception:
            # Some tokenizers may not accept this assignment; carry on with
            # only eos_token_ids cleared.
            pass

    generated_token_ids: list[int] = []
    final_response = None
    start_ns = time.perf_counter_ns()
    try:
        for response in mlx_stream_generate(
            target_model,
            tokenizer,
            prompt,
            max_tokens=max_new_tokens,
        ):
            final_response = response
            generated_token_ids.append(int(response.token))
    finally:
        if no_eos:
            tokenizer.eos_token_ids = original_eos_token_ids
            if eos_token_id_overridden:
                # Restoring an attribute whose override was rejected would
                # raise the same error here and discard the finished run.
                tokenizer.eos_token_id = original_eos_token_id

    elapsed_us = (time.perf_counter_ns() - start_ns) / 1_000.0
    if final_response is None:
        # The stream yielded nothing: report an empty generation.
        prompt_tokens = len(tokenizer.encode(prompt))
        return {
            "elapsed_us": elapsed_us,
            "prefill_us": 0.0,
            "prompt_token_count": prompt_tokens,
            "generated_token_ids": [],
            "generation_tokens": 0,
            "peak_memory_gb": float(mx.get_peak_memory()) / 1e9 if hasattr(mx, "get_peak_memory") else None,
        }

    prompt_tokens = int(final_response.prompt_tokens)
    prompt_tps = float(final_response.prompt_tps)
    generation_tokens = int(final_response.generation_tokens)
    generation_tps = float(final_response.generation_tps)
    # Prefill time is reconstructed from the reported prompt throughput.
    prefill_us = (prompt_tokens / prompt_tps) * 1e6 if prompt_tps > 0.0 else 0.0
    return {
        "elapsed_us": elapsed_us,
        "prefill_us": prefill_us,
        "prompt_token_count": prompt_tokens,
        "generated_token_ids": generated_token_ids,
        "generation_tokens": generation_tokens,
        "generation_tps": generation_tps,
        "peak_memory_gb": float(final_response.peak_memory),
    }
319
+
320
+
321
def _release_loaded_models() -> None:
    """Run a garbage collection pass and clear the MLX cache, best effort.

    ``mx.clear_cache`` is tried first; ``mx.metal.clear_cache`` is used only
    when the former is missing or raises. Errors from either are ignored.
    """
    gc.collect()

    if hasattr(mx, "clear_cache"):
        try:
            mx.clear_cache()
        except Exception:
            # Fall through to the metal-level call below.
            pass
        else:
            return

    metal = getattr(mx, "metal", None)
    if metal is not None and hasattr(metal, "clear_cache"):
        try:
            metal.clear_cache()
        except Exception:
            pass
334
+
335
+
336
def _run_once_sequential(
    *,
    prompt: str,
    max_new_tokens: int,
    block_tokens: int,
    verify_chunk_tokens: int | None,
    use_chat_template: bool,
    target_model_ref: str | None,
    draft_model_ref: str | None,
    quantize_draft: bool,
    no_eos: bool,
    split_sdpa: bool,
) -> dict[str, Any]:
    """Measure one baseline run followed by one DFlash run.

    The stock target model is loaded, measured and released before the
    DFlash target/draft pair is loaded, so the two configurations are not
    held at the same time.

    Returns:
        A dict with the trimmed ``baseline`` and ``dflash`` results, derived
        TTFT / throughput / speedup figures, a ``token_match`` flag comparing
        the generated token ids, and the three loader metadata dicts.
    """
    # --- Phase 1: stock mlx_lm baseline -----------------------------------
    stock_model, stock_tokenizer, pristine_meta = _load_pristine_target_bundle(target_model_ref)
    try:
        baseline = _generate_stock_baseline_once(
            target_model=stock_model,
            tokenizer=stock_tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            no_eos=no_eos,
        )
    finally:
        # Drop the references before clearing caches, even if the run failed.
        del stock_model, stock_tokenizer
        _release_loaded_models()

    # --- Phase 2: DFlash target + draft -----------------------------------
    target_model, tokenizer, target_meta = load_target_bundle(
        target_model_ref,
        lazy=True,
        split_full_attention_sdpa=split_sdpa,
    )
    draft_model, draft_meta = load_draft_bundle(
        draft_model_ref,
        lazy=True,
        quantize_draft=quantize_draft,
    )
    eos_ids = get_stop_token_ids(tokenizer)
    if no_eos:
        # With EOS disabled, the ids are suppressed instead of used as stops.
        stop_ids, suppress_ids = [], eos_ids
    else:
        stop_ids, suppress_ids = eos_ids, None
    try:
        dflash = generate_dflash_once(
            target_model=target_model,
            tokenizer=tokenizer,
            draft_model=draft_model,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            use_chat_template=use_chat_template,
            block_tokens=block_tokens,
            verify_chunk_tokens=verify_chunk_tokens,
            stop_token_ids=stop_ids,
            suppress_token_ids=suppress_ids,
        )
    finally:
        del target_model, tokenizer, draft_model
        _release_loaded_models()

    # --- Derived metrics ----------------------------------------------------
    baseline_elapsed = float(baseline["elapsed_us"])
    dflash_elapsed = float(dflash["elapsed_us"])
    baseline_tps = _generation_tps_from_baseline(baseline)
    dflash_tps = _generation_tps_from_dflash(dflash)

    outcome: dict[str, Any] = {
        "baseline": _strip_generation_payload(baseline),
        "dflash": _strip_generation_payload(dflash, drop_phase_timings=True),
        "speedup_vs_baseline": _speedup(baseline_elapsed, dflash_elapsed),
        "baseline_ttft_ms": _ttft_ms_from_baseline(baseline),
        "dflash_ttft_ms": _ttft_ms_from_dflash(dflash),
        "baseline_generation_tps": baseline_tps,
        "dflash_generation_tps": dflash_tps,
        "generation_speedup_vs_baseline": _generation_speedup(baseline_tps, dflash_tps),
        "token_match": baseline["generated_token_ids"] == dflash["generated_token_ids"],
    }
    outcome["target_meta"] = target_meta
    outcome["draft_meta"] = draft_meta
    outcome["pristine_target_meta"] = pristine_meta
    return outcome
418
+
419
+
420
def benchmark_once(
    *,
    prompt: str,
    max_new_tokens: int,
    block_tokens: int,
    verify_chunk_tokens: int | None,
    use_chat_template: bool,
    target_model_ref: str | None,
    draft_model_ref: str | None,
    quantize_draft: bool = False,
    no_eos: bool = False,
    split_sdpa: bool = True,
    cooldown: int = 10,
) -> dict[str, Any]:
    """Run a single baseline-vs-DFlash measurement and return its report.

    Thermal pressure is sampled (and warned about) before the run. No sleep
    is performed here; ``cooldown`` is only recorded in the report config.

    Returns:
        The report produced by ``_build_single_case_report`` with one run.
    """
    thermal_pressure = _get_thermal_pressure()
    _warn_if_throttled(thermal_pressure)

    run_settings = dict(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        block_tokens=block_tokens,
        verify_chunk_tokens=verify_chunk_tokens,
        use_chat_template=use_chat_template,
        target_model_ref=target_model_ref,
        draft_model_ref=draft_model_ref,
        quantize_draft=quantize_draft,
        no_eos=no_eos,
        split_sdpa=split_sdpa,
    )
    run = _run_once_sequential(**run_settings)

    # Loader metadata is lifted out of the run record; only the resolved
    # model references end up in the report.
    target_meta = run.pop("target_meta")
    draft_meta = run.pop("draft_meta")
    run.pop("pristine_target_meta", None)
    run.update(run_index=1, thermal_pressure=thermal_pressure)

    return _build_single_case_report(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        block_tokens=block_tokens,
        repeat=1,
        cooldown=cooldown,
        runs=[run],
        target_model=target_meta["resolved_model_ref"],
        draft_model=draft_meta["resolved_model_ref"],
    )
463
+
464
+
465
def benchmark_matrix(
    *,
    prompts: tuple[str, ...] = (),
    schedules: tuple[int, ...] = DEFAULT_SCHEDULES,
    repeat: int = DEFAULT_REPEAT,
    block_tokens: int = 16,
    verify_chunk_tokens: int | None = None,
    use_chat_template: bool = False,
    target_model_ref: str | None = None,
    draft_model_ref: str | None = None,
    quantize_draft: bool = False,
    no_eos: bool = False,
    split_sdpa: bool = True,
    cooldown: int = 10,
) -> dict[str, Any]:
    """Run one prompt/schedule case ``repeat`` times and return its report.

    Thermal pressure is sampled before every run, and the process sleeps for
    ``cooldown`` seconds between runs (not after the last one).

    Raises:
        ValueError: If ``prompts`` or ``schedules`` does not contain exactly
            one element.

    Returns:
        The report produced by ``_build_single_case_report`` for all runs.
    """
    # Loader metadata captured from the first run that provides it.
    captured_meta: dict[str, dict[str, Any] | None] = {"target_meta": None, "draft_meta": None}
    if len(prompts) != 1 or len(schedules) != 1:
        raise ValueError("benchmark_matrix currently expects exactly one prompt and one schedule.")
    prompt = prompts[0]
    max_new_tokens = schedules[0]
    runs: list[dict[str, Any]] = []

    for run_index in range(1, repeat + 1):
        thermal_pressure = _get_thermal_pressure()
        _warn_if_throttled(thermal_pressure)
        run = _run_once_sequential(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            block_tokens=block_tokens,
            verify_chunk_tokens=verify_chunk_tokens,
            use_chat_template=use_chat_template,
            target_model_ref=target_model_ref,
            draft_model_ref=draft_model_ref,
            quantize_draft=quantize_draft,
            no_eos=no_eos,
            split_sdpa=split_sdpa,
        )
        for meta_key in ("target_meta", "draft_meta"):
            if captured_meta[meta_key] is None:
                captured_meta[meta_key] = run.pop(meta_key)
            else:
                run.pop(meta_key, None)
        run.pop("pristine_target_meta", None)
        run["run_index"] = run_index
        run["thermal_pressure"] = thermal_pressure
        runs.append(run)

        is_last_run = run_index >= repeat
        if cooldown > 0 and not is_last_run:
            time.sleep(cooldown)

    target_meta = captured_meta["target_meta"]
    draft_meta = captured_meta["draft_meta"]
    return _build_single_case_report(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        block_tokens=block_tokens,
        repeat=repeat,
        cooldown=cooldown,
        runs=runs,
        target_model=target_meta["resolved_model_ref"] if target_meta is not None else "",
        draft_model=draft_meta["resolved_model_ref"] if draft_meta is not None else "",
    )
528
+
529
+
530
def main() -> None:
    """Command-line entry point: run the benchmark and save the JSON report.

    Parses the CLI options, runs either a single measurement or the repeated
    ("matrix") mode, writes the report to the default results path under
    ``benchmark/results`` and prints the same JSON to stdout.

    Matrix mode is used when ``--matrix`` is given or the effective repeat
    count is greater than one. An invalid ``--repeat`` is reported through
    argparse (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(description="Benchmark baseline MLX vs DFlash MLX runtime.")
    parser.add_argument("--prompt", required=True, help="Prompt to benchmark.")
    parser.add_argument("--max-tokens", type=int, default=64)
    parser.add_argument("--block-tokens", type=int, default=16)
    parser.add_argument("--verify-chunk-tokens", type=int, default=None)
    parser.add_argument("--matrix", action="store_true")
    parser.add_argument(
        "--repeat",
        type=int,
        default=None,
        help="Number of measured runs. Uses matrix mode automatically when > 1.",
    )
    parser.add_argument(
        "--cooldown",
        type=int,
        default=10,
        help="Seconds between runs for thermal stabilization.",
    )
    parser.add_argument("--no-chat-template", action="store_true")
    parser.add_argument("--model", default=None)
    parser.add_argument("--draft", default=None)
    parser.add_argument("--quantize-draft", action="store_true")
    parser.add_argument("--no-eos", action="store_true")
    parser.add_argument(
        "--split-sdpa",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable split_full_attention_sdpa when loading the target model (default: enabled).",
    )
    args = parser.parse_args()

    # Without --repeat: matrix mode uses the default repeat, otherwise one run.
    repeat = args.repeat if args.repeat is not None else (DEFAULT_REPEAT if args.matrix else 1)
    if repeat < 1:
        # Bad CLI input gets a usage error instead of a Python traceback.
        parser.error("--repeat must be >= 1")

    common_kwargs = {
        "block_tokens": args.block_tokens,
        "verify_chunk_tokens": args.verify_chunk_tokens,
        "use_chat_template": not args.no_chat_template,
        "target_model_ref": args.model,
        "draft_model_ref": args.draft,
        "quantize_draft": args.quantize_draft,
        "no_eos": args.no_eos,
        "split_sdpa": args.split_sdpa,
        "cooldown": args.cooldown,
    }
    if args.matrix or repeat > 1:
        result = benchmark_matrix(
            prompts=(args.prompt,),
            schedules=(args.max_tokens,),
            repeat=repeat,
            **common_kwargs,
        )
    else:
        result = benchmark_once(
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            **common_kwargs,
        )

    # Serialize once and reuse the text for both the file and stdout.
    report_json = json.dumps(result, indent=2)
    output_path = _default_results_path(
        target_model_ref=args.model,
        max_new_tokens=args.max_tokens,
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(report_json + "\n", encoding="utf-8")
    print(report_json)


if __name__ == "__main__":
    main()
dflash_mlx/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright 2026 bstnxbt
2
+ # MIT License — see LICENSE file
3
+ # Based on DFlash (arXiv:2602.06036)
4
+
5
# Version string of the dflash_mlx package.
__version__ = "0.1.0"