mlx-dspark 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlx_dspark/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ """mlx-dspark: DSpark speculative decoding for Apple Silicon (MLX).
2
+
3
+ DSpark (from DeepSeek's DeepSpec codebase) is a semi-autoregressive,
4
+ EAGLE-family speculative-decoding drafter:
5
+
6
+ - a *parallel backbone* proposes base logits for all K draft positions at once,
7
+ - a tiny *sequential head* (low-rank, previous-token-conditioned) corrects
8
+ suffix decay,
9
+ - a *confidence head* scores how likely each drafted token survives
10
+ verification (used here for adaptive draft length instead of the
11
+ server-side load-aware scheduler, which is irrelevant for single-user inference).
12
+
13
+ This package targets single-user local inference on Apple Silicon.
14
+ """
15
+
16
__version__ = "0.0.1"  # package version string

# Re-export the public API from the submodules.
from .config import DSparkConfig
from .load import (
    DEFAULT_DRAFTER,
    DEFAULT_TARGET,
    PRESETS,
    load_drafter,
    load_pair,
    load_target,
)
from .generate import GenResult, greedy_generate, speculative_generate
from .target import Target

__all__ = [
    # configuration and target wrapper
    "DSparkConfig", "Target",
    # loaders
    "load_drafter", "load_target", "load_pair",
    # generation entry points and their result type
    "speculative_generate", "greedy_generate", "GenResult",
    # model presets
    "PRESETS", "DEFAULT_TARGET", "DEFAULT_DRAFTER",
]
mlx_dspark/__main__.py ADDED
@@ -0,0 +1,84 @@
1
+ """CLI: run DSpark speculative decoding on Apple Silicon (streams tokens live).
2
+
3
+ Side-by-side demo (record each, then stack the two screen captures):
4
+
5
+ # left panel — plain target, no drafter
6
+ python -m mlx_dspark --mode baseline --prompt "Explain how rainbows form." --max-new-tokens 220
7
+
8
+ # right panel — DSpark speculative decoding (same prompt, same output, faster)
9
+ python -m mlx_dspark --mode dspark --prompt "Explain how rainbows form." --max-new-tokens 220
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import sys
16
+ import time
17
+
18
+ from .generate import greedy_generate, speculative_generate
19
+ from .load import PRESETS, load_drafter, load_target
20
+
21
+
22
+ def _emit(s: str) -> None:
23
+ sys.stdout.write(s)
24
+ sys.stdout.flush()
25
+
26
+
27
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: flags, defaults and help text for ``python -m mlx_dspark``."""
    parser = argparse.ArgumentParser(prog="mlx_dspark")
    parser.add_argument("--mode", choices=["dspark", "baseline"], default="dspark",
                        help="dspark = speculative decoding; baseline = plain greedy target")
    parser.add_argument("--family", choices=["gemma4", "qwen3"], default="gemma4",
                        help="model preset (target + drafter); overridden by --target/--drafter")
    parser.add_argument("--prompt", default="Explain how rainbows form, in a few sentences.")
    parser.add_argument("--target", default=None)
    parser.add_argument("--drafter", default=None)
    parser.add_argument("--max-new-tokens", type=int, default=220)
    parser.add_argument("--max-draft", type=int, default=4)
    parser.add_argument("--confidence-threshold", type=float, default=0.0)
    parser.add_argument("--drafter-bits", type=int, default=4)
    parser.add_argument("--no-chat-template", action="store_true")
    parser.add_argument("--no-stream", action="store_true")
    return parser


def main() -> None:
    """CLI entry point: load the model(s), generate for ``--prompt`` and print stats.

    In ``dspark`` mode both the target and the drafter are loaded and
    ``speculative_generate`` is used. In ``baseline`` mode only the target is
    loaded and ``greedy_generate`` is used. Explicit ``--target`` /
    ``--drafter`` override the ``--family`` preset.
    """
    args = _build_parser().parse_args()
    target_repo = args.target or PRESETS[args.family]["target"]
    drafter_repo = args.drafter or PRESETS[args.family]["drafter"]
    speculative = args.mode == "dspark"

    if speculative:
        label = "DSpark speculative"
        drafter_note = f", drafter={drafter_repo}"
    else:
        label = "Baseline (plain greedy)"
        drafter_note = ""
    print(f"loading {args.mode}: target={target_repo}{drafter_note}")

    target, tok = load_target(target_repo)
    drafter = None
    if speculative:
        # --drafter-bits <= 0 disables quantization; the bit width is clamped to >= 2.
        drafter, _ = load_drafter(drafter_repo, quantize=args.drafter_bits > 0,
                                  bits=max(args.drafter_bits, 2))

    on_text = None if args.no_stream else _emit
    banner = "=" * 64
    print("\n" + banner)
    print(f" ▶ {label} · {target_repo.rsplit('/', 1)[-1]}")
    print(banner)

    use_template = not args.no_chat_template
    if not speculative:
        res = greedy_generate(
            target, tok, args.prompt, max_new_tokens=args.max_new_tokens,
            apply_chat_template=use_template, on_text=on_text,
        )
        extra = ""
    else:
        res = speculative_generate(
            target, tok, drafter, args.prompt,
            max_new_tokens=args.max_new_tokens, max_draft_tokens=args.max_draft,
            confidence_threshold=args.confidence_threshold,
            apply_chat_template=use_template, on_text=on_text,
        )
        extra = f" · accept {res.mean_accept_len:.2f}/round · {res.target_forwards} target fwds"
    if args.no_stream:
        print(res.text)

    rule = "-" * 64
    print("\n" + rule)
    print(f" {res.num_tokens} tokens · {res.seconds:.2f}s · "
          f"\033[1m{res.tokens_per_sec:.1f} tok/s\033[0m{extra}")
    print(rule)


if __name__ == "__main__":
    main()
mlx_dspark/config.py ADDED
@@ -0,0 +1,145 @@
1
+ """DSpark drafter config — loaded from the HF checkpoint's config.json.
2
+
3
+ Supports two drafter families with a shared inference path:
4
+ - gemma4 (gemma4_text): k_eq_v attention, v_norm, partial/proportional rope,
5
+ sandwich norms + layer_scalar, gelu-tanh MLP, logit softcap.
6
+ - qwen3 (qwen3): standard GQA (separate v_proj, no v_norm), default rope,
7
+ Llama-style 2-norm layer, silu MLP, no softcap.
8
+ Only the fields the MLX inference path needs are pulled out.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ from dataclasses import dataclass, field
15
+ from pathlib import Path
16
+
17
+
18
+ @dataclass
19
+ class DSparkConfig:
20
+ family: str = "gemma4" # "gemma4" | "qwen3"
21
+
22
+ # core dims
23
+ hidden_size: int = 3840
24
+ vocab_size: int = 262144
25
+ num_hidden_layers: int = 5
26
+ intermediate_size: int = 15360
27
+ rms_norm_eps: float = 1e-6
28
+
29
+ # attention
30
+ num_attention_heads: int = 16
31
+ num_key_value_heads: int = 8
32
+ num_global_key_value_heads: int = 1
33
+ head_dim: int = 256
34
+ global_head_dim: int = 512
35
+ attention_k_eq_v: bool = True
36
+ attention_bias: bool = False
37
+
38
+ # rope
39
+ rope_theta: float = 1_000_000.0
40
+ partial_rotary_factor: float = 0.25
41
+ rope_type: str = "proportional"
42
+
43
+ # dspark specifics
44
+ block_size: int = 7
45
+ mask_token_id: int = 4
46
+ target_layer_ids: list[int] = field(default_factory=lambda: [5, 17, 29, 41, 46])
47
+ num_target_layers: int = 48
48
+
49
+ # markov + confidence
50
+ markov_rank: int = 256
51
+ markov_head_type: str = "vanilla"
52
+ enable_confidence_head: bool = True
53
+ confidence_head_with_markov: bool = True
54
+
55
+ # logits
56
+ final_logit_softcapping: float | None = 30.0
57
+ pad_token_id: int = 0
58
+
59
+ # ---- family-derived knobs (set in from_json) ----
60
+ mlp_activation: str = "gelu_tanh" # "gelu_tanh" | "silu"
61
+ norm_style: str = "gemma" # "gemma" (sandwich+scalar) | "qwen" (llama 2-norm)
62
+ use_v_norm: bool = True # gemma: RMSNormNoScale v_norm; qwen: none
63
+ attention_scaling: float | None = None # None -> 1/sqrt(attn_head_dim)
64
+
65
+ @property
66
+ def attn_head_dim(self) -> int:
67
+ """Head dim used by the drafter's own attention."""
68
+ return self.global_head_dim if self.family == "gemma4" else self.head_dim
69
+
70
+ @property
71
+ def n_kv_heads(self) -> int:
72
+ if self.family == "gemma4" and self.attention_k_eq_v:
73
+ return self.num_global_key_value_heads
74
+ return self.num_key_value_heads
75
+
76
+ @property
77
+ def scaling(self) -> float:
78
+ if self.attention_scaling is not None:
79
+ return self.attention_scaling
80
+ return self.attn_head_dim ** -0.5 if self.family == "qwen3" else 1.0
81
+
82
+ @property
83
+ def rope_parameters(self) -> dict:
84
+ return {"rope_type": self.rope_type, "partial_rotary_factor": self.partial_rotary_factor}
85
+
86
+ @classmethod
87
+ def from_json(cls, path: str | Path) -> "DSparkConfig":
88
+ with open(path) as f:
89
+ c = json.load(f)
90
+ mt = c.get("model_type", "")
91
+ family = "qwen3" if "qwen3" in mt else "gemma4"
92
+
93
+ if family == "qwen3":
94
+ rp = c.get("rope_parameters") or {}
95
+ return cls(
96
+ family="qwen3",
97
+ hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
98
+ num_hidden_layers=c["num_hidden_layers"],
99
+ intermediate_size=c["intermediate_size"],
100
+ rms_norm_eps=c.get("rms_norm_eps", 1e-6),
101
+ num_attention_heads=c["num_attention_heads"],
102
+ num_key_value_heads=c.get("num_key_value_heads", 8),
103
+ head_dim=c.get("head_dim", c["hidden_size"] // c["num_attention_heads"]),
104
+ attention_k_eq_v=False, attention_bias=c.get("attention_bias", False),
105
+ rope_theta=rp.get("rope_theta", c.get("rope_theta", 1_000_000.0)),
106
+ rope_type="default",
107
+ block_size=c["block_size"], mask_token_id=c["mask_token_id"],
108
+ target_layer_ids=list(c["target_layer_ids"]),
109
+ num_target_layers=c.get("num_target_layers", 36),
110
+ markov_rank=c.get("markov_rank", 256),
111
+ markov_head_type=c.get("markov_head_type", "vanilla"),
112
+ enable_confidence_head=c.get("enable_confidence_head", True),
113
+ confidence_head_with_markov=c.get("confidence_head_with_markov", True),
114
+ final_logit_softcapping=c.get("final_logit_softcapping", None),
115
+ pad_token_id=c.get("pad_token_id") or 0,
116
+ mlp_activation="silu", norm_style="qwen", use_v_norm=False,
117
+ )
118
+
119
+ rope = (c.get("rope_parameters") or {}).get("full_attention", {}) or {}
120
+ return cls(
121
+ family="gemma4",
122
+ hidden_size=c["hidden_size"], vocab_size=c["vocab_size"],
123
+ num_hidden_layers=c["num_hidden_layers"],
124
+ intermediate_size=c["intermediate_size"],
125
+ rms_norm_eps=c.get("rms_norm_eps", 1e-6),
126
+ num_attention_heads=c["num_attention_heads"],
127
+ num_key_value_heads=c.get("num_key_value_heads", 8),
128
+ num_global_key_value_heads=c.get("num_global_key_value_heads", 1),
129
+ head_dim=c.get("head_dim", 256), global_head_dim=c.get("global_head_dim", 512),
130
+ attention_k_eq_v=c.get("attention_k_eq_v", True),
131
+ attention_bias=c.get("attention_bias", False),
132
+ rope_theta=rope.get("rope_theta", 1_000_000.0),
133
+ partial_rotary_factor=rope.get("partial_rotary_factor", 0.25),
134
+ rope_type=rope.get("rope_type", "proportional"),
135
+ block_size=c["block_size"], mask_token_id=c["mask_token_id"],
136
+ target_layer_ids=list(c["target_layer_ids"]),
137
+ num_target_layers=c.get("num_target_layers", 48),
138
+ markov_rank=c.get("markov_rank", 256),
139
+ markov_head_type=c.get("markov_head_type", "vanilla"),
140
+ enable_confidence_head=c.get("enable_confidence_head", True),
141
+ confidence_head_with_markov=c.get("confidence_head_with_markov", True),
142
+ final_logit_softcapping=c.get("final_logit_softcapping", 30.0),
143
+ pad_token_id=c.get("pad_token_id", 0),
144
+ mlp_activation="gelu_tanh", norm_style="gemma", use_v_norm=True,
145
+ )
mlx_dspark/generate.py ADDED
@@ -0,0 +1,289 @@
1
+ """DSpark speculative decoding loop (greedy, batch=1) for Apple Silicon.
2
+
3
+ Per round:
4
+ 1. draft a block of K tokens from the parallel backbone + Markov head,
5
+ 2. verify them in one target forward,
6
+ 3. accept the matching prefix + 1 bonus token (so >=1 token/round always),
7
+ 4. trim the target KV cache and grow the fused-hidden context buffer.
8
+
9
+ Because the target verifies every token, the *output is exactly greedy target
10
+ decoding* regardless of drafter quality — drafter quality only shows up as the
11
+ acceptance length (tokens committed per target forward).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import time
17
+ from dataclasses import dataclass
18
+
19
+ import mlx.core as mx
20
+
21
# NOTE(review): nothing in this module reads TAP. speculative_generate takes
# its tap layers from `drafter.config.target_layer_ids`. Presumably left in
# place in case external code imports it; confirm before removing.
TAP = None


@dataclass
class GenResult:
    """Outcome of one generation call (speculative or plain greedy)."""

    text: str  # decoded output with stop tokens filtered out
    token_ids: list[int]  # generated ids, including a terminating stop token if one was hit
    num_tokens: int  # len(token_ids)
    num_rounds: int  # speculative: verify rounds; greedy: one per token
    accept_lengths: list[int]  # tokens committed by each round
    target_forwards: int  # target forward passes, prefill included
    seconds: float  # wall-clock time from prefill start to the end of decoding

    @property
    def mean_accept_len(self) -> float:
        """Average number of tokens committed per round.

        Computed from ``accept_lengths``, so the token produced by the prefill
        forward (which belongs to no round) is not counted. Falls back to
        ``num_tokens / num_rounds`` when no per-round data was recorded.
        """
        if self.accept_lengths:
            return sum(self.accept_lengths) / len(self.accept_lengths)
        return self.num_tokens / max(self.num_rounds, 1)

    @property
    def tokens_per_sec(self) -> float:
        """Throughput over ``seconds`` (guarded against a zero duration)."""
        return self.num_tokens / max(self.seconds, 1e-9)
41
+
42
+
43
def encode_prompt(tokenizer, prompt: str, use_chat: bool = True) -> list[int]:
    """Return the token ids for a single user turn.

    If ``use_chat`` is set and the tokenizer has a ``chat_template``, the
    prompt is rendered through ``apply_chat_template`` with a generation
    prompt. Gemma-4 uses ``<|turn>`` / ``<channel|>`` markers rather than
    Gemma-3's ``<start_of_turn>``, so hand-formatting would break the instruct
    model.

    The template call may return several shapes:
      * a flat list/tuple of ids, or a batch of one;
      * a mapping or object with ``input_ids``;
      * an object with ``ids``.
    If none of these match, the raw prompt is encoded without the template.
    """
    if not use_chat or not getattr(tokenizer, "chat_template", None):
        return list(tokenizer.encode(prompt))

    messages = [{"role": "user", "content": prompt}]
    rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Plain sequence: either the ids themselves or a batch of one.
    if isinstance(rendered, (list, tuple)) and rendered:
        head = rendered[0]
        if isinstance(head, int):
            return list(rendered)
        if isinstance(head, (list, tuple)):
            return list(head)

    # Mapping-like (dict / BatchEncoding) first, then attribute access.
    input_ids = None
    if hasattr(rendered, "__contains__") and "input_ids" in rendered:
        input_ids = rendered["input_ids"]
    elif hasattr(rendered, "input_ids"):
        input_ids = rendered.input_ids
    if input_ids is not None:
        seq = list(input_ids)
        is_batched = bool(seq) and isinstance(seq[0], (list, tuple))
        return list(seq[0]) if is_batched else seq

    # `tokenizers.Encoding`-style object.
    if hasattr(rendered, "ids"):
        return list(rendered.ids)
    return list(tokenizer.encode(prompt))
70
+
71
+
72
def eos_token_ids(tokenizer) -> set[int]:
    """Return every token id that should stop generation.

    Combines the tokenizer's declared ``eos_token_ids`` / ``eos_token_id``
    with known turn-end markers:
      * Gemma-4 ``<turn|>``;
      * Gemma-3 ``<end_of_turn>``;
      * Qwen ``<|im_end|>``;
      * ``<|endoftext|>`` and ``<eos>``.

    A marker is kept only if it resolves to a non-negative int different from
    ``unk_token_id``. In Gemma-4, ``<end_of_turn>`` maps to the UNK id and must
    be dropped. Lookup failures are ignored.
    """
    stop: set[int] = set()

    declared = getattr(tokenizer, "eos_token_ids", None)
    if isinstance(declared, int):
        stop.add(declared)
    elif declared:
        stop |= {int(x) for x in declared}

    single = getattr(tokenizer, "eos_token_id", None)
    if isinstance(single, int):
        stop.add(single)

    unk = getattr(tokenizer, "unk_token_id", None)

    def _lookup(marker):
        # Best-effort: tokenizers differ in whether/how they reject unknown tokens.
        try:
            return tokenizer.convert_tokens_to_ids(marker)
        except Exception:
            return None

    markers = ("<turn|>", "<end_of_turn>", "<|im_end|>", "<|endoftext|>", "<eos>")
    resolved = (_lookup(m) for m in markers)
    stop.update(t for t in resolved if isinstance(t, int) and t >= 0 and t != unk)
    return stop
94
+
95
+
96
def greedy_generate(
    target_model,
    tokenizer,
    prompt: str,
    *,
    max_new_tokens: int = 128,
    apply_chat_template: bool = True,
    on_text=None,
) -> GenResult:
    """Decode ``prompt`` with the target alone, one argmax token per forward.

    This is the fair "run the model normally" baseline. It uses no drafter and
    captures no hidden states (it calls ``target_model.plain``).

    Decoding stops after a stop token from ``eos_token_ids`` or once
    ``max_new_tokens`` tokens exist. At least one token is always produced.
    ``on_text``, if given, receives each newly decoded text fragment (stop
    tokens excluded).
    """
    stop_ids = eos_token_ids(tokenizer)
    prompt_ids = encode_prompt(tokenizer, prompt, use_chat=apply_chat_template)
    kv_cache = target_model.make_cache()

    generated: list[int] = []
    shown = 0  # length of the decoded text already handed to on_text

    def _top1(logits) -> int:
        # Greedy pick from the last position's logits.
        return int(mx.argmax(logits[0, -1]).item())

    def _publish() -> None:
        nonlocal shown
        if on_text is None:
            return
        text_so_far = tokenizer.decode([t for t in generated if t not in stop_ids])
        if len(text_so_far) > shown:
            on_text(text_so_far[shown:])
            shown = len(text_so_far)

    start = time.time()
    token = _top1(target_model.plain(mx.array([prompt_ids]), kv_cache))
    generated.append(token)
    _publish()
    while len(generated) < max_new_tokens and token not in stop_ids:
        token = _top1(target_model.plain(mx.array([[token]]), kv_cache))
        generated.append(token)
        _publish()
    elapsed = time.time() - start

    count = len(generated)
    return GenResult(
        text=tokenizer.decode([t for t in generated if t not in stop_ids]),
        token_ids=generated,
        num_tokens=count,
        num_rounds=count,
        accept_lengths=[1] * count,
        target_forwards=count,
        seconds=elapsed,
    )
145
+
146
+
147
+ def _make_target_cache(target):
148
+ return target.make_cache()
149
+
150
+
151
+ def _run_target(target, ids: mx.array, cache, tap: list[int]):
152
+ """ids: [1, L]. Returns (logits[1,L,V], fused_hidden[1,L,n_tap*H])."""
153
+ return target.run(ids, cache, tap)
154
+
155
+
156
def speculative_generate(
    target_model,
    tokenizer,
    drafter,
    prompt: str,
    *,
    max_new_tokens: int = 128,
    confidence_threshold: float = 0.0,
    max_draft_tokens: int | None = 4,
    apply_chat_template: bool = True,
    on_text=None,
    verbose: bool = False,
) -> GenResult:
    """Greedy speculative decoding with a DSpark drafter.

    Output is target-greedy by construction (up to fp tie-breaking on
    near-ties).

    Args:
        max_draft_tokens: caps how many tokens of the drafted block are
            verified per round. On Apple Silicon the target verify cost grows
            with tokens, so the optimum is ~= acceptance length (default 4).
            ``None`` verifies the full block, which is faithful but slower on
            M-series.
        confidence_threshold: when > 0, additionally truncates the block at
            the first position whose confidence-head probability falls below
            the threshold.

    Generation stops right after the first stop token (see ``eos_token_ids``)
    or once ``max_new_tokens`` tokens exist, exactly like ``greedy_generate``.
    Tokens the target verified beyond either limit in the last round are
    discarded.
    """
    cfg = drafter.config
    tap = list(cfg.target_layer_ids)
    k = cfg.block_size
    mask_id = cfg.mask_token_id
    cap = k if max_draft_tokens is None else max(1, min(max_draft_tokens, k))

    eos_ids = eos_token_ids(tokenizer)

    # --- tokenize prompt ---
    ids = encode_prompt(tokenizer, prompt, use_chat=apply_chat_template)
    prompt_ids = mx.array([ids])

    cache = _make_target_cache(target_model)
    ctx_caches = drafter.make_ctx_cache()

    def _eval_ctx() -> None:
        # Force K and V of every drafter context cache together, so nothing
        # from this context update is left pending into the next draft.
        mx.eval([c.k for c in ctx_caches] + [c.v for c in ctx_caches])

    t0 = time.time()

    # --- prefill ---
    logits, fused = _run_target(target_model, prompt_ids, cache, tap)
    n_cached = prompt_ids.shape[1]  # number of committed positions in the target cache
    drafter.update_context(fused, ctx_offset=0, ctx_caches=ctx_caches)
    pending = int(mx.argmax(logits[0, -1]).item())  # first committed token
    _eval_ctx()

    out_ids: list[int] = [pending]
    accept_lengths: list[int] = []
    target_forwards = 1
    streamed = 0

    def _stream():
        nonlocal streamed
        if on_text is None:
            return
        disp = [t for t in out_ids if t not in eos_ids]
        full = tokenizer.decode(disp)
        if len(full) > streamed:
            on_text(full[streamed:])
            streamed = len(full)

    _stream()
    done = pending in eos_ids
    while not done and len(out_ids) < max_new_tokens:
        # ---- 1. draft a block: [pending, MASK, ..., MASK] ----
        block_ids = [pending] + [mask_id] * (k - 1)
        noise = drafter.embed(mx.array([block_ids]))  # [1, k, H]
        block_hidden = drafter.backbone(noise, n_cached, ctx_caches)
        base_logits = drafter.compute_logits(block_hidden)[0]  # [k, V]
        sampled = drafter.sample_block(base_logits, first_prev_token=pending)
        mx.eval(sampled)
        full_draft = [int(x) for x in sampled.tolist()]
        draft = full_draft

        # optional confidence-based truncation (adaptive block length)
        if confidence_threshold > 0.0 and drafter.confidence_head is not None:
            prev_tokens = mx.array([pending] + full_draft[:-1])
            conf = mx.sigmoid(drafter.confidence_logits(block_hidden[0], prev_tokens))
            mx.eval(conf)
            below = [i for i, c in enumerate(conf.tolist()) if c < confidence_threshold]
            if below:
                draft = draft[: below[0]]
        draft = draft[:cap]
        if not draft:
            # Always verify >= 1 token. Keep the drafter's own (Markov-corrected)
            # first choice rather than the raw backbone argmax.
            draft = full_draft[:1]

        # ---- 2. verify [pending] + draft in one target forward ----
        verify_ids = mx.array([[pending] + draft])  # [1, 1+len(draft)]
        v_logits, v_fused = _run_target(target_model, verify_ids, cache, tap)
        mx.eval(v_logits, v_fused)
        target_forwards += 1
        tt = [int(x) for x in mx.argmax(v_logits[0], axis=-1).tolist()]  # [1+len(draft)]

        # ---- 3. accept matching prefix + bonus ----
        n = 0
        while n < len(draft) and draft[n] == tt[n]:
            n += 1
        bonus = tt[n]  # correction / continuation
        committed = draft[:n] + [bonus]
        accept_lengths.append(len(committed))

        # ---- 4. update caches/context ----
        trim = len(draft) - n  # rejected drafts fed to the target this round
        if trim > 0:
            for c in cache:
                if c is not None and hasattr(c, "trim"):
                    c.trim(trim)
        # commit [pending, accepted drafts] (positions n_cached..n_cached+n) as context
        drafter.update_context(
            v_fused[:, : n + 1, :], ctx_offset=n_cached, ctx_caches=ctx_caches
        )
        n_cached = n_cached + n + 1
        _eval_ctx()

        # ---- 5. emit, stopping at the token budget or the first stop token ----
        for tok in committed:
            if len(out_ids) >= max_new_tokens:
                break
            out_ids.append(tok)
            if tok in eos_ids:
                # A stop token may sit mid-block; everything after it is dropped
                # and decoding ends here (pending alone would not detect this).
                done = True
                break
        pending = committed[-1]
        _stream()

        if verbose:
            print(f" round {len(accept_lengths):3d}: drafted {len(draft)}, "
                  f"accepted {n}, committed {len(committed)}")

    secs = time.time() - t0
    # strip stop tokens for display
    disp = [t for t in out_ids if t not in eos_ids]
    text = tokenizer.decode(disp)
    return GenResult(
        text=text,
        token_ids=out_ids,
        num_tokens=len(out_ids),
        num_rounds=len(accept_lengths),
        accept_lengths=accept_lengths,
        target_forwards=target_forwards,
        seconds=secs,
    )
mlx_dspark/load.py ADDED
@@ -0,0 +1,117 @@
1
+ """Loaders for the target (Gemma-4 via mlx-vlm) and the DSpark drafter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import glob
6
+ import os
7
+
8
+ import mlx.core as mx
9
+ import mlx.nn as nn
10
+ from huggingface_hub import snapshot_download
11
+
12
+ from .config import DSparkConfig
13
+ from .model import DSparkDrafter
14
+ from .target import Target
15
+
16
# Each drafter must be paired with the *instruct* target it was trained
# against, at decent precision. Select one with load_pair("gemma4") or
# load_pair("qwen3").
PRESETS = dict(
    gemma4=dict(
        target="mlx-community/gemma-4-12B-it-8bit",
        drafter="deepseek-ai/dspark_gemma4_12b_block7",
    ),
    qwen3=dict(
        target="mlx-community/Qwen3-4B-8bit",
        drafter="deepseek-ai/dspark_qwen3_4b_block7",
    ),
)
# Defaults used by load_target / load_drafter when no repo is given.
DEFAULT_TARGET, DEFAULT_DRAFTER = PRESETS["gemma4"]["target"], PRESETS["gemma4"]["drafter"]
30
+
31
+
32
+ def _resolve(repo_or_path: str) -> str:
33
+ if os.path.isdir(repo_or_path):
34
+ return repo_or_path
35
+ return snapshot_download(repo_or_path)
36
+
37
+
38
def load_drafter(
    repo_or_path: str = DEFAULT_DRAFTER,
    *,
    quantize: bool = True,
    bits: int = 4,
    group_size: int = 64,
):
    """Return ``(drafter, config)``. Loads bf16 weights 1:1 by matching key names.

    The drafter is ~6.86 GB in bf16 and runs every speculative round, so by
    default it is quantized to 4-bit (~1.8 GB). This is what makes spec
    decoding a net speedup on Apple Silicon. Output correctness is unaffected
    (the target verifies every token); only acceptance length may change.

    Key-name mismatches between model and checkpoint are printed as a warning,
    and loading then proceeds non-strictly.

    Raises:
        FileNotFoundError: if the checkpoint has no ``*.safetensors`` shards.
            Previously this silently produced a randomly initialised drafter.
    """
    path = _resolve(repo_or_path)
    config = DSparkConfig.from_json(os.path.join(path, "config.json"))

    # Sorted so that shard order (and thus duplicate-key resolution) is deterministic.
    shards = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not shards:
        raise FileNotFoundError(f"no *.safetensors weight files found in {path!r}")

    drafter = DSparkDrafter(config)
    weights: dict[str, mx.array] = {}
    for shard in shards:
        weights.update(mx.load(shard))

    # Diagnose name mismatches before loading.
    model_keys = {k for k, _ in _flatten_params(drafter)}
    ckpt_keys = set(weights)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    if missing or unexpected:
        print("[load_drafter] WARNING key mismatch:")
        if missing:
            print(f"  missing in checkpoint ({len(missing)}): {missing[:8]}")
        if unexpected:
            print(f"  unexpected in checkpoint ({len(unexpected)}): {unexpected[:8]}")

    drafter.load_weights(list(weights.items()), strict=not (missing or unexpected))

    if quantize:
        # Quantize Linear/Embedding weights; norms/scalars stay full precision.
        nn.quantize(drafter, group_size=group_size, bits=bits)

    mx.eval(drafter.parameters())
    return drafter, config
80
+
81
+
82
def _flatten_params(module) -> list[tuple[str, mx.array]]:
    """Return ``module.parameters()`` as flat ``(dotted_name, array)`` pairs via ``tree_flatten``."""
    from mlx.utils import tree_flatten

    params = module.parameters()
    flat = tree_flatten(params)
    return flat
86
+
87
+
88
def _read_model_type(path: str) -> str:
    """Return ``model_type`` from ``<path>/config.json``, or ``""`` when the file or key is absent."""
    import json

    cfg_path = os.path.join(path, "config.json")
    if not os.path.exists(cfg_path):
        return ""
    with open(cfg_path) as f:
        return json.load(f).get("model_type", "")


def load_target(repo_or_path: str = DEFAULT_TARGET):
    """Return ``(Target, tokenizer)``.

    Dense Qwen3 text models are loaded with mlx-lm. Everything else, including
    Gemma-4 and any model_type mentioning ``moe``, is loaded with mlx-vlm. The
    model is then wrapped in the family-aware ``Target`` (hidden-state tap).
    """
    path = _resolve(repo_or_path)
    model_type = _read_model_type(path)

    if "qwen3" in model_type and "moe" not in model_type:
        from mlx_lm import load as lm_load

        model, tokenizer = lm_load(path)
    else:
        from mlx_vlm import load as vlm_load

        model, processor = vlm_load(path)
        # mlx-vlm hands back a processor; prefer its inner text tokenizer if present.
        tokenizer = getattr(processor, "tokenizer", processor)
    return Target(model, tokenizer), tokenizer
110
+
111
+
112
def load_pair(family: str = "gemma4", *, target_bits: str | None = None):
    """Convenience: load ``(target, tokenizer, drafter, cfg)`` for a preset family.

    Args:
        family: key into ``PRESETS`` (``"gemma4"`` or ``"qwen3"``).
        target_bits: accepted for API compatibility but currently ignored.
            The target is loaded exactly as published in the preset repo. A
            ``UserWarning`` is emitted when it is set, so callers are not
            misled.

    Raises:
        KeyError: if ``family`` is not a preset name.
    """
    try:
        preset = PRESETS[family]
    except KeyError:
        raise KeyError(
            f"unknown family {family!r}; expected one of {sorted(PRESETS)}"
        ) from None
    if target_bits is not None:
        import warnings

        warnings.warn(
            "load_pair: target_bits is not supported and is ignored; "
            "the target is loaded as published in the preset repo",
            stacklevel=2,
        )
    target, tok = load_target(preset["target"])
    drafter, cfg = load_drafter(preset["drafter"])
    return target, tok, drafter, cfg
mlx_dspark/model.py ADDED
@@ -0,0 +1,277 @@
1
+ """DSpark drafter in MLX — Gemma-4 and Qwen3 families.
2
+
3
+ Faithful port of the DeepSpec inference path. The EAGLE-style cross-attention is shared:
4
+ Q comes from the draft block, K/V from concat([fused_target_context, block]), with the
5
+ context K/V cached per layer (CtxCache). Family differences (norm layout, rope, MLP act,
6
+ v handling, logit softcap) are config-driven. Module attribute names match the HF
7
+ checkpoints so weights load 1:1.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import mlx.core as mx
13
+ import mlx.nn as nn
14
+
15
+ from mlx_vlm.models.gemma4.rope_utils import initialize_rope
16
+
17
+ from .config import DSparkConfig
18
+
19
+
20
class RMSNormNoScale(nn.Module):
    """Weight-free RMS normalisation over the last axis (Gemma-4 ``v_norm``).

    It holds only a float ``eps`` and no arrays, so it adds no learnable
    parameters.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        unit_weight = None  # no learnable scale
        return mx.fast.rms_norm(x, unit_weight, self.eps)
29
+
30
+
31
def _act(name: str):
    """Map an MLP activation name to its MLX function.

    ``"silu"`` maps to ``nn.silu``; anything else maps to ``nn.gelu_approx``
    (the tanh-approximated GELU).
    """
    if name == "silu":
        return nn.silu
    return nn.gelu_approx
33
+
34
+
35
class MLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))`` with no biases."""

    def __init__(self, config: DSparkConfig):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        # Attribute names match the checkpoint keys.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act = _act(config.mlp_activation)

    def __call__(self, x: mx.array) -> mx.array:
        gated = self.act(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
46
+
47
+
48
def _repeat_kv(x: mx.array, n_rep: int) -> mx.array:
    """Tile K/V heads for grouped-query attention.

    Turns a 4-D ``[B, n_kv, S, D]`` array into ``[B, n_kv * n_rep, S, D]`` by
    repeating each head ``n_rep`` times consecutively. ``n_rep == 1`` is a
    no-op.
    """
    if n_rep == 1:
        return x
    batch, kv_heads, seq, dim = x.shape
    tiled = mx.broadcast_to(x[:, :, None], (batch, kv_heads, n_rep, seq, dim))
    return tiled.reshape(batch, kv_heads * n_rep, seq, dim)
55
+
56
+
57
class CtxCache:
    """Per-layer store of the drafter's projected K/V for the committed target context.

    ``k`` holds roped keys and ``v`` the values (not roped). Both are ``None``
    until the first ``append``. After that, each append concatenates along
    axis 2.
    """

    __slots__ = ("k", "v")

    def __init__(self):
        self.k = self.v = None

    def append(self, k: mx.array, v: mx.array) -> None:
        """Add a new chunk of context K/V (stored as-is on the first call)."""
        if self.k is not None:
            self.k = mx.concatenate([self.k, k], axis=2)
            self.v = mx.concatenate([self.v, v], axis=2)
            return
        self.k, self.v = k, v
72
+
73
+
74
class DSparkAttention(nn.Module):
    """Cross-attention: Q from the draft block, K/V from ``[target_context, block]``.

    Context K/V are projected once per committed chunk (``update_ctx``) and
    kept in a ``CtxCache``. ``attend`` appends the block's own K/V on the fly
    and attends without a mask, so every block position sees the whole context
    and the whole block.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.attn_head_dim
        self.k_eq_v = config.attention_k_eq_v
        self.n_kv_heads = config.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads  # query heads per K/V head
        self.scale = config.scaling
        self.use_v_norm = config.use_v_norm

        h = config.hidden_size
        b = config.attention_bias
        self.q_proj = nn.Linear(h, self.n_heads * self.head_dim, bias=b)
        self.k_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=b)
        if not self.k_eq_v:
            self.v_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=b)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, h, bias=b)

        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        if self.use_v_norm:
            self.v_norm = RMSNormNoScale(eps=config.rms_norm_eps)

        self.rope = initialize_rope(
            dims=self.head_dim, base=config.rope_theta, traditional=False,
            scaling_config=config.rope_parameters,
        )

    def _kv(self, x: mx.array):
        """Project ``x`` [B, S, H] to (normed K, V), each [B, n_kv, S, head_dim]; no rope yet.

        With ``k_eq_v`` the V path reuses the ``k_proj`` output. ``v_norm`` is
        applied whenever it exists. Previously, ``k_eq_v`` without
        ``use_v_norm`` raised AttributeError.
        """
        B, S, _ = x.shape
        kp = self.k_proj(x).reshape(B, S, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_norm(kp)
        if self.k_eq_v:
            v = kp
        else:
            v = self.v_proj(x).reshape(B, S, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        if self.use_v_norm:
            v = self.v_norm(v)
        return k, v

    def update_ctx(self, fused_new: mx.array, ctx_offset: int, cache: CtxCache) -> None:
        """Project newly committed fused target states and append them to ``cache``.

        ``ctx_offset`` is the position of the first new state, used for rope.
        """
        k, v = self._kv(fused_new)
        cache.append(self.rope(k, offset=ctx_offset), v)  # V is not roped

    def attend(self, hidden: mx.array, block_offset: int, cache: CtxCache) -> mx.array:
        """Attend from the block ``hidden`` [B, q_len, H] over cached context + the block itself.

        ``block_offset`` is the position of the block's first token.
        """
        B, q_len, _ = hidden.shape
        q = self.q_proj(hidden).reshape(B, q_len, self.n_heads, self.head_dim)
        q = self.rope(self.q_norm(q).transpose(0, 2, 1, 3), offset=block_offset)

        k_blk, v_blk = self._kv(hidden)
        k_blk = self.rope(k_blk, offset=block_offset)
        if cache.k is None:
            # No committed context yet: attend within the block only.
            k, v = k_blk, v_blk
        else:
            k = mx.concatenate([cache.k, k_blk], axis=2)
            v = mx.concatenate([cache.v, v_blk], axis=2)

        # The fused SDPA kernel supports GQA/MQA natively (n_heads a multiple
        # of n_kv_heads). K/V are passed un-tiled instead of being copied
        # n_rep times over the whole context on every call.
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
        out = out.transpose(0, 2, 1, 3).reshape(B, q_len, -1)
        return self.o_proj(out)
137
+
138
+
139
class DSparkDecoderLayer(nn.Module):
    """One drafter transformer layer whose layout is chosen by ``config.norm_style``.

    * ``"gemma"``: sandwich RMSNorms around both attention and MLP, with the
      layer output multiplied by ``layer_scalar``.
    * otherwise (``"qwen"``): Llama-style pre-norm using two RMSNorms.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        eps = config.rms_norm_eps
        dim = config.hidden_size
        self.norm_style = config.norm_style
        self.self_attn = DSparkAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = nn.RMSNorm(dim, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(dim, eps=eps)
        if self.norm_style == "gemma":
            self.pre_feedforward_layernorm = nn.RMSNorm(dim, eps=eps)
            self.post_feedforward_layernorm = nn.RMSNorm(dim, eps=eps)
            self.layer_scalar = mx.ones((1,))

    def __call__(self, hidden, block_offset, cache: CtxCache):
        if self.norm_style == "gemma":
            return self._sandwich_forward(hidden, block_offset, cache)
        return self._pre_norm_forward(hidden, block_offset, cache)

    def _sandwich_forward(self, x, block_offset, cache):
        # Gemma layout: norm -> attn -> norm, residual; norm -> mlp -> norm, residual; scale.
        attn_out = self.self_attn.attend(self.input_layernorm(x), block_offset, cache)
        x = x + self.post_attention_layernorm(attn_out)
        mlp_out = self.mlp(self.pre_feedforward_layernorm(x))
        x = x + self.post_feedforward_layernorm(mlp_out)
        return x * self.layer_scalar

    def _pre_norm_forward(self, x, block_offset, cache):
        # Qwen/Llama layout: pre-norm attention residual, then pre-norm MLP residual.
        x = x + self.self_attn.attend(self.input_layernorm(x), block_offset, cache)
        return x + self.mlp(self.post_attention_layernorm(x))
175
+
176
+
177
class VanillaMarkov(nn.Module):
    """Low-rank previous-token logit correction.

    The bias for a previous token is ``markov_w2(markov_w1[prev])``, where the
    rank is ``config.markov_rank``.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        rank, vocab = config.markov_rank, config.vocab_size
        self.markov_w1 = nn.Embedding(vocab, rank)
        self.markov_w2 = nn.Linear(rank, vocab, bias=False)

    def prev_embeddings(self, token_ids: mx.array) -> mx.array:
        """Low-rank embedding of each previous token (``markov_w1`` lookup)."""
        return self.markov_w1(token_ids)

    def step_bias(self, token_ids: mx.array) -> mx.array:
        """Vocabulary-sized logit bias for each previous token."""
        low_rank = self.prev_embeddings(token_ids)
        return self.markov_w2(low_rank)
190
+
191
+
192
class ConfidenceHead(nn.Module):
    """Linear probe mapping each position's features to one acceptance logit."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.proj = nn.Linear(input_dim, 1)

    def __call__(self, features: mx.array) -> mx.array:
        scores = self.proj(features)  # trailing axis of size 1
        return mx.squeeze(scores, axis=-1)
199
+
200
+
201
class DSparkDrafter(nn.Module):
    """DSpark drafter.

    Components:
      * a parallel backbone over a draft block, conditioned on fused target
        hidden states;
      * an optional Markov previous-token head;
      * an optional confidence head.

    Attribute names mirror the HF checkpoint so weights load 1:1.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        self.config = config
        self.block_size = config.block_size
        self.mask_token_id = config.mask_token_id
        # gemma4 scales token embeddings by sqrt(hidden_size); qwen3 does not.
        self.embed_scale = (float(config.hidden_size) ** 0.5) if config.family == "gemma4" else 1.0
        self.softcap = config.final_logit_softcapping

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Projects the concatenation of the tapped target layers back to hidden_size.
        self.fc = nn.Linear(
            len(config.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False
        )
        self.hidden_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = [DSparkDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.markov_head = VanillaMarkov(config) if config.markov_rank > 0 else None
        self.confidence_head = None
        if config.enable_confidence_head:
            in_dim = config.hidden_size + (config.markov_rank if config.confidence_head_with_markov else 0)
            self.confidence_head = ConfidenceHead(in_dim)

    def embed(self, ids: mx.array) -> mx.array:
        """Scaled token embeddings for ``ids``."""
        return self.embed_tokens(ids) * self.embed_scale

    def fuse_target(self, target_hidden_cat: mx.array) -> mx.array:
        """Map concatenated tapped target hidden states to drafter width, then RMS-normalise."""
        return self.hidden_norm(self.fc(target_hidden_cat))

    def make_ctx_cache(self) -> list[CtxCache]:
        """One empty ``CtxCache`` per decoder layer."""
        return [CtxCache() for _ in self.layers]

    def update_context(self, target_hidden_cat, ctx_offset, ctx_caches) -> None:
        """Fuse newly committed target states and append their K/V to every layer's cache."""
        fused = self.fuse_target(target_hidden_cat)
        for layer, cache in zip(self.layers, ctx_caches):
            layer.self_attn.update_ctx(fused, ctx_offset, cache)

    def backbone(self, noise_embedding, block_offset, ctx_caches) -> mx.array:
        """Run all decoder layers over the block embedding, then apply the final norm."""
        h = noise_embedding
        for layer, cache in zip(self.layers, ctx_caches):
            h = layer(h, block_offset, cache)
        return self.norm(h)

    def compute_logits(self, hidden: mx.array) -> mx.array:
        """``lm_head`` logits, tanh-soft-capped to ``±softcap`` when a softcap is configured."""
        logits = self.lm_head(hidden)
        if self.softcap is not None:
            logits = mx.tanh(logits / self.softcap) * self.softcap
        return logits

    def sample_block(self, base_logits: mx.array, first_prev_token: int) -> mx.array:
        """Greedily pick one token per block row.

        Row ``i`` chooses ``argmax(base_logits[i] + markov_bias(prev))``, where
        ``prev`` is the token picked for row ``i-1``; row 0 uses
        ``first_prev_token``. Without a Markov head this is a plain per-row
        argmax.
        """
        k = base_logits.shape[0]
        if self.markov_head is None:
            return mx.argmax(base_logits, axis=-1)
        tokens = []
        prev = mx.array([first_prev_token])
        for i in range(k):
            step = base_logits[i] + self.markov_head.step_bias(prev)[0]
            nxt = mx.argmax(step, axis=-1, keepdims=True)
            tokens.append(nxt)
            prev = nxt
        return mx.concatenate(tokens)

    def confidence_logits(self, block_hidden, prev_token_ids):
        """Per-position confidence logits, or ``None`` without a confidence head.

        Markov embeddings of the previous tokens are concatenated only when
        the config asks for them *and* a Markov head exists. With
        ``markov_rank == 0`` the head's input width already excludes them
        (see ``__init__``).
        """
        if self.confidence_head is None:
            return None
        if self.config.confidence_head_with_markov and self.markov_head is not None:
            feats = mx.concatenate(
                [block_hidden, self.markov_head.prev_embeddings(prev_token_ids)], axis=-1
            )
        else:
            feats = block_hidden
        return self.confidence_head(feats)


# Backwards-compatible alias
Gemma4DSparkDrafter = DSparkDrafter
mlx_dspark/target.py ADDED
@@ -0,0 +1,64 @@
1
+ """Family-aware target wrapper: KV cache + a hidden-state tap at given layers.
2
+
3
+ - gemma4 (mlx-vlm): uses the built-in ``capture_layer_ids`` / ``hidden_sink`` hook.
4
+ - qwen3 (mlx-lm): no hook exists, so we replicate the model's forward loop and
5
+ capture the residual stream after the tapped layers.
6
+
7
+ Both expose: make_cache(), run(ids, cache, tap)->(logits, fused_hidden), and
8
+ plain(ids, cache)->logits (no capture, for the greedy baseline).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import mlx.core as mx
14
+
15
+
16
class Target:
    """Family-aware wrapper around a target model.

    mlx-vlm models (detected by a ``language_model`` attribute) are treated as the
    ``"gemma4"`` family and all others as mlx-lm ``"qwen3"`` models. Both expose
    :meth:`make_cache`, :meth:`run` (logits plus tapped hidden states) and
    :meth:`plain` (logits only, for the greedy baseline).
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # mlx-vlm models expose .language_model; mlx-lm models expose .model + (lm_head|tied)
        self.is_vlm = hasattr(model, "language_model")
        self.family = "gemma4" if self.is_vlm else "qwen3"
        if not self.is_vlm:
            # Tied-embedding models project with ``embed_tokens.as_linear`` instead of
            # ``lm_head`` (see ``_run_mlxlm``). A model without ``args`` counts as untied.
            self._tied = bool(getattr(getattr(model, "args", None), "tie_word_embeddings", False))

    # -- cache --
    def make_cache(self):
        """Return a fresh KV cache for the target model."""
        if self.is_vlm:
            return self.model.language_model.make_cache()
        from mlx_lm.models.cache import make_prompt_cache

        return make_prompt_cache(self.model)

    # -- forward with hidden-state tap --
    def run(self, ids: mx.array, cache, tap: list[int]):
        """Forward ``ids`` through the target and also return tapped hidden states.

        Args:
            ids: token ids, ``[1, L]``.
            cache: KV cache from :meth:`make_cache`. The mlx-lm path also accepts ``None``.
            tap: layer indices whose outputs are captured. On the mlx-lm path they are
                concatenated in exactly this order, so duplicates are repeated. The
                mlx-vlm path concatenates ``out.hidden_states`` in the order mlx-vlm
                returns them.

        Returns:
            ``(logits [1, L, V], fused_hidden [1, L, len(tap) * H])``.

        Raises:
            ValueError: on the mlx-lm path, if ``tap`` is empty or out of range.
        """
        if self.is_vlm:
            out = self.model.language_model(inputs=ids, cache=cache, capture_layer_ids=tap)
            return out.logits, mx.concatenate(out.hidden_states, axis=-1)
        return self._run_mlxlm(ids, cache, tap)

    @staticmethod
    def _check_tap(tap, n_layers: int) -> set[int]:
        """Return ``set(tap)`` after checking every index satisfies ``0 <= t < n_layers``.

        Raises:
            ValueError: if ``tap`` is empty or any index is out of range.
        """
        if not tap:
            raise ValueError("tap must contain at least one layer index")
        bad = [t for t in tap if not 0 <= t < n_layers]
        if bad:
            raise ValueError(
                f"tap layer indices {bad} are out of range for a model with {n_layers} layers"
            )
        return set(tap)

    def _run_mlxlm(self, ids, cache, tap):
        """Run an mlx-lm forward pass that also captures the residual stream after the ``tap`` layers.

        mlx-lm offers no hidden-state hook, so this replicates the forward loop:
        embed, then layers, then norm, then the output head.
        """
        from mlx_lm.models.base import create_attention_mask

        mm = self.model.model
        n_layers = len(mm.layers)
        wanted = self._check_tap(tap, n_layers)
        if cache is None:
            # Run without a KV cache: one ``None`` slot per layer.
            cache = [None] * n_layers
        h = mm.embed_tokens(ids)
        mask = create_attention_mask(h, cache[0])
        captured = {}
        for i, (layer, c) in enumerate(zip(mm.layers, cache)):
            h = layer(h, mask, c)
            if i in wanted:
                captured[i] = h
        hn = mm.norm(h)
        if self._tied:
            logits = mm.embed_tokens.as_linear(hn)
        else:
            logits = self.model.lm_head(hn)
        # Concatenate in ``tap`` order, repeating duplicates, so the feature width is
        # always len(tap) * H regardless of how ``tap`` is ordered.
        return logits, mx.concatenate([captured[t] for t in tap], axis=-1)

    # -- plain forward (no capture) for the greedy baseline --
    def plain(self, ids: mx.array, cache):
        """Return only the target's logits for ``ids`` (no hidden-state capture)."""
        if self.is_vlm:
            return self.model.language_model(inputs=ids, cache=cache).logits
        return self.model(ids, cache=cache)
@@ -0,0 +1,115 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-dspark
3
+ Version: 0.0.1
4
+ Summary: DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX
5
+ Project-URL: Homepage, https://github.com/ARahim3/mlx-dspark
6
+ Project-URL: Repository, https://github.com/ARahim3/mlx-dspark
7
+ Project-URL: Issues, https://github.com/ARahim3/mlx-dspark/issues
8
+ Author-email: erahim3 <erahim3@gmail.com>
9
+ License: MIT
10
+ License-File: LICENSE
11
+ License-File: NOTICE
12
+ Keywords: apple-silicon,deepspec,dspark,llm-inference,mlx,speculative-decoding
13
+ Classifier: Environment :: GPU
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: MacOS
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.10
19
+ Requires-Dist: huggingface-hub
20
+ Requires-Dist: mlx-lm>=0.31.3
21
+ Requires-Dist: mlx-vlm>=0.6.3
22
+ Requires-Dist: mlx>=0.31.2
23
+ Requires-Dist: numpy
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == 'dev'
26
+ Requires-Dist: ruff; extra == 'dev'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # mlx-dspark
30
+
31
+ **DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
32
+
33
+ DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
34
+ open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
35
+ accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
36
+ drafter checkpoints run natively on a Mac.
37
+
38
+ **Supported families** (auto-detected from the drafter config):
39
+
40
+ | family | target | drafter | RAM |
41
+ |---|---|---|---|
42
+ | `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
43
+ | `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
44
+
45
+ ## How it works
46
+
47
+ - A **parallel backbone** (5 Gemma-4 layers in the Gemma-4 drafter) consumes the target model's hidden states
48
+ (tapped EAGLE3-style, e.g. at target layers `[5,17,29,41,46]` for Gemma-4) and proposes a whole 7-token block at once.
49
+ - A **rank-256 Markov head** adds a correction conditioned on the previous token. It is sampled sequentially — the only
50
+ sequential cost in drafting — and cheaply counteracts "suffix decay" (falling draft accuracy at later block positions).
51
+ - A **confidence head** scores each draft position (optional adaptive block length).
52
+ - The target **verifies** every token, so output is **greedy-correct by construction**
53
+ (identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
54
+
55
+ The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
56
+ (~1.8 GB) so it's cheap to run every round.
57
+
58
+ ## Install
59
+
60
+ ```bash
61
+ uv venv --python 3.12
62
+ source .venv/bin/activate
63
+ uv pip install -e .
64
+ ```
65
+
66
+ ## Use
67
+
68
+ ```bash
69
+ # CLI — pick a family (downloads drafter + instruct target on first run)
70
+ python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
71
+ python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
72
+
73
+ # side-by-side demo: baseline (plain target) vs dspark (record each, stack)
74
+ python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
75
+ python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
76
+ ```
77
+
78
+ ```python
79
+ from mlx_dspark import load_pair, speculative_generate
80
+
81
+ target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
82
+ res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
83
+ print(res.text, res.mean_accept_len, res.tokens_per_sec)
84
+ ```
85
+
86
+ ## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
87
+
88
+ | family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
89
+ |---|---|---|---|---|---|
90
+ | **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
91
+ | **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
92
+
93
+ "greedy" = the plain target model decoding one token per forward pass (no drafter); "dspark" =
94
+ speculative decoding with the DSpark drafter. Both produce the **same** output; DSpark is just
95
+ faster. It can diverge from sequential greedy only on near-ties, where the top two logits differ by floating-point noise. Smaller/faster targets (Qwen3-4B) have a lower
96
+ per-token verify cost, so their optimal `--max-draft` is smaller (~2).
97
+
98
+ ### Target choice matters
99
+
100
+ The drafter is trained against a specific **instruct** model in **bf16**, so use the matching
101
+ instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (the base model gave `d_0` ~47% vs
102
+ 82%). Higher target precision raises acceptance: `-bf16` maximizes it, 4-bit verifies fastest, and 8-bit is
103
+ the sweet spot between the two.
104
+
105
+ ### Tuning
106
+
107
+ On Apple Silicon the target verify cost grows with the number of tokens verified, so the
108
+ optimum is to verify roughly as many tokens as are typically accepted: `--max-draft 4` (default). `--max-draft <block>`
109
+ (the full 7 tokens the drafter was trained for) matches the reference setup but is *slower* on M-series chips. Alternatively, `--confidence-threshold 0.6`
110
+ truncates the block adaptively using the drafter's confidence head.
111
+
112
+ ## License
113
+
114
+ MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
115
+ DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
@@ -0,0 +1,12 @@
1
+ mlx_dspark/__init__.py,sha256=fclV1VbqP6UzILCu35xa6hTDYV_BLzeH9Fr_32hbDwk,1156
2
+ mlx_dspark/__main__.py,sha256=YEfk3teljcUJrc17eMMz7WIILwGFvnEB4mNGmQqnKpQ,3417
3
+ mlx_dspark/config.py,sha256=Ey-tRJ3w6rAXSbe6LhuVHiyLdgbqNvZDGW8blq7WKS8,6257
4
+ mlx_dspark/generate.py,sha256=sOdrRg2GA8so3JH9MBKAMaJ3q6BYuop7lxTmdTs7xkc,10146
5
+ mlx_dspark/load.py,sha256=xLr0QLCFzQa7Qw0LkubT0Jpz0Mt58bvMtjdRBQxBAeQ,3946
6
+ mlx_dspark/model.py,sha256=mnI6Y_ZWZbeohF4FGsLRcvPFamUoy16_0gVFlIkDZew,10607
7
+ mlx_dspark/target.py,sha256=8SlmnMM4UO1qL36HZwaqGfniIEDoupU1T5G_wcKEpXw,2494
8
+ mlx_dspark-0.0.1.dist-info/METADATA,sha256=Wea6DUzWjUOBKFkwXY7D31Ss6TqYBUam6mXXIXfUxZc,5047
9
+ mlx_dspark-0.0.1.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
10
+ mlx_dspark-0.0.1.dist-info/licenses/LICENSE,sha256=7fTQ89BHhdRKOLqXpVQFZxTnwQUUtg8EeamGIrOz7eU,1064
11
+ mlx_dspark-0.0.1.dist-info/licenses/NOTICE,sha256=4sSSawyNMWFGbhUevnc0ZKPgOS3auWOdm77K_BqLRcs,908
12
+ mlx_dspark-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 erahim3
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,22 @@
1
+ mlx-dspark
2
+ ==========
3
+
4
+ This project is an independent MLX/Apple-Silicon port of the *inference path* of
5
+ DSpark, a speculative-decoding drafter open-sourced by DeepSeek as part of the
6
+ DeepSpec codebase:
7
+
8
+ - DeepSpec: https://github.com/deepseek-ai/DeepSpec (MIT License)
9
+
10
+ It loads the published DSpark drafter checkpoints (also released by DeepSeek under
11
+ their respective licenses):
12
+
13
+ - deepseek-ai/dspark_gemma4_12b_block7
14
+ - deepseek-ai/dspark_qwen3_4b_block7
15
+
16
+ The drafter architecture, training, and checkpoints are the work of DeepSeek and
17
+ the DeepSpec authors. This repository reimplements only the forward/verification
18
+ path for MLX and contains no DeepSpec source code. Target models (Gemma-4, Qwen3)
19
+ are downloaded at runtime from their respective publishers and are subject to
20
+ their own licenses (e.g. the Gemma Terms of Use and the Qwen license).
21
+
22
+ No model weights are bundled with this package.