@fugood/llama.node 1.3.8 → 1.4.1

This diff compares the contents of publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
Files changed (32)
  1. package/lib/binding.js +25 -18
  2. package/lib/binding.ts +19 -1
  3. package/lib/index.js +3 -3
  4. package/lib/index.ts +1 -1
  5. package/package.json +17 -17
  6. package/scripts/llama.cpp.patch +53 -4
  7. package/src/LlamaCompletionWorker.cpp +2 -2
  8. package/src/LlamaContext.cpp +6 -1
  9. package/src/llama.cpp/common/arg.cpp +1 -1
  10. package/src/llama.cpp/common/chat-parser.cpp +968 -0
  11. package/src/llama.cpp/common/chat.cpp +0 -952
  12. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -2
  13. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  14. package/src/llama.cpp/ggml/include/ggml-rpc.h +1 -1
  15. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -4
  16. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +336 -3
  17. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +11 -8
  18. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +22 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -1
  20. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +234 -1
  21. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +6 -0
  22. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  23. package/src/llama.cpp/src/llama-arch.cpp +48 -3
  24. package/src/llama.cpp/src/llama-arch.h +2 -0
  25. package/src/llama.cpp/src/llama-context.cpp +6 -2
  26. package/src/llama.cpp/src/llama-hparams.h +1 -1
  27. package/src/llama.cpp/src/llama-model.cpp +102 -5
  28. package/src/llama.cpp/src/llama-model.h +4 -0
  29. package/src/llama.cpp/src/llama-quant.cpp +13 -5
  30. package/src/llama.cpp/src/models/lfm2.cpp +5 -3
  31. package/src/llama.cpp/src/models/models.h +51 -1
  32. package/src/llama.cpp/src/models/qwen3next.cpp +1042 -0
@@ -0,0 +1,1042 @@
1
+ #include "ggml.h"
2
+ #include "models.h"
3
+
4
+ #define CHUNK_SIZE 64
5
+
6
+ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
7
+ llm_graph_context_mamba(params), model(model) {
8
+ ggml_tensor * cur;
9
+ ggml_tensor * inpL;
10
+
11
+ inpL = build_inp_embd(model.tok_embd);
12
+ cb(inpL, "model.embed_tokens", -1);
13
+
14
+ auto * inp = build_inp_mem_hybrid();
15
+
16
+ ggml_tensor * inp_pos = build_inp_pos();
17
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
18
+
19
+ ggml_tensor * causal_mask =
20
+ ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
21
+ GGML_TRI_TYPE_LOWER);
22
+
23
+ ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
24
+
25
+ ggml_build_forward_expand(gf, causal_mask);
26
+ ggml_build_forward_expand(gf, identity);
27
+
28
+ for (int il = 0; il < n_layer; ++il) {
29
+ ggml_tensor * inpSA = inpL;
30
+
31
+ cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
32
+ cb(cur, "attn_norm", il);
33
+
34
+ // Determine layer type and build appropriate attention mechanism
35
+ if (hparams.is_recurrent(il)) {
36
+ // Linear attention layer (gated delta net)
37
+ cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il);
38
+ } else {
39
+ // Full attention layer
40
+ cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
41
+ }
42
+
43
+ if (il == n_layer - 1 && inp_out_ids) {
44
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
45
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
46
+ }
47
+
48
+ // Residual connection
49
+ cur = ggml_add(ctx0, cur, inpSA);
50
+ cb(cur, "attn_residual", il);
51
+
52
+ // Save the tensor before post-attention norm for residual connection
53
+ ggml_tensor * ffn_residual = cur;
54
+
55
+ // Post-attention norm
56
+ ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
57
+ cb(attn_post_norm, "attn_post_norm", il);
58
+
59
+ // FFN layer (MoE or dense) - without residual connection
60
+ cur = build_layer_ffn(attn_post_norm, il);
61
+ cb(cur, "ffn_out", il);
62
+
63
+ // Residual connection for FFN - add to the tensor from before post_attention_layernorm
64
+ cur = ggml_add(ctx0, cur, ffn_residual);
65
+ cb(cur, "post_moe", il);
66
+
67
+ // Input for next layer
68
+ inpL = cur;
69
+ }
70
+ cur = inpL;
71
+
72
+ // Final norm
73
+ cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
74
+
75
+ cb(cur, "result_norm", -1);
76
+ res->t_embd = cur;
77
+
78
+ // LM head
79
+ cur = build_lora_mm(model.output, cur);
80
+
81
+ cb(cur, "result_output", -1);
82
+ res->t_logits = cur;
83
+
84
+ ggml_build_forward_expand(gf, cur);
85
+ }
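A minimal plain-C++ sketch (no ggml) of what the two helper tensors built above are assumed to contain: a strictly lower-triangular ones mask (the identity is added separately wherever the diagonal is needed, so GGML_TRI_TYPE_LOWER is assumed here to exclude the diagonal) and an n x n identity.

#include <vector>

// n x n, row-major: ones strictly below the diagonal, zeros elsewhere.
static std::vector<float> make_causal_mask(int n) {
    std::vector<float> m(n * n, 0.0f);
    for (int i = 1; i < n; ++i)
        for (int j = 0; j < i; ++j)
            m[i * n + j] = 1.0f;
    return m;
}

// n x n identity, as produced by diagonalizing a ones vector.
static std::vector<float> make_identity(int n) {
    std::vector<float> m(n * n, 0.0f);
    for (int i = 0; i < n; ++i)
        m[i * n + i] = 1.0f;
    return m;
}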
86
+
87
+ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
88
+ ggml_tensor * q,
89
+ ggml_tensor * k,
90
+ ggml_tensor * v,
91
+ ggml_tensor * g,
92
+ ggml_tensor * beta,
93
+ ggml_tensor * state,
94
+ ggml_tensor * causal_mask,
95
+ ggml_tensor * identity,
96
+ int il) {
97
+ GGML_ASSERT(ggml_is_contiguous(q));
98
+ GGML_ASSERT(ggml_is_contiguous(k));
99
+ GGML_ASSERT(ggml_is_contiguous(v));
100
+ GGML_ASSERT(ggml_is_contiguous(g));
101
+ GGML_ASSERT(ggml_is_contiguous(beta));
102
+ GGML_ASSERT(ggml_is_contiguous(state));
103
+
104
+ const int64_t S_k = q->ne[0];
105
+ const int64_t H_k = q->ne[1];
106
+ const int64_t n_tokens = q->ne[2];
107
+ const int64_t n_seqs = q->ne[3];
108
+
109
+ const int64_t S_v = v->ne[0];
110
+ const int64_t H_v = v->ne[1];
111
+
112
+ GGML_ASSERT(v->ne[2] == n_tokens);
113
+ GGML_ASSERT(k->ne[2] == n_tokens);
114
+ GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
115
+ GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
116
+ GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
117
+
118
+ GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
119
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
120
+
121
+ GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
122
+
123
+ // TODO: can this ever be false?
124
+ const bool use_qk_l2norm = true;
125
+
126
+ if (use_qk_l2norm) {
127
+ const float eps_norm = hparams.f_norm_rms_eps;
128
+
129
+ q = ggml_l2_norm(ctx0, q, eps_norm);
130
+ k = ggml_l2_norm(ctx0, k, eps_norm);
131
+ }
132
+
133
+ const float scale = 1.0f / sqrtf(S_v);
134
+
135
+ q = ggml_scale(ctx0, q, scale);
136
+
137
+ beta = ggml_sigmoid(ctx0, beta);
138
+
139
+ ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
140
+
141
+ cb(q, "q_in", il);
142
+ cb(k, "k_in", il);
143
+ cb(v, "v_in", il);
144
+ cb(beta, "beta_in", il);
145
+ cb(g, "g_in", il);
146
+
147
+ q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
148
+ k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
149
+ v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
150
+ g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
151
+
152
+ beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
153
+ state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
154
+
155
+ cb(q, "q_perm", il);
156
+ cb(k, "k_perm", il);
157
+ cb(v, "v_perm", il);
158
+ cb(beta, "beta_perm", il);
159
+ cb(g, "g_perm", il);
160
+ cb(state, "state_in", il);
161
+
162
+ GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
163
+ GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
164
+ GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
165
+ GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
166
+
167
+ // Do padding
168
+ const int64_t chunk_size = CHUNK_SIZE;
169
+
170
+ const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
171
+ const int64_t n_chunks = (n_tokens + pad) / chunk_size;
172
+
173
+ q = ggml_pad(ctx0, q, 0, pad, 0, 0);
174
+ k = ggml_pad(ctx0, k, 0, pad, 0, 0);
175
+ v = ggml_pad(ctx0, v, 0, pad, 0, 0);
176
+ g = ggml_pad(ctx0, g, pad, 0, 0, 0);
177
+ beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
178
+
179
+ cb(q, "q_pad", il);
180
+ cb(k, "k_pad", il);
181
+ cb(v, "v_pad", il);
182
+ cb(beta, "beta_pad", il);
183
+ cb(g, "g_pad", il);
184
+
185
+ ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
186
+ ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
187
+
188
+ cb(v_beta, "v_beta", il);
189
+ cb(k_beta, "k_beta", il);
190
+
191
+ ggml_tensor * chunked_mask =
192
+ ggml_view_4d(ctx0, causal_mask, chunk_size,
193
+ chunk_size, causal_mask->ne[2], causal_mask->ne[3],
194
+ causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
195
+
196
+ ggml_tensor * chunked_diag_mask =
197
+ ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
198
+ chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3],
199
+ causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
200
+
201
+ ggml_tensor * chunked_identity =
202
+ ggml_view_4d(ctx0, identity, chunk_size,
203
+ chunk_size, identity->ne[2], identity->ne[3],
204
+ identity->nb[1], identity->nb[2], identity->nb[3], 0);
205
+
206
+ q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
207
+ k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
208
+ k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
209
+ v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
210
+ v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
211
+
212
+ g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
213
+ beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
214
+
215
+ ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
216
+
217
+ cb(g_cumsum, "g_cumsum", il);
218
+
219
+ ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
220
+ ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
221
+
222
+ ggml_tensor * gcs_j_broadcast =
223
+ ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
224
+
225
+ ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
226
+
227
+ cb(decay_mask, "decay_mask", il);
228
+
229
+ decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
230
+ decay_mask = ggml_exp(ctx0, decay_mask);
231
+ decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
232
+
233
+ ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
234
+
235
+ ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
236
+ ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask));
237
+
238
+ cb(attn, "attn_pre_solve", il);
239
+
240
+ ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask);
241
+ ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
242
+
243
+ ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
244
+ attn = ggml_mul(ctx0, lin_solve, chunked_mask);
245
+ attn = ggml_add(ctx0, attn, chunked_identity);
246
+
247
+ cb(attn, "attn_solved", il);
248
+
249
+ v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
250
+
251
+ ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
252
+ ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
253
+
254
+ ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
255
+
256
+ cb(kbeta_gexp, "kbeta_gexp", il);
257
+
258
+ ggml_tensor * k_cumdecay =
259
+ ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
260
+
261
+ cb(k_cumdecay, "k_cumdecay", il);
262
+
263
+ ggml_tensor * core_attn_out = nullptr;
264
+ ggml_tensor * new_state = ggml_dup(ctx0, state);
265
+
266
+ cb(new_state, "new_state", il);
267
+
268
+ for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
269
+ auto chunkify = [=](ggml_tensor * t) {
270
+ return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
271
+ t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
272
+ };
273
+
274
+ auto chunkify_g = [=](ggml_tensor * t) {
275
+ return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
276
+ t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
277
+ };
278
+
279
+ ggml_tensor * k_chunk = chunkify(k);
280
+ ggml_tensor * q_chunk = chunkify(q);
281
+ ggml_tensor * v_chunk = chunkify(v);
282
+
283
+ ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
284
+ ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
285
+
286
+ ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
287
+ ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
288
+
289
+ ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
290
+
291
+ // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
292
+ attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
293
+ attn = ggml_mul(ctx0, attn, decay_mask_chunk);
294
+ attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask));
295
+
296
+ ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
297
+
298
+ // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
299
+ ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
300
+
301
+ // v_new = v_i - v_prime
302
+ ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
303
+ ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
304
+
305
+ // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
306
+ ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
307
+ ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
308
+
309
+ // core_attn_out[:, :, i] = attn_inter + attn @ v_new
310
+ ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
311
+
312
+ ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
313
+
314
+ core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
315
+
316
+ // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
317
+ // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
318
+ // key_gdiff = key * g_diff.unsqueeze(-1)
319
+ // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
320
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
321
+
322
+ ggml_tensor * g_cum_last =
323
+ ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
324
+ g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
325
+ g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
326
+
327
+ ggml_tensor * gexp_last =
328
+ ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
329
+
330
+ ggml_tensor * g_cum_last_3d =
331
+ ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
332
+
333
+ ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
334
+
335
+ ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
336
+
337
+ ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
338
+
339
+ ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
340
+ ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
341
+ g_diff_exp->ne[2] * g_diff_exp->ne[3]));
342
+
343
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
344
+
345
+ new_state = ggml_add(ctx0,
346
+ ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
347
+ ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
348
+ }
349
+
350
+ core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
351
+
352
+ ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
353
+ cb(output_tokens, "output_tokens", il);
354
+
355
+ // flatten output
356
+ ggml_tensor * flat_output =
357
+ ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
358
+
359
+ ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
360
+
361
+ return ggml_concat(ctx0, flat_output, flat_state, 0);
362
+ }
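The chunked path above pads the token dimension to a multiple of CHUNK_SIZE before splitting the sequence into fixed-size chunks. A self-contained sketch of that padding arithmetic (the example token counts are arbitrary):

#include <cstdio>

int main() {
    const long long chunk_size = 64; // CHUNK_SIZE above
    for (long long n_tokens : {1LL, 63LL, 64LL, 65LL, 200LL}) {
        const long long pad      = (chunk_size - n_tokens % chunk_size) % chunk_size;
        const long long n_chunks = (n_tokens + pad) / chunk_size;
        std::printf("n_tokens=%3lld  pad=%2lld  n_chunks=%lld\n", n_tokens, pad, n_chunks);
    }
    return 0;
}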
363
+
364
+ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
365
+ ggml_tensor * q,
366
+ ggml_tensor * k,
367
+ ggml_tensor * v,
368
+ ggml_tensor * g,
369
+ ggml_tensor * beta,
370
+ ggml_tensor * state,
371
+ ggml_tensor * causal_mask,
372
+ ggml_tensor * identity,
373
+ int il) {
374
+ GGML_ASSERT(ggml_is_contiguous(q));
375
+ GGML_ASSERT(ggml_is_contiguous(k));
376
+ GGML_ASSERT(ggml_is_contiguous(v));
377
+ GGML_ASSERT(ggml_is_contiguous(g));
378
+ GGML_ASSERT(ggml_is_contiguous(beta));
379
+ GGML_ASSERT(ggml_is_contiguous(state));
380
+
381
+ const int64_t S_k = q->ne[0];
382
+ const int64_t H_k = q->ne[1];
383
+ const int64_t n_tokens = q->ne[2];
384
+ const int64_t n_seqs = q->ne[3];
385
+
386
+ const int64_t S_v = v->ne[0];
387
+ const int64_t H_v = v->ne[1];
388
+
389
+ GGML_ASSERT(v->ne[2] == n_tokens);
390
+ GGML_ASSERT(k->ne[2] == n_tokens);
391
+ GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
392
+ GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
393
+ GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
394
+
395
+ GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
396
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
397
+
398
+ GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
399
+
400
+ // TODO: can this ever be false?
401
+ const bool use_qk_l2norm = true;
402
+
403
+ if (use_qk_l2norm) {
404
+ const float eps_norm = hparams.f_norm_rms_eps;
405
+
406
+ q = ggml_l2_norm(ctx0, q, eps_norm);
407
+ k = ggml_l2_norm(ctx0, k, eps_norm);
408
+ }
409
+
410
+ const float scale = 1.0f / sqrtf(S_v);
411
+
412
+ q = ggml_scale(ctx0, q, scale);
413
+
414
+ beta = ggml_sigmoid(ctx0, beta);
415
+
416
+ ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
417
+
418
+ cb(q, "q_in", il);
419
+ cb(k, "k_in", il);
420
+ cb(v, "v_in", il);
421
+ cb(beta, "beta_in", il);
422
+ cb(g, "g_in", il);
423
+
424
+ q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
425
+ k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
426
+ v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
427
+ g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
428
+
429
+ beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
430
+ state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
431
+
432
+ cb(q, "q_perm", il);
433
+ cb(k, "k_perm", il);
434
+ cb(v, "v_perm", il);
435
+ cb(beta, "beta_perm", il);
436
+ cb(g, "g_perm", il);
437
+ cb(state, "state_in", il);
438
+
439
+ GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
440
+ GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
441
+ GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
442
+ GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
443
+
444
+ ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
445
+ ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
446
+
447
+ ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
448
+
449
+ cb(k_beta, "k_beta", il);
450
+ cb(v_beta, "v_beta", il);
451
+ cb(g_cumsum, "g_cumsum", il);
452
+
453
+ ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [n_tokens, 1, H_v, n_seqs]
454
+ ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, n_tokens, H_v, n_seqs]
455
+
456
+ // Broadcast both tensors to [n_tokens, n_tokens, H_v, n_seqs]
457
+ // ggml_tensor * gcs_i_broadcast =
458
+ // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
459
+ // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
460
+ // Don't need this, this one will get auto-broadcast
461
+ ggml_tensor * gcs_j_broadcast =
462
+ ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, n_tokens, H_v, n_seqs] -> [n_tokens, n_tokens, H_v, n_seqs]
463
+
464
+ ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
465
+
466
+ // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
467
+ decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
468
+ // Apply exponential to get the decay mask values
469
+ decay_mask = ggml_exp(ctx0, decay_mask);
470
+ // Apply lower triangular mask again to ensure only lower triangular values remain
471
+ decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
472
+
473
+ cb(decay_mask, "decay_mask", il);
474
+
475
+ // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
476
+ ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
477
+
478
+ cb(kmulkbeta, "kmulkbeta", il);
479
+
480
+ ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
481
+ ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
482
+
483
+ cb(attn, "attn_pre_rec", il);
484
+
485
+ // for i in range(1, chunk_size):
486
+ // row = attn[..., i, :i].clone()
487
+ // sub = attn[..., :i, :i].clone()
488
+ // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
489
+ // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
490
+ //
491
+ // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
492
+ ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
493
+ ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
494
+
495
+ ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
496
+ attn = ggml_mul(ctx0, lin_solve, causal_mask);
497
+ attn = ggml_add(ctx0, attn, identity);
498
+
499
+ // value = attn @ v_beta
500
+ v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
501
+
502
+ cb(v, "value_beta", il);
503
+
504
+ // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
505
+ ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
506
+ ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
507
+
508
+ cb(gexp, "g_cum_exp", il);
509
+
510
+ ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
511
+
512
+ cb(kbeta_gexp, "kbeta_gexp", il);
513
+
514
+ ggml_tensor * k_cumdecay =
515
+ ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
516
+
517
+ cb(k_cumdecay, "k_cumdecay", il);
518
+
519
+ // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
520
+ attn = ggml_mul_mat(ctx0, k, q);
521
+ attn = ggml_mul(ctx0, attn, decay_mask);
522
+ attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask));
523
+
524
+ cb(attn, "attn_decay_key", il);
525
+
526
+ ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
527
+
528
+ // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
529
+ ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
530
+
531
+ cb(v_prime, "v_prime", il);
532
+
533
+ // v_new = v_i - v_prime
534
+ ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime);
535
+
536
+ ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
537
+
538
+ cb(v_new, "v_new", il);
539
+
540
+ // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
541
+ ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp);
542
+ ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
543
+
544
+ cb(attn_inter, "attn_inter", il);
545
+
546
+ // core_attn_out[:, :, i] = attn_inter + attn @ v_new
547
+ ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
548
+
549
+ cb(v_attn, "v_attn", il);
550
+
551
+ ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
552
+
553
+ cb(core_attn_out, "core_attn_out", il);
554
+
555
+ // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
556
+ // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
557
+ // key_gdiff = key * g_diff.unsqueeze(-1)
558
+ // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
559
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
560
+
561
+ ggml_tensor * g_cum_last =
562
+ ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3],
563
+ g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3],
564
+ g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
565
+
566
+ cb(g_cum_last, "g_cum_last", il);
567
+
568
+ ggml_tensor * gexp_last =
569
+ ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
570
+
571
+ cb(gexp_last, "gexp_last", il);
572
+
573
+ ggml_tensor * g_cum_last_3d =
574
+ ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
575
+
576
+ cb(g_cum_last_3d, "g_cum_last_3d", il);
577
+
578
+ ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
579
+
580
+ cb(g_cumsum_3d, "g_cumsum_3d", il);
581
+
582
+ ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
583
+
584
+ cb(g_diff, "g_diff", il);
585
+
586
+ ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
587
+
588
+ cb(g_diff_exp, "g_diff_exp", il);
589
+
590
+ ggml_tensor * key_gdiff = ggml_mul(ctx0, k,
591
+ ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
592
+ g_diff_exp->ne[2] * g_diff_exp->ne[3]));
593
+
594
+ cb(key_gdiff, "key_gdiff", il);
595
+
596
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
597
+
598
+ cb(kgdmulvnew, "kgdmulvnew", il);
599
+
600
+ state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew);
601
+
602
+ cb(state, "new_state", il);
603
+
604
+ // flatten output
605
+ ggml_tensor * flat_output =
606
+ ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
607
+
608
+ ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
609
+
610
+ return ggml_concat(ctx0, flat_output, flat_state, 0);
611
+ }
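The commented-out PyTorch-style row loop above is replaced by a triangular solve: with L = tril(attn, -1) strictly lower triangular, the loop computes X such that (I - L) X = attn. A small row-major forward-substitution sketch of that step (illustrative only, not the ggml_solve_tri API):

#include <vector>

// Solve (I - L) X = B for X, where L is strictly lower triangular (n x n, row-major).
static std::vector<float> solve_unit_lower(const std::vector<float> & L,
                                           const std::vector<float> & B, int n) {
    std::vector<float> X(B); // X[0,:] = B[0,:]; remaining rows are filled in below
    for (int i = 1; i < n; ++i) {
        for (int k = 0; k < i; ++k) {
            const float lik = L[i * n + k];
            // forward substitution: X[i,:] = B[i,:] + sum_{k<i} L[i][k] * X[k,:]
            for (int j = 0; j < n; ++j) {
                X[i * n + j] += lik * X[k * n + j];
            }
        }
    }
    return X;
}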
612
+
613
+ ggml_tensor * llm_build_qwen3next::build_norm_gated(
614
+ ggml_tensor * input,
615
+ ggml_tensor * weights,
616
+ ggml_tensor * gate,
617
+ int layer) {
618
+ ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
619
+ ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
620
+
621
+ return ggml_mul(ctx0, normalized, gated_silu);
622
+ }
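A sketch of what build_norm_gated computes per row, assuming build_norm with LLM_NORM_RMS is a standard RMS norm with a learned scale w (the epsilon default below is an assumption; the graph uses hparams.f_norm_rms_eps): out = RMSNorm(x) * w * SiLU(gate).

#include <cmath>
#include <vector>

static std::vector<float> rms_norm_gated(const std::vector<float> & x,
                                         const std::vector<float> & w,
                                         const std::vector<float> & gate,
                                         float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float inv_rms = 1.0f / std::sqrt(ss / x.size() + eps);

    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        const float silu = gate[i] / (1.0f + std::exp(-gate[i])); // SiLU(g) = g * sigmoid(g)
        out[i] = x[i] * inv_rms * w[i] * silu;
    }
    return out;
}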
623
+
624
+ ggml_tensor * llm_build_qwen3next::build_layer_attn(
625
+ llm_graph_input_attn_kv * inp,
626
+ ggml_tensor * cur,
627
+ ggml_tensor * inp_pos,
628
+ int il) {
629
+ const int64_t n_embd_head = hparams.n_embd_head_v;
630
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
631
+
632
+ // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
633
+
634
+ // Qwen3Next uses a single Q projection that outputs query + gate
635
+ ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
636
+ cb(Qcur_full, "Qcur_full", il);
637
+
638
+ Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
639
+
640
+ // Split Q projection into query and gate
641
+ // The split should be along dimension 0 (the feature dimension)
642
+ ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
643
+ Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
644
+ ggml_tensor * gate =
645
+ ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
646
+ Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
647
+ cb(Qcur, "Qcur", il);
648
+ cb(gate, "gate", il);
649
+
650
+ // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
651
+ Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
652
+ cb(Qcur, "Qcur_reshaped", il);
653
+
654
+ // Apply Q normalization
655
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
656
+ cb(Qcur, "Qcur_normed", il);
657
+
658
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
659
+ cb(Kcur, "Kcur", il);
660
+
661
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
662
+ cb(Vcur, "Vcur", il);
663
+
664
+ // Apply K normalization
665
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
666
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
667
+ cb(Kcur, "Kcur_normed", il);
668
+
669
+ // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
670
+ gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
671
+ cb(gate, "gate_reshaped", il);
672
+
673
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
674
+
675
+ // Apply RoPE
676
+ Qcur = ggml_rope_ext(
677
+ ctx0, Qcur, inp_pos, nullptr,
678
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
679
+ ext_factor, attn_factor, beta_fast, beta_slow);
680
+
681
+ Kcur = ggml_rope_ext(
682
+ ctx0, Kcur, inp_pos, nullptr,
683
+ n_rot, rope_type, n_ctx_orig, freq_base,
684
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
685
+
686
+ cb(Qcur, "Qcur", il);
687
+ cb(Kcur, "Kcur", il);
688
+ cb(Vcur, "Vcur", il);
689
+
690
+ // Attention computation
691
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
692
+
693
+ cur = build_attn(inp,
694
+ nullptr, nullptr,
695
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
696
+ cb(cur, "attn_pregate", il);
697
+
698
+ ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
699
+ cb(gate_sigmoid, "gate_sigmoid", il);
700
+
701
+ cur = ggml_mul(ctx0, cur, gate_sigmoid);
702
+ cb(cur, "attn_gated", il);
703
+
704
+ cur = build_lora_mm(model.layers[il].wo, cur);
705
+ cb(cur, "attn_output", il);
706
+
707
+ return cur;
708
+ }
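The full-attention layer above packs [query | gate] into one projection per head and, after attention, multiplies the output elementwise by sigmoid(gate). A single-head, plain-C++ sketch of that split and gating (names are illustrative):

#include <cmath>
#include <vector>

struct SplitQG {
    std::vector<float> q;    // first n_embd_head values of the joint projection
    std::vector<float> gate; // remaining n_embd_head values
};

static SplitQG split_q_gate(const std::vector<float> & qg, int n_embd_head) {
    SplitQG s;
    s.q.assign(qg.begin(), qg.begin() + n_embd_head);
    s.gate.assign(qg.begin() + n_embd_head, qg.begin() + 2 * n_embd_head);
    return s;
}

// Elementwise output gating applied after attention: out = attn_out * sigmoid(gate).
static std::vector<float> apply_output_gate(const std::vector<float> & attn_out,
                                            const std::vector<float> & gate) {
    std::vector<float> out(attn_out.size());
    for (size_t i = 0; i < attn_out.size(); ++i) {
        out[i] = attn_out[i] / (1.0f + std::exp(-gate[i]));
    }
    return out;
}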
709
+
710
+ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
711
+ llm_graph_input_rs * inp,
712
+ ggml_tensor * cur,
713
+ ggml_tensor * causal_mask,
714
+ ggml_tensor * identity,
715
+ int il) {
716
+ const auto * mctx_cur = inp->mctx;
717
+
718
+ const int64_t d_inner = hparams.ssm_d_inner;
719
+ const int64_t n_seqs = ubatch.n_seqs;
720
+ const int64_t head_k_dim = hparams.ssm_d_state;
721
+ const int64_t num_k_heads = hparams.ssm_n_group;
722
+ const int64_t num_v_heads = hparams.ssm_dt_rank;
723
+ const int64_t head_v_dim = d_inner / num_v_heads;
724
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
725
+
726
+ const auto kv_head = mctx_cur->get_head();
727
+
728
+ GGML_ASSERT(n_seqs != 0);
729
+ GGML_ASSERT(ubatch.equal_seqs());
730
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
731
+
732
+ // Input projections
733
+ ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
734
+ cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
735
+
736
+ ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
737
+ cb(mixed_ba, "linear_attn_mixed_ba", il);
738
+
739
+ int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
740
+ ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
741
+
742
+ // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
743
+ int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
744
+ ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
745
+
746
+ // Split mixed_ba into b and a (beta and alpha parameters)
747
+ int64_t split_sizes_ba[2] = {
748
+ num_v_heads / num_k_heads, // beta size
749
+ num_v_heads / num_k_heads // alpha size
750
+ };
751
+
752
+ ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
753
+ mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
754
+ cb(b, "b", il);
755
+
756
+ ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
757
+ mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
758
+ split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
759
+ cb(a, "a", il);
760
+
761
+ // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
762
+ ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
763
+ ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
764
+
765
+ GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
766
+
767
+ ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
768
+ ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
769
+ cb(alpha_softplus, "a_softplus", il);
770
+ ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
771
+ cb(gate, "gate", il);
772
+
773
+ // Split mixed_qkvz into query, key, value, z
774
+ int64_t split_sizes_qkvz[4] = {
775
+ head_k_dim, // query size
776
+ head_k_dim, // key size
777
+ head_v_dim * num_v_heads / num_k_heads, // value size
778
+ head_v_dim * num_v_heads / num_k_heads // z size
779
+ };
780
+
781
+ ggml_tensor * query =
782
+ ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
783
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
784
+ cb(query, "q", il);
785
+
786
+ ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
787
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
788
+ split_sizes_qkvz[0] * sizeof(float));
789
+ cb(key, "k", il);
790
+
791
+ ggml_tensor * value =
792
+ ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
793
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
794
+ (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
795
+ cb(value, "v", il);
796
+
797
+ ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
798
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
799
+ (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
800
+ cb(z, "z", il);
801
+
802
+ GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) ==
803
+ ggml_nelements(mixed_qkvz));
804
+
805
+ // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
806
+ // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
807
+ ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
808
+ cb(query_flat, "query_flat", il);
809
+
810
+ // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
811
+ ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
812
+ cb(key_flat, "key_flat", il);
813
+
814
+ // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
815
+ ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
816
+ cb(value_flat, "value_flat", il);
817
+
818
+ // Get convolution states from cache
819
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
820
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
821
+
822
+ // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
823
+
824
+ // Build the convolution states tensor
825
+ ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
826
+ cb(conv_states, "conv_states", il);
827
+
828
+ // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
829
+ ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
830
+ qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
831
+ cb(qkv_mixed, "qkv_mixed", il);
832
+
833
+ qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
834
+ cb(qkv_mixed, "qkv_mixed_permuted", il);
835
+
836
+ // Calculate the total conv dimension
837
+ int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
838
+
839
+ // Calculate convolution kernel size
840
+ ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
841
+ const int64_t conv_kernel_size = conv_kernel->ne[0];
842
+ const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
843
+ conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
844
+ cb(conv_states, "conv_states_reshaped", il);
845
+
846
+ ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
847
+ cb(conv_input, "conv_input", il);
848
+
849
+ // Update convolution state cache
850
+ // Extract the last (conv_kernel_size - 1) states from conv_input
851
+ ggml_tensor * last_conv_states =
852
+ ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
853
+ conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
854
+ cb(last_conv_states, "last_conv_states", il);
855
+
856
+ ggml_tensor * state_update_target =
857
+ ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
858
+ kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
859
+ cb(state_update_target, "state_update_target", il);
860
+
861
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
862
+ cb(conv_states_all, "conv_states_updated", il);
863
+
864
+ // Apply SSM convolution
865
+ ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
866
+ cb(conv_output_proper, "conv_output_raw", il);
867
+
868
+ conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
869
+ cb(conv_output_proper, "conv_output_pre_silu", il);
870
+
871
+ ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
872
+ cb(conv_output_silu, "conv_output_silu", il);
873
+
874
+ ggml_tensor * conv_qkv_mix =
875
+ ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
876
+ cb(conv_qkv_mix, "conv_qkv_mix", il);
877
+
878
+ // Extract the convolved Q, K, V from conv_output
879
+ ggml_tensor * q_conv =
880
+ ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
881
+ cb(q_conv, "q_conv", il);
882
+ ggml_tensor * k_conv =
883
+ ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
884
+ head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
885
+ cb(k_conv, "k_conv", il);
886
+ ggml_tensor * v_conv =
887
+ ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
888
+ 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
889
+ cb(v_conv, "v_conv", il);
890
+
891
+ // Unsqueeze them
892
+ q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
893
+ k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
894
+ v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
895
+
896
+ beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
897
+
898
+ ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
899
+ state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
900
+ cb(state, "state_predelta", il);
901
+
902
+ // if the number of K heads and V heads differ, repeat Q/K so the head counts match
903
+ if (num_k_heads != num_v_heads) {
904
+ GGML_ASSERT(num_v_heads % num_k_heads == 0);
905
+ int64_t repeat_factor = num_v_heads / num_k_heads;
906
+
907
+ // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
908
+ ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
909
+ ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
910
+
911
+ // Repeat along the third dimension (the new dimension with size 1)
912
+ ggml_tensor * q_repeated =
913
+ ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
914
+ ggml_tensor * k_repeated =
915
+ ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
916
+
917
+ // Reshape back to merge the head and repeat dimensions
918
+ // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
919
+ // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
920
+ q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
921
+ k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
922
+ }
923
+
924
+ cb(q_conv, "q_conv_predelta", il);
925
+ cb(k_conv, "k_conv_predelta", il);
926
+ cb(v_conv, "v_conv_predelta", il);
927
+
928
+ // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
929
+ ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
930
+ build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) :
931
+ build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
932
+ cb(attn_out, "attn_out", il);
933
+
934
+ // The tensors were concatenated 1d, so we need to extract them 1d as well
935
+ const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
936
+ ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
937
+ cb(attn_out_1d, "attn_out_1d", il);
938
+
939
+ ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
940
+ cb(attn_out_final, "attn_out_reshaped", il);
941
+
942
+ // Extract the state part (second part of the concatenated tensor)
943
+ // It starts right after the output_flat_size elements of the flattened attention output
944
+ const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
945
+
946
+ ggml_tensor * state_1d =
947
+ ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
948
+ cb(state_1d, "state_1d", il);
949
+
950
+ // Update the recurrent states
951
+ ggml_build_forward_expand(gf,
952
+ ggml_cpy(ctx0, state_1d,
953
+ ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
954
+ kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
955
+
956
+ GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
957
+
958
+ // Reshape both attn_out_final and z to 2D tensors for normalization
959
+ // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
960
+ ggml_tensor * attn_out_2d_final =
961
+ ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
962
+
963
+ // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
964
+ ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
965
+
966
+ // Apply gated normalization: self.norm(core_attn_out, z)
967
+ ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
968
+
969
+ // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim * n_heads, n_tokens, n_seqs]
970
+ ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
971
+ cb(final_output, "final_output", il);
972
+
973
+ // Output projection
974
+ cur = build_lora_mm(model.layers[il].ssm_out, final_output);
975
+ cb(cur, "linear_attn_out", il);
976
+
977
+ // Reshape back to original dimensions
978
+ cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
979
+ return cur;
980
+ }
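When the number of K heads differs from the number of V heads, the code above repeat-interleaves each Q/K head so the head counts match before the delta-net step. A plain-C++ sketch of that interleaved repeat (heads stored contiguously, row-major):

#include <vector>

static std::vector<float> repeat_interleave_heads(const std::vector<float> & heads,
                                                  int n_heads, int head_dim,
                                                  int repeat_factor) {
    std::vector<float> out;
    out.reserve(heads.size() * repeat_factor);
    for (int h = 0; h < n_heads; ++h) {
        for (int r = 0; r < repeat_factor; ++r) { // duplicate head h repeat_factor times in a row
            out.insert(out.end(),
                       heads.begin() + static_cast<size_t>(h) * head_dim,
                       heads.begin() + static_cast<size_t>(h + 1) * head_dim);
        }
    }
    return out;
}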
981
+
982
+ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
983
+ // Check if this is an MoE layer
984
+ if (model.layers[il].ffn_gate_inp != nullptr) {
985
+ // MoE branch
986
+ ggml_tensor * moe_out =
987
+ build_moe_ffn(cur,
988
+ model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
989
+ model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
990
+ nullptr,
991
+ n_expert, n_expert_used, LLM_FFN_SILU,
992
+ true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
993
+ cb(moe_out, "ffn_moe_out", il);
994
+
995
+ // Add shared experts if present - following Qwen3Next reference implementation
996
+ if (model.layers[il].ffn_up_shexp != nullptr) {
997
+ ggml_tensor * ffn_shexp =
998
+ build_ffn(cur,
999
+ model.layers[il].ffn_up_shexp, NULL, NULL,
1000
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
1001
+ model.layers[il].ffn_down_shexp, NULL, NULL,
1002
+ NULL,
1003
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
1004
+ cb(ffn_shexp, "ffn_shexp", il);
1005
+
1006
+ // Apply shared expert gating as in the reference implementation
1007
+ // The shared expert has its own gate that is sigmoided
1008
+ // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
1009
+ ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
1010
+ cb(shared_gate, "shared_expert_gate", il);
1011
+
1012
+ // Apply sigmoid to the gate
1013
+ shared_gate = ggml_sigmoid(ctx0, shared_gate);
1014
+ cb(shared_gate, "shared_expert_gate_sigmoid", il);
1015
+
1016
+ // The gate needs to be broadcast to match the dimensions of ffn_shexp
1017
+ // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
1018
+ // We need to repeat the gate along the feature dimension
1019
+ shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
1020
+ cb(shared_gate, "shared_expert_gate_broadcast", il);
1021
+
1022
+ // Apply the gate to the shared expert output
1023
+ ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
1024
+ cb(ffn_shexp, "ffn_shexp_gated", il);
1025
+
1026
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
1027
+ cb(cur, "ffn_out", il);
1028
+ } else {
1029
+ cur = moe_out;
1030
+ }
1031
+ } else {
1032
+ // Dense FFN branch (not expected to be used: Qwen3Next layers appear to be MoE)
1033
+ cur = build_ffn(cur,
1034
+ model.layers[il].ffn_up, NULL, NULL,
1035
+ model.layers[il].ffn_gate, NULL, NULL,
1036
+ model.layers[il].ffn_down, NULL, NULL,
1037
+ NULL,
1038
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
1039
+ cb(cur, "ffn_out", il);
1040
+ }
1041
+ return cur;
1042
+ }
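A sketch of the shared-expert combination at the end of build_layer_ffn: the routed MoE output is added to the shared expert's output, scaled by a per-token sigmoid gate. gate_logit stands for the scalar produced by ffn_gate_inp_shexp for one token; the names are illustrative.

#include <cmath>
#include <vector>

static std::vector<float> combine_moe_and_shared(const std::vector<float> & moe_out,
                                                 const std::vector<float> & shexp_out,
                                                 float gate_logit) {
    const float g = 1.0f / (1.0f + std::exp(-gate_logit)); // sigmoid gate, broadcast over features
    std::vector<float> out(moe_out.size());
    for (size_t i = 0; i < moe_out.size(); ++i) {
        out[i] = moe_out[i] + g * shexp_out[i];
    }
    return out;
}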