dspark-mlx 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dspark_mlx/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+ # Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
4
+
5
+ __version__ = "0.1.0"
6
+
7
+ from .adapter import BaseModelAdapter, BlockOut, StepOut
8
+ from .arch.backbone import DraftArch, DraftBackbone
9
+ from .events import SummaryEvent, TokenEvent
10
+ from .generate import generate
11
+ from .loader import KNOWN_MODELS, load_draft, load_host, resolve_model
12
+ from .loading import is_dspark_checkpoint, load_drafter, map_checkpoint_key
13
+ from .loop import generate_eager
14
+ from .model.config import DSparkArgs
15
+ from .model.drafter import DSparkDrafter
16
+ from .quant import quantize_drafter
17
+ from .registry import ARCH_REGISTRY, resolve_arch
18
+ from .verify import AcceptResult, greedy_accept, speculative_sample_accept
19
+
20
# Public names re-exported at package level, grouped by the submodule that
# provides each one (see the imports above).
__all__ = [
    # .adapter
    "BaseModelAdapter",
    "BlockOut",
    "StepOut",
    # .model
    "DSparkArgs",
    "DSparkDrafter",
    # .arch.backbone / .registry
    "DraftArch",
    "DraftBackbone",
    "resolve_arch",
    "ARCH_REGISTRY",
    # .generate / .loop
    "generate",
    "generate_eager",
    # .loader
    "load_draft",
    "load_host",
    "resolve_model",
    "KNOWN_MODELS",
    # .verify
    "greedy_accept",
    "speculative_sample_accept",
    "AcceptResult",
    # .events
    "TokenEvent",
    "SummaryEvent",
    # .loading
    "load_drafter",
    "map_checkpoint_key",
    "is_dspark_checkpoint",
    # .quant
    "quantize_drafter",
]
dspark_mlx/adapter.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+ # Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
4
+
5
+ """The seam between dspark-mlx (drafter + verify/accept loop) and a host base model.
6
+
7
+ dspark-mlx is target-agnostic: it owns the DSpark draft stack and the lossless
8
+ accept policy, but never the base model. The host (e.g. omlx over its
9
+ ``patches/deepseek_v4`` model) implements :class:`BaseModelAdapter` so the drafter
10
+ can (a) read the ``main_hidden`` it conditions on, (b) get the base distribution for
11
+ each candidate token during verify, and (c) snapshot/roll back base KV when a block is
12
+ only partially accepted.
13
+
14
+ Logit conventions (one decode cycle):
15
+ - ``prefill`` / ``decode_step`` return ``StepOut.logits`` = ``p_1``, the base
16
+ distribution for the *first* drafted token. It is free — already computed by the
17
+ step that produced the anchor — so the verify forward never recomputes it.
18
+ - ``verify_forward`` runs ONE base forward over the K drafted tokens and returns the
19
+ K base distributions ``p_2 .. p_{K+1}`` (``p_{K+1}`` is the bonus position).
20
+ - The generate loop concatenates ``[p_1] + [p_2..p_{K+1}]`` into the ``[K+1, V]`` block
21
+ the accept policy consumes (see :mod:`dspark_mlx.verify`).
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from dataclasses import dataclass
27
+ from typing import Any, Protocol, Tuple, runtime_checkable
28
+
29
+ import mlx.core as mx
30
+
31
+
32
@dataclass
class StepOut:
    """Result of one base-model forward pass at the anchor position.

    Attributes:
        logits: ``[b, V]`` base distribution for the next (first drafted) token.
        main_hidden: ``[b, dim * len(target_layer_ids)]`` concatenation of the
            target-layer hidden states.
    """

    logits: mx.array
    main_hidden: mx.array
38
+
39
+
40
@dataclass
class BlockOut:
    """Result of the base model's verify forward over a K-token draft block.

    Attributes:
        per_pos_logits: ``[b, K, V]`` base distributions ``p_2 .. p_{K+1}``.
        per_pos_main_hidden: ``[b, K, D]`` main hidden at each verified position.
        main_hidden_last: ``[b, D]`` convenience alias for the last verified
            position.
    """

    per_pos_logits: mx.array
    per_pos_main_hidden: mx.array
    main_hidden_last: mx.array
47
+
48
+
49
@runtime_checkable
class BaseModelAdapter(Protocol):
    """Contract a host implements; the host owns the base model and its KV cache."""

    #: Indices of the main-model layers whose hidden states are concatenated
    #: to form ``main_hidden``.
    target_layer_ids: Tuple[int, ...]

    def prefill(self, tokens: mx.array) -> StepOut:
        """Consume the prompt.

        Returns the logits for the first generated token together with
        ``main_hidden``.
        """
        ...

    def decode_step(self, token: mx.array) -> StepOut:
        """Advance the base model by a single token.

        Returns that token's next-token logits together with ``main_hidden``.
        """
        ...

    def verify_forward(self, block_tokens: mx.array) -> BlockOut:
        """Run a single base forward over the K drafted tokens.

        Returns ``p_2..p_{K+1}`` plus ``main_hidden_last``. The call appends K
        entries to the base KV cache speculatively; the caller discards the
        rejected tail through :meth:`kv_rollback`.
        """
        ...

    def kv_snapshot(self) -> Any:
        """Return an opaque handle for the base KV state before a speculative block."""
        ...

    def kv_rollback(self, n_keep: int) -> None:
        """Discard speculatively-appended KV past the ``n_keep`` accepted tokens."""
        ...
@@ -0,0 +1,6 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+
4
+ from .backbone import DraftArch, DraftBackbone
5
+
6
# Names re-exported from ``.backbone`` by this subpackage.
__all__ = [
    "DraftArch",
    "DraftBackbone",
]
@@ -0,0 +1,57 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+ # Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
4
+
5
+ """The per-architecture seam for DSpark drafters.
6
+
7
+ DSpark ships one recipe (EAGLE-style context projection + Markov bias + confidence head +
8
+ block drafting) realized over different base-model decoder layers — DeepSeek-V4 (windowed
9
+ MLA + MoE + Hyper-Connections, bundled fp8/fp4 ``mtp.*`` checkpoint), Qwen3 and Gemma4
10
+ (standalone bf16 ``layers.*`` checkpoints, full-context GQA). ``generate()`` drives any of
11
+ them through the ``DraftBackbone`` interface; a ``DraftArch`` descriptor registers how to
12
+ build and load each one (see :mod:`dspark_mlx.registry`).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Callable, Optional, Protocol, Tuple, runtime_checkable
19
+
20
+ import mlx.core as mx
21
+
22
+
23
@runtime_checkable
class DraftBackbone(Protocol):
    """Interface of a loaded DSpark drafter for one base architecture.

    This is the object ``generate`` drives.
    """

    block_size: int

    def forward_spec(
        self,
        input_ids: mx.array,
        main_hidden: mx.array,
        start_pos: int = 0,
    ) -> Optional[Tuple[mx.array, mx.array, mx.array]]:
        """Seed context on prefill, draft a block on decode.

        With ``start_pos == 0`` the call is a prefill that only seeds context;
        otherwise it drafts and yields ``(ids, logits, confidence)``.
        """
        ...

    def advance(self, main_hidden: mx.array, position: int) -> None:
        """Move the drafter's context forward over one committed token."""
        ...
38
+
39
+
40
@dataclass(frozen=True)
class DraftArch:
    """Registry descriptor: how to build and load a DSpark drafter for one base architecture.

    Attributes:
        name: Identifier of the architecture entry.
        model_types: ``model_type`` strings this entry handles.
        build: Factory called as ``(config: dict, *, max_seq_len) -> DraftBackbone``.
        key_map: Translates a checkpoint key into a drafter parameter path, or
            ``None`` when the key has no destination.
    """

    name: str
    model_types: Tuple[str, ...]
    build: Callable[..., DraftBackbone]
    key_map: Callable[[str], Optional[str]]

    def supports(self, model_type: Optional[str]) -> bool:
        """Return ``True`` when ``model_type`` is one of :attr:`model_types`."""
        handled = self.model_types
        return model_type in handled
51
+
52
+
53
def config_model_type(config: Any) -> Optional[str]:
    """Read ``model_type`` from a mapping-like or attribute-like config.

    Any :class:`collections.abc.Mapping` (plain ``dict``, ``MappingProxyType``,
    ...) is consulted by key first; if that yields nothing, the value is looked
    up as an attribute instead.

    Args:
        config: A mapping or an object exposing a ``model_type`` attribute.

    Returns:
        The ``model_type`` value, or ``None`` when neither lookup finds one.
    """
    # Function-scope import so the block does not depend on file-level imports.
    from collections.abc import Mapping

    if isinstance(config, Mapping):
        model_type = config.get("model_type")
        if model_type is not None:
            return model_type
    # Attribute fallback; for a plain dict this is ``None``, matching the
    # result of a missing key.
    return getattr(config, "model_type", None)
@@ -0,0 +1,35 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+ # Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
4
+
5
+ """DeepSeek-V4-Flash-DSpark backbone descriptor.
6
+
7
+ The windowed MLA + hash-MoE + Hyper-Connections realization, drafting from the ``mtp.*``
8
+ namespace of the bundled fp8/fp4 checkpoint. The model code lives under ``dspark_mlx.model``
9
+ (its parity tests pin it); this module just registers it as a DraftArch.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Optional
15
+
16
+ from ..loading import map_checkpoint_key
17
+ from ..model.config import DSparkArgs
18
+ from ..model.drafter import DSparkDrafter
19
+ from .backbone import DraftArch, DraftBackbone
20
+
21
+
22
def build(config: dict, *, max_seq_len: int = 8192) -> DraftBackbone:
    """Create the DeepSeek-V4 DSpark drafter from a raw config dict.

    ``config`` is parsed with ``DSparkArgs.from_dict`` and ``max_seq_len`` is
    forwarded to ``DSparkDrafter`` unchanged.
    """
    args = DSparkArgs.from_dict(config)
    drafter = DSparkDrafter(args, max_seq_len=max_seq_len)
    return drafter
24
+
25
+
26
def key_map(key: str) -> Optional[str]:
    """Translate a checkpoint key by delegating to ``map_checkpoint_key``."""
    # mtp.N.* -> blocks.N.*, embed/head pass through
    mapped = map_checkpoint_key(key)
    return mapped
28
+
29
+
30
# Registry descriptor tying the "deepseek_v4" model type to this module's
# ``build`` factory and ``key_map`` translator.
DEEPSEEK_V4 = DraftArch(name="deepseek_v4", model_types=("deepseek_v4",), build=build, key_map=key_map)
@@ -0,0 +1,327 @@
1
+ # Copyright 2026 popfido
2
+ # Licensed under the Apache License, Version 2.0 - see LICENSE file
3
+ # Based on DeepSeek DSpark (deepseek-ai/DeepSpec: dspark/gemma4/modeling.py)
4
+
5
+ """Gemma4 DSpark draft backbone (standalone bf16 ``layers.*`` checkpoint).
6
+
7
+ Stock Gemma4 decoder layers with the DSpark context/noise K/V split. Gemma deltas vs Qwen3:
8
+ K=V sharing (no v_proj; separate scaled k_norm + weightless v_norm), attention scale 1.0,
9
+ partial (proportional) RoPE — only ``partial_rotary_factor`` of head_dim rotates, the rest
10
+ pass through — four sandwich norms + a per-layer ``layer_scalar``, GeGLU (gelu-tanh) MLP, and
11
+ final-logit softcapping.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import dataclasses
17
+ import re as _re
18
+ from dataclasses import dataclass
19
+ from typing import Mapping, Tuple
20
+
21
+ import mlx.core as mx
22
+ import mlx.nn as nn
23
+
24
+ from ..model.heads import DSparkConfidenceHead, DSparkMarkovHead
25
+ from ..model.norm_rope import RMSNorm
26
+ from ..recipe import draft_block_decode
27
+ from .backbone import DraftArch
28
+ from .qwen3 import _apply_rope # shared NeoX rotate_half application
29
+
30
+ _GEMMA4_LAYER_RE = _re.compile(r"layers\.(\d+)\.(.+)$")
31
+
32
+
33
@dataclass
class Gemma4DSparkArgs:
    """Hyper-parameters of the Gemma4 DSpark drafter.

    Build instances from a raw config mapping with :meth:`from_dict`, which
    resolves a few nested / aliased keys and ignores unknown ones.
    """

    vocab_size: int = 262144
    hidden_size: int = 3840
    num_hidden_layers: int = 5
    num_attention_heads: int = 16
    num_key_value_heads: int = 1  # global KV head count (k=v)
    head_dim: int = 512  # global_head_dim
    intermediate_size: int = 15360
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    partial_rotary_factor: float = 0.25
    attention_k_eq_v: bool = True
    final_logit_softcapping: float = 30.0
    target_layer_ids: Tuple[int, ...] = (5, 17, 29, 41, 46)
    num_target_layers: int = 48
    block_size: int = 7
    mask_token_id: int = 4
    markov_rank: int = 256
    temperature: float = 0.0
    max_position_embeddings: int = 262144

    @property
    def fc_in(self) -> int:
        """Input width of the context projection: one ``hidden_size`` slab per target layer."""
        n_targets = len(self.target_layer_ids)
        return n_targets * self.hidden_size

    @classmethod
    def from_dict(cls, params: Mapping) -> "Gemma4DSparkArgs":
        """Create args from a config mapping.

        Key resolution, applied on a copy of ``params``:

        - ``rope_parameters.full_attention.rope_theta`` (when truthy) overrides
          ``rope_theta``; ``partial_rotary_factor`` in the same sub-dict
          overrides the top-level value whenever the key is present.
        - a truthy ``global_head_dim`` overrides ``head_dim``.
        - a non-``None`` ``num_global_key_value_heads`` overrides
          ``num_key_value_heads``.
        - keys that are not dataclass fields are dropped, and
          ``target_layer_ids`` is converted to a tuple.
        """
        raw = dict(params)

        rope_cfg = raw.get("rope_parameters") or {}
        full_attn = rope_cfg.get("full_attention") or {}
        theta = full_attn.get("rope_theta")
        if theta:
            raw["rope_theta"] = theta
        if "partial_rotary_factor" in full_attn:
            raw["partial_rotary_factor"] = full_attn["partial_rotary_factor"]

        global_head_dim = raw.get("global_head_dim")
        if global_head_dim:
            raw["head_dim"] = global_head_dim
        global_kv_heads = raw.get("num_global_key_value_heads")
        if global_kv_heads is not None:
            raw["num_key_value_heads"] = global_kv_heads

        known = {field.name for field in dataclasses.fields(cls)}
        selected = {}
        for key, value in raw.items():
            if key not in known:
                continue
            selected[key] = tuple(value) if key == "target_layer_ids" else value
        return cls(**selected)
76
+
77
+
78
def rope_tables(position_ids: mx.array, head_dim: int, theta: float, partial: float) -> Tuple[mx.array, mx.array]:
    """Build cos/sin tables for proportional (partial) RoPE.

    Only ``int(partial * head_dim // 2)`` frequencies are non-zero; the rest of
    the ``head_dim // 2`` slots are padded with zero frequency, so their angle
    is always 0 (cos 1, sin 0).

    Args:
        position_ids: 1-D positions (indexed as ``position_ids[:, None]``).
        head_dim: Attention head dimension.
        theta: RoPE base.
        partial: Fraction of ``head_dim`` that rotates.

    Returns:
        ``(cos, sin)``, each with ``head_dim // 2`` frequencies duplicated along
        the last axis.
    """
    n_rotating = int(partial * head_dim // 2)
    exponents = mx.arange(0, 2 * n_rotating, 2).astype(mx.float32) / head_dim
    inv_rot = 1.0 / (theta ** exponents)

    n_passthrough = head_dim // 2 - n_rotating
    if n_passthrough > 0:
        zero_freqs = mx.zeros((n_passthrough,), dtype=mx.float32)
        inv_freq = mx.concatenate([inv_rot, zero_freqs])
    else:
        inv_freq = inv_rot

    positions = position_ids.astype(mx.float32)
    angles = positions[:, None] * inv_freq[None, :]
    table = mx.concatenate([angles, angles], axis=-1)
    return mx.cos(table), mx.sin(table)
87
+
88
+
89
class Gemma4DSparkAttention(nn.Module):
    """Attention where queries come from the draft block and keys/values from context + block.

    When ``attention_k_eq_v`` is set there is no ``v_proj``: the key projection
    output is reused as the value input. Keys go through ``k_norm`` and RoPE,
    values through ``v_norm`` only. Attention is computed with ``scale=1.0``
    and no mask.
    """

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        hidden = args.hidden_size
        self.nh, self.nkv, self.hd = args.num_attention_heads, args.num_key_value_heads, args.head_dim
        self.k_eq_v = args.attention_k_eq_v
        q_width = self.nh * self.hd
        kv_width = self.nkv * self.hd
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        if self.k_eq_v:
            self.v_proj = None
        else:
            self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)
        self.q_norm = RMSNorm(self.hd, args.rms_norm_eps)
        self.k_norm = RMSNorm(self.hd, args.rms_norm_eps)
        self.v_norm = RMSNorm(self.hd, args.rms_norm_eps, with_scale=False)

    def __call__(self, hidden: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        """Uncached path: project context and block together, then attend.

        ``cos``/``sin`` are applied in full to the keys (context + block) and
        their last ``q`` rows to the queries.
        """
        b, q, _ = hidden.shape
        ctx = target_ctx.shape[1]
        total = ctx + q

        queries = self.q_proj(hidden).reshape(b, q, self.nh, self.hd)
        queries = self.q_norm(queries)

        k_ctx = self.k_proj(target_ctx)
        k_noise = self.k_proj(hidden)
        if self.k_eq_v:
            v_ctx, v_noise = k_ctx, k_noise
        else:
            v_ctx = self.v_proj(target_ctx)
            v_noise = self.v_proj(hidden)

        keys = mx.concatenate([k_ctx, k_noise], axis=1).reshape(b, total, self.nkv, self.hd)
        keys = self.k_norm(keys)
        values = mx.concatenate([v_ctx, v_noise], axis=1).reshape(b, total, self.nkv, self.hd)
        values = self.v_norm(values)

        queries = _apply_rope(queries, cos[-q:], sin[-q:])
        keys = _apply_rope(keys, cos, sin)  # v is not rotated

        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)
        # Gemma4 scale==1
        attended = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=1.0, mask=None)
        merged = attended.transpose(0, 2, 1, 3).reshape(b, q, self.nh * self.hd)
        return self.o_proj(merged)

    # --- cached path (Phase 3b): context K/V precomputed once, reused every block ---
    def context_kv(self, proj_ctx: mx.array, cos: mx.array, sin: mx.array):
        """Compute normalised (and, for K, rotated) context K/V in head-major layout."""
        b, n, _ = proj_ctx.shape
        k_flat = self.k_proj(proj_ctx)
        if self.k_eq_v:
            v_flat = k_flat
        else:
            v_flat = self.v_proj(proj_ctx)
        keys = self.k_norm(k_flat.reshape(b, n, self.nkv, self.hd))
        values = self.v_norm(v_flat.reshape(b, n, self.nkv, self.hd))
        keys = _apply_rope(keys, cos, sin)  # v is not rotated
        return keys.transpose(0, 2, 1, 3), values.transpose(0, 2, 1, 3)

    def attend_cached(self, noise, ctx_k, ctx_v, cos, sin):
        """Attend the block over cached context K/V followed by the block's own K/V.

        ``ctx_k`` / ``ctx_v`` may be ``None``, in which case only the block's
        own keys/values are used.
        """
        b, q, _ = noise.shape
        queries = self.q_norm(self.q_proj(noise).reshape(b, q, self.nh, self.hd))
        k_flat = self.k_proj(noise)
        if self.k_eq_v:
            v_flat = k_flat
        else:
            v_flat = self.v_proj(noise)
        block_k = self.k_norm(k_flat.reshape(b, q, self.nkv, self.hd))
        block_v = self.v_norm(v_flat.reshape(b, q, self.nkv, self.hd))

        queries = _apply_rope(queries, cos, sin).transpose(0, 2, 1, 3)
        block_k = _apply_rope(block_k, cos, sin).transpose(0, 2, 1, 3)
        block_v = block_v.transpose(0, 2, 1, 3)

        if ctx_k is None:
            keys = block_k
        else:
            keys = mx.concatenate([ctx_k, block_k], axis=2)
        if ctx_v is None:
            values = block_v
        else:
            values = mx.concatenate([ctx_v, block_v], axis=2)

        attended = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=1.0, mask=None)
        merged = attended.transpose(0, 2, 1, 3).reshape(b, q, self.nh * self.hd)
        return self.o_proj(merged)
145
+
146
+
147
class Gemma4MLP(nn.Module):
    """Gated MLP: ``down(gelu_approx(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        hidden, inner = args.hidden_size, args.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated feed-forward transform to ``x``."""
        gate = nn.gelu_approx(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
156
+
157
+
158
class Gemma4DSparkLayer(nn.Module):
    """One decoder layer: attention and MLP, each wrapped in a pre- and post-norm.

    Both sub-blocks are residual; the layer output is multiplied by
    ``layer_scalar`` (initialised to ones).
    """

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        width = args.hidden_size
        eps = args.rms_norm_eps
        self.self_attn = Gemma4DSparkAttention(args)
        self.mlp = Gemma4MLP(args)
        self.input_layernorm = RMSNorm(width, eps)
        self.post_attention_layernorm = RMSNorm(width, eps)
        self.pre_feedforward_layernorm = RMSNorm(width, eps)
        self.post_feedforward_layernorm = RMSNorm(width, eps)
        self.layer_scalar = mx.ones((1,), dtype=mx.float32)

    def _feedforward_and_scale(self, hidden: mx.array) -> mx.array:
        """Apply the sandwiched MLP residual, then the per-layer scalar."""
        ff_in = self.pre_feedforward_layernorm(hidden)
        ff_out = self.post_feedforward_layernorm(self.mlp(ff_in))
        hidden = hidden + ff_out
        return hidden * self.layer_scalar

    def __call__(self, hidden: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        """Uncached forward: attention sees ``target_ctx`` plus the block itself."""
        attn_in = self.input_layernorm(hidden)
        attn_out = self.self_attn(attn_in, target_ctx, cos, sin)
        hidden = hidden + self.post_attention_layernorm(attn_out)
        return self._feedforward_and_scale(hidden)

    def context_kv(self, proj_ctx: mx.array, cos: mx.array, sin: mx.array):
        """Delegate context K/V computation to the attention module."""
        return self.self_attn.context_kv(proj_ctx, cos, sin)

    def forward_cached(self, hidden, ctx_k, ctx_v, cos, sin) -> mx.array:
        """Cached forward: attention reuses precomputed context ``ctx_k`` / ``ctx_v``."""
        attn_in = self.input_layernorm(hidden)
        attn_out = self.self_attn.attend_cached(attn_in, ctx_k, ctx_v, cos, sin)
        hidden = hidden + self.post_attention_layernorm(attn_out)
        return self._feedforward_and_scale(hidden)
186
+
187
+
188
class Gemma4Backbone(nn.Module):
    """Context projection, a stack of ``Gemma4DSparkLayer`` and a final norm."""

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        width, eps = args.hidden_size, args.rms_norm_eps
        self.fc = nn.Linear(args.fc_in, width, bias=False)
        self.hidden_norm = RMSNorm(width, eps)
        self.layers = [Gemma4DSparkLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = RMSNorm(width, eps)

    def project_context(self, target_hidden: mx.array) -> mx.array:
        """Project the target hidden states through ``fc`` and normalise the result."""
        projected = self.fc(target_hidden)
        return self.hidden_norm(projected)

    def __call__(self, noise_embed: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        """Run the block embeddings through every layer, then the final norm."""
        hidden = noise_embed
        for block in self.layers:
            hidden = block(hidden, target_ctx, cos, sin)
        return self.norm(hidden)
204
+
205
+
206
class Gemma4DSparkDrafter(nn.Module):
    """Gemma4 DSpark drafter (DraftBackbone). Same loop as Qwen3 + partial RoPE + softcap.

    The class keeps two independent pieces of state:

    - streaming path (``forward_spec`` / ``advance``): ``_ctx`` holds the
      projected context and ``_next_pos`` the next position (written here but
      not read inside this class);
    - eager cached path (``extend_context`` / ``draft``): ``_ctx_k`` /
      ``_ctx_v`` hold per-layer context K/V and ``_committed`` counts the
      context positions already cached.
    """

    def __init__(self, args: Gemma4DSparkArgs, max_seq_len: int = 8192):
        # ``max_seq_len`` is accepted so the constructor matches the ``build``
        # call; it is not read anywhere in this class.
        super().__init__()
        self.args = args
        self.block_size = args.block_size
        self.temperature = args.temperature
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.backbone = Gemma4Backbone(args)
        self.markov_head = DSparkMarkovHead(args.vocab_size, args.markov_rank)
        self.confidence_head = DSparkConfidenceHead(args.hidden_size + args.markov_rank, bias=True)
        self.reset()

    def reset(self) -> None:
        """Clear the state of both the streaming and the eager cached path."""
        self._ctx = None
        self._next_pos = 0
        self._ctx_k = None  # per-layer cached context K (eager path)
        self._ctx_v = None
        self._committed = 0

    def _project_one(self, main_hidden: mx.array) -> mx.array:
        """Project a single position: reshape to ``[b, 1, -1]`` and run the context projection."""
        return self.backbone.project_context(main_hidden.reshape(main_hidden.shape[0], 1, -1))

    def _rope_at(self, positions: mx.array):
        """cos/sin tables for explicit ``positions`` using this drafter's RoPE settings."""
        return rope_tables(positions, self.args.head_dim, self.args.rope_theta, self.args.partial_rotary_factor)

    def _rope(self, length: int):
        """cos/sin tables for positions ``0 .. length - 1``."""
        return self._rope_at(mx.arange(length))

    def _softcap(self, logits: mx.array) -> mx.array:
        """Apply ``tanh`` soft-capping when ``final_logit_softcapping`` is truthy."""
        sc = self.args.final_logit_softcapping
        return mx.tanh(logits / sc) * sc if sc else logits

    def _embed(self, ids: mx.array) -> mx.array:
        """Gemma scales token embeddings by sqrt(hidden) (Gemma4TextScaledWordEmbedding)."""
        e = self.embed_tokens(ids)
        return e * mx.array(self.args.hidden_size ** 0.5, dtype=e.dtype)

    def _block_embed(self, anchor_token: mx.array) -> mx.array:
        """Embed one draft block: the anchor token followed by ``block_size - 1`` mask tokens."""
        b = anchor_token.shape[0]
        anchor = anchor_token.astype(mx.int32).reshape(b, 1)
        masks = mx.full((b, self.block_size - 1), self.args.mask_token_id, dtype=mx.int32)
        return self._embed(mx.concatenate([anchor, masks], axis=1))

    def _require_context(self, caller: str) -> None:
        """Fail with a clear message when the streaming context was never seeded."""
        if self._ctx is None:
            raise RuntimeError(
                f"{caller} called before prefill: call forward_spec(..., start_pos=0) first"
            )

    # --- reference-matched eager interface (see qwen3.py for the rationale) ---

    def extend_context(self, new_hiddens: mx.array) -> None:
        """Append context positions to the per-layer K/V cache.

        ``new_hiddens`` is indexed as ``shape[1]`` for the number of new
        positions (presumably ``[b, n, fc_in]`` -- confirm against callers).
        A zero-length input is a no-op.
        """
        n = new_hiddens.shape[1]
        if n == 0:
            return
        proj = self.backbone.project_context(new_hiddens)
        # New positions continue from the number already committed.
        cos, sin = self._rope_at(mx.arange(self._committed, self._committed + n))
        layers = self.backbone.layers
        if self._ctx_k is None:
            self._ctx_k = [None] * len(layers)
            self._ctx_v = [None] * len(layers)
        for i, layer in enumerate(layers):
            k, v = layer.context_kv(proj, cos, sin)
            self._ctx_k[i] = k if self._ctx_k[i] is None else mx.concatenate([self._ctx_k[i], k], axis=2)
            self._ctx_v[i] = v if self._ctx_v[i] is None else mx.concatenate([self._ctx_v[i], v], axis=2)
        self._committed += n

    def draft(self, anchor_token: mx.array):
        """Draft one block from ``anchor_token`` against the cached context K/V.

        Works with an empty cache too (each layer then attends over the block
        only). Returns whatever ``draft_block_decode`` returns.
        """
        start = self._committed
        noise_embed = self._block_embed(anchor_token)
        # The block occupies the positions right after the committed context.
        cos, sin = self._rope_at(mx.arange(start, start + self.block_size))
        layers = self.backbone.layers
        ck = self._ctx_k or [None] * len(layers)
        cv = self._ctx_v or [None] * len(layers)
        h = noise_embed
        for i, layer in enumerate(layers):
            h = layer.forward_cached(h, ck[i], cv[i], cos, sin)
        block_hidden = self.backbone.norm(h)
        logits = self._softcap(self.lm_head(block_hidden.astype(mx.float32)))
        return draft_block_decode(
            logits, block_hidden, anchor_token, self.markov_head, self.confidence_head,
            self.block_size, self.temperature,
        )

    def forward_spec(self, input_ids: mx.array, main_hidden: mx.array, start_pos: int = 0):
        """Prefill (``start_pos == 0``) seeds context; otherwise draft one block.

        Prefill projects ``main_hidden``, stores all but the last projected
        position as context and returns ``None``. Decode appends the projection
        of ``main_hidden`` to the context, runs the uncached backbone over the
        anchor + mask block and returns the result of ``draft_block_decode``.

        Raises:
            RuntimeError: if a decode call is made before any prefill.
        """
        if start_pos == 0:
            full = self.backbone.project_context(main_hidden)
            self._ctx = full[:, :-1]
            self._next_pos = full.shape[1] - 1
            return None
        self._require_context("forward_spec")
        self._ctx = mx.concatenate([self._ctx, self._project_one(main_hidden)], axis=1)
        self._next_pos = start_pos + 1
        noise_embed = self._block_embed(input_ids)
        # Tables cover context + block; attention slices the tail for the queries.
        cos, sin = self._rope(self._ctx.shape[1] + self.block_size)
        block_hidden = self.backbone(noise_embed, self._ctx, cos, sin)
        logits = self._softcap(self.lm_head(block_hidden.astype(mx.float32)))
        return draft_block_decode(
            logits, block_hidden, input_ids, self.markov_head, self.confidence_head,
            self.block_size, self.temperature,
        )

    def advance(self, main_hidden: mx.array, position: int) -> None:
        """Append the projection of one committed token's ``main_hidden`` to the context.

        Raises:
            RuntimeError: if called before any prefill.
        """
        self._require_context("advance")
        self._ctx = mx.concatenate([self._ctx, self._project_one(main_hidden)], axis=1)
        self._next_pos = position + 1
308
+
309
+
310
def gemma4_key_map(key):
    """Map a Gemma4 DSpark checkpoint key to the drafter's parameter path.

    - ``embed_tokens.weight`` / ``lm_head.weight`` and anything under
      ``markov_head.`` or ``confidence_head.`` keep their name;
    - ``fc.weight``, ``hidden_norm.weight`` and ``norm.weight`` move under
      ``backbone.``;
    - ``layers.<n>.<rest>`` becomes ``backbone.layers.<n>.<rest>``;
    - every other key maps to ``None``.
    """
    unchanged_names = ("embed_tokens.weight", "lm_head.weight")
    unchanged_prefixes = ("markov_head.", "confidence_head.")
    if key in unchanged_names or key.startswith(unchanged_prefixes):
        return key
    if key in ("fc.weight", "hidden_norm.weight", "norm.weight"):
        return f"backbone.{key}"
    layer = _GEMMA4_LAYER_RE.match(key)
    if layer is None:
        return None
    index, rest = layer.groups()
    return f"backbone.layers.{index}.{rest}"
321
+
322
+
323
def build(config, *, max_seq_len: int = 8192) -> Gemma4DSparkDrafter:
    """Create a Gemma4 DSpark drafter from a raw config mapping.

    ``config`` is parsed with ``Gemma4DSparkArgs.from_dict`` and ``max_seq_len``
    is forwarded to the drafter constructor unchanged.
    """
    args = Gemma4DSparkArgs.from_dict(config)
    drafter = Gemma4DSparkDrafter(args, max_seq_len=max_seq_len)
    return drafter
325
+
326
+
327
# Registry descriptor tying the "gemma4" / "gemma4_text" model types to this
# module's ``build`` factory and ``gemma4_key_map`` translator.
GEMMA4 = DraftArch(
    name="gemma4",
    model_types=("gemma4", "gemma4_text"),
    build=build,
    key_map=gemma4_key_map,
)