mlx-dspark 0.0.1__tar.gz

@@ -0,0 +1,17 @@
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.egg-info/
5
+ build/
6
+ dist/
7
+ .ruff_cache/
8
+ .pytest_cache/
9
+ # local model/drafter weights
10
+ checkpoints/
11
+ *.safetensors
12
+ scratch/
13
+ # local Claude skill copy (not part of the package)
14
+ .claude/
15
+ # private dev/agent tracking notes (kept local, not published)
16
+ CLAUDE.md
17
+ NOTES.md
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 erahim3
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,22 @@
1
+ mlx-dspark
2
+ ==========
3
+
4
+ This project is an independent MLX/Apple-Silicon port of the *inference path* of
5
+ DSpark, a speculative-decoding drafter open-sourced by DeepSeek as part of the
6
+ DeepSpec codebase:
7
+
8
+ - DeepSpec: https://github.com/deepseek-ai/DeepSpec (MIT License)
9
+
10
+ It loads the published DSpark drafter checkpoints (also released by DeepSeek under
11
+ their respective licenses):
12
+
13
+ - deepseek-ai/dspark_gemma4_12b_block7
14
+ - deepseek-ai/dspark_qwen3_4b_block7
15
+
16
+ The drafter architecture, training, and checkpoints are the work of DeepSeek and
17
+ the DeepSpec authors. This repository reimplements only the forward/verification
18
+ path for MLX and contains no DeepSpec source code. Target models (Gemma-4, Qwen3)
19
+ are downloaded at runtime from their respective publishers and are subject to
20
+ their own licenses (e.g. the Gemma Terms of Use and the Qwen license).
21
+
22
+ No model weights are bundled with this package.
@@ -0,0 +1,115 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-dspark
3
+ Version: 0.0.1
4
+ Summary: DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX
5
+ Project-URL: Homepage, https://github.com/ARahim3/mlx-dspark
6
+ Project-URL: Repository, https://github.com/ARahim3/mlx-dspark
7
+ Project-URL: Issues, https://github.com/ARahim3/mlx-dspark/issues
8
+ Author-email: erahim3 <erahim3@gmail.com>
9
+ License: MIT
10
+ License-File: LICENSE
11
+ License-File: NOTICE
12
+ Keywords: apple-silicon,deepspec,dspark,llm-inference,mlx,speculative-decoding
13
+ Classifier: Environment :: GPU
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: MacOS
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.10
19
+ Requires-Dist: huggingface-hub
20
+ Requires-Dist: mlx-lm>=0.31.3
21
+ Requires-Dist: mlx-vlm>=0.6.3
22
+ Requires-Dist: mlx>=0.31.2
23
+ Requires-Dist: numpy
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == 'dev'
26
+ Requires-Dist: ruff; extra == 'dev'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # mlx-dspark
30
+
31
+ **DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
32
+
33
+ DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
34
+ open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
35
+ accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
36
+ drafter checkpoints run natively on a Mac.
37
+
38
+ **Supported families** (auto-detected from the drafter config):
39
+
40
+ | family | target | drafter | RAM |
41
+ |---|---|---|---|
42
+ | `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
43
+ | `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
44
+
45
+ ## How it works
46
+
47
+ - A **parallel backbone** (5 Gemma-4 layers for the `gemma4` drafter) consumes the target model's hidden
48
+ states (tapped at target layers `[5,17,29,41,46]` in that case, EAGLE3-style) and proposes a 7-token block at once.
49
+ - A **rank-256 Markov head** adds a previous-token correction, sampled sequentially — the only
50
+ sequential cost, which kills "suffix decay" cheaply.
51
+ - A **confidence head** scores each draft position (optional adaptive block length).
52
+ - The target **verifies** every token, so output is **greedy-correct by construction**
53
+ (identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
54
+
55
+ The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
56
+ (~1.8 GB) so it's cheap to run every round.
57
+
58
+ ## Install
59
+
60
+ ```bash
61
+ uv venv --python 3.12
62
+ source .venv/bin/activate
63
+ uv pip install -e .
64
+ ```
65
+
66
+ ## Use
67
+
68
+ ```bash
69
+ # CLI — pick a family (downloads drafter + instruct target on first run)
70
+ python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
71
+ python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
72
+
73
+ # side-by-side demo: baseline (plain target) vs dspark (record each, stack)
74
+ python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
75
+ python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
76
+ ```
77
+
78
+ ```python
79
+ from mlx_dspark import load_pair, speculative_generate
80
+
81
+ target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
82
+ res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
83
+ print(res.text, res.mean_accept_len, res.tokens_per_sec)
84
+ ```
85
+
86
+ ## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
87
+
88
+ | family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
89
+ |---|---|---|---|---|---|
90
+ | **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
91
+ | **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
92
+
93
+ "greedy" = the plain target model decoding one token per forward pass (no drafter); "dspark" =
94
+ speculative decoding with the DSpark drafter. Both produce **identical** output, apart from
95
+ near-ties where the top-two logit margin is ≈ 0; DSpark is just faster. Smaller, faster targets
96
+ (Qwen3-4B) have a lower per-token verify cost, so the optimal `--max-draft` is smaller (~2).
97
+
98
+ ### Target choice matters
99
+
100
+ The drafter is trained against a specific **instruct** model in **bf16** — use the matching
101
+ instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (base gave `d_0` ~47% vs
102
+ 82%). Higher precision raises acceptance; 8-bit is the sweet spot. `-bf16` maximizes acceptance;
103
+ 4-bit verifies faster.
104
+
105
+ ### Tuning
106
+
107
+ On Apple Silicon the target's verify cost grows with the number of tokens verified, so throughput
107
+ peaks when the draft cap is close to the typical acceptance length: `--max-draft 4` (the default).
108
+ Setting `--max-draft` to the full block size (7) is faithful but *slower* on M-series.
109
+ Alternatively, `--confidence-threshold 0.6` truncates each block adaptively via the drafter's confidence head.
111
+
112
+ ## License
113
+
114
+ MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
115
+ DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
@@ -0,0 +1,87 @@
1
+ # mlx-dspark
2
+
3
+ **DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
4
+
5
+ DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
6
+ open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
7
+ accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
8
+ drafter checkpoints run natively on a Mac.
9
+
10
+ **Supported families** (auto-detected from the drafter config):
11
+
12
+ | family | target | drafter | RAM |
13
+ |---|---|---|---|
14
+ | `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
15
+ | `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
16
+
17
+ ## How it works
18
+
19
+ - A **parallel backbone** (5 Gemma-4 layers for the `gemma4` drafter) consumes the target model's hidden
20
+ states (tapped at target layers `[5,17,29,41,46]` in that case, EAGLE3-style) and proposes a 7-token block at once.
21
+ - A **rank-256 Markov head** adds a previous-token correction, sampled sequentially — the only
22
+ sequential cost, which kills "suffix decay" cheaply.
23
+ - A **confidence head** scores each draft position (optional adaptive block length).
24
+ - The target **verifies** every token, so output is **greedy-correct by construction**
25
+ (identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
26
+
27
+ The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
28
+ (~1.8 GB) so it's cheap to run every round.
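
The verification step in the last bullet above can be illustrated with a minimal sketch. It is not taken from this package: `accept_draft` and its arguments are hypothetical names, and the target's greedy choices are passed in as a plain list instead of being computed by a model. Given the drafted tokens and the target's argmax token at each of those positions plus one more, a round keeps the longest matching prefix and then appends the target's own token at the first mismatch (or the extra token if everything matched).

```python
def accept_draft(draft, target_argmax):
    """Return the tokens emitted by one speculative round under greedy verification.

    draft:          k drafted token ids.
    target_argmax:  k + 1 token ids; entry i is the target's greedy choice for
                    position i given the prompt plus draft[:i].
    """
    if len(target_argmax) != len(draft) + 1:
        raise ValueError("target_argmax must have one more entry than draft")
    emitted = []
    for i, tok in enumerate(draft):
        if tok != target_argmax[i]:
            break  # first disagreement: everything after it is discarded
        emitted.append(tok)
    # Append the target's own token at the first mismatch, or the extra
    # token after a full match, so every round yields at least one token.
    emitted.append(target_argmax[len(emitted)])
    return emitted


print(accept_draft([11, 12, 13, 14], [11, 12, 99, 7, 8]))
print(accept_draft([11, 12], [11, 12, 13]))
```

Every emitted token is one the target would have chosen greedily, which is why the result matches plain greedy decoding regardless of how good the drafter is; the drafter only affects how many tokens each round yields.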
29
+
30
+ ## Install
31
+
32
+ ```bash
33
+ uv venv --python 3.12
34
+ source .venv/bin/activate
35
+ uv pip install -e .
36
+ ```
37
+
38
+ ## Use
39
+
40
+ ```bash
41
+ # CLI — pick a family (downloads drafter + instruct target on first run)
42
+ python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
43
+ python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
44
+
45
+ # side-by-side demo: baseline (plain target) vs dspark (record each, stack)
46
+ python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
47
+ python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
48
+ ```
49
+
50
+ ```python
51
+ from mlx_dspark import load_pair, speculative_generate
52
+
53
+ target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
54
+ res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
55
+ print(res.text, res.mean_accept_len, res.tokens_per_sec)
56
+ ```
57
+
58
+ ## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
59
+
60
+ | family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
61
+ |---|---|---|---|---|---|
62
+ | **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
63
+ | **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
64
+
65
+ "greedy" = the plain target model decoding one token per forward pass (no drafter); "dspark" =
66
+ speculative decoding with the DSpark drafter. Both produce **identical** output, apart from
67
+ near-ties where the top-two logit margin is ≈ 0; DSpark is just faster. Smaller, faster targets
68
+ (Qwen3-4B) have a lower per-token verify cost, so the optimal `--max-draft` is smaller (~2).
69
+
70
+ ### Target choice matters
71
+
72
+ The drafter is trained against a specific **instruct** model in **bf16** — use the matching
73
+ instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (base gave `d_0` ~47% vs
74
+ 82%). Higher precision raises acceptance; 8-bit is the sweet spot. `-bf16` maximizes acceptance;
75
+ 4-bit verifies faster.
76
+
77
+ ### Tuning
78
+
79
+ On Apple Silicon the target's verify cost grows with the number of tokens verified, so throughput
80
+ peaks when the draft cap is close to the typical acceptance length: `--max-draft 4` (the default).
81
+ Setting `--max-draft` to the full block size (7) is faithful but *slower* on M-series.
82
+ Alternatively, `--confidence-threshold 0.6` truncates each block adaptively via the drafter's confidence head.
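
The trade-off described in this section can be sketched with a toy cost model. Everything here is an assumption made for illustration: each drafted token is accepted independently with probability `p`, a round always ends with one extra token from the target, and a round costs a fixed amount plus a per-position verify cost. The function names and all numbers are hypothetical and are not measurements from this project.

```python
def expected_tokens(p, k):
    """Expected tokens emitted per round with draft cap k.

    Assumes each drafted token is accepted independently with probability p,
    and that the round always ends with one extra token from the target.
    """
    return 1.0 + sum(p ** i for i in range(1, k + 1))


def round_cost(k, draft_cost, verify_base, verify_per_token):
    """Cost of one round: drafting plus a verify pass over k + 1 positions."""
    return draft_cost + verify_base + verify_per_token * (k + 1)


def best_cap(p, max_k, draft_cost, verify_base, verify_per_token):
    """Draft cap in 1..max_k that maximizes expected tokens per unit cost."""
    def throughput(k):
        return expected_tokens(p, k) / round_cost(k, draft_cost, verify_base, verify_per_token)
    return max(range(1, max_k + 1), key=throughput)


# Illustrative numbers only (arbitrary cost units, not measurements).
for per_token in (0.0, 0.3):
    print(per_token, best_cap(p=0.8, max_k=7, draft_cost=0.2,
                              verify_base=1.0, verify_per_token=per_token))
```

In this model a flat verify cost favors the full block, while a verify cost that grows with each extra position pulls the best cap down toward the expected acceptance length.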
83
+
84
+ ## License
85
+
86
+ MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
87
+ DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
@@ -0,0 +1,52 @@
1
+ """Warm benchmark: greedy vs DSpark speculative (both warm), multi-trial.
2
+
3
+ python benchmark.py # gemma4
4
+ python benchmark.py qwen3
5
+ """
6
+ import sys, time, mlx.core as mx
7
+ from mlx_dspark.load import load_pair
8
+ from mlx_dspark.generate import (
9
+ speculative_generate, _make_target_cache, eos_token_ids, encode_prompt,
10
+ )
11
+
12
+ family = sys.argv[1] if len(sys.argv) > 1 else "gemma4"
13
+ target, tok, drafter, cfg = load_pair(family)
14
+ eos = eos_token_ids(tok)
15
+ N = 100
16
+ PROMPTS = [
17
+ "Explain how rainbows form.",
18
+ "Write a Python function to check if a string is a palindrome.",
19
+ "Give three tips for staying focused while working.",
20
+ ]
21
+
22
+ def greedy(ids, n):
23
+ cache = _make_target_cache(target)
24
+ lg = target.plain(mx.array([ids]), cache); mx.eval(lg)
25
+ nx = int(mx.argmax(lg[0, -1]).item()); out = [nx]; t = time.time()
26
+ while len(out) < n and nx not in eos:
27
+ lg = target.plain(mx.array([[nx]]), cache); mx.eval(lg)
28
+ nx = int(mx.argmax(lg[0, -1]).item()); out.append(nx)
29
+ return len(out), time.time() - t
30
+
31
+ print(f"[{family}] warming up (ramping clocks)...")
32
+ for _ in range(2):
33
+ greedy(encode_prompt(tok, "Tell me about the sea.", True), 120)
34
+ speculative_generate(target, tok, drafter, "Tell me about the sea.",
35
+ max_new_tokens=120, max_draft_tokens=3)
36
+
37
+ print(f"\n{'prompt':<6} {'greedy':>9} | {'cap':>3} {'spec':>9} {'accept':>7} {'speedup':>8}")
38
+ agg = {}
39
+ for i, p in enumerate(PROMPTS):
40
+ ids = encode_prompt(tok, p, True)
41
+ gt, gs = greedy(ids, N); gtps = gt / gs
42
+ for cap in (2, 3, 4):
43
+ res = speculative_generate(target, tok, drafter, p, max_new_tokens=N,
44
+ max_draft_tokens=cap)
45
+         sp = res.tokens_per_sec / gtps  # throughput ratio; still valid if the runs stop at different lengths
46
+ agg.setdefault(cap, []).append(sp)
47
+ tag = f"P{i}" if cap == 2 else ""
48
+ print(f"{tag:<6} {gtps:>7.1f}/s | {cap:>3} {res.tokens_per_sec:>7.1f}/s "
49
+ f"{res.mean_accept_len:>7.2f} {sp:>7.2f}x")
50
+ print("\nmean speedup by cap:")
51
+ for cap, v in agg.items():
52
+ print(f" cap={cap}: {sum(v)/len(v):.2f}x")
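
When a baseline run and a speculative run may stop at different token counts (for example, one of them hits an EOS token early), comparing tokens per second is safer than comparing elapsed time. A minimal sketch of that calculation follows; `throughput_speedup` is a hypothetical helper and is not part of the package.

```python
def throughput_speedup(base_tokens, base_seconds, spec_tokens, spec_seconds):
    """Speedup as a ratio of tokens/second, valid even if token counts differ."""
    if min(base_tokens, spec_tokens) <= 0 or min(base_seconds, spec_seconds) <= 0:
        raise ValueError("token counts and durations must be positive")
    return (spec_tokens / spec_seconds) / (base_tokens / base_seconds)


# Same token count: equal to the ratio of elapsed times (5.0 / 2.5).
print(throughput_speedup(100, 5.0, 100, 2.5))

# The speculative run stopped after 40 tokens in 1.0 s. The ratio of elapsed
# times alone would be 5.0 / 1.0, which overstates the gain.
print(throughput_speedup(100, 5.0, 40, 1.0))
```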
@@ -0,0 +1,42 @@
1
+ [project]
2
+ name = "mlx-dspark"
3
+ version = "0.0.1"
4
+ description = "DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ license = { text = "MIT" }
8
+ authors = [{ name = "erahim3", email = "erahim3@gmail.com" }]
9
+ keywords = ["mlx", "apple-silicon", "speculative-decoding", "dspark", "deepspec", "llm-inference"]
10
+ classifiers = [
11
+ "License :: OSI Approved :: MIT License",
12
+ "Programming Language :: Python :: 3",
13
+ "Operating System :: MacOS",
14
+ "Environment :: GPU",
15
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
16
+ ]
17
+ dependencies = [
18
+ "mlx>=0.31.2",
19
+ "mlx-lm>=0.31.3",
20
+ # mlx-vlm gives us Gemma-4 + the existing EAGLE3/DFlash spec-decode reference impl
21
+ "mlx-vlm>=0.6.3",
22
+ "numpy",
23
+ "huggingface-hub",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ dev = ["pytest", "ruff"]
28
+
29
+ [project.urls]
30
+ Homepage = "https://github.com/ARahim3/mlx-dspark"
31
+ Repository = "https://github.com/ARahim3/mlx-dspark"
32
+ Issues = "https://github.com/ARahim3/mlx-dspark/issues"
33
+
34
+ [build-system]
35
+ requires = ["hatchling"]
36
+ build-backend = "hatchling.build"
37
+
38
+ [tool.hatch.build.targets.wheel]
39
+ packages = ["src/mlx_dspark"]
40
+
41
+ [tool.ruff]
42
+ line-length = 100
@@ -0,0 +1,42 @@
1
+ """mlx-dspark: DSpark speculative decoding for Apple Silicon (MLX).
2
+
3
+ DSpark (from DeepSeek's DeepSpec codebase) is a semi-autoregressive,
4
+ EAGLE-family speculative-decoding drafter:
5
+
6
+ - a *parallel backbone* proposes base logits for all K draft positions at once,
7
+ - a tiny *sequential head* (low-rank, previous-token-conditioned) corrects
8
+ suffix decay,
9
+ - a *confidence head* scores how likely each drafted token survives
10
+ verification (used here for adaptive draft length instead of the
11
+ server-side load-aware scheduler, which is irrelevant single-user).
12
+
13
+ This package targets single-user local inference on Apple Silicon.
14
+ """
15
+
16
+ __version__ = "0.0.1"
17
+
18
+ from .config import DSparkConfig
19
+ from .load import (
20
+ DEFAULT_DRAFTER,
21
+ DEFAULT_TARGET,
22
+ PRESETS,
23
+ load_drafter,
24
+ load_pair,
25
+ load_target,
26
+ )
27
+ from .generate import GenResult, greedy_generate, speculative_generate
28
+ from .target import Target
29
+
30
+ __all__ = [
31
+ "DSparkConfig",
32
+ "Target",
33
+ "load_drafter",
34
+ "load_target",
35
+ "load_pair",
36
+ "speculative_generate",
37
+ "greedy_generate",
38
+ "GenResult",
39
+ "PRESETS",
40
+ "DEFAULT_TARGET",
41
+ "DEFAULT_DRAFTER",
42
+ ]
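
The package docstring above says the confidence head is used for adaptive draft length. The exact rule is not visible in this diff, so the following is only one plausible reading, stated as an assumption: cut the drafted block at the first position whose confidence falls below the threshold. `truncate_by_confidence` is a hypothetical name, not a function exported by the package.

```python
def truncate_by_confidence(draft, confidences, threshold):
    """Keep drafted tokens up to (not including) the first low-confidence one.

    ASSUMPTION: this "first position below threshold" rule is illustrative;
    the package's real truncation rule is not shown in this diff.
    """
    if len(draft) != len(confidences):
        raise ValueError("draft and confidences must have the same length")
    kept = []
    for tok, conf in zip(draft, confidences):
        if conf < threshold:
            break
        kept.append(tok)
    return kept


print(truncate_by_confidence([1, 2, 3, 4], [0.9, 0.7, 0.5, 0.8], 0.6))
print(truncate_by_confidence([1, 2, 3, 4], [0.9, 0.7, 0.5, 0.8], 0.0))
```

Under this rule a threshold of `0.0` never truncates, since no confidence in `[0, 1]` falls below it.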
@@ -0,0 +1,84 @@
1
+ """CLI: run DSpark speculative decoding on Apple Silicon (streams tokens live).
2
+
3
+ Side-by-side demo (record each, then stack the two screen captures):
4
+
5
+ # left panel — plain target, no drafter
6
+ python -m mlx_dspark --mode baseline --prompt "Explain how rainbows form." --max-new-tokens 220
7
+
8
+ # right panel — DSpark speculative decoding (same prompt, same output, faster)
9
+ python -m mlx_dspark --mode dspark --prompt "Explain how rainbows form." --max-new-tokens 220
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import sys
16
+ import time
17
+
18
+ from .generate import greedy_generate, speculative_generate
19
+ from .load import PRESETS, load_drafter, load_target
20
+
21
+
22
+ def _emit(s: str) -> None:
23
+ sys.stdout.write(s)
24
+ sys.stdout.flush()
25
+
26
+
27
+ def main() -> None:
28
+ ap = argparse.ArgumentParser(prog="mlx_dspark")
29
+ ap.add_argument("--mode", choices=["dspark", "baseline"], default="dspark",
30
+ help="dspark = speculative decoding; baseline = plain greedy target")
31
+ ap.add_argument("--family", choices=["gemma4", "qwen3"], default="gemma4",
32
+ help="model preset (target + drafter); overridden by --target/--drafter")
33
+ ap.add_argument("--prompt", default="Explain how rainbows form, in a few sentences.")
34
+ ap.add_argument("--target", default=None)
35
+ ap.add_argument("--drafter", default=None)
36
+ ap.add_argument("--max-new-tokens", type=int, default=220)
37
+ ap.add_argument("--max-draft", type=int, default=4)
38
+ ap.add_argument("--confidence-threshold", type=float, default=0.0)
39
+ ap.add_argument("--drafter-bits", type=int, default=4)
40
+ ap.add_argument("--no-chat-template", action="store_true")
41
+ ap.add_argument("--no-stream", action="store_true")
42
+ args = ap.parse_args()
43
+ target_repo = args.target or PRESETS[args.family]["target"]
44
+ drafter_repo = args.drafter or PRESETS[args.family]["drafter"]
45
+
46
+ label = "DSpark speculative" if args.mode == "dspark" else "Baseline (plain greedy)"
47
+ print(f"loading {args.mode}: target={target_repo}"
48
+ + (f", drafter={drafter_repo}" if args.mode == "dspark" else ""))
49
+ target, tok = load_target(target_repo)
50
+ drafter = None
51
+ if args.mode == "dspark":
52
+ drafter, _ = load_drafter(drafter_repo, quantize=args.drafter_bits > 0,
53
+ bits=max(args.drafter_bits, 2))
54
+
55
+ on_text = None if args.no_stream else _emit
56
+ print("\n" + "=" * 64)
57
+ print(f" ▶ {label} · {target_repo.split('/')[-1]}")
58
+ print("=" * 64)
59
+
60
+ if args.mode == "dspark":
61
+ res = speculative_generate(
62
+ target, tok, drafter, args.prompt,
63
+ max_new_tokens=args.max_new_tokens, max_draft_tokens=args.max_draft,
64
+ confidence_threshold=args.confidence_threshold,
65
+ apply_chat_template=not args.no_chat_template, on_text=on_text,
66
+ )
67
+ extra = f" · accept {res.mean_accept_len:.2f}/round · {res.target_forwards} target fwds"
68
+ else:
69
+ res = greedy_generate(
70
+ target, tok, args.prompt, max_new_tokens=args.max_new_tokens,
71
+ apply_chat_template=not args.no_chat_template, on_text=on_text,
72
+ )
73
+ extra = ""
74
+ if args.no_stream:
75
+ print(res.text)
76
+
77
+ print("\n" + "-" * 64)
78
+ print(f" {res.num_tokens} tokens · {res.seconds:.2f}s · "
79
+ f"\033[1m{res.tokens_per_sec:.1f} tok/s\033[0m{extra}")
80
+ print("-" * 64)
81
+
82
+
83
+ if __name__ == "__main__":
84
+ main()
@@ -0,0 +1,145 @@
1
+ """DSpark drafter config — loaded from the HF checkpoint's config.json.
2
+
3
+ Supports two drafter families with a shared inference path:
4
+ - gemma4 (gemma4_text): k_eq_v attention, v_norm, partial/proportional rope,
5
+ sandwich norms + layer_scalar, gelu-tanh MLP, logit softcap.
6
+ - qwen3 (qwen3): standard GQA (separate v_proj, no v_norm), default rope,
7
+ Llama-style 2-norm layer, silu MLP, no softcap.
8
+ Only the fields the MLX inference path needs are pulled out.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ from dataclasses import dataclass, field
15
+ from pathlib import Path
16
+
17
+
18
+ @dataclass
19
+ class DSparkConfig:
20
+ family: str = "gemma4" # "gemma4" | "qwen3"
21
+
22
+ # core dims
23
+ hidden_size: int = 3840
24
+ vocab_size: int = 262144
25
+ num_hidden_layers: int = 5
26
+ intermediate_size: int = 15360
27
+ rms_norm_eps: float = 1e-6
28
+
29
+ # attention
30
+ num_attention_heads: int = 16
31
+ num_key_value_heads: int = 8
32
+ num_global_key_value_heads: int = 1
33
+ head_dim: int = 256
34
+ global_head_dim: int = 512
35
+ attention_k_eq_v: bool = True
36
+ attention_bias: bool = False
37
+
38
+ # rope
39
+ rope_theta: float = 1_000_000.0
40
+ partial_rotary_factor: float = 0.25
41
+ rope_type: str = "proportional"
42
+
43
+ # dspark specifics
44
+ block_size: int = 7
45
+ mask_token_id: int = 4
46
+ target_layer_ids: list[int] = field(default_factory=lambda: [5, 17, 29, 41, 46])
47
+ num_target_layers: int = 48
48
+
49
+ # markov + confidence
50
+ markov_rank: int = 256
51
+ markov_head_type: str = "vanilla"
52
+ enable_confidence_head: bool = True
53
+ confidence_head_with_markov: bool = True
54
+
55
+ # logits
56
+ final_logit_softcapping: float | None = 30.0
57
+ pad_token_id: int = 0
58
+
59
+ # ---- family-derived knobs (set in from_json) ----
60
+ mlp_activation: str = "gelu_tanh" # "gelu_tanh" | "silu"
61
+ norm_style: str = "gemma" # "gemma" (sandwich+scalar) | "qwen" (llama 2-norm)
62
+ use_v_norm: bool = True # gemma: RMSNormNoScale v_norm; qwen: none
63
+     attention_scaling: float | None = None  # None -> qwen3: 1/sqrt(attn_head_dim); gemma4: 1.0
64
+
65
+ @property
66
+ def attn_head_dim(self) -> int:
67
+ """Head dim used by the drafter's own attention."""
68
+ return self.global_head_dim if self.family == "gemma4" else self.head_dim
69
+
70
+ @property
71
+ def n_kv_heads(self) -> int:
72
+ if self.family == "gemma4" and self.attention_k_eq_v:
73
+ return self.num_global_key_value_heads
74
+ return self.num_key_value_heads
75
+
76
+ @property
77
+ def scaling(self) -> float:
78
+ if self.attention_scaling is not None:
79
+ return self.attention_scaling
80
+ return self.attn_head_dim ** -0.5 if self.family == "qwen3" else 1.0
81
+
82
+ @property
83
+ def rope_parameters(self) -> dict:
84
+ return {"rope_type": self.rope_type, "partial_rotary_factor": self.partial_rotary_factor}
85
+
86
+ @classmethod
87
+ def from_json(cls, path: str | Path) -> "DSparkConfig":
88
+ with open(path) as f:
89
+ c = json.load(f)
90
+ mt = c.get("model_type", "")
91
+ family = "qwen3" if "qwen3" in mt else "gemma4"
92
+
93
+ if family == "qwen3":
94
+ rp = c.get("rope_parameters") or {}
95
+ return cls(
96
+ family="qwen3",
97
+ hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
98
+ num_hidden_layers=c["num_hidden_layers"],
99
+ intermediate_size=c["intermediate_size"],
100
+ rms_norm_eps=c.get("rms_norm_eps", 1e-6),
101
+ num_attention_heads=c["num_attention_heads"],
102
+ num_key_value_heads=c.get("num_key_value_heads", 8),
103
+ head_dim=c.get("head_dim", c["hidden_size"] // c["num_attention_heads"]),
104
+ attention_k_eq_v=False, attention_bias=c.get("attention_bias", False),
105
+ rope_theta=rp.get("rope_theta", c.get("rope_theta", 1_000_000.0)),
106
+ rope_type="default",
107
+ block_size=c["block_size"], mask_token_id=c["mask_token_id"],
108
+ target_layer_ids=list(c["target_layer_ids"]),
109
+ num_target_layers=c.get("num_target_layers", 36),
110
+ markov_rank=c.get("markov_rank", 256),
111
+ markov_head_type=c.get("markov_head_type", "vanilla"),
112
+ enable_confidence_head=c.get("enable_confidence_head", True),
113
+ confidence_head_with_markov=c.get("confidence_head_with_markov", True),
114
+ final_logit_softcapping=c.get("final_logit_softcapping", None),
115
+ pad_token_id=c.get("pad_token_id") or 0,
116
+ mlp_activation="silu", norm_style="qwen", use_v_norm=False,
117
+ )
118
+
119
+ rope = (c.get("rope_parameters") or {}).get("full_attention", {}) or {}
120
+ return cls(
121
+ family="gemma4",
122
+ hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
123
+ num_hidden_layers=c["num_hidden_layers"],
124
+ intermediate_size=c["intermediate_size"],
125
+ rms_norm_eps=c.get("rms_norm_eps", 1e-6),
126
+ num_attention_heads=c["num_attention_heads"],
127
+ num_key_value_heads=c.get("num_key_value_heads", 8),
128
+ num_global_key_value_heads=c.get("num_global_key_value_heads", 1),
129
+ head_dim=c.get("head_dim", 256), global_head_dim=c.get("global_head_dim", 512),
130
+ attention_k_eq_v=c.get("attention_k_eq_v", True),
131
+ attention_bias=c.get("attention_bias", False),
132
+ rope_theta=rope.get("rope_theta", 1_000_000.0),
133
+ partial_rotary_factor=rope.get("partial_rotary_factor", 0.25),
134
+ rope_type=rope.get("rope_type", "proportional"),
135
+ block_size=c["block_size"], mask_token_id=c["mask_token_id"],
136
+ target_layer_ids=list(c["target_layer_ids"]),
137
+ num_target_layers=c.get("num_target_layers", 48),
138
+ markov_rank=c.get("markov_rank", 256),
139
+ markov_head_type=c.get("markov_head_type", "vanilla"),
140
+ enable_confidence_head=c.get("enable_confidence_head", True),
141
+ confidence_head_with_markov=c.get("confidence_head_with_markov", True),
142
+ final_logit_softcapping=c.get("final_logit_softcapping", 30.0),
143
+             pad_token_id=c.get("pad_token_id") or 0,
144
+ mlp_activation="gelu_tanh", norm_style="gemma", use_v_norm=True,
145
+ )
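
`from_json` above indexes some keys directly (`c[...]`) and falls back to defaults for the rest, so a config that lacks one of the directly indexed keys fails with a `KeyError`. A standalone sketch of a pre-check is shown below. It relies only on the key names visible in `from_json`; `REQUIRED_KEYS` and `missing_required_keys` are hypothetical names and are not part of the package.

```python
import json
import tempfile
from pathlib import Path

# Keys that from_json reads with c[...] for both families (no default).
REQUIRED_KEYS = (
    "hidden_size",
    "vocab_size",
    "num_hidden_layers",
    "intermediate_size",
    "num_attention_heads",
    "block_size",
    "mask_token_id",
    "target_layer_ids",
)


def missing_required_keys(path):
    """Return the required keys that are absent from a drafter config.json."""
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    return [k for k in REQUIRED_KEYS if k not in cfg]


# Demo with a deliberately incomplete config written to a temp directory.
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.json"
    cfg_path.write_text(
        json.dumps({"model_type": "qwen3", "hidden_size": 2560, "block_size": 7}),
        encoding="utf-8",
    )
    print(missing_required_keys(cfg_path))
```

Running a check like this before loading gives one message listing every absent key, where the loader itself would stop at the first one.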