meadow-mind 0.1.0__tar.gz

@@ -0,0 +1,170 @@
1
+ Metadata-Version: 2.4
2
+ Name: meadow-mind
3
+ Version: 0.1.0
4
+ Summary: Language-rule decision mind: zero-training, ~0.4s real decisions for games and control. One install, one import.
5
+ Author: Hey-Meadow Lab
6
+ License: MIT
7
+ Project-URL: Homepage, https://meadow-mind.pages.dev
8
+ Project-URL: Demo, https://meadow-mind.pages.dev/en.html
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: mlx>=0.20
12
+ Requires-Dist: mlx-lm>=0.20
13
+ Requires-Dist: numpy
14
+ Requires-Dist: huggingface_hub
15
+ Provides-Extra: games
16
+ Requires-Dist: gymnasium[box2d,toy-text]; extra == "games"
17
+ Requires-Dist: matplotlib; extra == "games"
18
+ Requires-Dist: imageio[ffmpeg]; extra == "games"
19
+
20
+ # Meadow Mind
21
+
22
+ **Zero training. Sub-second reactions (~400 ms).**
23
+ A language-rule decision mind: write the policy as one sentence, describe the state as one sentence, and a local 7B model makes a real decision every ~0.4 s. No RL, no reward engineering, no gradients, no samples.
24
+
25
+ 🌐 **Demo site**: [meadow-mind.pages.dev](https://meadow-mind.pages.dev) (中文) · [English](https://meadow-mind.pages.dev/en.html) · [繁體中文 README](README.zh-TW.md)
26
+
27
+ ```bash
28
+ pip install meadow-mind # weights auto-download on first use
29
+ ```
30
+
31
+ ```python
32
+ from meadow_mind import MeadowMind, tasks
33
+
34
+ mind = MeadowMind() # loads once, runs on-device
35
+ task = tasks.mountaincar()
36
+ mind.check(task) # sanity gate: decision-table exam
37
+ action, info = mind.decide(task, obs) # obs = latest Gymnasium observation (e.g. from env.reset()); env action out (~0.4s)
38
+ ```
39
+
40
+ ## Results
41
+
42
+ All on official Gymnasium environments, untouched physics, **zero training**. Every frame below corresponds to one real model decision; no scripted policy, no edited speed-ups.
43
+
44
+ | Balance · CartPole-v1<br>**400/400 perfect** (solve bar 195) | Landing · LunarLander-v3<br>**+251 safe landing** (solve bar 200) |
45
+ |---|---|
46
+ | ![CartPole](assets/balance.gif) | ![LunarLander](assets/landing.gif) |
47
+
48
+ | Maze · FrozenLake 8×8<br>**goal in 14 steps = shortest path** | Momentum · MountainCar-v0<br>**flag in 103 steps** (limit 200) |
49
+ |---|---|
50
+ | ![Maze](assets/maze.gif) | ![MountainCar](assets/mountaincar.gif) |
51
+
52
+ The MountainCar policy is one counterintuitive sentence — `"push in the same direction the car is moving, to pump energy like a swing"` — which replaces an entire RL reward curve.
53
+
54
+ ### Real-time reflex (wall-clock, not turn-based)
55
+
56
+ The model runs in a thread while obstacles fall in real time. If it is still thinking when the obstacle lands, it really crashes.
57
+
58
+ | Parkour dodge: full-generation crashes at #1, Meadow Mind clears 5/6 | Shape+color match: 6/6, down to a 0.72 s window |
59
+ |---|---|
60
+ | ![Parkour](assets/parkour.gif) | ![Shape+color](assets/shape_color.gif) |
61
+
62
+ ### Working memory
63
+
64
+ A funnel maze forces both runs into the same dead-end pocket. Reactive (left) paces at its mouth forever; with `Task(memory=True)` (right) it struggles, backs out, and detours to the goal in 22 steps. The only difference is five words in the perception sentence.
65
+
66
+ ![Memory](assets/memory.gif)
67
+
68
+ ## Decision latency: traditional LLM vs Meadow Mind
69
+
70
+ A traditional LLM agent must **generate its full answer before acting** — and latency grows with answer length. Meadow Mind reads the rule and the situation and decides in **one fixed-latency pass**, right at human reaction speed (0.3–0.4 s):
71
+
72
+ ![Latency](assets/latency.png)
73
+
74
+ ```
75
+ Traditional LLM agent Meadow Mind
76
+ ───────────────────── ───────────
77
+ state → long prompt state → one sentence (Perceiver)
78
+ → generate the answer → one sentence rule (Policy)
79
+ token by token (1.2–3.9 s, → ONE decision pass, fixed ~0.4 s
80
+ grows with length) → action letter (Actuator)
81
+ → parse free text → act exam-gated before deployment
82
+ ```
83
+
84
+ ## Why a diffusion LLM underneath
85
+
86
+ Meadow Mind is built on a diffusion language model (MeadowCoder-7B), not an autoregressive one. The differences that matter:
87
+
88
+ | | AR-LLM | Diffusion LLM (Meadow dLLM) |
89
+ |---|---|---|
90
+ | Generation | left to right, one token at a time; words are final once written | drafts the **whole answer at once**, then refines it over multiple steps |
91
+ | Mid-course correction | cannot edit what is already written; fixing means regenerating everything | **refines while working** — any region can be re-opened and corrected in place |
92
+ | Task awareness | sees only the next word | **global**: senses the entire task and answer shape at once |
93
+ | Pre-answer self-sense | none | **Σ**: before answering, Meadow dLLM senses whether it understands the task; low Σ coherence becomes an escalation signal instead of a wrong answer |
94
+ | Decision latency | grows with answer length | **fixed**, independent of answer length |
95
+ | Long free-form prose | mature, strong ecosystem | weaker; smaller ecosystem (honest trade-off) |
96
+
97
+ Two of these are what make Meadow Mind possible: **multi-step self-correction** (it can fix its own draft while working) and **global task perception with Σ** (it knows what it is being asked — and whether it understands — before committing to an answer).
98
+
99
+ ## How it works
100
+
101
+ ```
102
+ ┌────────────────────────────────────────────────────┐
103
+ │ ① Perceiver your code: numbers -> one sentence │
104
+ │ "The pole tilts right, fast spin." │
105
+ ├────────────────────────────────────────────────────┤
106
+ │ ② Rule one sentence = the policy │
107
+ │ edit behavior by editing words │
108
+ ├────────────────────────────────────────────────────┤
109
+ │ ③ Mind 7B on-device model reads rule+state, │
110
+ │ answers an action letter in a single │
111
+ │ decision pass, fixed ~0.4 s │
112
+ ├────────────────────────────────────────────────────┤
113
+ │ ④ Actuator letter -> env action │
114
+ └────────────────────────────────────────────────────┘
115
+ ```
116
+
117
+ There is no reward in the loop. The env score is only a report card; improvement happens by **outcome feedback**: the episode trace shows which sentence was wrong, and you edit it. (LunarLander went from a +27.5 crash to a +251 landing by adding one touchdown-cushion line to the perceiver. Ten seconds.)
118
+
119
+ ## Wire up a new game (5 steps)
120
+
121
+ 1. **Understand the task, explore input-output.** Variables, actions, win/lose conditions; the reaction deadline must be looser than ~0.4 s. List every action and watch its effect.
122
+ 2. **Build perception words.** One sentence describing the current situation. Bucket continuous values (small/big, fast/slow); always include a velocity/trend term.
123
+ 3. **Imprint the rule.** Invert the effects into "on situation X do action B". Keyword → letter, one-layer mapping, multiple choice only.
124
+ 4. **Decide on memory.** Ask: *"is revisiting the same state a failure signal?"* Yes (maze, exploration, dead ends) → `Task(memory=True)`. No (balance, landing, tracking — repetition IS the job) → keep it off; annotations measurably hurt regulation tasks (CartPole sanity 7/8 → 6/8). Unsure → leave off; the runner prints a hint when it detects looping.
125
+ 5. **Take the exam.** Enumerate every situation with its expected letter; `mind.check(task)` passes with at most 1 miss. Failures mean the wording is incomplete — rephrase and re-check, no training.
126
+
127
+ Or skip all five: hand `meadow_mind.ai_prompt()` plus your game description to any code agent, and it wires the task for you. You only review the exam score.
128
+
129
+ ## API
130
+
131
+ ### `MeadowMind(model_path=None)`
132
+ Weight resolution: `MEADOW_MIND_MODEL` env → explicit path → local cache (`~/.meadow-mind/models/`) → auto-download.
133
+
134
+ | Method | |
135
+ |---|---|
136
+ | `mind.decide(task, obs) -> (action, info)` | one real decision; `info = {status, letter, lat}` |
137
+ | `mind.check(task) -> (ok, n)` | sanity gate; raises if the decision table fails |
138
+
139
+ ### `Task(...)`
140
+
141
+ | Field | |
142
+ |---|---|
143
+ | `perceive(obs) -> str` (or `perceive(obs, task)` with memory) | perception layer |
144
+ | `rule` / `option_text` / `options` / `act_text` | the one-sentence policy and its multiple-choice actions |
145
+ | `sanity` | the exam: `[(status sentence, expected letter)]` |
146
+ | `memory` / `mem_key` | working-memory switch (default off) + state key fn |
147
+ | `env_id` / `env_kwargs` / `max_steps` / `judge` | environment wiring and report card |
148
+
149
+ With `memory=True` the runner auto-tracks `task.visited`; use `task.seen(key)` inside `perceive` to annotate, e.g. `(safe, already visited)`.
150
+
151
+ ### CLI
152
+
153
+ ```bash
154
+ meadow-mind cartpole # sanity gate -> play one episode -> video + verdict
155
+ ```
156
+
157
+ ## Honest limits
158
+
159
+ - Reaction floor is one decision pass (~0.4 s, about 2.5 decisions per second). Tighter deadlines (1 m pole, Pong trajectory prediction) are out of reach today.
160
+ - Suited to tasks whose situations can be said in a sentence and whose policy fits a rule. Continuous high-precision control is not.
161
+ - The perceiver is human-designed (or AI-generated via `ai_prompt()`); the model's job is reading the rule and deciding.
162
+
163
+ ## Roadmap
164
+
165
+ - **v0.2** — layered perception with early action: accumulate confidence through the network and act when it crosses a threshold; easy situations should land near ~0.15 s.
166
+ - Rule-learning loop: discover rules like MountainCar's swing trick from failed episodes automatically (no gradients — the learned artifact is a readable sentence).
167
+
168
+ ## License
169
+
170
+ MIT © Hey-Meadow Lab
@@ -0,0 +1,151 @@
1
+ # Meadow Mind
2
+
3
+ **Zero training. Sub-second reactions (~400 ms).**
4
+ A language-rule decision mind: write the policy as one sentence, describe the state as one sentence, and a local 7B model makes a real decision every ~0.4 s. No RL, no reward engineering, no gradients, no samples.
5
+
6
+ 🌐 **Demo site**: [meadow-mind.pages.dev](https://meadow-mind.pages.dev) (中文) · [English](https://meadow-mind.pages.dev/en.html) · [繁體中文 README](README.zh-TW.md)
7
+
8
+ ```bash
9
+ pip install meadow-mind # weights auto-download on first use
10
+ ```
11
+
12
+ ```python
13
+ from meadow_mind import MeadowMind, tasks
14
+
15
+ mind = MeadowMind() # loads once, runs on-device
16
+ task = tasks.mountaincar()
17
+ mind.check(task) # sanity gate: decision-table exam
18
+ action, info = mind.decide(task, obs) # obs = latest Gymnasium observation (e.g. from env.reset()); env action out (~0.4s)
19
+ ```
20
+
21
+ ## Results
22
+
23
+ All on official Gymnasium environments, untouched physics, **zero training**. Every frame below corresponds to one real model decision; no scripted policy, no edited speed-ups.
24
+
25
+ | Balance · CartPole-v1<br>**400/400 perfect** (solve bar 195) | Landing · LunarLander-v3<br>**+251 safe landing** (solve bar 200) |
26
+ |---|---|
27
+ | ![CartPole](assets/balance.gif) | ![LunarLander](assets/landing.gif) |
28
+
29
+ | Maze · FrozenLake 8×8<br>**goal in 14 steps = shortest path** | Momentum · MountainCar-v0<br>**flag in 103 steps** (limit 200) |
30
+ |---|---|
31
+ | ![Maze](assets/maze.gif) | ![MountainCar](assets/mountaincar.gif) |
32
+
33
+ The MountainCar policy is one counterintuitive sentence — `"push in the same direction the car is moving, to pump energy like a swing"` — which replaces an entire RL reward curve.
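To make the swing rule concrete, here is a minimal sketch of what a perceiver and a hand-written reference policy for MountainCar-v0 could look like. The function names, the "almost still" threshold, the letter assignment, and the tie-break are assumptions for illustration; they are not taken from `tasks.mountaincar()`. The action ids (0 = push left, 1 = no push, 2 = push right) are the ones Gymnasium defines for this environment.

```python
# Hypothetical perceiver and reference policy for MountainCar-v0.
# Names and thresholds are illustrative, not the package's own task definition.

RULE = "Push in the same direction the car is moving, to pump energy like a swing."

# Letter -> Gymnasium action id (0 = push left, 1 = no push, 2 = push right).
ACTIONS = {"A": 0, "B": 1, "C": 2}


def perceive(obs):
    """Turn (position, velocity) into one short status sentence."""
    position, velocity = obs
    if abs(velocity) < 1e-3:  # assumed threshold for "almost still"
        motion = "almost still"
    elif velocity > 0:
        motion = "moving right"
    else:
        motion = "moving left"
    # The valley floor sits at roughly x = -0.5 in this environment.
    side = "left of the valley floor" if position < -0.5 else "right of the valley floor"
    return f"The car is {side} and {motion}."


def reference_letter(status):
    """The letter the rule asks for, written by hand to compare against the model."""
    if "moving right" in status:
        return "C"
    if "moving left" in status:
        return "A"
    return "C"  # assumed tie-break when the car is almost still


status = perceive((-0.6, 0.02))
action = ACTIONS[reference_letter(status)]
```

A table of such status sentences paired with their reference letters is the kind of input the sanity exam described later expects.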
34
+
35
+ ### Real-time reflex (wall-clock, not turn-based)
36
+
37
+ The model runs in a thread while obstacles fall in real time. If it is still thinking when the obstacle lands, it really crashes.
38
+
39
+ | Parkour dodge: full-generation crashes at #1, Meadow Mind clears 5/6 | Shape+color match: 6/6, down to a 0.72 s window |
40
+ |---|---|
41
+ | ![Parkour](assets/parkour.gif) | ![Shape+color](assets/shape_color.gif) |
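The wall-clock setup described above can be approximated with a worker thread and a join timeout. This is a sketch under assumptions: `decide_with_deadline`, `fast_decider`, and `slow_decider` are hypothetical names, and the sleeps merely stand in for model latency; the package's own real-time runner may be structured differently.

```python
import threading
import time


def decide_with_deadline(decide_fn, status, deadline_s):
    """Run decide_fn(status) in a worker thread.

    Returns (letter, elapsed_seconds); letter is None when the decision
    did not arrive before the deadline, which the game treats as a crash.
    """
    result = {}

    def worker():
        result["letter"] = decide_fn(status)

    thread = threading.Thread(target=worker, daemon=True)
    start = time.perf_counter()
    thread.start()
    thread.join(timeout=deadline_s)  # wait no longer than the obstacle allows
    elapsed = time.perf_counter() - start
    if thread.is_alive():
        return None, elapsed  # still thinking when the obstacle lands
    return result.get("letter"), elapsed


def fast_decider(status):
    time.sleep(0.01)  # stand-in for a short, fixed-latency decision pass
    return "A"


def slow_decider(status):
    time.sleep(0.5)  # stand-in for a long token-by-token generation
    return "A"


on_time, _ = decide_with_deadline(fast_decider, "obstacle on the left", 0.3)
too_late, _ = decide_with_deadline(slow_decider, "obstacle on the left", 0.05)
```

The late worker is left to finish on its own as a daemon thread; a real runner would probably also want to discard its stale answer.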
42
+
43
+ ### Working memory
44
+
45
+ A funnel maze forces both runs into the same dead-end pocket. Reactive (left) paces at its mouth forever; with `Task(memory=True)` (right) it struggles, backs out, and detours to the goal in 22 steps. The only difference is five words in the perception sentence.
46
+
47
+ ![Memory](assets/memory.gif)
48
+
49
+ ## Decision latency: traditional LLM vs Meadow Mind
50
+
51
+ A traditional LLM agent must **generate its full answer before acting** — and latency grows with answer length. Meadow Mind reads the rule and the situation and decides in **one fixed-latency pass**, right at human reaction speed (0.3–0.4 s):
52
+
53
+ ![Latency](assets/latency.png)
54
+
55
+ ```
56
+ Traditional LLM agent Meadow Mind
57
+ ───────────────────── ───────────
58
+ state → long prompt state → one sentence (Perceiver)
59
+ → generate the answer → one sentence rule (Policy)
60
+ token by token (1.2–3.9 s, → ONE decision pass, fixed ~0.4 s
61
+ grows with length) → action letter (Actuator)
62
+ → parse free text → act exam-gated before deployment
63
+ ```
64
+
65
+ ## Why a diffusion LLM underneath
66
+
67
+ Meadow Mind is built on a diffusion language model (MeadowCoder-7B), not an autoregressive one. The differences that matter:
68
+
69
+ | | AR-LLM | Diffusion LLM (Meadow dLLM) |
70
+ |---|---|---|
71
+ | Generation | left to right, one token at a time; words are final once written | drafts the **whole answer at once**, then refines it over multiple steps |
72
+ | Mid-course correction | cannot edit what is already written; fixing means regenerating everything | **refines while working** — any region can be re-opened and corrected in place |
73
+ | Task awareness | sees only the next word | **global**: senses the entire task and answer shape at once |
74
+ | Pre-answer self-sense | none | **Σ**: before answering, Meadow dLLM senses whether it understands the task; low Σ coherence becomes an escalation signal instead of a wrong answer |
75
+ | Decision latency | grows with answer length | **fixed**, independent of answer length |
76
+ | Long free-form prose | mature, strong ecosystem | weaker; smaller ecosystem (honest trade-off) |
77
+
78
+ Two of these are what make Meadow Mind possible: **multi-step self-correction** (it can fix its own draft while working) and **global task perception with Σ** (it knows what it is being asked — and whether it understands — before committing to an answer).
79
+
80
+ ## How it works
81
+
82
+ ```
83
+ ┌────────────────────────────────────────────────────┐
84
+ │ ① Perceiver your code: numbers -> one sentence │
85
+ │ "The pole tilts right, fast spin." │
86
+ ├────────────────────────────────────────────────────┤
87
+ │ ② Rule one sentence = the policy │
88
+ │ edit behavior by editing words │
89
+ ├────────────────────────────────────────────────────┤
90
+ │ ③ Mind 7B on-device model reads rule+state, │
91
+ │ answers an action letter in a single │
92
+ │ decision pass, fixed ~0.4 s │
93
+ ├────────────────────────────────────────────────────┤
94
+ │ ④ Actuator letter -> env action │
95
+ └────────────────────────────────────────────────────┘
96
+ ```
97
+
98
+ There is no reward in the loop. The env score is only a report card; improvement happens by **outcome feedback**: the episode trace shows which sentence was wrong, and you edit it. (LunarLander went from a +27.5 crash to a +251 landing by adding one touchdown-cushion line to the perceiver. Ten seconds.)
99
+
100
+ ## Wire up a new game (5 steps)
101
+
102
+ 1. **Understand the task, explore input-output.** Variables, actions, win/lose conditions; the reaction deadline must be looser than ~0.4 s. List every action and watch its effect.
103
+ 2. **Build perception words.** One sentence describing the current situation. Bucket continuous values (small/big, fast/slow); always include a velocity/trend term.
104
+ 3. **Imprint the rule.** Invert the effects into "on situation X do action B". Keyword → letter, one-layer mapping, multiple choice only.
105
+ 4. **Decide on memory.** Ask: *"is revisiting the same state a failure signal?"* Yes (maze, exploration, dead ends) → `Task(memory=True)`. No (balance, landing, tracking — repetition IS the job) → keep it off; annotations measurably hurt regulation tasks (CartPole sanity 7/8 → 6/8). Unsure → leave off; the runner prints a hint when it detects looping.
106
+ 5. **Take the exam.** Enumerate every situation with its expected letter; `mind.check(task)` passes with at most 1 miss. Failures mean the wording is incomplete — rephrase and re-check, no training.
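Step 5 can be pictured as a plain decision table plus a miss counter. The sketch below is an assumption about the shape of that check, not the implementation of `mind.check`: `run_exam` and `keyword_decider` are hypothetical, the decider is a keyword stub standing in for the model, and the CartPole sentences and letters are made up for illustration.

```python
def run_exam(decide_letter, sanity, max_misses=1):
    """Check a decider against a table of (status sentence, expected letter).

    Returns (ok, misses), where misses lists (status, expected, got) and
    ok is True when the number of misses stays within max_misses.
    """
    misses = []
    for status, expected in sanity:
        got = decide_letter(status)
        if got != expected:
            misses.append((status, expected, got))
    return len(misses) <= max_misses, misses


def keyword_decider(status):
    """Stub standing in for the model: one keyword maps to one letter."""
    return "B" if "tilts right" in status else "A"


# Illustrative table: A = push left, B = push right (assumed lettering).
SANITY = [
    ("The pole tilts right, fast spin.", "B"),
    ("The pole tilts right, slow spin.", "B"),
    ("The pole tilts left, fast spin.", "A"),
    ("The pole tilts left, slow spin.", "A"),
]

ok, misses = run_exam(keyword_decider, SANITY)
```

When the exam fails, the `misses` list points at the sentences whose wording needs another pass, which matches the rephrase-and-recheck loop in step 5.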
107
+
108
+ Or skip all five: hand `meadow_mind.ai_prompt()` plus your game description to any code agent, and it wires the task for you. You only review the exam score.
109
+
110
+ ## API
111
+
112
+ ### `MeadowMind(model_path=None)`
113
+ Weight resolution: `MEADOW_MIND_MODEL` env → explicit path → local cache (`~/.meadow-mind/models/`) → auto-download.
114
+
115
+ | Method | |
116
+ |---|---|
117
+ | `mind.decide(task, obs) -> (action, info)` | one real decision; `info = {status, letter, lat}` |
118
+ | `mind.check(task) -> (ok, n)` | sanity gate; raises if the decision table fails |
119
+
120
+ ### `Task(...)`
121
+
122
+ | Field | |
123
+ |---|---|
124
+ | `perceive(obs) -> str` (or `perceive(obs, task)` with memory) | perception layer |
125
+ | `rule` / `option_text` / `options` / `act_text` | the one-sentence policy and its multiple-choice actions |
126
+ | `sanity` | the exam: `[(status sentence, expected letter)]` |
127
+ | `memory` / `mem_key` | working-memory switch (default off) + state key fn |
128
+ | `env_id` / `env_kwargs` / `max_steps` / `judge` | environment wiring and report card |
129
+
130
+ With `memory=True` the runner auto-tracks `task.visited`; use `task.seen(key)` inside `perceive` to annotate, e.g. `(safe, already visited)`.
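A rough illustration of how such an annotation could flow into the perception sentence follows. `MemoryStub` is a stand-in written for this example, not the package's `Task` class; its `mark` method is hypothetical and only imitates the bookkeeping the runner is said to do automatically. The observation is assumed to be a `(row, col)` grid cell.

```python
class MemoryStub:
    """Stand-in for a task with working memory; not the package's Task class."""

    def __init__(self):
        self.visited = set()

    def seen(self, key):
        """True when this state key has been visited before."""
        return key in self.visited

    def mark(self, key):
        """Hypothetical helper imitating the runner's automatic tracking."""
        self.visited.add(key)


def perceive(obs, task):
    """Describe a grid cell, adding a short note when it was already visited."""
    row, col = obs
    note = "safe, already visited" if task.seen(obs) else "safe"
    return f"The cell at row {row}, column {col} is ({note})."


task = MemoryStub()
first = perceive((2, 3), task)   # no annotation on the first visit
task.mark((2, 3))
second = perceive((2, 3), task)  # now carries the "already visited" note
```

The rule sentence can then refer to the annotation, for example by preferring cells that are not marked as visited.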
131
+
132
+ ### CLI
133
+
134
+ ```bash
135
+ meadow-mind cartpole # sanity gate -> play one episode -> video + verdict
136
+ ```
137
+
138
+ ## Honest limits
139
+
140
+ - Reaction floor is one decision pass (~0.4 s ≈ 2 Hz). Tighter deadlines (1 m pole, Pong trajectory prediction) are out of reach today.
141
+ - Suited to tasks whose situations can be said in a sentence and whose policy fits a rule. Continuous high-precision control is not.
142
+ - The perceiver is human-designed (or AI-generated via `ai_prompt()`); the model's job is reading the rule and deciding.
143
+
144
+ ## Roadmap
145
+
146
+ - **v0.2** — layered perception with early action: accumulate confidence through the network and act when it crosses a threshold; easy situations should land near ~0.15 s.
147
+ - Rule-learning loop: discover rules like MountainCar's swing trick from failed episodes automatically (no gradients — the learned artifact is a readable sentence).
148
+
149
+ ## License
150
+
151
+ MIT © Hey-Meadow Lab
@@ -0,0 +1,21 @@
1
+ """Meadow Mind — language-rule decision mind. One install, one import.
2
+
3
+ pip install meadow-mind
4
+
5
+ from meadow_mind import MeadowMind, Task, tasks, ai_prompt
6
+
7
+ mind = MeadowMind() # model auto-downloads on first use
8
+ task = tasks.mountaincar()
9
+ mind.check(task) # sanity gate
10
+ action, info = mind.decide(task, obs) # obs in, env action out
11
+
12
+ Everything below this API (engine, weights, decoding) is internal.
13
+ Give ai_prompt() to any code agent to wire a NEW game automatically.
14
+ """
15
+ from .mind import MeadowMind
16
+ from .task import Task
17
+ from .prompt import ai_prompt, AI_PROMPT
18
+ from . import tasks
19
+
20
+ __version__ = "0.1.0"
21
+ __all__ = ["MeadowMind", "Task", "tasks", "ai_prompt", "AI_PROMPT", "__version__"]
@@ -0,0 +1,302 @@
1
+ """DiffuCoder MLX engine — block-wise (semi-autoregressive) masked-diffusion
2
+ generation with a per-layer K/V cache and Dream boundary-shift fix.
3
+
4
+ This is the same proven inference path as diffucoder-play/diffucoder_mlx_block.py,
5
+ packaged as a reusable engine for the OpenAI-compatible server. Speed comes from
6
+ the algorithm (block-wise KV cache) + 8-bit weights, NOT from "being MLX".
7
+ """
8
+ import time
9
+
10
+ import numpy as np
11
+ import mlx.core as mx
12
+ from mlx_lm import load
13
+
14
+ MASK = 151666 # <|mask|>
15
+ IM_END = 151645 # <|im_end|>
16
+
17
+
18
+ def _top_p_filter(logits, top_p):
19
+ sorted_desc = -mx.sort(-logits, axis=-1)
20
+ cum = mx.cumsum(mx.softmax(sorted_desc, axis=-1), axis=-1)
21
+ keep = cum <= top_p
22
+ keep = mx.concatenate([mx.ones_like(keep[:, :1]), keep[:, :-1]], axis=-1)
23
+ k = mx.sum(keep, axis=-1).astype(mx.int32)
24
+ thresh = mx.take_along_axis(sorted_desc, (k - 1)[:, None], axis=-1)
25
+ return mx.where(logits < thresh, -1e9, logits)
26
+
27
+
28
+ def _sample_tokens(logits, temperature, top_p, neg_entropy):
29
+ if temperature and temperature > 0:
30
+ logits = logits / temperature
31
+ if top_p is not None and top_p < 1:
32
+ logits = _top_p_filter(logits, top_p)
33
+ probs = mx.softmax(logits, axis=-1)
34
+ if temperature and temperature > 0:
35
+ x0 = mx.random.categorical(logits)
36
+ conf = mx.take_along_axis(probs, x0[:, None], axis=-1)[:, 0]
37
+ else:
38
+ x0 = mx.argmax(probs, axis=-1)
39
+ conf = mx.max(probs, axis=-1)
40
+ if neg_entropy:
41
+ conf = mx.sum(probs * mx.log(probs + 1e-10), axis=-1)
42
+ return conf, x0
43
+
44
+
45
+ class _Cache:
46
+ def __init__(self, n):
47
+ self.k = [None] * n
48
+ self.v = [None] * n
49
+ self.length = 0
50
+
51
+ def append(self, ks, vs):
52
+ for li in range(len(self.k)):
53
+ self.k[li] = ks[li] if self.k[li] is None else mx.concatenate([self.k[li], ks[li]], axis=2)
54
+ self.v[li] = vs[li] if self.v[li] is None else mx.concatenate([self.v[li], vs[li]], axis=2)
55
+ self.length += ks[0].shape[2]
56
+
57
+
58
+ class DiffuCoderEngine:
59
+ def __init__(self, model_path, system="You are a helpful coding assistant."):
60
+ t0 = time.time()
61
+ self.is_llada = "llada" in model_path.lower()
62
+ self.mask_token_id = 156895 if self.is_llada else 151666
63
+ tokenizer_config = {"trust_remote_code": True} if self.is_llada else None
64
+ self.model, self.tok = load(model_path, tokenizer_config=tokenizer_config)
65
+ self.tie = getattr(self.model.args, "tie_word_embeddings", False)
66
+ self.layers = self.model.model.layers
67
+ self.n_layers = len(self.layers)
68
+ self.system = system
69
+ self.model_path = model_path
70
+ self.load_time = time.time() - t0
71
+ self._pc_ids = None # prefix cache: last prompt's token ids
72
+ self._pc_kv = None # prefix cache: last prompt's per-layer K/V
73
+
74
+ def _forward(self, ids_mx, offset, cache, attend_len, mask=None):
75
+ m = self.model.model
76
+ h = m.embed_tokens(ids_mx)
77
+ ks, vs = [], []
78
+ for li, layer in enumerate(self.layers):
79
+ attn = layer.self_attn
80
+ x = layer.input_layernorm(h)
81
+ B, L, _ = x.shape
82
+ q = attn.q_proj(x).reshape(B, L, attn.n_heads, -1).transpose(0, 2, 1, 3)
83
+ k = attn.k_proj(x).reshape(B, L, attn.n_kv_heads, -1).transpose(0, 2, 1, 3)
84
+ v = attn.v_proj(x).reshape(B, L, attn.n_kv_heads, -1).transpose(0, 2, 1, 3)
85
+ q = attn.rope(q, offset=offset)
86
+ k = attn.rope(k, offset=offset)
87
+ pk, pv = cache.k[li], cache.v[li]
88
+ if pk is not None and attend_len > 0:
89
+ kk = mx.concatenate([pk[:, :, :attend_len, :], k], axis=2)
90
+ vv = mx.concatenate([pv[:, :, :attend_len, :], v], axis=2)
91
+ else:
92
+ kk, vv = k, v
93
+ amask = mask.astype(q.dtype) if mask is not None else None
94
+ out = mx.fast.scaled_dot_product_attention(q, kk, vv, scale=attn.scale, mask=amask)
95
+ out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
96
+ h = h + attn.o_proj(out)
97
+ h = h + layer.mlp(layer.post_attention_layernorm(h))
98
+ ks.append(k)
99
+ vs.append(v)
100
+ h = m.norm(h)
101
+ logits = m.embed_tokens.as_linear(h) if self.tie else self.model.lm_head(h)
102
+ return logits, ks, vs
103
+
104
+ def _prefill(self, ids, use_cache=True, causal=True):
105
+ """Build the prompt's K/V cache. With prefix caching, reuse the shared prefix
106
+ of the previous prompt and only forward the new suffix → long conversations
107
+ don't re-prefill everything (fixes the 'gets slower each turn' problem).
108
+ causal=True encodes the prompt left-to-right so the prefix K/V is stable
109
+ across turns (required for caching; bidirectional prefill is not reusable)."""
110
+ cache = _Cache(self.n_layers)
111
+ m = 0
112
+ if use_cache and causal and self._pc_ids is not None:
113
+ ci = self._pc_ids
114
+ lim = min(len(ids), len(ci))
115
+ while m < lim and ids[m] == ci[m]:
116
+ m += 1
117
+ m = min(m, len(ids) - 1) # leave >=1 token to actually forward
118
+ if m > 8: # worth reusing
119
+ for li in range(self.n_layers):
120
+ cache.k[li] = self._pc_kv.k[li][:, :, :m, :]
121
+ cache.v[li] = self._pc_kv.v[li][:, :, :m, :]
122
+ cache.length = m
123
+ else:
124
+ m = 0
125
+ Ln = len(ids) - m
126
+ # causal mask: new tokens see ALL m cached prefix + only earlier new tokens.
127
+ # This makes each prompt token's K/V depend only on its left context -> the
128
+ # prefix K/V is stable across turns -> reusable (bidirectional prefill is not).
129
+ mask = None
130
+ if causal:
131
+ tri = mx.triu(mx.full((Ln, Ln), -1e9), k=1)
132
+ mask = mx.concatenate([mx.zeros((Ln, m)), tri], axis=1) if m > 0 else tri
133
+ _, ks, vs = self._forward(mx.array(ids[m:][None]), m, cache, m, mask=mask)
134
+ cache.append(ks, vs)
135
+ if use_cache and causal:
136
+ mx.eval(cache.k[0], cache.v[0])
137
+ self._pc_ids, self._pc_kv = ids.copy(), cache
138
+ return cache, m
139
+
140
+ # ---- Σ: step-0 draft (the model's instant full-answer guess before any commit) ----
141
+ def step0_draft(self, prompt_text, n=96):
142
+ """One forward over [prompt + all-MASK]: the model already drafts the whole
143
+ answer (77-100% of the final words) before committing anything. Returns the
144
+ draft text + per-position confidence. AR has no equivalent."""
145
+ pids = np.array(self.tok.encode(prompt_text), dtype=np.int64)
146
+ x = np.concatenate([pids, np.full(n, self.mask_token_id, dtype=np.int64)])
147
+ plen = len(pids)
148
+ logits = self._forward_full(mx.array(x[None]))
149
+ ans = logits[mx.array(list(range(plen, plen + n)))]
150
+ probs = mx.softmax(ans, axis=-1)
151
+ conf = np.array(mx.max(probs, axis=-1).astype(mx.float32))
152
+ ids = np.array(mx.argmax(ans, axis=-1))
153
+ return {"ids": ids, "conf": conf, "text": self.tok.decode(ids.tolist()), "coherence": float(conf.mean())}
154
+
155
+ def route(self, user_msg, n=96):
156
+ """Prefill Gating: read the step-0 draft to pick mode / trigger RAG BEFORE generating."""
157
+ d = self.step0_draft(self.build_prompt([{"role": "user", "content": user_msg}]), n)
158
+ t = d["text"].lower()
159
+ sig = []
160
+ if any(k in t for k in ("<!doctype", "<html", "<div", "<body")): sig.append("html")
161
+ if "matplotlib" in t or "plt." in t: sig.append("chart")
162
+ if "select " in t and "from " in t: sig.append("sql")
163
+ if "def " in t or "class " in t: sig.append("code")
164
+ mode = "sectioned" if "html" in sig else "single"
165
+ # low coherence = model doesn't 'have it' -> escalate / inject RAG before generating
166
+ escalate = d["coherence"] < 0.45
167
+ return {"signals": sig, "mode": mode, "coherence": round(d["coherence"], 2),
168
+ "escalate": escalate, "draft_head": d["text"][:70].replace("\n", " ")}
169
+
170
+ # ---- infilling primitive (DiffuCoder's native strength) ----
171
+ def _forward_full(self, x_mx):
172
+ """Full bidirectional forward over the whole sequence (no block cache)."""
173
+ m = self.model.model
174
+ h = m.embed_tokens(x_mx)
175
+ for layer in self.layers:
176
+ h = layer(h, None, None)
177
+ h = m.norm(h)
178
+ logits = m.embed_tokens.as_linear(h) if self.tie else self.model.lm_head(h)
179
+ if self.is_llada:
180
+ return logits[0]
181
+ return mx.concatenate([logits[:, :1], logits[:, :-1]], axis=1)[0] # Dream shift
182
+
183
+ def infill(self, prompt_text, pre, n_slot, post, steps=8, temperature=0.2, top_p=0.95):
184
+ """Fill `n_slot` masked tokens between fixed `pre`/`post`.
185
+ Returns (full_text, slot_text) where slot_text is only the filled slot region."""
186
+ pids = np.array(self.tok.encode(prompt_text), dtype=np.int64)
187
+ pre_ids = self.tok.encode(pre, add_special_tokens=False)
188
+ answer = np.array(pre_ids + [self.mask_token_id] * n_slot
189
+ + self.tok.encode(post, add_special_tokens=False), dtype=np.int64)
190
+ x = np.concatenate([pids, answer])
191
+ plen = len(pids)
192
+ slot0 = plen + len(pre_ids)
193
+ ts = np.linspace(1, 1e-12, steps + 1)
194
+ for i in range(steps):
195
+ mi = x == self.mask_token_id
196
+ if not mi.any():
197
+ break
198
+ logits = self._forward_full(mx.array(x[None]))
199
+ mpos = np.nonzero(mi)[0]
200
+ conf, x0 = _sample_tokens(logits[mx.array(mpos)], temperature, top_p, neg_entropy=True)
201
+ mx.eval(conf, x0)
202
+ conf = np.array(conf.astype(mx.float32)); x0 = np.array(x0.astype(mx.int32))
203
+ k = int(len(mpos) * (1 - ts[i + 1] / ts[i])) if i < steps - 1 else len(mpos)
204
+ if k > 0:
205
+ order = np.argsort(-conf)[:k]
206
+ x[mpos[order]] = x0[order]
207
+ out = x[plen:]
208
+ slot_ids = x[slot0: slot0 + n_slot]
209
+ slot_text = self.tok.decode([int(t) for t in slot_ids.tolist() if t != self.mask_token_id])
210
+ full = self.tok.decode(out[out != self.mask_token_id].tolist())
211
+ return full, slot_text
212
+
213
+ # prefill dictionary: tool -> (pre, slots, post). Structure fixed, only args infilled.
214
+ TOOL_SCAFFOLDS = {
215
+ "read_file": ('{"name": "read_file", "arguments": {"path": "', 8, '"}}'),
216
+ "write_file": ('{"name": "write_file", "arguments": {"path": "', 8, '", "content": "..."}}'),
217
+ "run_bash": ('{"name": "run_bash", "arguments": {"command": "', 12, '"}}'),
218
+ "search_code": ('{"name": "search_code", "arguments": {"query": "', 10, '"}}'),
219
+ "git_commit": ('{"name": "git_commit", "arguments": {"message": "', 12, '"}}'),
220
+ }
221
+
222
+ def tool_call(self, user_msg, tool_name, steps=8):
223
+ """Structure-guaranteed tool call: infill only the arg value, rebuild from scaffold."""
224
+ pre, n, post = self.TOOL_SCAFFOLDS[tool_name]
225
+ prompt = self.build_prompt([{"role": "user", "content": user_msg}])
226
+ _, slot = self.infill(prompt, pre, n, post, steps=steps)
227
+ # keep only the value up to the first closing quote / special token / newline
228
+ val = slot.split('"')[0].split("<|")[0].split("\n")[0].strip()
229
+ return pre + val + post # guaranteed valid JSON structure
230
+
231
+ @staticmethod
232
+ def _flatten(content):
233
+ """OpenAI content can be str | list[{type,text}] | None — flatten to text."""
234
+ if content is None:
235
+ return ""
236
+ if isinstance(content, str):
237
+ return content
238
+ if isinstance(content, list):
239
+ return "".join(p.get("text", "") if isinstance(p, dict) else str(p) for p in content)
240
+ return str(content)
241
+
242
+ def build_prompt(self, messages):
243
+ """messages: list of {role, content}. Prepend the engine's default system if none given."""
244
+ has_sys = any(m.get("role") == "system" for m in messages)
245
+ parts = []
246
+ if not has_sys:
247
+ parts.append(f"<|im_start|>system\n{self.system}<|im_end|>\n")
248
+ for m in messages:
249
+ parts.append(f"<|im_start|>{m.get('role','user')}\n{self._flatten(m.get('content'))}<|im_end|>\n")
250
+ parts.append("<|im_start|>assistant\n")
251
+ return "".join(parts)
252
+
253
+ def generate(self, prompt, max_new=128, block_size=32, tokens_per_step=8,
254
+ temperature=0.2, top_p=0.95, alg="entropy", use_prefix_cache=False,
255
+ causal_prefill=False):
256
+ # DEFAULT = bidirectional prefill, no cache (DiffuCoder is bidirectional-trained;
257
+ # this preserves quality). prefix-cache REQUIRES causal_prefill to be correct, but
258
+ # causal encoding degrades quality -> it's an opt-in speed/quality tradeoff, not the
259
+ # fix for long-conversation slowdown. The clean fix is memory-backed short context.
260
+ ids = np.array(self.tok.encode(prompt), dtype=np.int64)
261
+ t0 = time.time()
262
+ cache, reused = self._prefill(ids, use_prefix_cache, causal_prefill)
263
+ prev_tok = int(ids[-1])
264
+
265
+ out = []
266
+ n_blocks = (max_new + block_size - 1) // block_size
267
+ eps = 1e-12
268
+ neg = mx.array(-1e30, dtype=mx.float32)
269
+ for _ in range(n_blocks):
270
+ off = cache.length
271
+ block = mx.full((block_size,), self.mask_token_id, dtype=mx.int32)
272
+ steps = max(1, block_size // tokens_per_step)
273
+ timesteps = np.linspace(1, eps, steps + 1)
274
+ for i in range(steps):
275
+ # unmask selection stays fully in MLX — no per-step GPU->CPU->GPU round-trip
276
+ is_mask = block == self.mask_token_id
277
+ window = mx.concatenate([mx.array([prev_tok], dtype=mx.int32), block])
278
+ logits, _, _ = self._forward(window[None], off - 1, cache, off - 1)
279
+ conf, x0 = _sample_tokens(logits[0][:-1], temperature, top_p, neg_entropy=(alg == "entropy"))
280
+ conf = mx.where(is_mask, conf.astype(mx.float32), neg) # only masked are candidates
281
+ t, s = timesteps[i], timesteps[i + 1]
282
+ frac = (1 - s / t) if i < steps - 1 else 1.0
283
+ n_transfer = mx.floor(mx.sum(is_mask.astype(mx.float32)) * frac).astype(mx.int32)
284
+ rank = mx.argsort(mx.argsort(-conf)) # rank 0 = most confident
285
+ block = mx.where(rank < n_transfer, x0.astype(mx.int32), block)
286
+ block_list = block.tolist() # one eval per block, not per step
287
+ out.extend(block_list)
288
+ _, ks, vs = self._forward(block[None], off, cache, off)
289
+ cache.append(ks, vs)
290
+ prev_tok = block_list[-1]
291
+ if IM_END in block_list:
292
+ break
293
+ mx.eval(cache.k[0])
294
+ dt = time.time() - t0
295
+
296
+ out = np.array(out)
297
+ if IM_END in out.tolist():
298
+ out = out[: out.tolist().index(IM_END)]
299
+ out = out[out != self.mask_token_id]
300
+ text = self.tok.decode(out.tolist())
301
+ return {"text": text, "time": dt, "n_tokens": int(len(out)), "tok_per_s": len(out) / dt if dt else 0.0,
302
+ "prompt_len": int(len(ids)), "prefix_reused": int(reused)}
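The unmasking schedule that `infill` and `generate` above share (commit a fraction `1 - s/t` of the still-masked positions at each step, then everything left on the last step) can be isolated in a few lines. This is a pure-Python sketch with hypothetical helper names (`linspace`, `unmask_schedule`); it only counts how many positions each step would commit and leaves out the confidence ranking.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive (stand-in for numpy.linspace)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def unmask_schedule(n_masked, steps, eps=1e-12):
    """Return how many masked positions each denoising step commits."""
    ts = linspace(1.0, eps, steps + 1)  # timesteps from 1 down to ~0
    remaining = n_masked
    counts = []
    for i in range(steps):
        if i < steps - 1:
            # Fraction of the remaining masks to commit, floored to an integer.
            k = int(remaining * (1 - ts[i + 1] / ts[i]))
        else:
            k = remaining  # the final step commits whatever is still masked
        counts.append(k)
        remaining -= k
    return counts


counts = unmask_schedule(32, 4)
```

With linearly spaced timesteps each step tends to take a roughly even share of the block, though flooring can push a token or so toward later steps; the final step always clears the remainder, so the counts sum to the number of masked positions.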