mlx-dspark 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_dspark-0.0.1/.gitignore +17 -0
- mlx_dspark-0.0.1/LICENSE +21 -0
- mlx_dspark-0.0.1/NOTICE +22 -0
- mlx_dspark-0.0.1/PKG-INFO +115 -0
- mlx_dspark-0.0.1/README.md +87 -0
- mlx_dspark-0.0.1/benchmark.py +52 -0
- mlx_dspark-0.0.1/pyproject.toml +42 -0
- mlx_dspark-0.0.1/src/mlx_dspark/__init__.py +42 -0
- mlx_dspark-0.0.1/src/mlx_dspark/__main__.py +84 -0
- mlx_dspark-0.0.1/src/mlx_dspark/config.py +145 -0
- mlx_dspark-0.0.1/src/mlx_dspark/generate.py +289 -0
- mlx_dspark-0.0.1/src/mlx_dspark/load.py +117 -0
- mlx_dspark-0.0.1/src/mlx_dspark/model.py +277 -0
- mlx_dspark-0.0.1/src/mlx_dspark/target.py +64 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
.venv/
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.pyc
|
|
4
|
+
*.egg-info/
|
|
5
|
+
build/
|
|
6
|
+
dist/
|
|
7
|
+
.ruff_cache/
|
|
8
|
+
.pytest_cache/
|
|
9
|
+
# local model/drafter weights
|
|
10
|
+
checkpoints/
|
|
11
|
+
*.safetensors
|
|
12
|
+
scratch/
|
|
13
|
+
# local Claude skill copy (not part of the package)
|
|
14
|
+
.claude/
|
|
15
|
+
# private dev/agent tracking notes (kept local, not published)
|
|
16
|
+
CLAUDE.md
|
|
17
|
+
NOTES.md
|
mlx_dspark-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 erahim3
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
mlx_dspark-0.0.1/NOTICE
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
mlx-dspark
|
|
2
|
+
==========
|
|
3
|
+
|
|
4
|
+
This project is an independent MLX/Apple-Silicon port of the *inference path* of
|
|
5
|
+
DSpark, a speculative-decoding drafter open-sourced by DeepSeek as part of the
|
|
6
|
+
DeepSpec codebase:
|
|
7
|
+
|
|
8
|
+
- DeepSpec: https://github.com/deepseek-ai/DeepSpec (MIT License)
|
|
9
|
+
|
|
10
|
+
It loads the published DSpark drafter checkpoints (also released by DeepSeek under
|
|
11
|
+
their respective licenses):
|
|
12
|
+
|
|
13
|
+
- deepseek-ai/dspark_gemma4_12b_block7
|
|
14
|
+
- deepseek-ai/dspark_qwen3_4b_block7
|
|
15
|
+
|
|
16
|
+
The drafter architecture, training, and checkpoints are the work of DeepSeek and
|
|
17
|
+
the DeepSpec authors. This repository reimplements only the forward/verification
|
|
18
|
+
path for MLX and contains no DeepSpec source code. Target models (Gemma-4, Qwen3)
|
|
19
|
+
are downloaded at runtime from their respective publishers and are subject to
|
|
20
|
+
their own licenses (e.g. the Gemma Terms of Use and the Qwen license).
|
|
21
|
+
|
|
22
|
+
No model weights are bundled with this package.
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx-dspark
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX
|
|
5
|
+
Project-URL: Homepage, https://github.com/ARahim3/mlx-dspark
|
|
6
|
+
Project-URL: Repository, https://github.com/ARahim3/mlx-dspark
|
|
7
|
+
Project-URL: Issues, https://github.com/ARahim3/mlx-dspark/issues
|
|
8
|
+
Author-email: erahim3 <erahim3@gmail.com>
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
License-File: NOTICE
|
|
12
|
+
Keywords: apple-silicon,deepspec,dspark,llm-inference,mlx,speculative-decoding
|
|
13
|
+
Classifier: Environment :: GPU
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: MacOS
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.10
|
|
19
|
+
Requires-Dist: huggingface-hub
|
|
20
|
+
Requires-Dist: mlx-lm>=0.31.3
|
|
21
|
+
Requires-Dist: mlx-vlm>=0.6.3
|
|
22
|
+
Requires-Dist: mlx>=0.31.2
|
|
23
|
+
Requires-Dist: numpy
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
26
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
27
|
+
Description-Content-Type: text/markdown
|
|
28
|
+
|
|
29
|
+
# mlx-dspark
|
|
30
|
+
|
|
31
|
+
**DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
|
|
32
|
+
|
|
33
|
+
DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
|
|
34
|
+
open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
|
|
35
|
+
accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
|
|
36
|
+
drafter checkpoints run natively on a Mac.
|
|
37
|
+
|
|
38
|
+
**Supported families** (auto-detected from the drafter config):
|
|
39
|
+
|
|
40
|
+
| family | target | drafter | RAM |
|
|
41
|
+
|---|---|---|---|
|
|
42
|
+
| `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
|
|
43
|
+
| `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
|
|
44
|
+
|
|
45
|
+
## How it works
|
|
46
|
+
|
|
47
|
+
- A **parallel backbone** (5 Gemma-4 layers) consumes the target model's hidden states
|
|
48
|
+
(tapped at layers `[5,17,29,41,46]` for `gemma4`; each drafter config lists its own tap layers, EAGLE3-style) and proposes a 7-token block at once.
|
|
49
|
+
- A **rank-256 Markov head** adds a previous-token correction, sampled sequentially — the only
|
|
50
|
+
sequential cost, which kills "suffix decay" cheaply.
|
|
51
|
+
- A **confidence head** scores each draft position (optional adaptive block length).
|
|
52
|
+
- The target **verifies** every token, so output is **greedy-correct by construction**
|
|
53
|
+
(identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
|
|
54
|
+
|
|
55
|
+
The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
|
|
56
|
+
(~1.8 GB) so it's cheap to run every round.
|
|
57
|
+
|
|
58
|
+
## Install
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
uv venv --python 3.12
|
|
62
|
+
source .venv/bin/activate
|
|
63
|
+
uv pip install -e .
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## Use
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# CLI — pick a family (downloads drafter + instruct target on first run)
|
|
70
|
+
python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
|
|
71
|
+
python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
|
|
72
|
+
|
|
73
|
+
# side-by-side demo: baseline (plain target) vs dspark (record each, stack)
|
|
74
|
+
python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
|
|
75
|
+
python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
from mlx_dspark import load_pair, speculative_generate
|
|
80
|
+
|
|
81
|
+
target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
|
|
82
|
+
res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
|
|
83
|
+
print(res.text, res.mean_accept_len, res.tokens_per_sec)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
|
|
87
|
+
|
|
88
|
+
| family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
|
|
89
|
+
|---|---|---|---|---|---|
|
|
90
|
+
| **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
|
|
91
|
+
| **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
|
|
92
|
+
|
|
93
|
+
"greedy" = the plain target model decoding one token per forward (no drafter); "dspark" =
|
|
94
|
+
speculative decoding with the DSpark drafter. Both produce **identical** output — DSpark is just
|
|
95
|
+
faster (the two outputs can differ only where the top logits are effectively tied, i.e. logit margin ≈ 0). Smaller/faster targets (Qwen3-4B) have a lower
|
|
96
|
+
per-token verify cost, so the optimal `--max-draft` is smaller (~2).
|
|
97
|
+
|
|
98
|
+
### Target choice matters
|
|
99
|
+
|
|
100
|
+
The drafter is trained against a specific **instruct** model in **bf16** — use the matching
|
|
101
|
+
instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (base gave `d_0` ~47% vs
|
|
102
|
+
82%). Higher precision raises acceptance; 8-bit is the sweet spot. `-bf16` maximizes acceptance;
|
|
103
|
+
4-bit verifies faster.
|
|
104
|
+
|
|
105
|
+
### Tuning
|
|
106
|
+
|
|
107
|
+
On Apple Silicon the target verify cost grows with the number of tokens verified, so the
|
|
108
|
+
optimum is to verify about as many tokens as the acceptance length: `--max-draft 4` (default). `--max-draft <block>`
|
|
109
|
+
(full 7) is faithful but *slower* on M-series. `--confidence-threshold 0.6` instead truncates
|
|
110
|
+
the block adaptively via the drafter's confidence head.
|
|
111
|
+
|
|
112
|
+
## License
|
|
113
|
+
|
|
114
|
+
MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
|
|
115
|
+
DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# mlx-dspark
|
|
2
|
+
|
|
3
|
+
**DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
|
|
4
|
+
|
|
5
|
+
DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
|
|
6
|
+
open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
|
|
7
|
+
accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
|
|
8
|
+
drafter checkpoints run natively on a Mac.
|
|
9
|
+
|
|
10
|
+
**Supported families** (auto-detected from the drafter config):
|
|
11
|
+
|
|
12
|
+
| family | target | drafter | RAM |
|
|
13
|
+
|---|---|---|---|
|
|
14
|
+
| `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
|
|
15
|
+
| `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
|
|
16
|
+
|
|
17
|
+
## How it works
|
|
18
|
+
|
|
19
|
+
- A **parallel backbone** (5 Gemma-4 layers) consumes the target model's hidden states
|
|
20
|
+
(tapped at layers `[5,17,29,41,46]` for `gemma4`; each drafter config lists its own tap layers, EAGLE3-style) and proposes a 7-token block at once.
|
|
21
|
+
- A **rank-256 Markov head** adds a previous-token correction, sampled sequentially — the only
|
|
22
|
+
sequential cost, which kills "suffix decay" cheaply.
|
|
23
|
+
- A **confidence head** scores each draft position (optional adaptive block length).
|
|
24
|
+
- The target **verifies** every token, so output is **greedy-correct by construction**
|
|
25
|
+
(identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
|
|
26
|
+
|
|
27
|
+
The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
|
|
28
|
+
(~1.8 GB) so it's cheap to run every round.
|
|
29
|
+
|
|
30
|
+
## Install
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
uv venv --python 3.12
|
|
34
|
+
source .venv/bin/activate
|
|
35
|
+
uv pip install -e .
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## Use
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
# CLI — pick a family (downloads drafter + instruct target on first run)
|
|
42
|
+
python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
|
|
43
|
+
python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
|
|
44
|
+
|
|
45
|
+
# side-by-side demo: baseline (plain target) vs dspark (record each, stack)
|
|
46
|
+
python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
|
|
47
|
+
python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from mlx_dspark import load_pair, speculative_generate
|
|
52
|
+
|
|
53
|
+
target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
|
|
54
|
+
res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
|
|
55
|
+
print(res.text, res.mean_accept_len, res.tokens_per_sec)
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
|
|
59
|
+
|
|
60
|
+
| family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
|
|
61
|
+
|---|---|---|---|---|---|
|
|
62
|
+
| **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
|
|
63
|
+
| **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
|
|
64
|
+
|
|
65
|
+
"greedy" = the plain target model decoding one token per forward (no drafter); "dspark" =
|
|
66
|
+
speculative decoding with the DSpark drafter. Both produce **identical** output — DSpark is just
|
|
67
|
+
faster (the two outputs can differ only where the top logits are effectively tied, i.e. logit margin ≈ 0). Smaller/faster targets (Qwen3-4B) have a lower
|
|
68
|
+
per-token verify cost, so the optimal `--max-draft` is smaller (~2).
|
|
69
|
+
|
|
70
|
+
### Target choice matters
|
|
71
|
+
|
|
72
|
+
The drafter is trained against a specific **instruct** model in **bf16** — use the matching
|
|
73
|
+
instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (base gave `d_0` ~47% vs
|
|
74
|
+
82%). Higher precision raises acceptance; 8-bit is the sweet spot. `-bf16` maximizes acceptance;
|
|
75
|
+
4-bit verifies faster.
|
|
76
|
+
|
|
77
|
+
### Tuning
|
|
78
|
+
|
|
79
|
+
On Apple Silicon the target verify cost grows with the number of tokens verified, so the
|
|
80
|
+
optimum is to verify about as many tokens as the acceptance length: `--max-draft 4` (default). `--max-draft <block>`
|
|
81
|
+
(full 7) is faithful but *slower* on M-series. `--confidence-threshold 0.6` instead truncates
|
|
82
|
+
the block adaptively via the drafter's confidence head.
|
|
83
|
+
|
|
84
|
+
## License
|
|
85
|
+
|
|
86
|
+
MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
|
|
87
|
+
DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Warm benchmark: greedy vs DSpark speculative (both warm), multi-trial.
|
|
2
|
+
|
|
3
|
+
python benchmark.py # gemma4
|
|
4
|
+
python benchmark.py qwen3
|
|
5
|
+
"""
|
|
6
|
+
import sys, time, mlx.core as mx
|
|
7
|
+
from mlx_dspark.load import load_pair
|
|
8
|
+
from mlx_dspark.generate import (
|
|
9
|
+
speculative_generate, _make_target_cache, eos_token_ids, encode_prompt,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Benchmark knobs: tokens generated per trial and the prompts that get timed.
N = 100
PROMPTS = [
    "Explain how rainbows form.",
    "Write a Python function to check if a string is a palindrome.",
    "Give three tips for staying focused while working.",
]

# The model family is the optional first CLI argument; "gemma4" when omitted.
family = "gemma4"
if len(sys.argv) > 1:
    family = sys.argv[1]

# Load the target, tokenizer and drafter once; every trial below reuses them.
target, tok, drafter, cfg = load_pair(family)
eos = eos_token_ids(tok)
|
|
21
|
+
|
|
22
|
+
def greedy(ids, n):
    """Greedy-decode up to ``n`` tokens with the plain target and time the decode loop.

    Args:
        ids: Prompt token ids; wrapped in a list to form a batch of one
            (presumably a flat list of ints -- confirm against ``encode_prompt``).
        n: Maximum number of tokens to produce, counting the first token that
            comes out of the prefill forward.

    Returns:
        ``(num_tokens, seconds)``. The timer starts *after* the prefill forward,
        so ``seconds`` covers only the one-token-per-forward decode loop.
    """
    cache = _make_target_cache(target)

    # Prefill: a single forward over the whole prompt (not timed).
    logits = target.plain(mx.array([ids]), cache)
    mx.eval(logits)
    next_id = int(mx.argmax(logits[0, -1]).item())
    out = [next_id]

    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which follows the wall clock and can jump if the system time is adjusted.
    start = time.perf_counter()
    while len(out) < n and next_id not in eos:
        logits = target.plain(mx.array([[next_id]]), cache)
        mx.eval(logits)  # force evaluation so the step is included in the timing
        next_id = int(mx.argmax(logits[0, -1]).item())
        out.append(next_id)
    return len(out), time.perf_counter() - start
|
|
30
|
+
|
|
31
|
+
# Warm-up: two untimed rounds of both decoders before anything is measured.
print(f"[{family}] warming up (ramping clocks)...")
for _ in range(2):
    greedy(encode_prompt(tok, "Tell me about the sea.", True), 120)
    speculative_generate(target, tok, drafter, "Tell me about the sea.",
                         max_new_tokens=120, max_draft_tokens=3)

# max_draft_tokens values compared for every prompt.
CAPS = (2, 3, 4)

print(f"\n{'prompt':<6} {'greedy':>9} | {'cap':>3} {'spec':>9} {'accept':>7} {'speedup':>8}")
agg = {}  # cap -> list of per-prompt speedups
for i, p in enumerate(PROMPTS):
    ids = encode_prompt(tok, p, True)
    gt, gs = greedy(ids, N)
    gtps = gt / gs
    for cap in CAPS:
        res = speculative_generate(target, tok, drafter, p, max_new_tokens=N,
                                   max_draft_tokens=cap)
        # Compare throughputs rather than raw wall times: the two runs are not
        # guaranteed to emit the same number of tokens, and this keeps the
        # speedup column consistent with the tok/s columns printed beside it.
        sp = res.tokens_per_sec / gtps
        agg.setdefault(cap, []).append(sp)
        # Label only the first row of each prompt's group.
        tag = f"P{i}" if cap == CAPS[0] else ""
        print(f"{tag:<6} {gtps:>7.1f}/s | {cap:>3} {res.tokens_per_sec:>7.1f}/s "
              f"{res.mean_accept_len:>7.2f} {sp:>7.2f}x")

print("\nmean speedup by cap:")
for cap, v in agg.items():
    print(f" cap={cap}: {sum(v)/len(v):.2f}x")
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "mlx-dspark"
|
|
3
|
+
version = "0.0.1"
|
|
4
|
+
description = "DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10"
|
|
7
|
+
license = { text = "MIT" }
|
|
8
|
+
authors = [{ name = "erahim3", email = "erahim3@gmail.com" }]
|
|
9
|
+
keywords = ["mlx", "apple-silicon", "speculative-decoding", "dspark", "deepspec", "llm-inference"]
|
|
10
|
+
classifiers = [
|
|
11
|
+
"License :: OSI Approved :: MIT License",
|
|
12
|
+
"Programming Language :: Python :: 3",
|
|
13
|
+
"Operating System :: MacOS",
|
|
14
|
+
"Environment :: GPU",
|
|
15
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
16
|
+
]
|
|
17
|
+
dependencies = [
|
|
18
|
+
"mlx>=0.31.2",
|
|
19
|
+
"mlx-lm>=0.31.3",
|
|
20
|
+
# mlx-vlm gives us Gemma-4 + the existing EAGLE3/DFlash spec-decode reference impl
|
|
21
|
+
"mlx-vlm>=0.6.3",
|
|
22
|
+
"numpy",
|
|
23
|
+
"huggingface-hub",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
dev = ["pytest", "ruff"]
|
|
28
|
+
|
|
29
|
+
[project.urls]
|
|
30
|
+
Homepage = "https://github.com/ARahim3/mlx-dspark"
|
|
31
|
+
Repository = "https://github.com/ARahim3/mlx-dspark"
|
|
32
|
+
Issues = "https://github.com/ARahim3/mlx-dspark/issues"
|
|
33
|
+
|
|
34
|
+
[build-system]
|
|
35
|
+
requires = ["hatchling"]
|
|
36
|
+
build-backend = "hatchling.build"
|
|
37
|
+
|
|
38
|
+
[tool.hatch.build.targets.wheel]
|
|
39
|
+
packages = ["src/mlx_dspark"]
|
|
40
|
+
|
|
41
|
+
[tool.ruff]
|
|
42
|
+
line-length = 100
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""mlx-dspark: DSpark speculative decoding for Apple Silicon (MLX).
|
|
2
|
+
|
|
3
|
+
DSpark (from DeepSeek's DeepSpec codebase) is a semi-autoregressive,
|
|
4
|
+
EAGLE-family speculative-decoding drafter:
|
|
5
|
+
|
|
6
|
+
- a *parallel backbone* proposes base logits for all K draft positions at once,
|
|
7
|
+
- a tiny *sequential head* (low-rank, previous-token-conditioned) corrects
|
|
8
|
+
suffix decay,
|
|
9
|
+
- a *confidence head* scores how likely each drafted token survives
|
|
10
|
+
verification (used here for adaptive draft length instead of the
|
|
11
|
+
server-side load-aware scheduler, which is irrelevant single-user).
|
|
12
|
+
|
|
13
|
+
This package targets single-user local inference on Apple Silicon.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
__version__ = "0.0.1"
|
|
17
|
+
|
|
18
|
+
from .config import DSparkConfig
|
|
19
|
+
from .load import (
|
|
20
|
+
DEFAULT_DRAFTER,
|
|
21
|
+
DEFAULT_TARGET,
|
|
22
|
+
PRESETS,
|
|
23
|
+
load_drafter,
|
|
24
|
+
load_pair,
|
|
25
|
+
load_target,
|
|
26
|
+
)
|
|
27
|
+
from .generate import GenResult, greedy_generate, speculative_generate
|
|
28
|
+
from .target import Target
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
"DSparkConfig",
|
|
32
|
+
"Target",
|
|
33
|
+
"load_drafter",
|
|
34
|
+
"load_target",
|
|
35
|
+
"load_pair",
|
|
36
|
+
"speculative_generate",
|
|
37
|
+
"greedy_generate",
|
|
38
|
+
"GenResult",
|
|
39
|
+
"PRESETS",
|
|
40
|
+
"DEFAULT_TARGET",
|
|
41
|
+
"DEFAULT_DRAFTER",
|
|
42
|
+
]
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""CLI: run DSpark speculative decoding on Apple Silicon (streams tokens live).
|
|
2
|
+
|
|
3
|
+
Side-by-side demo (record each, then stack the two screen captures):
|
|
4
|
+
|
|
5
|
+
# left panel — plain target, no drafter
|
|
6
|
+
python -m mlx_dspark --mode baseline --prompt "Explain how rainbows form." --max-new-tokens 220
|
|
7
|
+
|
|
8
|
+
# right panel — DSpark speculative decoding (same prompt, same output, faster)
|
|
9
|
+
python -m mlx_dspark --mode dspark --prompt "Explain how rainbows form." --max-new-tokens 220
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
from .generate import greedy_generate, speculative_generate
|
|
19
|
+
from .load import PRESETS, load_drafter, load_target
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _emit(s: str) -> None:
|
|
23
|
+
sys.stdout.write(s)
|
|
24
|
+
sys.stdout.flush()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def main() -> None:
    """Parse the CLI flags, load the model(s), run one generation and print a stats footer.

    ``--mode dspark`` loads the target and the drafter and calls
    ``speculative_generate``; ``--mode baseline`` loads only the target and calls
    ``greedy_generate``. Text is streamed to stdout as it is produced unless
    ``--no-stream`` is given, in which case the full text is printed once at the end.
    """
    parser = argparse.ArgumentParser(prog="mlx_dspark")
    parser.add_argument(
        "--mode", choices=["dspark", "baseline"], default="dspark",
        help="dspark = speculative decoding; baseline = plain greedy target",
    )
    parser.add_argument(
        "--family", choices=["gemma4", "qwen3"], default="gemma4",
        help="model preset (target + drafter); overridden by --target/--drafter",
    )
    parser.add_argument("--prompt", default="Explain how rainbows form, in a few sentences.")
    parser.add_argument("--target", default=None)
    parser.add_argument("--drafter", default=None)
    parser.add_argument("--max-new-tokens", type=int, default=220)
    parser.add_argument("--max-draft", type=int, default=4)
    parser.add_argument("--confidence-threshold", type=float, default=0.0)
    parser.add_argument("--drafter-bits", type=int, default=4)
    parser.add_argument("--no-chat-template", action="store_true")
    parser.add_argument("--no-stream", action="store_true")
    opts = parser.parse_args()

    speculative = opts.mode == "dspark"

    # Explicit repos win; the family preset is only consulted as a fallback.
    target_repo = opts.target if opts.target else PRESETS[opts.family]["target"]
    drafter_repo = opts.drafter if opts.drafter else PRESETS[opts.family]["drafter"]

    if speculative:
        label = "DSpark speculative"
        drafter_note = f", drafter={drafter_repo}"
    else:
        label = "Baseline (plain greedy)"
        drafter_note = ""
    print(f"loading {opts.mode}: target={target_repo}" + drafter_note)

    target, tok = load_target(target_repo)
    drafter = None
    if speculative:
        bits = opts.drafter_bits
        # A non-positive --drafter-bits turns quantization off; the bit width
        # passed along is still clamped to at least 2.
        drafter, _ = load_drafter(drafter_repo, quantize=bits > 0, bits=max(bits, 2))

    on_text = _emit if not opts.no_stream else None
    use_chat_template = not opts.no_chat_template

    banner_rule = "=" * 64
    print("\n" + banner_rule)
    print(f" ▶ {label} · {target_repo.split('/')[-1]}")
    print(banner_rule)

    if speculative:
        res = speculative_generate(
            target, tok, drafter, opts.prompt,
            max_new_tokens=opts.max_new_tokens,
            max_draft_tokens=opts.max_draft,
            confidence_threshold=opts.confidence_threshold,
            apply_chat_template=use_chat_template,
            on_text=on_text,
        )
        extra = f" · accept {res.mean_accept_len:.2f}/round · {res.target_forwards} target fwds"
    else:
        res = greedy_generate(
            target, tok, opts.prompt,
            max_new_tokens=opts.max_new_tokens,
            apply_chat_template=use_chat_template,
            on_text=on_text,
        )
        extra = ""

    # Nothing was streamed, so emit the whole completion now.
    if opts.no_stream:
        print(res.text)

    footer_rule = "-" * 64
    print("\n" + footer_rule)
    print(f" {res.num_tokens} tokens · {res.seconds:.2f}s · "
          f"\033[1m{res.tokens_per_sec:.1f} tok/s\033[0m{extra}")
    print(footer_rule)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Entry point for `python -m mlx_dspark` (and direct execution of this file).
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""DSpark drafter config — loaded from the HF checkpoint's config.json.
|
|
2
|
+
|
|
3
|
+
Supports two drafter families with a shared inference path:
|
|
4
|
+
- gemma4 (gemma4_text): k_eq_v attention, v_norm, partial/proportional rope,
|
|
5
|
+
sandwich norms + layer_scalar, gelu-tanh MLP, logit softcap.
|
|
6
|
+
- qwen3 (qwen3): standard GQA (separate v_proj, no v_norm), default rope,
|
|
7
|
+
Llama-style 2-norm layer, silu MLP, no softcap.
|
|
8
|
+
Only the fields the MLX inference path needs are pulled out.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class DSparkConfig:
|
|
20
|
+
family: str = "gemma4" # "gemma4" | "qwen3"
|
|
21
|
+
|
|
22
|
+
# core dims
|
|
23
|
+
hidden_size: int = 3840
|
|
24
|
+
vocab_size: int = 262144
|
|
25
|
+
num_hidden_layers: int = 5
|
|
26
|
+
intermediate_size: int = 15360
|
|
27
|
+
rms_norm_eps: float = 1e-6
|
|
28
|
+
|
|
29
|
+
# attention
|
|
30
|
+
num_attention_heads: int = 16
|
|
31
|
+
num_key_value_heads: int = 8
|
|
32
|
+
num_global_key_value_heads: int = 1
|
|
33
|
+
head_dim: int = 256
|
|
34
|
+
global_head_dim: int = 512
|
|
35
|
+
attention_k_eq_v: bool = True
|
|
36
|
+
attention_bias: bool = False
|
|
37
|
+
|
|
38
|
+
# rope
|
|
39
|
+
rope_theta: float = 1_000_000.0
|
|
40
|
+
partial_rotary_factor: float = 0.25
|
|
41
|
+
rope_type: str = "proportional"
|
|
42
|
+
|
|
43
|
+
# dspark specifics
|
|
44
|
+
block_size: int = 7
|
|
45
|
+
mask_token_id: int = 4
|
|
46
|
+
target_layer_ids: list[int] = field(default_factory=lambda: [5, 17, 29, 41, 46])
|
|
47
|
+
num_target_layers: int = 48
|
|
48
|
+
|
|
49
|
+
# markov + confidence
|
|
50
|
+
markov_rank: int = 256
|
|
51
|
+
markov_head_type: str = "vanilla"
|
|
52
|
+
enable_confidence_head: bool = True
|
|
53
|
+
confidence_head_with_markov: bool = True
|
|
54
|
+
|
|
55
|
+
# logits
|
|
56
|
+
final_logit_softcapping: float | None = 30.0
|
|
57
|
+
pad_token_id: int = 0
|
|
58
|
+
|
|
59
|
+
# ---- family-derived knobs (set in from_json) ----
|
|
60
|
+
mlp_activation: str = "gelu_tanh" # "gelu_tanh" | "silu"
|
|
61
|
+
norm_style: str = "gemma" # "gemma" (sandwich+scalar) | "qwen" (llama 2-norm)
|
|
62
|
+
use_v_norm: bool = True # gemma: RMSNormNoScale v_norm; qwen: none
|
|
63
|
+
attention_scaling: float | None = None # None -> 1/sqrt(attn_head_dim)
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def attn_head_dim(self) -> int:
|
|
67
|
+
"""Head dim used by the drafter's own attention."""
|
|
68
|
+
return self.global_head_dim if self.family == "gemma4" else self.head_dim
|
|
69
|
+
|
|
70
|
+
@property
|
|
71
|
+
def n_kv_heads(self) -> int:
|
|
72
|
+
if self.family == "gemma4" and self.attention_k_eq_v:
|
|
73
|
+
return self.num_global_key_value_heads
|
|
74
|
+
return self.num_key_value_heads
|
|
75
|
+
|
|
76
|
+
@property
|
|
77
|
+
def scaling(self) -> float:
|
|
78
|
+
if self.attention_scaling is not None:
|
|
79
|
+
return self.attention_scaling
|
|
80
|
+
return self.attn_head_dim ** -0.5 if self.family == "qwen3" else 1.0
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def rope_parameters(self) -> dict:
|
|
84
|
+
return {"rope_type": self.rope_type, "partial_rotary_factor": self.partial_rotary_factor}
|
|
85
|
+
|
|
86
|
+
@classmethod
|
|
87
|
+
def from_json(cls, path: str | Path) -> "DSparkConfig":
|
|
88
|
+
with open(path) as f:
|
|
89
|
+
c = json.load(f)
|
|
90
|
+
mt = c.get("model_type", "")
|
|
91
|
+
family = "qwen3" if "qwen3" in mt else "gemma4"
|
|
92
|
+
|
|
93
|
+
if family == "qwen3":
|
|
94
|
+
rp = c.get("rope_parameters") or {}
|
|
95
|
+
return cls(
|
|
96
|
+
family="qwen3",
|
|
97
|
+
hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
|
|
98
|
+
num_hidden_layers=c["num_hidden_layers"],
|
|
99
|
+
intermediate_size=c["intermediate_size"],
|
|
100
|
+
rms_norm_eps=c.get("rms_norm_eps", 1e-6),
|
|
101
|
+
num_attention_heads=c["num_attention_heads"],
|
|
102
|
+
num_key_value_heads=c.get("num_key_value_heads", 8),
|
|
103
|
+
head_dim=c.get("head_dim", c["hidden_size"] // c["num_attention_heads"]),
|
|
104
|
+
attention_k_eq_v=False, attention_bias=c.get("attention_bias", False),
|
|
105
|
+
rope_theta=rp.get("rope_theta", c.get("rope_theta", 1_000_000.0)),
|
|
106
|
+
rope_type="default",
|
|
107
|
+
block_size=c["block_size"], mask_token_id=c["mask_token_id"],
|
|
108
|
+
target_layer_ids=list(c["target_layer_ids"]),
|
|
109
|
+
num_target_layers=c.get("num_target_layers", 36),
|
|
110
|
+
markov_rank=c.get("markov_rank", 256),
|
|
111
|
+
markov_head_type=c.get("markov_head_type", "vanilla"),
|
|
112
|
+
enable_confidence_head=c.get("enable_confidence_head", True),
|
|
113
|
+
confidence_head_with_markov=c.get("confidence_head_with_markov", True),
|
|
114
|
+
final_logit_softcapping=c.get("final_logit_softcapping", None),
|
|
115
|
+
pad_token_id=c.get("pad_token_id") or 0,
|
|
116
|
+
mlp_activation="silu", norm_style="qwen", use_v_norm=False,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
rope = (c.get("rope_parameters") or {}).get("full_attention", {}) or {}
|
|
120
|
+
return cls(
|
|
121
|
+
family="gemma4",
|
|
122
|
+
hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
|
|
123
|
+
num_hidden_layers=c["num_hidden_layers"],
|
|
124
|
+
intermediate_size=c["intermediate_size"],
|
|
125
|
+
rms_norm_eps=c.get("rms_norm_eps", 1e-6),
|
|
126
|
+
num_attention_heads=c["num_attention_heads"],
|
|
127
|
+
num_key_value_heads=c.get("num_key_value_heads", 8),
|
|
128
|
+
num_global_key_value_heads=c.get("num_global_key_value_heads", 1),
|
|
129
|
+
head_dim=c.get("head_dim", 256), global_head_dim=c.get("global_head_dim", 512),
|
|
130
|
+
attention_k_eq_v=c.get("attention_k_eq_v", True),
|
|
131
|
+
attention_bias=c.get("attention_bias", False),
|
|
132
|
+
rope_theta=rope.get("rope_theta", 1_000_000.0),
|
|
133
|
+
partial_rotary_factor=rope.get("partial_rotary_factor", 0.25),
|
|
134
|
+
rope_type=rope.get("rope_type", "proportional"),
|
|
135
|
+
block_size=c["block_size"], mask_token_id=c["mask_token_id"],
|
|
136
|
+
target_layer_ids=list(c["target_layer_ids"]),
|
|
137
|
+
num_target_layers=c.get("num_target_layers", 48),
|
|
138
|
+
markov_rank=c.get("markov_rank", 256),
|
|
139
|
+
markov_head_type=c.get("markov_head_type", "vanilla"),
|
|
140
|
+
enable_confidence_head=c.get("enable_confidence_head", True),
|
|
141
|
+
confidence_head_with_markov=c.get("confidence_head_with_markov", True),
|
|
142
|
+
final_logit_softcapping=c.get("final_logit_softcapping", 30.0),
|
|
143
|
+
pad_token_id=c.get("pad_token_id", 0),
|
|
144
|
+
mlp_activation="gelu_tanh", norm_style="gemma", use_v_norm=True,
|
|
145
|
+
)
|