mlx-dspark 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_dspark/__init__.py +42 -0
- mlx_dspark/__main__.py +84 -0
- mlx_dspark/config.py +145 -0
- mlx_dspark/generate.py +289 -0
- mlx_dspark/load.py +117 -0
- mlx_dspark/model.py +277 -0
- mlx_dspark/target.py +64 -0
- mlx_dspark-0.0.1.dist-info/METADATA +115 -0
- mlx_dspark-0.0.1.dist-info/RECORD +12 -0
- mlx_dspark-0.0.1.dist-info/WHEEL +4 -0
- mlx_dspark-0.0.1.dist-info/licenses/LICENSE +21 -0
- mlx_dspark-0.0.1.dist-info/licenses/NOTICE +22 -0
mlx_dspark/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""mlx-dspark: DSpark speculative decoding for Apple Silicon (MLX).
|
|
2
|
+
|
|
3
|
+
DSpark (from DeepSeek's DeepSpec codebase) is a semi-autoregressive,
|
|
4
|
+
EAGLE-family speculative-decoding drafter:
|
|
5
|
+
|
|
6
|
+
- a *parallel backbone* proposes base logits for all K draft positions at once,
|
|
7
|
+
- a tiny *sequential head* (low-rank, previous-token-conditioned) corrects
|
|
8
|
+
suffix decay,
|
|
9
|
+
- a *confidence head* scores how likely each drafted token survives
|
|
10
|
+
verification (used here for adaptive draft length instead of the
|
|
11
|
+
server-side load-aware scheduler, which is irrelevant single-user).
|
|
12
|
+
|
|
13
|
+
This package targets single-user local inference on Apple Silicon.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
__version__ = "0.0.1"
|
|
17
|
+
|
|
18
|
+
from .config import DSparkConfig
|
|
19
|
+
from .load import (
|
|
20
|
+
DEFAULT_DRAFTER,
|
|
21
|
+
DEFAULT_TARGET,
|
|
22
|
+
PRESETS,
|
|
23
|
+
load_drafter,
|
|
24
|
+
load_pair,
|
|
25
|
+
load_target,
|
|
26
|
+
)
|
|
27
|
+
from .generate import GenResult, greedy_generate, speculative_generate
|
|
28
|
+
from .target import Target
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
"DSparkConfig",
|
|
32
|
+
"Target",
|
|
33
|
+
"load_drafter",
|
|
34
|
+
"load_target",
|
|
35
|
+
"load_pair",
|
|
36
|
+
"speculative_generate",
|
|
37
|
+
"greedy_generate",
|
|
38
|
+
"GenResult",
|
|
39
|
+
"PRESETS",
|
|
40
|
+
"DEFAULT_TARGET",
|
|
41
|
+
"DEFAULT_DRAFTER",
|
|
42
|
+
]
|
mlx_dspark/__main__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""CLI: run DSpark speculative decoding on Apple Silicon (streams tokens live).
|
|
2
|
+
|
|
3
|
+
Side-by-side demo (record each, then stack the two screen captures):
|
|
4
|
+
|
|
5
|
+
# left panel — plain target, no drafter
|
|
6
|
+
python -m mlx_dspark --mode baseline --prompt "Explain how rainbows form." --max-new-tokens 220
|
|
7
|
+
|
|
8
|
+
# right panel — DSpark speculative decoding (same prompt, same output, faster)
|
|
9
|
+
python -m mlx_dspark --mode dspark --prompt "Explain how rainbows form." --max-new-tokens 220
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
from .generate import greedy_generate, speculative_generate
|
|
19
|
+
from .load import PRESETS, load_drafter, load_target
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _emit(s: str) -> None:
    """Write ``s`` to stdout as-is (no newline added) and flush right away.

    Used as the streaming callback so partial output appears immediately.
    """
    out = sys.stdout
    out.write(s)
    out.flush()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: load the models, generate, print a summary.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so calling ``main()`` behaves as before;
            passing a list lets the CLI be driven programmatically.
    """
    ap = argparse.ArgumentParser(prog="mlx_dspark")
    ap.add_argument("--mode", choices=["dspark", "baseline"], default="dspark",
                    help="dspark = speculative decoding; baseline = plain greedy target")
    # Choices come from PRESETS so a newly added preset is selectable
    # without touching the CLI.
    ap.add_argument("--family", choices=list(PRESETS), default="gemma4",
                    help="model preset (target + drafter); overridden by --target/--drafter")
    ap.add_argument("--prompt", default="Explain how rainbows form, in a few sentences.")
    ap.add_argument("--target", default=None)
    ap.add_argument("--drafter", default=None)
    ap.add_argument("--max-new-tokens", type=int, default=220)
    ap.add_argument("--max-draft", type=int, default=4)
    ap.add_argument("--confidence-threshold", type=float, default=0.0)
    ap.add_argument("--drafter-bits", type=int, default=4)
    ap.add_argument("--no-chat-template", action="store_true")
    ap.add_argument("--no-stream", action="store_true")
    args = ap.parse_args(argv)

    # Explicit --target/--drafter take precedence over the family preset.
    target_repo = args.target or PRESETS[args.family]["target"]
    drafter_repo = args.drafter or PRESETS[args.family]["drafter"]
    speculative = args.mode == "dspark"

    label = "DSpark speculative" if speculative else "Baseline (plain greedy)"
    print(f"loading {args.mode}: target={target_repo}"
          + (f", drafter={drafter_repo}" if speculative else ""))
    target, tok = load_target(target_repo)
    drafter = None
    if speculative:
        # --drafter-bits <= 0 disables quantization; otherwise the bit width
        # is floored at 2.
        drafter, _ = load_drafter(drafter_repo, quantize=args.drafter_bits > 0,
                                  bits=max(args.drafter_bits, 2))

    on_text = None if args.no_stream else _emit
    print("\n" + "=" * 64)
    print(f" ▶ {label} · {target_repo.split('/')[-1]}")
    print("=" * 64)

    if speculative:
        res = speculative_generate(
            target, tok, drafter, args.prompt,
            max_new_tokens=args.max_new_tokens, max_draft_tokens=args.max_draft,
            confidence_threshold=args.confidence_threshold,
            apply_chat_template=not args.no_chat_template, on_text=on_text,
        )
        extra = f" · accept {res.mean_accept_len:.2f}/round · {res.target_forwards} target fwds"
    else:
        res = greedy_generate(
            target, tok, args.prompt, max_new_tokens=args.max_new_tokens,
            apply_chat_template=not args.no_chat_template, on_text=on_text,
        )
        extra = ""
    if args.no_stream:
        # Nothing was streamed, so print the full text once.
        print(res.text)

    print("\n" + "-" * 64)
    print(f" {res.num_tokens} tokens · {res.seconds:.2f}s · "
          f"\033[1m{res.tokens_per_sec:.1f} tok/s\033[0m{extra}")
    print("-" * 64)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
if __name__ == "__main__":
|
|
84
|
+
main()
|
mlx_dspark/config.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""DSpark drafter config — loaded from the HF checkpoint's config.json.
|
|
2
|
+
|
|
3
|
+
Supports two drafter families with a shared inference path:
|
|
4
|
+
- gemma4 (gemma4_text): k_eq_v attention, v_norm, partial/proportional rope,
|
|
5
|
+
sandwich norms + layer_scalar, gelu-tanh MLP, logit softcap.
|
|
6
|
+
- qwen3 (qwen3): standard GQA (separate v_proj, no v_norm), default rope,
|
|
7
|
+
Llama-style 2-norm layer, silu MLP, no softcap.
|
|
8
|
+
Only the fields the MLX inference path needs are pulled out.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class DSparkConfig:
    """Drafter hyper-parameters used by the MLX inference path.

    Field defaults are the gemma4 values; ``from_json`` overrides them from a
    checkpoint ``config.json`` and sets the family-derived knobs.
    """

    family: str = "gemma4"  # "gemma4" | "qwen3"

    # core dims
    hidden_size: int = 3840
    vocab_size: int = 262144
    num_hidden_layers: int = 5
    intermediate_size: int = 15360
    rms_norm_eps: float = 1e-6

    # attention
    num_attention_heads: int = 16
    num_key_value_heads: int = 8
    num_global_key_value_heads: int = 1
    head_dim: int = 256
    global_head_dim: int = 512
    attention_k_eq_v: bool = True
    attention_bias: bool = False

    # rope
    rope_theta: float = 1_000_000.0
    partial_rotary_factor: float = 0.25
    rope_type: str = "proportional"

    # dspark specifics
    block_size: int = 7
    mask_token_id: int = 4
    target_layer_ids: list[int] = field(default_factory=lambda: [5, 17, 29, 41, 46])
    num_target_layers: int = 48

    # markov + confidence
    markov_rank: int = 256
    markov_head_type: str = "vanilla"
    enable_confidence_head: bool = True
    confidence_head_with_markov: bool = True

    # logits
    final_logit_softcapping: float | None = 30.0
    pad_token_id: int = 0

    # ---- family-derived knobs (set in from_json) ----
    mlp_activation: str = "gelu_tanh"  # "gelu_tanh" | "silu"
    norm_style: str = "gemma"  # "gemma" (sandwich+scalar) | "qwen" (llama 2-norm)
    use_v_norm: bool = True  # gemma: RMSNormNoScale v_norm; qwen: none
    # None -> family default (see `scaling`): attn_head_dim ** -0.5 for qwen3, 1.0 for gemma4
    attention_scaling: float | None = None

    @property
    def attn_head_dim(self) -> int:
        """Head dim used by the drafter's own attention."""
        return self.global_head_dim if self.family == "gemma4" else self.head_dim

    @property
    def n_kv_heads(self) -> int:
        """Number of K/V heads: the global count for gemma4 with k_eq_v, else the regular count."""
        if self.family == "gemma4" and self.attention_k_eq_v:
            return self.num_global_key_value_heads
        return self.num_key_value_heads

    @property
    def scaling(self) -> float:
        """Attention score scale: the explicit override if set, else the family default."""
        if self.attention_scaling is not None:
            return self.attention_scaling
        return self.attn_head_dim ** -0.5 if self.family == "qwen3" else 1.0

    @property
    def rope_parameters(self) -> dict:
        """Rope settings as a dict with keys ``rope_type`` and ``partial_rotary_factor``."""
        return {"rope_type": self.rope_type, "partial_rotary_factor": self.partial_rotary_factor}

    @classmethod
    def from_json(cls, path: str | Path) -> "DSparkConfig":
        """Build a config from a checkpoint ``config.json``.

        The family is "qwen3" when ``model_type`` contains "qwen3", otherwise
        "gemma4". A missing or ``null`` ``pad_token_id`` becomes 0 for both
        families.

        Raises:
            KeyError: if a required key (e.g. ``hidden_size``, ``block_size``,
                ``mask_token_id``, ``target_layer_ids``) is absent.
        """
        with open(path, encoding="utf-8") as f:
            c = json.load(f)
        family = "qwen3" if "qwen3" in (c.get("model_type") or "") else "gemma4"

        # Fields that both families read from the JSON in exactly the same way.
        common = dict(
            family=family,
            hidden_size=c["hidden_size"],
            vocab_size=c["vocab_size"],
            num_hidden_layers=c["num_hidden_layers"],
            intermediate_size=c["intermediate_size"],
            rms_norm_eps=c.get("rms_norm_eps", 1e-6),
            num_attention_heads=c["num_attention_heads"],
            num_key_value_heads=c.get("num_key_value_heads", 8),
            attention_bias=c.get("attention_bias", False),
            block_size=c["block_size"],
            mask_token_id=c["mask_token_id"],
            target_layer_ids=list(c["target_layer_ids"]),
            markov_rank=c.get("markov_rank", 256),
            markov_head_type=c.get("markov_head_type", "vanilla"),
            enable_confidence_head=c.get("enable_confidence_head", True),
            confidence_head_with_markov=c.get("confidence_head_with_markov", True),
            # `or 0` also maps an explicit JSON null to 0 so the field stays an int.
            pad_token_id=c.get("pad_token_id") or 0,
        )

        if family == "qwen3":
            rp = c.get("rope_parameters") or {}
            return cls(
                **common,
                head_dim=c.get("head_dim", c["hidden_size"] // c["num_attention_heads"]),
                attention_k_eq_v=False,
                # Prefer the nested rope_parameters value, then the top-level key.
                rope_theta=rp.get("rope_theta", c.get("rope_theta", 1_000_000.0)),
                rope_type="default",
                num_target_layers=c.get("num_target_layers", 36),
                final_logit_softcapping=c.get("final_logit_softcapping", None),
                mlp_activation="silu", norm_style="qwen", use_v_norm=False,
            )

        # gemma4 keeps its rope settings under rope_parameters["full_attention"].
        rope = (c.get("rope_parameters") or {}).get("full_attention", {}) or {}
        return cls(
            **common,
            num_global_key_value_heads=c.get("num_global_key_value_heads", 1),
            head_dim=c.get("head_dim", 256),
            global_head_dim=c.get("global_head_dim", 512),
            attention_k_eq_v=c.get("attention_k_eq_v", True),
            rope_theta=rope.get("rope_theta", 1_000_000.0),
            partial_rotary_factor=rope.get("partial_rotary_factor", 0.25),
            rope_type=rope.get("rope_type", "proportional"),
            num_target_layers=c.get("num_target_layers", 48),
            final_logit_softcapping=c.get("final_logit_softcapping", 30.0),
            mlp_activation="gelu_tanh", norm_style="gemma", use_v_norm=True,
        )
|
mlx_dspark/generate.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""DSpark speculative decoding loop (greedy, batch=1) for Apple Silicon.
|
|
2
|
+
|
|
3
|
+
Per round:
|
|
4
|
+
1. draft a block of K tokens from the parallel backbone + Markov head,
|
|
5
|
+
2. verify them in one target forward,
|
|
6
|
+
3. accept the matching prefix + 1 bonus token (so >=1 token/round always),
|
|
7
|
+
4. trim the target KV cache and grow the fused-hidden context buffer.
|
|
8
|
+
|
|
9
|
+
Because the target verifies every token, the *output is exactly greedy target
|
|
10
|
+
decoding* regardless of drafter quality — drafter quality only shows up as the
|
|
11
|
+
acceptance length (tokens committed per target forward).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import time
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
|
|
19
|
+
import mlx.core as mx
|
|
20
|
+
|
|
21
|
+
TAP = None # set from drafter config at call time
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class GenResult:
    """Outcome of one generation call (greedy or speculative)."""

    text: str  # decoded output with stop tokens removed
    token_ids: list[int]  # generated token ids
    num_tokens: int  # len(token_ids)
    num_rounds: int  # number of decode rounds
    accept_lengths: list[int]  # tokens committed in each round
    target_forwards: int  # number of target-model forward passes
    seconds: float  # elapsed generation time

    @property
    def mean_accept_len(self) -> float:
        """Average number of tokens committed per round.

        Computed from ``accept_lengths`` so the token produced by the prefill
        forward (which belongs to no round) does not inflate the figure. Falls
        back to ``num_tokens / max(num_rounds, 1)`` when no rounds were recorded.
        """
        if self.accept_lengths:
            return sum(self.accept_lengths) / len(self.accept_lengths)
        return self.num_tokens / max(self.num_rounds, 1)

    @property
    def tokens_per_sec(self) -> float:
        """Throughput in tokens per second; the divisor is floored at 1e-9."""
        return self.num_tokens / max(self.seconds, 1e-9)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def encode_prompt(tokenizer, prompt: str, use_chat: bool = True) -> list[int]:
    """Turn a user prompt into token ids, via the chat template when available.

    Gemma-4 uses `<|turn>` / `<channel|>` markers (NOT Gemma-3's `<start_of_turn>`),
    so the template is applied through the tokenizer rather than hand-formatted.

    ``apply_chat_template`` results are normalised as follows: a flat sequence
    of ints, a nested sequence (first row is used), a mapping or object exposing
    ``input_ids``, or an object exposing ``ids``. If chat templating is disabled,
    unavailable, or its result matches none of those shapes, the prompt is
    encoded with ``tokenizer.encode``.
    """
    if not (use_chat and getattr(tokenizer, "chat_template", None)):
        return list(tokenizer.encode(prompt))

    messages = [{"role": "user", "content": prompt}]
    rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Plain (possibly batched) sequence of ids.
    if isinstance(rendered, (list, tuple)) and rendered:
        head = rendered[0]
        if isinstance(head, int):
            return list(rendered)
        if isinstance(head, (list, tuple)):
            return list(head)

    # Mapping-style or attribute-style container holding `input_ids`.
    if hasattr(rendered, "__contains__") and "input_ids" in rendered:
        input_ids = rendered["input_ids"]
    elif hasattr(rendered, "input_ids"):
        input_ids = rendered.input_ids
    else:
        input_ids = None

    if input_ids is not None:
        rows = list(input_ids)
        if rows and isinstance(rows[0], (list, tuple)):
            return list(rows[0])
        return rows

    if hasattr(rendered, "ids"):
        return list(rendered.ids)
    return list(tokenizer.encode(prompt))
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def eos_token_ids(tokenizer) -> set[int]:
    """Gather every token id that should stop generation.

    Sources: the tokenizer's ``eos_token_ids`` (an int or an iterable),
    its ``eos_token_id``, and a fixed list of turn-end / end-of-text marker
    strings looked up with ``convert_tokens_to_ids``. A marker is skipped when
    the lookup raises, returns a non-int or a negative id, or resolves to the
    tokenizer's unk id (in Gemma-4 `<end_of_turn>` is the UNK id, so it must
    not be treated as a stop token).
    """
    stop: set[int] = set()

    multi = getattr(tokenizer, "eos_token_ids", None)
    if isinstance(multi, int):
        stop.add(multi)
    elif multi:
        stop.update(int(x) for x in multi)

    single = getattr(tokenizer, "eos_token_id", None)
    if isinstance(single, int):
        stop.add(single)

    unk_id = getattr(tokenizer, "unk_token_id", None)
    # Gemma-4 (<turn|>), Gemma-3 (<end_of_turn>), Qwen (<|im_end|>), raw eos
    markers = ("<turn|>", "<end_of_turn>", "<|im_end|>", "<|endoftext|>", "<eos>")
    for marker in markers:
        try:
            tid = tokenizer.convert_tokens_to_ids(marker)
        except Exception:
            # Best effort: a tokenizer that cannot look this marker up is skipped.
            continue
        if not isinstance(tid, int) or tid < 0 or tid == unk_id:
            continue
        stop.add(tid)
    return stop
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def greedy_generate(
    target_model,
    tokenizer,
    prompt: str,
    *,
    max_new_tokens: int = 128,
    apply_chat_template: bool = True,
    on_text=None,
) -> GenResult:
    """Plain greedy decoding of the target (no drafter, no hidden-state capture) —
    the fair 'run the model normally' baseline. Streams via on_text.

    Args:
        target_model: object providing ``make_cache()`` and ``plain(ids, cache)``;
            ``plain`` must return logits indexable as ``logits[0, -1]``.
        tokenizer: tokenizer used for prompt encoding, stop ids and decoding.
        prompt: user prompt text.
        max_new_tokens: decoding stops once this many tokens were produced
            (the token from the prefill forward is always produced).
        apply_chat_template: forwarded to ``encode_prompt`` as ``use_chat``.
        on_text: optional callable receiving each newly decoded text suffix.

    Returns:
        GenResult with one round, one accepted token and one target forward
        per generated token.
    """
    eos_ids = eos_token_ids(tokenizer)
    ids = encode_prompt(tokenizer, prompt, use_chat=apply_chat_template)
    cache = target_model.make_cache()

    # perf_counter is monotonic, so the measured interval (and tok/s) cannot be
    # skewed by system clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    logits = target_model.plain(mx.array([ids]), cache)
    nxt = int(mx.argmax(logits[0, -1]).item())
    out_ids = [nxt]
    streamed = 0  # number of characters already handed to on_text

    def _stream():
        """Decode everything so far and emit only the not-yet-streamed suffix."""
        nonlocal streamed
        if on_text is None:
            return
        disp = [t for t in out_ids if t not in eos_ids]
        full = tokenizer.decode(disp)
        if len(full) > streamed:
            on_text(full[streamed:])
            streamed = len(full)

    _stream()
    while len(out_ids) < max_new_tokens and nxt not in eos_ids:
        # Feed only the last token; earlier positions live in the cache.
        logits = target_model.plain(mx.array([[nxt]]), cache)
        nxt = int(mx.argmax(logits[0, -1]).item())
        out_ids.append(nxt)
        _stream()

    secs = time.perf_counter() - t0
    # Stop tokens stay in token_ids but are stripped from the displayed text.
    disp = [t for t in out_ids if t not in eos_ids]
    return GenResult(
        text=tokenizer.decode(disp),
        token_ids=out_ids,
        num_tokens=len(out_ids),
        num_rounds=len(out_ids),
        accept_lengths=[1] * len(out_ids),
        target_forwards=len(out_ids),
        seconds=secs,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _make_target_cache(target):
    """Return a fresh cache for ``target`` by delegating to ``target.make_cache()``."""
    return target.make_cache()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _run_target(target, ids: mx.array, cache, tap: list[int]):
    """Run one target forward by delegating to ``target.run(ids, cache, tap)``.

    ids: [1, L]. Returns (logits[1,L,V], fused_hidden[1,L,n_tap*H]).
    ``tap`` is the list of target layer ids passed through to ``target.run``.
    """
    return target.run(ids, cache, tap)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def speculative_generate(
    target_model,
    tokenizer,
    drafter,
    prompt: str,
    *,
    max_new_tokens: int = 128,
    confidence_threshold: float = 0.0,
    max_draft_tokens: int | None = 4,
    apply_chat_template: bool = True,
    on_text=None,
    verbose: bool = False,
) -> GenResult:
    """Greedy speculative decoding. Output is target-greedy by construction (up to
    fp tie-breaking on near-ties). ``max_draft_tokens`` caps how many of the 7-token
    block are verified per round; on Apple Silicon the target verify cost grows with
    tokens, so the optimum is ~= acceptance length (default 4). ``None`` = full block
    (faithful but slower on M-series). ``confidence_threshold`` > 0 instead truncates
    the block adaptively using the drafter's confidence head.

    Generation stops at the first stop token or once ``max_new_tokens`` tokens were
    produced, whichever comes first; tokens a round verified beyond either limit
    are discarded so the result matches what ``greedy_generate`` would return.

    Args:
        target_model: object providing ``make_cache()`` and ``run(ids, cache, tap)``.
        tokenizer: tokenizer used for prompt encoding, stop ids and decoding.
        drafter: DSpark drafter (``config``, ``make_ctx_cache``, ``update_context``,
            ``embed``, ``backbone``, ``compute_logits``, ``sample_block``, and
            optionally ``confidence_head`` / ``confidence_logits``).
        prompt: user prompt text.
        on_text: optional callable receiving each newly decoded text suffix.
        verbose: print a per-round drafted/accepted/committed line.

    Returns:
        GenResult; ``accept_lengths[i]`` is the number of tokens round ``i`` added
        to the output.
    """
    cfg = drafter.config
    tap = list(cfg.target_layer_ids)
    k = cfg.block_size
    mask_id = cfg.mask_token_id
    # Clamp the per-round draft length to [1, block_size].
    cap = k if max_draft_tokens is None else max(1, min(max_draft_tokens, k))

    eos_ids = eos_token_ids(tokenizer)

    # --- tokenize prompt ---
    ids = encode_prompt(tokenizer, prompt, use_chat=apply_chat_template)
    prompt_ids = mx.array([ids])

    cache = _make_target_cache(target_model)
    ctx_caches = drafter.make_ctx_cache()
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    t0 = time.perf_counter()

    # --- prefill ---
    logits, fused = _run_target(target_model, prompt_ids, cache, tap)
    n_cached = prompt_ids.shape[1]
    drafter.update_context(fused, ctx_offset=0, ctx_caches=ctx_caches)
    pending = int(mx.argmax(logits[0, -1]).item())  # first committed token
    mx.eval([c.k for c in ctx_caches])

    out_ids: list[int] = [pending]
    accept_lengths: list[int] = []
    target_forwards = 1
    streamed = 0  # number of characters already handed to on_text

    def _stream():
        """Decode everything so far and emit only the not-yet-streamed suffix."""
        nonlocal streamed
        if on_text is None:
            return
        disp = [t for t in out_ids if t not in eos_ids]
        full = tokenizer.decode(disp)
        if len(full) > streamed:
            on_text(full[streamed:])
            streamed = len(full)

    _stream()
    while len(out_ids) < max_new_tokens and pending not in eos_ids:
        # ---- 1. draft a block ----
        block_ids = [pending] + [mask_id] * (k - 1)
        noise = drafter.embed(mx.array([block_ids]))  # [1, k, H]
        block_hidden = drafter.backbone(noise, n_cached, ctx_caches)
        base_logits = drafter.compute_logits(block_hidden)[0]  # [k, V]
        draft = drafter.sample_block(base_logits, first_prev_token=pending)
        mx.eval(draft)
        draft = [int(x) for x in draft.tolist()]

        # optional confidence-based truncation (adaptive block length)
        if confidence_threshold > 0.0 and drafter.confidence_head is not None:
            prev_tokens = mx.array([pending] + draft[:-1])
            conf = mx.sigmoid(drafter.confidence_logits(block_hidden[0], prev_tokens))
            mx.eval(conf)
            below = [i for i, c in enumerate(conf.tolist()) if c < confidence_threshold]
            if below:
                # Keep only the tokens before the first low-confidence position.
                draft = draft[: below[0]]
        if cap < len(draft):
            draft = draft[:cap]
        if not draft:
            draft = [int(mx.argmax(base_logits[0]).item())]  # always propose >=1

        # ---- 2. verify with the target ----
        verify_ids = mx.array([[pending] + draft])  # [1, 1+len(draft)]
        v_logits, v_fused = _run_target(target_model, verify_ids, cache, tap)
        mx.eval(v_logits, v_fused)
        target_forwards += 1
        tt = mx.argmax(v_logits[0], axis=-1)  # [1+len(draft)]
        tt = [int(x) for x in tt.tolist()]

        # ---- 3. accept matching prefix + bonus ----
        n = 0
        while n < len(draft) and draft[n] == tt[n]:
            n += 1
        bonus = tt[n]  # correction / continuation
        verified = draft[:n] + [bonus]

        # ---- 4. update caches/context ----
        # The verify forward appended 1+len(draft) positions; drop the rejected ones.
        trim = len(draft) - n
        if trim > 0:
            for c in cache:
                if c is not None and hasattr(c, "trim"):
                    c.trim(trim)
        # commit [pending, accepted drafts] (positions n_cached..n_cached+n) as context
        drafter.update_context(
            v_fused[:, : n + 1, :], ctx_offset=n_cached, ctx_caches=ctx_caches
        )
        n_cached = n_cached + n + 1
        mx.eval([c.k for c in ctx_caches])

        # ---- 5. emit: stop at the first stop token and never exceed the budget ----
        # `room` is >= 1 here because the loop condition guarantees it.
        room = max_new_tokens - len(out_ids)
        committed: list[int] = []
        for tok in verified[:room]:
            committed.append(tok)
            if tok in eos_ids:
                break
        out_ids.extend(committed)
        accept_lengths.append(len(committed))
        # Use the last *emitted* token: if it is a stop token the loop ends here,
        # even when the verified run continued past it.
        pending = committed[-1]
        _stream()

        if verbose:
            print(f" round {len(accept_lengths):3d}: drafted {len(draft)}, "
                  f"accepted {n}, committed {len(committed)}")

    secs = time.perf_counter() - t0
    # strip trailing eos for display
    disp = [t for t in out_ids if t not in eos_ids]
    text = tokenizer.decode(disp)
    return GenResult(
        text=text,
        token_ids=out_ids,
        num_tokens=len(out_ids),
        num_rounds=len(accept_lengths),
        accept_lengths=accept_lengths,
        target_forwards=target_forwards,
        seconds=secs,
    )
|
mlx_dspark/load.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Loaders for the target (Gemma-4 via mlx-vlm) and the DSpark drafter."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import glob
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
import mlx.core as mx
|
|
9
|
+
import mlx.nn as nn
|
|
10
|
+
from huggingface_hub import snapshot_download
|
|
11
|
+
|
|
12
|
+
from .config import DSparkConfig
|
|
13
|
+
from .model import DSparkDrafter
|
|
14
|
+
from .target import Target
|
|
15
|
+
|
|
16
|
+
# The drafter must be paired with the *instruct* target it was trained against, at decent
|
|
17
|
+
# precision. Presets below; pick with load_pair("gemma4") or load_pair("qwen3").
|
|
18
|
+
PRESETS = {
|
|
19
|
+
"gemma4": {
|
|
20
|
+
"target": "mlx-community/gemma-4-12B-it-8bit",
|
|
21
|
+
"drafter": "deepseek-ai/dspark_gemma4_12b_block7",
|
|
22
|
+
},
|
|
23
|
+
"qwen3": {
|
|
24
|
+
"target": "mlx-community/Qwen3-4B-8bit",
|
|
25
|
+
"drafter": "deepseek-ai/dspark_qwen3_4b_block7",
|
|
26
|
+
},
|
|
27
|
+
}
|
|
28
|
+
DEFAULT_TARGET = PRESETS["gemma4"]["target"]
|
|
29
|
+
DEFAULT_DRAFTER = PRESETS["gemma4"]["drafter"]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _resolve(repo_or_path: str) -> str:
    """Map a local directory or a hub repo id to a directory path.

    An existing directory is returned unchanged; anything else is handed to
    ``snapshot_download`` and its result is returned.
    """
    is_local_dir = os.path.isdir(repo_or_path)
    return repo_or_path if is_local_dir else snapshot_download(repo_or_path)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def load_drafter(
    repo_or_path: str = DEFAULT_DRAFTER,
    *,
    quantize: bool = True,
    bits: int = 4,
    group_size: int = 64,
):
    """Return (drafter, config). Loads bf16 weights 1:1 by matching key names.

    The drafter is ~6.86 GB in bf16 and runs every speculative round, so by
    default it is quantized to 4-bit (~1.8 GB) — this is what makes spec
    decoding a net speedup on Apple Silicon. Output correctness is unaffected
    (the target verifies every token); only acceptance length may change.

    Args:
        repo_or_path: local checkpoint directory or hub repo id.
        quantize: quantize the model after the weights are loaded.
        bits: quantization bit width passed to ``nn.quantize``.
        group_size: quantization group size passed to ``nn.quantize``.

    Raises:
        FileNotFoundError: if the checkpoint directory contains no
            ``*.safetensors`` file (otherwise the drafter would silently keep
            its freshly initialised weights).
    """
    path = _resolve(repo_or_path)
    config = DSparkConfig.from_json(os.path.join(path, "config.json"))
    drafter = DSparkDrafter(config)

    # Sorted so shards are always merged in the same order.
    shard_paths = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(f"no *.safetensors weight files found in {path!r}")

    weights: dict[str, mx.array] = {}
    for st in shard_paths:
        weights.update(mx.load(st))

    # Diagnose name mismatches before loading.
    model_keys = {k for k, _ in _flatten_params(drafter)}
    ckpt_keys = set(weights.keys())
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    if missing or unexpected:
        print("[load_drafter] WARNING key mismatch:")
        if missing:
            print(f" missing in checkpoint ({len(missing)}): {missing[:8]}")
        if unexpected:
            print(f" unexpected in checkpoint ({len(unexpected)}): {unexpected[:8]}")

    # Strict loading only when the key sets match exactly; a mismatch was
    # already reported above and is loaded best-effort.
    drafter.load_weights(list(weights.items()), strict=not (missing or unexpected))

    if quantize:
        # Quantize Linear/Embedding weights; norms/scalars stay full precision.
        nn.quantize(drafter, group_size=group_size, bits=bits)

    mx.eval(drafter.parameters())
    return drafter, config
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _flatten_params(module) -> list[tuple[str, mx.array]]:
    """Return ``module.parameters()`` flattened by ``mlx.utils.tree_flatten``.

    The caller in this file uses the first element of each pair as the
    parameter's key name.
    """
    # Imported locally, as in the original code.
    from mlx.utils import tree_flatten

    return tree_flatten(module.parameters())
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def load_target(repo_or_path: str = DEFAULT_TARGET):
    """Load the target model and return ``(Target, tokenizer)``.

    The loader is picked from ``model_type`` in the checkpoint's config.json
    (empty when the file is absent): a type containing "qwen3" but not "moe"
    is loaded with mlx-lm; everything else (e.g. Gemma-4) goes through mlx-vlm,
    whose processor's ``tokenizer`` attribute is used when present and the
    processor itself otherwise. The model is then wrapped in ``Target``.
    """
    import json

    path = _resolve(repo_or_path)
    cfg_path = os.path.join(path, "config.json")

    model_type = ""
    if os.path.exists(cfg_path):
        with open(cfg_path) as f:
            model_type = json.load(f).get("model_type", "")

    is_dense_qwen3 = "qwen3" in model_type and "moe" not in model_type
    if is_dense_qwen3:
        from mlx_lm import load as lm_load

        model, tokenizer = lm_load(path)
    else:
        from mlx_vlm import load as vlm_load

        model, processor = vlm_load(path)
        tokenizer = getattr(processor, "tokenizer", processor)

    wrapped = Target(model, tokenizer)
    return wrapped, tokenizer
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def load_pair(family: str = "gemma4", *, target_bits: str | None = None):
    """Load a preset's target and drafter: ``(target, tokenizer, drafter, cfg)``.

    ``family`` must be a key of ``PRESETS`` (an unknown key raises KeyError).
    ``target_bits`` is accepted but not read by this function.
    """
    preset = PRESETS[family]
    # Target first, then drafter — same load order as the tuple layout.
    target, tok = load_target(preset["target"])
    drafter, cfg = load_drafter(preset["drafter"])
    loaded = (target, tok, drafter, cfg)
    return loaded
|
mlx_dspark/model.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""DSpark drafter in MLX — Gemma-4 and Qwen3 families.
|
|
2
|
+
|
|
3
|
+
Faithful port of the DeepSpec inference path. The EAGLE-style cross-attention is shared:
|
|
4
|
+
Q comes from the draft block, K/V from concat([fused_target_context, block]), with the
|
|
5
|
+
context K/V cached per layer (CtxCache). Family differences (norm layout, rope, MLP act,
|
|
6
|
+
v handling, logit softcap) are config-driven. Module attribute names match the HF
|
|
7
|
+
checkpoints so weights load 1:1.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import mlx.core as mx
|
|
13
|
+
import mlx.nn as nn
|
|
14
|
+
|
|
15
|
+
from mlx_vlm.models.gemma4.rope_utils import initialize_rope
|
|
16
|
+
|
|
17
|
+
from .config import DSparkConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RMSNormNoScale(nn.Module):
    """RMSNorm with no learnable weight (Gemma-4 v_norm)."""

    def __init__(self, eps: float = 1e-6):
        """Store the epsilon that is passed to the norm on every call."""
        super().__init__()
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        """Apply ``mx.fast.rms_norm`` to ``x`` with ``None`` as the weight."""
        return mx.fast.rms_norm(x, None, self.eps)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _act(name: str):
    """Pick the MLP activation: ``nn.silu`` for "silu", ``nn.gelu_approx`` for any other name."""
    if name == "silu":
        return nn.silu
    return nn.gelu_approx
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config: DSparkConfig):
        """Create the three bias-free projections and select the activation from the config."""
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act = _act(config.mlp_activation)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated MLP to ``x``."""
        gate = self.act(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _repeat_kv(x: mx.array, n_rep: int) -> mx.array:
    """Duplicate every KV head ``n_rep`` times along the head axis.

    ``x`` is unpacked as ``(batch, kv_heads, seq, head_dim)``. The result has
    ``kv_heads * n_rep`` heads, with the copies of each head adjacent. When
    ``n_rep`` is 1 the input is returned unchanged.
    """
    if n_rep == 1:
        return x
    batch, kv_heads, seq_len, head_dim = x.shape
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    tiled = mx.broadcast_to(
        mx.expand_dims(x, 2), (batch, kv_heads, n_rep, seq_len, head_dim)
    )
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class CtxCache:
    """Holds one layer's projected target-context keys and values.

    Both slots start as ``None``; every ``append`` after the first extends
    them along axis 2.
    """

    __slots__ = ("k", "v")

    def __init__(self):
        self.k = self.v = None

    def append(self, k: mx.array, v: mx.array) -> None:
        """Store ``k``/``v`` on first use, otherwise concatenate them onto the cache."""
        if self.k is not None:
            self.k = mx.concatenate([self.k, k], axis=2)
            self.v = mx.concatenate([self.v, v], axis=2)
            return
        self.k = k
        self.v = v
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class DSparkAttention(nn.Module):
    """Cross-attention: Q from the draft block, K/V from [target_context, block].

    The target-context K/V are projected once per new context chunk
    (``update_ctx``) and stored in a ``CtxCache``; ``attend`` projects the
    draft block, appends its K/V to the cached context and runs unmasked
    scaled-dot-product attention over the concatenation.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.attn_head_dim
        self.k_eq_v = config.attention_k_eq_v
        self.n_kv_heads = config.n_kv_heads
        # Number of query heads served by each KV head.
        self.n_rep = self.n_heads // self.n_kv_heads
        self.scale = config.scaling
        self.use_v_norm = config.use_v_norm

        h = config.hidden_size
        b = config.attention_bias
        self.q_proj = nn.Linear(h, self.n_heads * self.head_dim, bias=b)
        self.k_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=b)
        if not self.k_eq_v:
            # With k_eq_v the value path reuses k_proj, so no v_proj exists.
            self.v_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=b)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, h, bias=b)

        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        if self.use_v_norm:
            self.v_norm = RMSNormNoScale(eps=config.rms_norm_eps)

        self.rope = initialize_rope(
            dims=self.head_dim,
            base=config.rope_theta,
            traditional=False,
            scaling_config=config.rope_parameters,
        )

    def _kv(self, x: mx.array):
        """Project ``x`` to ``(normed K, V)``, each ``(B, n_kv_heads, S, head_dim)``.

        K is normalised but NOT roped here; callers apply rope with the
        offset that fits their position. With ``k_eq_v`` the raw ``k_proj``
        output doubles as V. V is passed through ``v_norm`` only when
        ``use_v_norm`` is set, because that is the only case in which the
        ``v_norm`` module exists.
        """
        B, S, _ = x.shape
        kp = self.k_proj(x).reshape(B, S, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_norm(kp)
        if self.k_eq_v:
            v = kp
        else:
            v = self.v_proj(x).reshape(B, S, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        if self.use_v_norm:
            v = self.v_norm(v)
        return k, v

    def update_ctx(self, fused_new: mx.array, ctx_offset: int, cache: CtxCache) -> None:
        """Project new fused target context and append its K/V to ``cache``.

        K is roped starting at ``ctx_offset``; V is stored un-roped.
        """
        k, v = self._kv(fused_new)
        cache.append(self.rope(k, offset=ctx_offset), v)  # V is not roped

    def attend(self, hidden: mx.array, block_offset: int, cache: CtxCache) -> mx.array:
        """Attend from the draft block over the cached context plus the block itself.

        ``hidden`` is unpacked as ``(B, q_len, hidden)``. Q and the block's K
        are roped at ``block_offset``. No mask is applied. If ``cache`` has
        not been filled yet, attention runs over the block's own K/V only.
        """
        B, q_len, _ = hidden.shape
        q = self.q_proj(hidden).reshape(B, q_len, self.n_heads, self.head_dim)
        q = self.rope(self.q_norm(q).transpose(0, 2, 1, 3), offset=block_offset)

        k_blk, v_blk = self._kv(hidden)
        k_blk = self.rope(k_blk, offset=block_offset)
        if cache.k is None:
            # Empty context cache: nothing to prepend.
            k, v = k_blk, v_blk
        else:
            k = mx.concatenate([cache.k, k_blk], axis=2)
            v = mx.concatenate([cache.v, v_blk], axis=2)

        k = _repeat_kv(k, self.n_rep)
        v = _repeat_kv(v, self.n_rep)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
        out = out.transpose(0, 2, 1, 3).reshape(B, q_len, -1)
        return self.o_proj(out)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class DSparkDecoderLayer(nn.Module):
    """One drafter layer: cross-attention followed by a gated MLP.

    ``config.norm_style == "gemma"`` selects the four-norm layout (norm before
    and after each sub-block, output multiplied by ``layer_scalar``); any other
    value selects the two-norm pre-norm layout.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.norm_style = config.norm_style
        self.self_attn = DSparkAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)
        if self.norm_style == "gemma":
            # Extra norms and the output scalar exist only in the gemma layout.
            self.pre_feedforward_layernorm = nn.RMSNorm(width, eps=eps)
            self.post_feedforward_layernorm = nn.RMSNorm(width, eps=eps)
            self.layer_scalar = mx.ones((1,))

    def __call__(self, hidden, block_offset, cache: CtxCache):
        # Both layouts start the same way: norm, then cross-attention.
        attn_out = self.self_attn.attend(self.input_layernorm(hidden), block_offset, cache)

        if self.norm_style != "gemma":
            # qwen / llama 2-norm: post_attention_layernorm is the pre-MLP norm.
            mid = hidden + attn_out
            return mid + self.mlp(self.post_attention_layernorm(mid))

        # gemma: each sub-block output is normed before joining the residual.
        mid = hidden + self.post_attention_layernorm(attn_out)
        ff_out = self.mlp(self.pre_feedforward_layernorm(mid))
        out = mid + self.post_feedforward_layernorm(ff_out)
        return out * self.layer_scalar
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class VanillaMarkov(nn.Module):
    """Low-rank previous-token correction: ``logits += w2(w1[prev_token])``.

    The rank is ``config.markov_rank``: ``markov_w1`` embeds a token id into
    that rank and ``markov_w2`` maps it back to vocabulary size without bias.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        vocab = config.vocab_size
        rank = config.markov_rank
        self.markov_w1 = nn.Embedding(vocab, rank)
        self.markov_w2 = nn.Linear(rank, vocab, bias=False)

    def prev_embeddings(self, token_ids: mx.array) -> mx.array:
        """Return the ``markov_w1`` embedding of each id in ``token_ids``."""
        return self.markov_w1(token_ids)

    def step_bias(self, token_ids: mx.array) -> mx.array:
        """Return the vocabulary-sized logit bias for each id in ``token_ids``."""
        low_rank = self.markov_w1(token_ids)
        return self.markov_w2(low_rank)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class ConfidenceHead(nn.Module):
    """Linear scorer mapping each feature vector to a single logit."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.proj = nn.Linear(input_dim, 1)

    def __call__(self, features: mx.array) -> mx.array:
        scores = self.proj(features)
        # Drop the trailing size-1 axis produced by the projection.
        return scores.squeeze(-1)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class DSparkDrafter(nn.Module):
    """DSpark drafter: fused-target cross-attention backbone plus optional heads.

    Components built from ``config``:

    - ``embed_tokens`` / ``lm_head``: token embedding and output projection.
    - ``fc`` + ``hidden_norm``: fuse the concatenated tapped target hidden
      states (``len(target_layer_ids) * hidden_size`` wide) down to
      ``hidden_size``.
    - ``layers`` + ``norm``: the decoder stack and its final norm.
    - ``markov_head``: present only when ``config.markov_rank > 0``.
    - ``confidence_head``: present only when ``config.enable_confidence_head``.
    """

    def __init__(self, config: DSparkConfig):
        super().__init__()
        self.config = config
        self.block_size = config.block_size
        self.mask_token_id = config.mask_token_id
        # Only the gemma4 family scales embeddings by sqrt(hidden_size).
        self.embed_scale = (float(config.hidden_size) ** 0.5) if config.family == "gemma4" else 1.0
        self.softcap = config.final_logit_softcapping

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.fc = nn.Linear(
            len(config.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False
        )
        self.hidden_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = [DSparkDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.markov_head = VanillaMarkov(config) if config.markov_rank > 0 else None
        self.confidence_head = None
        if config.enable_confidence_head:
            in_dim = config.hidden_size + (config.markov_rank if config.confidence_head_with_markov else 0)
            self.confidence_head = ConfidenceHead(in_dim)

    def embed(self, ids: mx.array) -> mx.array:
        """Embed token ids and apply the family-specific embedding scale."""
        return self.embed_tokens(ids) * self.embed_scale

    def fuse_target(self, target_hidden_cat: mx.array) -> mx.array:
        """Project the concatenated tapped target hidden states and normalise them."""
        return self.hidden_norm(self.fc(target_hidden_cat))

    def make_ctx_cache(self) -> list[CtxCache]:
        """Return one empty ``CtxCache`` per decoder layer."""
        return [CtxCache() for _ in self.layers]

    def update_context(self, target_hidden_cat, ctx_offset, ctx_caches) -> None:
        """Fuse new target hidden states once and append their K/V to every layer's cache."""
        fused = self.fuse_target(target_hidden_cat)
        for layer, cache in zip(self.layers, ctx_caches):
            layer.self_attn.update_ctx(fused, ctx_offset, cache)

    def backbone(self, noise_embedding, block_offset, ctx_caches) -> mx.array:
        """Run the block embeddings through all layers and the final norm."""
        h = noise_embedding
        for layer, cache in zip(self.layers, ctx_caches):
            h = layer(h, block_offset, cache)
        return self.norm(h)

    def compute_logits(self, hidden: mx.array) -> mx.array:
        """Apply ``lm_head`` and, when configured, tanh soft-capping."""
        logits = self.lm_head(hidden)
        if self.softcap is not None:
            logits = mx.tanh(logits / self.softcap) * self.softcap
        return logits

    def sample_block(self, base_logits: mx.array, first_prev_token: int) -> mx.array:
        """Greedily pick one token per block position.

        ``base_logits`` is indexed along axis 0 by block position. Without a
        Markov head this is a plain argmax over the last axis. With one, the
        positions are decoded in order: each step adds the Markov bias of the
        previously chosen token (``first_prev_token`` for position 0) before
        taking the argmax.
        """
        k = base_logits.shape[0]
        # An empty block has nothing to decode sequentially; the plain argmax
        # returns the matching empty result instead of concatenating an empty
        # list.
        if self.markov_head is None or k == 0:
            return mx.argmax(base_logits, axis=-1)
        tokens = []
        prev = mx.array([first_prev_token])
        for i in range(k):
            step = base_logits[i] + self.markov_head.step_bias(prev)[0]
            nxt = mx.argmax(step, axis=-1, keepdims=True)
            tokens.append(nxt)
            prev = nxt
        return mx.concatenate(tokens)

    def confidence_logits(self, block_hidden, prev_token_ids):
        """Score each block position with the confidence head.

        Returns ``None`` when no confidence head is configured. When
        ``confidence_head_with_markov`` is set and a Markov head exists, the
        previous-token Markov embeddings are concatenated onto
        ``block_hidden`` along the last axis. If the Markov head is absent
        (``markov_rank <= 0``), only ``block_hidden`` is used.
        """
        if self.confidence_head is None:
            return None
        if self.config.confidence_head_with_markov and self.markov_head is not None:
            feats = mx.concatenate(
                [block_hidden, self.markov_head.prev_embeddings(prev_token_ids)], axis=-1
            )
        else:
            feats = block_hidden
        return self.confidence_head(feats)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# Backwards-compatible alias
|
|
277
|
+
Gemma4DSparkDrafter = DSparkDrafter
|
mlx_dspark/target.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Family-aware target wrapper: KV cache + a hidden-state tap at given layers.
|
|
2
|
+
|
|
3
|
+
- gemma4 (mlx-vlm): uses the built-in ``capture_layer_ids`` / ``hidden_sink`` hook.
|
|
4
|
+
- qwen3 (mlx-lm): no hook exists, so we replicate the model's forward loop and
|
|
5
|
+
capture the residual stream after the tapped layers.
|
|
6
|
+
|
|
7
|
+
Both expose: make_cache(), run(ids, cache, tap)->(logits, fused_hidden), and
|
|
8
|
+
plain(ids, cache)->logits (no capture, for the greedy baseline).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import mlx.core as mx
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Target:
    """Wrap a loaded target model behind one interface for both families.

    A model exposing ``.language_model`` is treated as an mlx-vlm model
    (family ``"gemma4"``); anything else is treated as an mlx-lm model
    (family ``"qwen3"``).
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # mlx-vlm models expose .language_model; mlx-lm models expose .model + (lm_head|tied)
        self.is_vlm = hasattr(model, "language_model")
        self.family = "gemma4" if self.is_vlm else "qwen3"
        # Always defined so the attribute exists on VLM instances too; it is
        # only consulted on the mlx-lm path.
        self._tied = False
        if not self.is_vlm:
            self._tied = bool(getattr(getattr(model, "args", None), "tie_word_embeddings", False))

    # -- cache --
    def make_cache(self):
        """Create a fresh KV cache for the wrapped model."""
        if self.is_vlm:
            return self.model.language_model.make_cache()
        from mlx_lm.models.cache import make_prompt_cache

        return make_prompt_cache(self.model)

    # -- forward with hidden-state tap --
    def run(self, ids: mx.array, cache, tap: list[int]):
        """ids [1,L] -> (logits [1,L,V], fused_hidden [1,L,n_tap*H])."""
        if self.is_vlm:
            out = self.model.language_model(inputs=ids, cache=cache, capture_layer_ids=tap)
            return out.logits, mx.concatenate(out.hidden_states, axis=-1)
        return self._run_mlxlm(ids, cache, tap)

    def _run_mlxlm(self, ids, cache, tap):
        """Replicate the mlx-lm forward loop, capturing the output of tapped layers.

        Captures are collected in layer order and concatenated on the last
        axis.

        Raises:
            ValueError: if ``tap`` is empty or names a layer index outside
                ``range(len(model.model.layers))``.
        """
        mm = self.model.model
        tapset = set(tap)
        n_layers = len(mm.layers)
        # Validate up front: an out-of-range tap would otherwise be dropped
        # silently and yield a fused tensor of the wrong width.
        out_of_range = sorted(t for t in tapset if not 0 <= t < n_layers)
        if not tapset or out_of_range:
            raise ValueError(
                f"tap must be a non-empty list of layer indices in [0, {n_layers}); "
                f"got {list(tap)!r}"
            )

        from mlx_lm.models.base import create_attention_mask

        h = mm.embed_tokens(ids)
        mask = create_attention_mask(h, cache[0])
        captured = []
        for i, (layer, c) in enumerate(zip(mm.layers, cache)):
            h = layer(h, mask, c)
            if i in tapset:
                captured.append(h)
        hn = mm.norm(h)
        if self._tied:
            # Tied embeddings: reuse the embedding matrix as the output head.
            logits = mm.embed_tokens.as_linear(hn)
        else:
            logits = self.model.lm_head(hn)
        return logits, mx.concatenate(captured, axis=-1)

    # -- plain forward (no capture) for the greedy baseline --
    def plain(self, ids: mx.array, cache):
        """Run the target without capturing hidden states and return its logits."""
        if self.is_vlm:
            return self.model.language_model(inputs=ids, cache=cache).logits
        return self.model(ids, cache=cache)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx-dspark
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: DSpark speculative decoding (semi-autoregressive drafting) for Apple Silicon via MLX
|
|
5
|
+
Project-URL: Homepage, https://github.com/ARahim3/mlx-dspark
|
|
6
|
+
Project-URL: Repository, https://github.com/ARahim3/mlx-dspark
|
|
7
|
+
Project-URL: Issues, https://github.com/ARahim3/mlx-dspark/issues
|
|
8
|
+
Author-email: erahim3 <erahim3@gmail.com>
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
License-File: NOTICE
|
|
12
|
+
Keywords: apple-silicon,deepspec,dspark,llm-inference,mlx,speculative-decoding
|
|
13
|
+
Classifier: Environment :: GPU
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: MacOS
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.10
|
|
19
|
+
Requires-Dist: huggingface-hub
|
|
20
|
+
Requires-Dist: mlx-lm>=0.31.3
|
|
21
|
+
Requires-Dist: mlx-vlm>=0.6.3
|
|
22
|
+
Requires-Dist: mlx>=0.31.2
|
|
23
|
+
Requires-Dist: numpy
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
26
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
27
|
+
Description-Content-Type: text/markdown
|
|
28
|
+
|
|
29
|
+
# mlx-dspark
|
|
30
|
+
|
|
31
|
+
**DSpark speculative decoding for Apple Silicon**, built on [MLX](https://github.com/ml-explore/mlx).
|
|
32
|
+
|
|
33
|
+
DSpark is DeepSeek's semi-autoregressive, EAGLE-family speculative-decoding drafter,
|
|
34
|
+
open-sourced in the [DeepSpec](https://github.com/deepseek-ai/DeepSpec) codebase and used to
|
|
35
|
+
accelerate DeepSeek-V4. This project ports the **inference path** to MLX so the published
|
|
36
|
+
drafter checkpoints run natively on a Mac.
|
|
37
|
+
|
|
38
|
+
**Supported families** (auto-detected from the drafter config):
|
|
39
|
+
|
|
40
|
+
| family | target | drafter | RAM |
|
|
41
|
+
|---|---|---|---|
|
|
42
|
+
| `gemma4` | `gemma-4-12B-it-8bit` | `deepseek-ai/dspark_gemma4_12b_block7` | ~32 GB+ |
|
|
43
|
+
| `qwen3` | `Qwen3-4B-8bit` | `deepseek-ai/dspark_qwen3_4b_block7` | ~16 GB |
|
|
44
|
+
|
|
45
|
+
## How it works
|
|
46
|
+
|
|
47
|
+
- A **parallel backbone** (5 Gemma-4 layers) consumes the target model's hidden states
|
|
48
|
+
(tapped at layers `[5,17,29,41,46]`, EAGLE3-style) and proposes a 7-token block at once.
|
|
49
|
+
- A **rank-256 Markov head** adds a previous-token correction, sampled sequentially — the only
|
|
50
|
+
sequential cost, which kills "suffix decay" cheaply.
|
|
51
|
+
- A **confidence head** scores each draft position (optional adaptive block length).
|
|
52
|
+
- The target **verifies** every token, so output is **greedy-correct by construction**
|
|
53
|
+
(identical to plain greedy decoding, up to floating-point tie-breaking on near-ties).
|
|
54
|
+
|
|
55
|
+
The drafter is loaded 1:1 from the HF checkpoint and **quantized to 4-bit** by default
|
|
56
|
+
(~1.8 GB) so it's cheap to run every round.
|
|
57
|
+
|
|
58
|
+
## Install
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
uv venv --python 3.12
|
|
62
|
+
source .venv/bin/activate
|
|
63
|
+
uv pip install -e .
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## Use
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# CLI — pick a family (downloads drafter + instruct target on first run)
|
|
70
|
+
python -m mlx_dspark --family qwen3 --prompt "Explain how rainbows form."
|
|
71
|
+
python -m mlx_dspark --family gemma4 --prompt "Explain how rainbows form." --max-new-tokens 256
|
|
72
|
+
|
|
73
|
+
# side-by-side demo: baseline (plain target) vs dspark (record each, stack)
|
|
74
|
+
python -m mlx_dspark --family qwen3 --mode baseline --prompt "..." --max-new-tokens 400
|
|
75
|
+
python -m mlx_dspark --family qwen3 --mode dspark --prompt "..." --max-new-tokens 400
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
from mlx_dspark import load_pair, speculative_generate
|
|
80
|
+
|
|
81
|
+
target, tok, drafter, cfg = load_pair("qwen3") # or "gemma4"
|
|
82
|
+
res = speculative_generate(target, tok, drafter, "Explain how rainbows form.")
|
|
83
|
+
print(res.text, res.mean_accept_len, res.tokens_per_sec)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Results (M4 Pro, 48 GB; 8-bit instruct target, 4-bit drafter; warm — `python benchmark.py`)
|
|
87
|
+
|
|
88
|
+
| family | drafter `d_0` | accept len | greedy (baseline) | dspark (this project) | speedup |
|
|
89
|
+
|---|---|---|---|---|---|
|
|
90
|
+
| **Gemma-4 12B** | ~82% | 2.5–3.6 | ~17 tok/s | ~28 tok/s | **~1.5–1.6×** |
|
|
91
|
+
| **Qwen3-4B** | ~85% | 2.1–2.8 | ~49 tok/s | ~66 tok/s | **~1.3–1.4×** |
|
|
92
|
+
|
|
93
|
+
"greedy" = the plain target model decoding one token per forward (no drafter); "dspark" =
|
|
94
|
+
speculative decoding with the DSpark drafter. Both produce **identical** output — DSpark is just
|
|
95
|
+
faster (the two can differ only where the top two logits are effectively tied, i.e. margin ≈ 0). Smaller/faster targets (Qwen3-4B) have a lower
|
|
96
|
+
per-token verify cost, so the optimal `--max-draft` is smaller (~2).
|
|
97
|
+
|
|
98
|
+
### Target choice matters
|
|
99
|
+
|
|
100
|
+
The drafter is trained against a specific **instruct** model in **bf16** — use the matching
|
|
101
|
+
instruct target (`gemma-4-12B-it` / `Qwen3-4B`), not the base model (base gave `d_0` ~47% vs
|
|
102
|
+
82%). Higher precision raises acceptance; 8-bit is the sweet spot. `-bf16` maximizes acceptance;
|
|
103
|
+
4-bit verifies faster.
|
|
104
|
+
|
|
105
|
+
### Tuning
|
|
106
|
+
|
|
107
|
+
On Apple Silicon the target verify cost grows with the number of tokens verified, so the
|
|
108
|
+
optimum is to verify roughly as many tokens as the mean acceptance length: `--max-draft 4` (default). `--max-draft <block>`
|
|
109
|
+
(full 7) is faithful but *slower* on M-series. `--confidence-threshold 0.6` instead truncates
|
|
110
|
+
the block adaptively via the drafter's confidence head.
|
|
111
|
+
|
|
112
|
+
## License
|
|
113
|
+
|
|
114
|
+
MIT — see [`LICENSE`](LICENSE). This is an independent MLX port of the inference path of
|
|
115
|
+
DeepSeek's DSpark drafter; see [`NOTICE`](NOTICE) for attribution. No model weights are bundled.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
mlx_dspark/__init__.py,sha256=fclV1VbqP6UzILCu35xa6hTDYV_BLzeH9Fr_32hbDwk,1156
|
|
2
|
+
mlx_dspark/__main__.py,sha256=YEfk3teljcUJrc17eMMz7WIILwGFvnEB4mNGmQqnKpQ,3417
|
|
3
|
+
mlx_dspark/config.py,sha256=Ey-tRJ3w6rAXSbe6LhuVHiyLdgbqNvZDGW8blq7WKS8,6257
|
|
4
|
+
mlx_dspark/generate.py,sha256=sOdrRg2GA8so3JH9MBKAMaJ3q6BYuop7lxTmdTs7xkc,10146
|
|
5
|
+
mlx_dspark/load.py,sha256=xLr0QLCFzQa7Qw0LkubT0Jpz0Mt58bvMtjdRBQxBAeQ,3946
|
|
6
|
+
mlx_dspark/model.py,sha256=mnI6Y_ZWZbeohF4FGsLRcvPFamUoy16_0gVFlIkDZew,10607
|
|
7
|
+
mlx_dspark/target.py,sha256=8SlmnMM4UO1qL36HZwaqGfniIEDoupU1T5G_wcKEpXw,2494
|
|
8
|
+
mlx_dspark-0.0.1.dist-info/METADATA,sha256=Wea6DUzWjUOBKFkwXY7D31Ss6TqYBUam6mXXIXfUxZc,5047
|
|
9
|
+
mlx_dspark-0.0.1.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
10
|
+
mlx_dspark-0.0.1.dist-info/licenses/LICENSE,sha256=7fTQ89BHhdRKOLqXpVQFZxTnwQUUtg8EeamGIrOz7eU,1064
|
|
11
|
+
mlx_dspark-0.0.1.dist-info/licenses/NOTICE,sha256=4sSSawyNMWFGbhUevnc0ZKPgOS3auWOdm77K_BqLRcs,908
|
|
12
|
+
mlx_dspark-0.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 erahim3
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
mlx-dspark
|
|
2
|
+
==========
|
|
3
|
+
|
|
4
|
+
This project is an independent MLX/Apple-Silicon port of the *inference path* of
|
|
5
|
+
DSpark, a speculative-decoding drafter open-sourced by DeepSeek as part of the
|
|
6
|
+
DeepSpec codebase:
|
|
7
|
+
|
|
8
|
+
- DeepSpec: https://github.com/deepseek-ai/DeepSpec (MIT License)
|
|
9
|
+
|
|
10
|
+
It loads the published DSpark drafter checkpoints (also released by DeepSeek under
|
|
11
|
+
their respective licenses):
|
|
12
|
+
|
|
13
|
+
- deepseek-ai/dspark_gemma4_12b_block7
|
|
14
|
+
- deepseek-ai/dspark_qwen3_4b_block7
|
|
15
|
+
|
|
16
|
+
The drafter architecture, training, and checkpoints are the work of DeepSeek and
|
|
17
|
+
the DeepSpec authors. This repository reimplements only the forward/verification
|
|
18
|
+
path for MLX and contains no DeepSpec source code. Target models (Gemma-4, Qwen3)
|
|
19
|
+
are downloaded at runtime from their respective publishers and are subject to
|
|
20
|
+
their own licenses (e.g. the Gemma Terms of Use and the Qwen license).
|
|
21
|
+
|
|
22
|
+
No model weights are bundled with this package.
|