dspark-mlx 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dspark_mlx/__init__.py +45 -0
- dspark_mlx/adapter.py +78 -0
- dspark_mlx/arch/__init__.py +6 -0
- dspark_mlx/arch/backbone.py +57 -0
- dspark_mlx/arch/deepseek_v4.py +35 -0
- dspark_mlx/arch/gemma4.py +327 -0
- dspark_mlx/arch/qwen3.py +337 -0
- dspark_mlx/cli.py +164 -0
- dspark_mlx/events.py +26 -0
- dspark_mlx/generate.py +88 -0
- dspark_mlx/hosts/__init__.py +14 -0
- dspark_mlx/hosts/gemma4_unified.py +64 -0
- dspark_mlx/hosts/mlx_lm.py +122 -0
- dspark_mlx/kernels.py +85 -0
- dspark_mlx/loader.py +92 -0
- dspark_mlx/loading.py +96 -0
- dspark_mlx/loop.py +95 -0
- dspark_mlx/model/__init__.py +6 -0
- dspark_mlx/model/attention.py +142 -0
- dspark_mlx/model/block.py +116 -0
- dspark_mlx/model/config.py +78 -0
- dspark_mlx/model/drafter.py +53 -0
- dspark_mlx/model/heads.py +49 -0
- dspark_mlx/model/hyper.py +67 -0
- dspark_mlx/model/moe.py +113 -0
- dspark_mlx/model/norm_rope.py +77 -0
- dspark_mlx/quant.py +39 -0
- dspark_mlx/recipe.py +63 -0
- dspark_mlx/registry.py +37 -0
- dspark_mlx/verify.py +118 -0
- dspark_mlx-0.1.0.dist-info/METADATA +108 -0
- dspark_mlx-0.1.0.dist-info/RECORD +36 -0
- dspark_mlx-0.1.0.dist-info/WHEEL +5 -0
- dspark_mlx-0.1.0.dist-info/entry_points.txt +2 -0
- dspark_mlx-0.1.0.dist-info/licenses/LICENSE +182 -0
- dspark_mlx-0.1.0.dist-info/top_level.txt +1 -0
dspark_mlx/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2026 popfido
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 - see LICENSE file
|
|
3
|
+
# Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
|
|
4
|
+
|
|
5
|
+
# Package version string (matches the dspark-mlx 0.1.0 wheel this module ships in).
__version__ = "0.1.0"
|
|
6
|
+
|
|
7
|
+
from .adapter import BaseModelAdapter, BlockOut, StepOut
|
|
8
|
+
from .arch.backbone import DraftArch, DraftBackbone
|
|
9
|
+
from .events import SummaryEvent, TokenEvent
|
|
10
|
+
from .generate import generate
|
|
11
|
+
from .loader import KNOWN_MODELS, load_draft, load_host, resolve_model
|
|
12
|
+
from .loading import is_dspark_checkpoint, load_drafter, map_checkpoint_key
|
|
13
|
+
from .loop import generate_eager
|
|
14
|
+
from .model.config import DSparkArgs
|
|
15
|
+
from .model.drafter import DSparkDrafter
|
|
16
|
+
from .quant import quantize_drafter
|
|
17
|
+
from .registry import ARCH_REGISTRY, resolve_arch
|
|
18
|
+
from .verify import AcceptResult, greedy_accept, speculative_sample_accept
|
|
19
|
+
|
|
20
|
+
# Public API: names re-exported from the submodules imported above.
__all__ = [
    "BaseModelAdapter",
    "BlockOut",
    "StepOut",
    "DSparkArgs",
    "DSparkDrafter",
    "DraftArch",
    "DraftBackbone",
    "resolve_arch",
    "ARCH_REGISTRY",
    "generate",
    "generate_eager",
    "load_draft",
    "load_host",
    "resolve_model",
    "KNOWN_MODELS",
    "greedy_accept",
    "speculative_sample_accept",
    "AcceptResult",
    "TokenEvent",
    "SummaryEvent",
    "load_drafter",
    "map_checkpoint_key",
    "is_dspark_checkpoint",
    "quantize_drafter",
]
|
dspark_mlx/adapter.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright 2026 popfido
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 - see LICENSE file
|
|
3
|
+
# Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
|
|
4
|
+
|
|
5
|
+
"""The seam between dspark-mlx (drafter + verify/accept loop) and a host base model.
|
|
6
|
+
|
|
7
|
+
dspark-mlx is target-agnostic: it owns the DSpark draft stack and the lossless
|
|
8
|
+
accept policy, but never the base model. The host (e.g. omlx over its
|
|
9
|
+
``patches/deepseek_v4`` model) implements :class:`BaseModelAdapter` so the drafter
|
|
10
|
+
can (a) read the ``main_hidden`` it conditions on, (b) get the base distribution for
|
|
11
|
+
each candidate token during verify, and (c) snapshot/roll back base KV when a block is
|
|
12
|
+
only partially accepted.
|
|
13
|
+
|
|
14
|
+
Logit conventions (one decode cycle):
|
|
15
|
+
- ``prefill`` / ``decode_step`` return ``StepOut.logits`` = ``p_1``, the base
|
|
16
|
+
distribution for the *first* drafted token. It is free — already computed by the
|
|
17
|
+
step that produced the anchor — so the verify forward never recomputes it.
|
|
18
|
+
- ``verify_forward`` runs ONE base forward over the K drafted tokens and returns the
|
|
19
|
+
K base distributions ``p_2 .. p_{K+1}`` (``p_{K+1}`` is the bonus position).
|
|
20
|
+
- The generate loop concatenates ``[p_1] + [p_2..p_{K+1}]`` into the ``[K+1, V]`` block
|
|
21
|
+
the accept policy consumes (see :mod:`dspark_mlx.verify`).
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from dataclasses import dataclass
|
|
27
|
+
from typing import Any, Protocol, Tuple, runtime_checkable
|
|
28
|
+
|
|
29
|
+
import mlx.core as mx
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class StepOut:
    """Output of a single base forward at the anchor position.

    Returned by ``BaseModelAdapter.prefill`` and ``BaseModelAdapter.decode_step``.
    """

    logits: mx.array  # [b, V] base distribution for the next (first drafted) token
    main_hidden: mx.array  # [b, dim * len(target_layer_ids)] concat of target-layer hiddens
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class BlockOut:
    """Output of the base verify forward over a K-token draft block.

    Returned by ``BaseModelAdapter.verify_forward``. ``per_pos_logits`` holds the
    ``p_2 .. p_{K+1}`` rows that follow ``StepOut.logits`` (``p_1``) in the
    ``[K+1, V]`` block the accept policy consumes.
    """

    per_pos_logits: mx.array  # [b, K, V] base distributions p_2 .. p_{K+1}
    per_pos_main_hidden: mx.array  # [b, K, D] main hidden at each verified position
    main_hidden_last: mx.array  # [b, D] convenience alias for the last verified position
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@runtime_checkable
class BaseModelAdapter(Protocol):
    """Host contract. Implementations own the base model and its KV cache.

    As a ``runtime_checkable`` Protocol, ``isinstance(obj, BaseModelAdapter)`` only
    checks that the members below exist, not their signatures or behavior.
    """

    #: Main-model layer indices whose hidden states are concatenated into ``main_hidden``.
    target_layer_ids: Tuple[int, ...]

    def prefill(self, tokens: mx.array) -> StepOut:
        """Process the prompt; return logits for the first generated token + main_hidden."""
        ...

    def decode_step(self, token: mx.array) -> StepOut:
        """Advance one token; return its next-token logits + main_hidden."""
        ...

    def verify_forward(self, block_tokens: mx.array) -> BlockOut:
        """Run one base forward over K draft tokens.

        Returns a :class:`BlockOut` carrying the base distributions ``p_2 .. p_{K+1}``,
        the main hidden at every verified position (``per_pos_main_hidden``) and the
        last of those as ``main_hidden_last``.

        Appends K entries to the base KV cache speculatively; the caller rolls back the
        rejected tail via :meth:`kv_rollback`.
        """
        ...

    def kv_snapshot(self) -> Any:
        """Opaque handle capturing base KV state before a speculative block."""
        ...

    def kv_rollback(self, n_keep: int) -> None:
        """Drop speculatively-appended KV beyond ``n_keep`` accepted tokens.

        Note: this takes a token count, not the handle returned by :meth:`kv_snapshot`.
        """
        ...
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright 2026 popfido
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 - see LICENSE file
|
|
3
|
+
# Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
|
|
4
|
+
|
|
5
|
+
"""The per-architecture seam for DSpark drafters.
|
|
6
|
+
|
|
7
|
+
DSpark ships one recipe (EAGLE-style context projection + Markov bias + confidence head +
|
|
8
|
+
block drafting) realized over different base-model decoder layers — DeepSeek-V4 (windowed
|
|
9
|
+
MLA + MoE + Hyper-Connections, bundled fp8/fp4 ``mtp.*`` checkpoint), Qwen3 and Gemma4
|
|
10
|
+
(standalone bf16 ``layers.*`` checkpoints, full-context GQA). ``generate()`` drives any of
|
|
11
|
+
them through the ``DraftBackbone`` interface; a ``DraftArch`` descriptor registers how to
|
|
12
|
+
build and load each one (see :mod:`dspark_mlx.registry`).
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any, Callable, Optional, Protocol, Tuple, runtime_checkable
|
|
19
|
+
|
|
20
|
+
import mlx.core as mx
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@runtime_checkable
class DraftBackbone(Protocol):
    """A loaded DSpark drafter for one base architecture (what ``generate`` consumes)."""

    #: Size of one draft block.
    block_size: int

    def forward_spec(
        self, input_ids: mx.array, main_hidden: mx.array, start_pos: int = 0
    ) -> Optional[Tuple[mx.array, mx.array, mx.array]]:
        """Prefill (start_pos==0) seeds context; decode drafts (ids, logits, confidence).

        The prefill call returns ``None``; decode calls return the
        ``(ids, logits, confidence)`` triple for one draft block.
        """
        ...

    def advance(self, main_hidden: mx.array, position: int) -> None:
        """Slide the drafter's context over one committed token."""
        ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class DraftArch:
    """Registry record describing how to build and load one architecture's DSpark drafter.

    Fields:
        name: registry name of the architecture.
        model_types: config ``model_type`` strings this entry accepts.
        build: factory ``(config, *, max_seq_len) -> DraftBackbone``.
        key_map: rewrites a checkpoint key to a drafter parameter path, returning
            ``None`` when the key has no drafter counterpart.
    """

    name: str
    model_types: Tuple[str, ...]
    build: Callable[..., DraftBackbone]  # (config: dict, *, max_seq_len) -> DraftBackbone
    key_map: Callable[[str], Optional[str]]  # checkpoint key -> drafter param path (or None)

    def supports(self, model_type: Optional[str]) -> bool:
        """Whether ``model_type`` is one of the registered model types (None never is)."""
        accepted = self.model_types
        return model_type in accepted
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def config_model_type(config: Any) -> Optional[str]:
    """Return ``model_type`` from ``config``, or None when it is absent.

    Dicts are read by key; any other object is read by attribute.
    """
    if not isinstance(config, dict):
        return getattr(config, "model_type", None)
    return config.get("model_type")
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright 2026 popfido
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 - see LICENSE file
|
|
3
|
+
# Based on DeepSeek DSpark (DeepSeek-V4-Flash-DSpark, deepseek-ai/DeepSpec)
|
|
4
|
+
|
|
5
|
+
"""DeepSeek-V4-Flash-DSpark backbone descriptor.
|
|
6
|
+
|
|
7
|
+
The windowed MLA + hash-MoE + Hyper-Connections realization, drafting from the ``mtp.*``
|
|
8
|
+
namespace of the bundled fp8/fp4 checkpoint. The model code lives under ``dspark_mlx.model``
|
|
9
|
+
(its parity tests pin it); this module just registers it as a DraftArch.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
from ..loading import map_checkpoint_key
|
|
17
|
+
from ..model.config import DSparkArgs
|
|
18
|
+
from ..model.drafter import DSparkDrafter
|
|
19
|
+
from .backbone import DraftArch, DraftBackbone
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def build(config: dict, *, max_seq_len: int = 8192) -> DraftBackbone:
    """DraftArch factory: parse ``config`` into DSparkArgs and construct a DSparkDrafter."""
    parsed = DSparkArgs.from_dict(config)
    return DSparkDrafter(parsed, max_seq_len=max_seq_len)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def key_map(key: str) -> Optional[str]:
    """DraftArch key mapper for DeepSeek-V4: delegate to ``map_checkpoint_key``.

    ``mtp.N.*`` keys map to ``blocks.N.*``; embed/head keys pass through.
    """
    mapped = map_checkpoint_key(key)
    return mapped
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Registry descriptor: configs whose model_type is "deepseek_v4" are built with ``build``
# and have their checkpoint keys rewritten by ``key_map``.
DEEPSEEK_V4 = DraftArch(
    name="deepseek_v4",
    model_types=("deepseek_v4",),
    build=build,
    key_map=key_map,
)
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
# Copyright 2026 popfido
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 - see LICENSE file
|
|
3
|
+
# Based on DeepSeek DSpark (deepseek-ai/DeepSpec: dspark/gemma4/modeling.py)
|
|
4
|
+
|
|
5
|
+
"""Gemma4 DSpark draft backbone (standalone bf16 ``layers.*`` checkpoint).
|
|
6
|
+
|
|
7
|
+
Stock Gemma4 decoder layers with the DSpark context/noise K/V split. Gemma deltas vs Qwen3:
|
|
8
|
+
K=V sharing (no v_proj; separate scaled k_norm + weightless v_norm), attention scale 1.0,
|
|
9
|
+
partial (proportional) RoPE — only ``partial_rotary_factor`` of head_dim rotates, the rest
|
|
10
|
+
pass through — four sandwich norms + a per-layer ``layer_scalar``, GeGLU (gelu-tanh) MLP, and
|
|
11
|
+
final-logit softcapping.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
import re as _re
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Mapping, Tuple
|
|
20
|
+
|
|
21
|
+
import mlx.core as mx
|
|
22
|
+
import mlx.nn as nn
|
|
23
|
+
|
|
24
|
+
from ..model.heads import DSparkConfidenceHead, DSparkMarkovHead
|
|
25
|
+
from ..model.norm_rope import RMSNorm
|
|
26
|
+
from ..recipe import draft_block_decode
|
|
27
|
+
from .backbone import DraftArch
|
|
28
|
+
from .qwen3 import _apply_rope # shared NeoX rotate_half application
|
|
29
|
+
|
|
30
|
+
# Matches checkpoint keys of the form "layers.<index>.<rest>": group 1 is the layer
# index, group 2 the remaining parameter path.
_GEMMA4_LAYER_RE = _re.compile(r"layers\.(\d+)\.(.+)$")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class Gemma4DSparkArgs:
    """Configuration of the Gemma4 DSpark drafter.

    ``from_dict`` accepts a HF-style config mapping and folds the global-attention
    overrides (``rope_parameters["full_attention"]``, ``global_head_dim``,
    ``num_global_key_value_heads``) into the flat fields below.
    """

    vocab_size: int = 262144
    hidden_size: int = 3840
    num_hidden_layers: int = 5
    num_attention_heads: int = 16
    num_key_value_heads: int = 1  # global KV head count (k=v)
    head_dim: int = 512  # global_head_dim
    intermediate_size: int = 15360
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    partial_rotary_factor: float = 0.25
    attention_k_eq_v: bool = True
    final_logit_softcapping: float = 30.0
    target_layer_ids: Tuple[int, ...] = (5, 17, 29, 41, 46)
    num_target_layers: int = 48
    block_size: int = 7
    mask_token_id: int = 4
    markov_rank: int = 256
    temperature: float = 0.0
    max_position_embeddings: int = 262144

    @property
    def fc_in(self) -> int:
        """Width of the concatenated target-layer hiddens fed to the context projection."""
        n_targets = len(self.target_layer_ids)
        return n_targets * self.hidden_size

    @classmethod
    def from_dict(cls, params: Mapping) -> "Gemma4DSparkArgs":
        """Build args from ``params``; keys that are not dataclass fields are ignored."""
        cfg = dict(params)
        full_rope = (cfg.get("rope_parameters") or {}).get("full_attention") or {}
        # Global (full-attention) settings override the flat/sliding defaults.
        if full_rope.get("rope_theta"):
            cfg["rope_theta"] = full_rope["rope_theta"]
        if "partial_rotary_factor" in full_rope:
            cfg["partial_rotary_factor"] = full_rope["partial_rotary_factor"]
        if cfg.get("global_head_dim"):
            cfg["head_dim"] = cfg["global_head_dim"]
        if cfg.get("num_global_key_value_heads") is not None:
            cfg["num_key_value_heads"] = cfg["num_global_key_value_heads"]

        field_names = frozenset(f.name for f in dataclasses.fields(cls))
        kwargs = {}
        for key, value in cfg.items():
            if key not in field_names:
                continue
            # JSON configs carry lists; the field is a tuple.
            kwargs[key] = tuple(value) if key == "target_layer_ids" else value
        return cls(**kwargs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def rope_tables(position_ids: mx.array, head_dim: int, theta: float, partial: float) -> Tuple[mx.array, mx.array]:
    """Proportional (partial) RoPE tables; returns ``(cos, sin)``.

    Only the first ``int(partial * head_dim // 2)`` frequencies are non-zero. The rest
    are 0, so their cos/sin are 1/0 and those dims pass through unrotated. The
    frequency exponent is divided by ``head_dim`` (not the rotated width). Each table
    has one row per entry of ``position_ids`` and ``2 * (head_dim // 2)`` columns (the
    frequency vector duplicated, rotate-half layout).
    """
    n_rotating = int(partial * head_dim // 2)
    n_static = head_dim // 2 - n_rotating
    exponents = mx.arange(0, 2 * n_rotating, 2).astype(mx.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    if n_static > 0:
        inv_freq = mx.concatenate([inv_freq, mx.zeros((n_static,), dtype=mx.float32)])
    angles = position_ids.astype(mx.float32)[:, None] * inv_freq[None, :]
    doubled = mx.concatenate([angles, angles], axis=-1)
    return mx.cos(doubled), mx.sin(doubled)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Gemma4DSparkAttention(nn.Module):
    """DSpark attention over a Gemma4 layer: block queries attend to context + block keys.

    Queries come only from the block (``hidden`` / ``noise``). Keys and values come from
    both the projected target context and the block, through the same ``k_proj`` (and
    ``v_proj`` when K != V). No attention mask is applied, and the softmax scale is 1.0.
    """

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        h, nh, nkv, hd = args.hidden_size, args.num_attention_heads, args.num_key_value_heads, args.head_dim
        self.nh, self.nkv, self.hd = nh, nkv, hd
        self.k_eq_v = args.attention_k_eq_v
        self.q_proj = nn.Linear(h, nh * hd, bias=False)
        self.k_proj = nn.Linear(h, nkv * hd, bias=False)
        # With K=V sharing there is no v_proj; values reuse the raw (pre-norm) k_proj output.
        self.v_proj = None if self.k_eq_v else nn.Linear(h, nkv * hd, bias=False)
        self.o_proj = nn.Linear(nh * hd, h, bias=False)
        self.q_norm = RMSNorm(hd, args.rms_norm_eps)
        self.k_norm = RMSNorm(hd, args.rms_norm_eps)
        self.v_norm = RMSNorm(hd, args.rms_norm_eps, with_scale=False)  # weightless norm

    def __call__(self, hidden: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        """Uncached path: project context and block K/V together and attend.

        ``cos``/``sin`` are expected to cover the ``ctx + q`` key positions. The queries
        use their last ``q`` rows.
        """
        b, q, _ = hidden.shape
        ctx = target_ctx.shape[1]
        qh = self.q_norm(self.q_proj(hidden).reshape(b, q, self.nh, self.hd))
        k_ctx, k_noise = self.k_proj(target_ctx), self.k_proj(hidden)
        v_ctx, v_noise = (k_ctx, k_noise) if self.k_eq_v else (self.v_proj(target_ctx), self.v_proj(hidden))
        # Keys/values are laid out as [context positions..., block positions...].
        k = self.k_norm(mx.concatenate([k_ctx, k_noise], axis=1).reshape(b, ctx + q, self.nkv, self.hd))
        v = self.v_norm(mx.concatenate([v_ctx, v_noise], axis=1).reshape(b, ctx + q, self.nkv, self.hd))
        qh = _apply_rope(qh, cos[-q:], sin[-q:])
        k = _apply_rope(k, cos, sin)  # v is not rotated
        # -> [b, heads, seq, head_dim] for the fused attention kernel.
        qh = qh.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)
        out = mx.fast.scaled_dot_product_attention(qh, k, v, scale=1.0, mask=None)  # Gemma4 scale==1
        out = out.transpose(0, 2, 1, 3).reshape(b, q, self.nh * self.hd)
        return self.o_proj(out)

    # --- cached path (Phase 3b): context K/V precomputed once, reused every block ---
    def context_kv(self, proj_ctx: mx.array, cos: mx.array, sin: mx.array):
        """Return rotated, normed context ``(k, v)`` transposed to [b, nkv, n, head_dim].

        ``cos``/``sin`` must correspond to the ``n`` context positions in ``proj_ctx``.
        """
        b, n, _ = proj_ctx.shape
        kc = self.k_proj(proj_ctx)
        vc = kc if self.k_eq_v else self.v_proj(proj_ctx)
        k = self.k_norm(kc.reshape(b, n, self.nkv, self.hd))
        v = self.v_norm(vc.reshape(b, n, self.nkv, self.hd))
        k = _apply_rope(k, cos, sin)  # v is not rotated
        return k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3)

    def attend_cached(self, noise, ctx_k, ctx_v, cos, sin):
        """Attend block queries over cached context K/V (if any) plus the block's own K/V.

        ``cos``/``sin`` cover only the ``q`` block positions here. ``ctx_k``/``ctx_v``
        may be None, in which case the block attends only to itself.
        """
        b, q, _ = noise.shape
        qh = self.q_norm(self.q_proj(noise).reshape(b, q, self.nh, self.hd))
        kn = self.k_proj(noise)
        vn = kn if self.k_eq_v else self.v_proj(noise)
        nk = self.k_norm(kn.reshape(b, q, self.nkv, self.hd))
        nv = self.v_norm(vn.reshape(b, q, self.nkv, self.hd))
        qh = _apply_rope(qh, cos, sin).transpose(0, 2, 1, 3)
        nk = _apply_rope(nk, cos, sin).transpose(0, 2, 1, 3)
        nv = nv.transpose(0, 2, 1, 3)
        # Concatenate along the sequence axis (axis 2 after the transpose): context first.
        k = nk if ctx_k is None else mx.concatenate([ctx_k, nk], axis=2)
        v = nv if ctx_v is None else mx.concatenate([ctx_v, nv], axis=2)
        out = mx.fast.scaled_dot_product_attention(qh, k, v, scale=1.0, mask=None)
        out = out.transpose(0, 2, 1, 3).reshape(b, q, self.nh * self.hd)
        return self.o_proj(out)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class Gemma4MLP(nn.Module):
    """GeGLU feed-forward: ``down(gelu_tanh(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        d_model, d_ff = args.hidden_size, args.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.gelu_approx(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class Gemma4DSparkLayer(nn.Module):
    """One Gemma4 decoder layer: sandwich-normed attention and MLP, then ``layer_scalar``.

    ``__call__`` attends over a projected context passed in full; ``forward_cached``
    reuses per-layer context K/V produced by :meth:`context_kv`.
    """

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        dim, eps = args.hidden_size, args.rms_norm_eps
        self.self_attn = Gemma4DSparkAttention(args)
        self.mlp = Gemma4MLP(args)
        self.input_layernorm = RMSNorm(dim, eps)
        self.post_attention_layernorm = RMSNorm(dim, eps)
        self.pre_feedforward_layernorm = RMSNorm(dim, eps)
        self.post_feedforward_layernorm = RMSNorm(dim, eps)
        self.layer_scalar = mx.ones((1,), dtype=mx.float32)

    def _finish(self, residual: mx.array, attn_out: mx.array) -> mx.array:
        """Shared tail: normed attention residual, sandwich-normed MLP residual, then scale."""
        residual = residual + self.post_attention_layernorm(attn_out)
        ffn_out = self.mlp(self.pre_feedforward_layernorm(residual))
        residual = residual + self.post_feedforward_layernorm(ffn_out)
        return residual * self.layer_scalar

    def __call__(self, hidden: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        attn_out = self.self_attn(self.input_layernorm(hidden), target_ctx, cos, sin)
        return self._finish(hidden, attn_out)

    def context_kv(self, proj_ctx: mx.array, cos: mx.array, sin: mx.array):
        """Per-layer context K/V for the cached path (delegates to the attention module)."""
        return self.self_attn.context_kv(proj_ctx, cos, sin)

    def forward_cached(self, hidden, ctx_k, ctx_v, cos, sin) -> mx.array:
        attn_out = self.self_attn.attend_cached(self.input_layernorm(hidden), ctx_k, ctx_v, cos, sin)
        return self._finish(hidden, attn_out)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class Gemma4Backbone(nn.Module):
    """Context projection (``fc`` + ``hidden_norm``), a stack of DSpark layers, final norm."""

    def __init__(self, args: Gemma4DSparkArgs):
        super().__init__()
        self.fc = nn.Linear(args.fc_in, args.hidden_size, bias=False)
        self.hidden_norm = RMSNorm(args.hidden_size, args.rms_norm_eps)
        self.layers = [Gemma4DSparkLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = RMSNorm(args.hidden_size, args.rms_norm_eps)

    def project_context(self, target_hidden: mx.array) -> mx.array:
        """Map concatenated target-layer hiddens (width ``fc_in``) to normed drafter width."""
        projected = self.fc(target_hidden)
        return self.hidden_norm(projected)

    def __call__(self, noise_embed: mx.array, target_ctx: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
        x = noise_embed
        for blk in self.layers:
            x = blk(x, target_ctx, cos, sin)
        return self.norm(x)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class Gemma4DSparkDrafter(nn.Module):
    """Gemma4 DSpark drafter (DraftBackbone). Same loop as Qwen3 + partial RoPE + softcap.

    Two ways to drive it share the same weights:

    * ``forward_spec`` / ``advance`` (the ``DraftBackbone`` protocol) keep the projected
      context in ``self._ctx`` and run every layer over context + block on each draft.
    * ``extend_context`` / ``draft`` (eager path) cache per-layer context K/V once, so a
      draft only projects the ``block_size`` block tokens.

    ``max_seq_len`` is accepted for signature parity with the other ``DraftArch.build``
    targets; this class does not use it.
    """

    def __init__(self, args: Gemma4DSparkArgs, max_seq_len: int = 8192):
        super().__init__()
        self.args = args
        self.block_size = args.block_size
        self.temperature = args.temperature
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.backbone = Gemma4Backbone(args)
        self.markov_head = DSparkMarkovHead(args.vocab_size, args.markov_rank)
        self.confidence_head = DSparkConfidenceHead(args.hidden_size + args.markov_rank, bias=True)
        self.reset()

    def reset(self) -> None:
        """Drop all drafting state for both interfaces."""
        self._ctx = None  # projected context (forward_spec / advance path)
        self._next_pos = 0
        self._ctx_k = None  # per-layer cached context K (eager path)
        self._ctx_v = None  # per-layer cached context V (eager path)
        self._committed = 0  # context positions cached so far (eager path)

    def _project_one(self, main_hidden: mx.array) -> mx.array:
        """Project one position's concatenated target hiddens to a [b, 1, hidden] context row."""
        return self.backbone.project_context(main_hidden.reshape(main_hidden.shape[0], 1, -1))

    def _append_context(self, main_hidden: mx.array) -> None:
        """Append one projected row to ``self._ctx``.

        If no prefill has seeded the context yet (``self._ctx is None``), the context
        starts from this row instead of concatenating onto ``None``.
        """
        row = self._project_one(main_hidden)
        self._ctx = row if self._ctx is None else mx.concatenate([self._ctx, row], axis=1)

    def _rope_at(self, positions: mx.array):
        """cos/sin tables for explicit ``positions`` (proportional RoPE)."""
        a = self.args
        return rope_tables(positions, a.head_dim, a.rope_theta, a.partial_rotary_factor)

    def _rope(self, length: int):
        """cos/sin tables for positions ``0 .. length - 1``."""
        return self._rope_at(mx.arange(length))

    def _softcap(self, logits: mx.array) -> mx.array:
        """Gemma final-logit softcapping ``tanh(logits / sc) * sc``; identity when sc is falsy."""
        sc = self.args.final_logit_softcapping
        return mx.tanh(logits / sc) * sc if sc else logits

    def _embed(self, ids: mx.array) -> mx.array:
        """Gemma scales token embeddings by sqrt(hidden) (Gemma4TextScaledWordEmbedding)."""
        e = self.embed_tokens(ids)
        return e * mx.array(self.args.hidden_size ** 0.5, dtype=e.dtype)

    def _block_embed(self, anchor_token: mx.array) -> mx.array:
        """Embed ``[anchor, mask, ..., mask]`` (``block_size`` tokens) for every batch row."""
        b = anchor_token.shape[0]
        anchor = anchor_token.astype(mx.int32).reshape(b, 1)
        noise = mx.full((b, self.block_size - 1), self.args.mask_token_id, dtype=mx.int32)
        return self._embed(mx.concatenate([anchor, noise], axis=1))

    def _decode_block(self, block_hidden: mx.array, anchor_token: mx.array):
        """lm_head (in float32) + softcap over the block, then the shared DSpark decode."""
        logits = self._softcap(self.lm_head(block_hidden.astype(mx.float32)))
        return draft_block_decode(
            logits, block_hidden, anchor_token, self.markov_head, self.confidence_head,
            self.block_size, self.temperature,
        )

    # --- reference-matched eager interface (see qwen3.py for the rationale) ---

    def extend_context(self, new_hiddens: mx.array) -> None:
        """Cache per-layer context K/V for ``new_hiddens.shape[1]`` new positions.

        Positions continue from ``self._committed``; a zero-length input is a no-op.
        """
        n = new_hiddens.shape[1]
        if n == 0:
            return
        proj = self.backbone.project_context(new_hiddens)
        cos, sin = self._rope_at(mx.arange(self._committed, self._committed + n))
        layers = self.backbone.layers
        if self._ctx_k is None:
            self._ctx_k = [None] * len(layers)
            self._ctx_v = [None] * len(layers)
        for i, layer in enumerate(layers):
            k, v = layer.context_kv(proj, cos, sin)
            # Cached K/V are [b, nkv, seq, head_dim]; grow along the sequence axis.
            self._ctx_k[i] = k if self._ctx_k[i] is None else mx.concatenate([self._ctx_k[i], k], axis=2)
            self._ctx_v[i] = v if self._ctx_v[i] is None else mx.concatenate([self._ctx_v[i], v], axis=2)
        self._committed += n

    def draft(self, anchor_token: mx.array):
        """Draft one block after ``anchor_token`` from the cached context K/V (eager path).

        Block positions start at ``self._committed``. With no cached context the block
        attends only to itself.
        """
        start = self._committed
        noise_embed = self._block_embed(anchor_token)
        cos, sin = self._rope_at(mx.arange(start, start + self.block_size))
        n_layers = len(self.backbone.layers)
        ck = self._ctx_k or [None] * n_layers
        cv = self._ctx_v or [None] * n_layers
        h = noise_embed
        for i, layer in enumerate(self.backbone.layers):
            h = layer.forward_cached(h, ck[i], cv[i], cos, sin)
        block_hidden = self.backbone.norm(h)
        return self._decode_block(block_hidden, anchor_token)

    def forward_spec(self, input_ids: mx.array, main_hidden: mx.array, start_pos: int = 0):
        """DraftBackbone entry point.

        ``start_pos == 0`` (prefill) projects ``main_hidden`` over the prompt, keeps all
        but the last position as context and returns None. Otherwise ``main_hidden`` is
        appended as one context row and a block is drafted after ``input_ids``,
        returning the ``draft_block_decode`` result.
        """
        if start_pos == 0:
            full = self.backbone.project_context(main_hidden)
            self._ctx = full[:, :-1]
            self._next_pos = full.shape[1] - 1
            return None
        self._append_context(main_hidden)
        self._next_pos = start_pos + 1
        noise_embed = self._block_embed(input_ids)
        # RoPE spans context + block: the block occupies the last block_size positions.
        cos, sin = self._rope(self._ctx.shape[1] + self.block_size)
        block_hidden = self.backbone(noise_embed, self._ctx, cos, sin)
        return self._decode_block(block_hidden, input_ids)

    def advance(self, main_hidden: mx.array, position: int) -> None:
        """Slide the context over one committed token at ``position``."""
        self._append_context(main_hidden)
        self._next_pos = position + 1
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def gemma4_key_map(key: str) -> "str | None":
    """Map a standalone Gemma4 DSpark checkpoint key to a drafter parameter path.

    Top-level embed/lm_head and the DSpark heads keep their names. ``fc``,
    ``hidden_norm``, ``norm`` and ``layers.N.*`` move under ``backbone.``. Anything
    else returns None.
    """
    keeps_name = key in ("embed_tokens.weight", "lm_head.weight") or key.startswith(
        ("markov_head.", "confidence_head.")
    )
    if keeps_name:
        return key
    if key in ("fc.weight", "hidden_norm.weight", "norm.weight"):
        return f"backbone.{key}"
    layer_match = _GEMMA4_LAYER_RE.match(key)
    if layer_match is None:
        return None
    idx, rest = layer_match.group(1), layer_match.group(2)
    return f"backbone.layers.{idx}.{rest}"
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def build(config, *, max_seq_len: int = 8192) -> Gemma4DSparkDrafter:
    """DraftArch factory: parse ``config`` into Gemma4DSparkArgs and construct the drafter."""
    args = Gemma4DSparkArgs.from_dict(config)
    return Gemma4DSparkDrafter(args, max_seq_len=max_seq_len)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# Registry descriptor for Gemma4 drafters: accepts configs whose model_type is "gemma4"
# or "gemma4_text", builds via ``build`` and maps checkpoint keys with ``gemma4_key_map``.
GEMMA4 = DraftArch(name="gemma4", model_types=("gemma4", "gemma4_text"), build=build, key_map=gemma4_key_map)
|