@fugood/llama.node 1.4.13 → 1.4.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. package/lib/binding.ts +23 -2
  2. package/lib/index.js +2 -1
  3. package/lib/index.ts +8 -1
  4. package/lib/parallel.ts +2 -2
  5. package/package.json +15 -15
  6. package/scripts/llama.cpp.patch +9 -12
  7. package/src/LlamaContext.cpp +16 -4
  8. package/src/llama.cpp/CMakeLists.txt +24 -8
  9. package/src/llama.cpp/common/CMakeLists.txt +3 -34
  10. package/src/llama.cpp/common/arg.cpp +183 -60
  11. package/src/llama.cpp/common/arg.h +0 -8
  12. package/src/llama.cpp/common/chat-parser.cpp +115 -0
  13. package/src/llama.cpp/common/chat.cpp +67 -0
  14. package/src/llama.cpp/common/chat.h +1 -0
  15. package/src/llama.cpp/common/common.cpp +2 -1
  16. package/src/llama.cpp/common/common.h +12 -7
  17. package/src/llama.cpp/common/debug.cpp +165 -0
  18. package/src/llama.cpp/common/debug.h +43 -0
  19. package/src/llama.cpp/common/download.cpp +88 -369
  20. package/src/llama.cpp/common/download.h +32 -5
  21. package/src/llama.cpp/common/preset.cpp +87 -2
  22. package/src/llama.cpp/common/preset.h +10 -1
  23. package/src/llama.cpp/ggml/include/ggml.h +5 -0
  24. package/src/llama.cpp/include/llama.h +5 -2
  25. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  26. package/src/llama.cpp/src/llama-arch.cpp +35 -0
  27. package/src/llama.cpp/src/llama-arch.h +1 -0
  28. package/src/llama.cpp/src/llama-chat.cpp +20 -0
  29. package/src/llama.cpp/src/llama-chat.h +1 -0
  30. package/src/llama.cpp/src/llama-graph.cpp +31 -43
  31. package/src/llama.cpp/src/llama-mmap.cpp +78 -42
  32. package/src/llama.cpp/src/llama-mmap.h +5 -4
  33. package/src/llama.cpp/src/llama-model-loader.cpp +17 -5
  34. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  35. package/src/llama.cpp/src/llama-model.cpp +225 -101
  36. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  37. package/src/llama.cpp/src/llama-sampling.cpp +1 -1
  38. package/src/llama.cpp/src/llama-vocab.cpp +37 -24
  39. package/src/llama.cpp/src/llama-vocab.h +1 -0
  40. package/src/llama.cpp/src/llama.cpp +63 -27
  41. package/src/llama.cpp/src/models/exaone-moe.cpp +146 -0
  42. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +13 -3
  43. package/src/llama.cpp/src/models/models.h +13 -2
  44. package/src/llama.cpp/src/models/qwen3next.cpp +198 -182
@@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
86
86
  ggml_build_forward_expand(gf, cur);
87
87
  }
88
88
 
89
- ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
89
+ // utility to get one slice from the third dimension
90
+ // input dim: [x, y, c, b]
91
+ // output dim: [x, y, 1, b]
92
+ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
93
+ return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
94
+ t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
95
+ }
96
+
97
+ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
90
98
  ggml_tensor * q,
91
99
  ggml_tensor * k,
92
100
  ggml_tensor * v,
@@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
187
195
  beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
188
196
 
189
197
  ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
198
+ cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
190
199
 
191
- cb(g_cumsum, "g_cumsum", il);
192
-
193
- ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
200
+ ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
194
201
  ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
195
202
 
196
203
  ggml_tensor * gcs_j_broadcast =
197
204
  ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
198
205
 
199
206
  ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
200
-
201
- cb(decay_mask, "decay_mask", il);
207
+ cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
202
208
 
203
209
  decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
204
210
  decay_mask = ggml_exp(ctx0, decay_mask);
@@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
208
214
 
209
215
  ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
210
216
  ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
211
-
212
- cb(attn, "attn_pre_solve", il);
217
+ cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
213
218
 
214
219
  ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
215
220
  ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
@@ -217,8 +222,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
217
222
  ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
218
223
  attn = ggml_mul(ctx0, lin_solve, causal_mask);
219
224
  attn = ggml_add(ctx0, attn, identity);
220
-
221
- cb(attn, "attn_solved", il);
225
+ cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
222
226
 
223
227
  v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
224
228
 
@@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
226
230
  ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
227
231
 
228
232
  ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
229
-
230
- cb(kbeta_gexp, "kbeta_gexp", il);
233
+ cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
231
234
 
232
235
  ggml_tensor * k_cumdecay =
233
236
  ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
237
+ cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
234
238
 
235
- cb(k_cumdecay, "k_cumdecay", il);
239
+ ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
240
+ attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
241
+ attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
242
+ cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
236
243
 
237
- ggml_tensor * core_attn_out = nullptr;
238
- ggml_tensor * new_state = ggml_dup(ctx0, state);
239
244
 
240
- cb(new_state, "new_state", il);
245
+ // vectorized calculation of key_gdiff
246
+ // improved from the chunked version:
247
+ // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
248
+ // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
249
+ // key_gdiff = key * g_diff.unsqueeze(-1)
250
+ // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
251
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
241
252
 
242
- for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
243
- auto chunkify = [=](ggml_tensor * t) {
244
- return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
245
- t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
246
- };
253
+ // get last element in g_cumsum along chunk_size dimension (ne0)
254
+ // example: [[x, y, z, ..., last], ...] -> [[last], ...]
255
+ ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
256
+ g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
257
+ (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
258
+ g_last = ggml_cont(ctx0, g_last);
259
+ cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
247
260
 
248
- auto chunkify_g = [=](ggml_tensor * t) {
249
- return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
250
- t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
251
- };
261
+ ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
262
+ cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
263
+
264
+ ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
265
+ cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
266
+
267
+ ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
268
+ ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
269
+ cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
270
+
271
+
272
+ // state to be updated per chunk
273
+ ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
274
+ cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
252
275
 
253
- ggml_tensor * k_chunk = chunkify(k);
254
- ggml_tensor * q_chunk = chunkify(q);
255
- ggml_tensor * v_chunk = chunkify(v);
276
+ // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
277
+ ggml_tensor * core_attn_out = nullptr;
278
+
279
+ for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
280
+ // shape: (S_k, chunk_size, 1, H_k * n_seqs)
281
+ ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
256
282
 
257
- ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
258
- ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
283
+ // shape: (S_v, chunk_size, 1, H_v * n_seqs)
284
+ ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
259
285
 
260
- ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
261
- ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
286
+ // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
287
+ ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
262
288
 
263
- ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
289
+ // shape: (chunk_size, 1, H_v * n_seqs)
290
+ ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
264
291
 
265
292
  // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
266
- attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
267
- attn = ggml_mul(ctx0, attn, decay_mask_chunk);
268
- attn = ggml_mul(ctx0, attn, diag_mask);
293
+ // replaced by precomputed attn_kq
294
+ ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
295
+ cb(attn_chunk, "attn_chunk", il);
269
296
 
270
297
  ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
271
298
 
272
299
  // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
273
300
  ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
301
+ cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
274
302
 
275
303
  // v_new = v_i - v_prime
276
304
  ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
277
305
  ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
306
+ cb(v_new, "v_new_chunk", il);
278
307
 
279
308
  // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
280
309
  ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
281
310
  ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
311
+ cb(attn_inter, "attn_inter_chunk", il);
282
312
 
283
313
  // core_attn_out[:, :, i] = attn_inter + attn @ v_new
284
- ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
314
+ ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
315
+ cb(v_attn, "v_attn_chunk", il);
285
316
 
286
317
  ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
318
+ cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
287
319
 
288
- core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
320
+ core_attn_out = core_attn_out == nullptr
321
+ ? core_attn_out_chunk
322
+ : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
289
323
 
290
- // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
291
- // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
292
- // key_gdiff = key * g_diff.unsqueeze(-1)
293
324
  // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
294
- // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
295
-
296
- ggml_tensor * g_cum_last =
297
- ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
298
- g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
299
- g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
300
-
301
- ggml_tensor * gexp_last =
302
- ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
303
-
304
- ggml_tensor * g_cum_last_3d =
305
- ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
306
-
307
- ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
308
-
309
- ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
310
-
311
- ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
312
-
313
- ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
314
- ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
315
- g_diff_exp->ne[2] * g_diff_exp->ne[3]));
316
-
317
- ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
325
+ ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
326
+ //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
327
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
318
328
 
329
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
330
+ ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
319
331
  new_state = ggml_add(ctx0,
320
- ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
332
+ ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
321
333
  ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
322
334
  }
323
335
 
324
- core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
325
-
326
- ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
336
+ // truncate padded tokens
337
+ ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
338
+ S_v, n_tokens, H_v, n_seqs,
339
+ ggml_row_size(core_attn_out->type, S_v),
340
+ ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
341
+ ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
342
+ output_tokens = ggml_cont(ctx0, output_tokens);
327
343
  cb(output_tokens, "output_tokens", il);
328
344
 
329
- // flatten output
330
- ggml_tensor * flat_output =
331
- ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
332
-
333
- ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
345
+ // permute back to (S_v, H_v, n_tokens, n_seqs)
346
+ output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
347
+ output_tokens = ggml_cont(ctx0, output_tokens);
334
348
 
335
- return ggml_concat(ctx0, flat_output, flat_state, 0);
349
+ return {output_tokens, new_state};
336
350
  }
337
351
 
338
- ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
352
+ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
339
353
  ggml_tensor * q,
340
354
  ggml_tensor * k,
341
355
  ggml_tensor * v,
@@ -419,11 +433,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
419
433
  cb(core_attn_out, "output_tokens", il);
420
434
  cb(state, "new_state", il);
421
435
 
422
- // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
423
- ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
424
- ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
425
-
426
- return ggml_concat(ctx0, flat_output, flat_state, 0);
436
+ return {core_attn_out, state};
427
437
  }
428
438
 
429
439
  ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -523,6 +533,88 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
523
533
  return cur;
524
534
  }
525
535
 
536
+ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
537
+ ggml_tensor * input,
538
+ int il) {
539
+ const int64_t d_inner = hparams.ssm_d_inner;
540
+ const int64_t n_seqs = ubatch.n_seqs;
541
+ const int64_t head_k_dim = hparams.ssm_d_state;
542
+ const int64_t num_k_heads = hparams.ssm_n_group;
543
+ const int64_t num_v_heads = hparams.ssm_dt_rank;
544
+ const int64_t head_v_dim = d_inner / num_v_heads;
545
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
546
+
547
+ if (model.layers[il].wqkv) {
548
+ // optimized path
549
+ ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
550
+ qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
551
+ cb(qkv_mixed, "linear_attn_qkv_mixed", il);
552
+
553
+ ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
554
+ cb(z, "z", il);
555
+
556
+ return { qkv_mixed, z };
557
+
558
+ } else {
559
+ // legacy (slower) path
560
+ ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
561
+ cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
562
+
563
+ int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
564
+ ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
565
+
566
+ // Split mixed_qkvz into query, key, value, z
567
+ int64_t split_sizes_qkvz[4] = {
568
+ head_k_dim, // query size
569
+ head_k_dim, // key size
570
+ head_v_dim * num_v_heads / num_k_heads, // value size
571
+ head_v_dim * num_v_heads / num_k_heads // z size
572
+ };
573
+
574
+ ggml_tensor * query =
575
+ ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
576
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
577
+ cb(query, "q", il);
578
+
579
+ ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
580
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
581
+ split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
582
+ cb(key, "k", il);
583
+
584
+ ggml_tensor * value =
585
+ ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
586
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
587
+ (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
588
+ cb(value, "v", il);
589
+
590
+ ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
591
+ mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
592
+ (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
593
+ z = ggml_cont(ctx0, z);
594
+ cb(z, "z", il);
595
+
596
+ // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
597
+ // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
598
+ ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
599
+ cb(query_flat, "query_flat", il);
600
+
601
+ // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
602
+ ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
603
+ cb(key_flat, "key_flat", il);
604
+
605
+ // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
606
+ ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
607
+ cb(value_flat, "value_flat", il);
608
+
609
+ // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
610
+ ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
611
+ qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
612
+ cb(qkv_mixed, "qkv_mixed", il);
613
+
614
+ return { qkv_mixed, z };
615
+ }
616
+ }
617
+
526
618
  ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
527
619
  llm_graph_input_rs * inp,
528
620
  ggml_tensor * cur,
@@ -547,15 +639,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
547
639
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
548
640
 
549
641
  // Input projections
550
- ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
551
- cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
642
+ auto qkvz = build_qkvz(cur, il);
643
+ ggml_tensor * qkv_mixed = qkvz.first;
644
+ ggml_tensor * z = qkvz.second;
552
645
 
553
646
  ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
554
647
  cb(mixed_ba, "linear_attn_mixed_ba", il);
555
648
 
556
- int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
557
- ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
558
-
559
649
  // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
560
650
  int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
561
651
  ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
@@ -575,8 +665,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
575
665
  split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
576
666
  cb(a, "a", il);
577
667
 
578
- // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
579
- ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
668
+ ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
669
+
670
+ // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
580
671
  ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
581
672
 
582
673
  ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@@ -585,48 +676,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
585
676
  ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
586
677
  cb(gate, "gate", il);
587
678
 
588
- // Split mixed_qkvz into query, key, value, z
589
- int64_t split_sizes_qkvz[4] = {
590
- head_k_dim, // query size
591
- head_k_dim, // key size
592
- head_v_dim * num_v_heads / num_k_heads, // value size
593
- head_v_dim * num_v_heads / num_k_heads // z size
594
- };
595
-
596
- ggml_tensor * query =
597
- ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
598
- mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
599
- cb(query, "q", il);
600
-
601
- ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
602
- mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
603
- split_sizes_qkvz[0] * sizeof(float));
604
- cb(key, "k", il);
605
-
606
- ggml_tensor * value =
607
- ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
608
- mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
609
- (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
610
- cb(value, "v", il);
611
-
612
- ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
613
- mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
614
- (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
615
- cb(z, "z", il);
616
-
617
- // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
618
- // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
619
- ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
620
- cb(query_flat, "query_flat", il);
621
-
622
- // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
623
- ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
624
- cb(key_flat, "key_flat", il);
625
-
626
- // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
627
- ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
628
- cb(value_flat, "value_flat", il);
629
-
630
679
  // Get convolution states from cache
631
680
  ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
632
681
  ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
@@ -637,17 +686,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
637
686
  ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
638
687
  cb(conv_states, "conv_states", il);
639
688
 
640
- // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
641
- ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
642
- qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
643
- cb(qkv_mixed, "qkv_mixed", il);
644
-
645
- qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
646
- cb(qkv_mixed, "qkv_mixed_permuted", il);
647
-
648
- // Calculate the total conv dimension
649
- int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
650
-
651
689
  // Calculate convolution kernel size
652
690
  ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
653
691
  const int64_t conv_kernel_size = conv_kernel->ne[0];
@@ -655,6 +693,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
655
693
  conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
656
694
  cb(conv_states, "conv_states_reshaped", il);
657
695
 
696
+ qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
697
+ cb(qkv_mixed, "qkv_mixed_permuted", il);
698
+
658
699
  ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
659
700
  cb(conv_input, "conv_input", il);
660
701
 
@@ -677,26 +718,25 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
677
718
  ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
678
719
  cb(conv_output_proper, "conv_output_raw", il);
679
720
 
680
- conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
681
- cb(conv_output_proper, "conv_output_pre_silu", il);
682
-
683
721
  ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
684
722
  cb(conv_output_silu, "conv_output_silu", il);
685
723
 
686
- ggml_tensor * conv_qkv_mix =
687
- ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
688
- cb(conv_qkv_mix, "conv_qkv_mix", il);
724
+ ggml_tensor * conv_qkv_mix = conv_output_silu;
725
+
726
+ // Calculate the total conv dimension
727
+ int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
728
+ int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
689
729
 
690
730
  // Extract the convolved Q, K, V from conv_output
691
731
  ggml_tensor * q_conv =
692
- ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
732
+ ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
693
733
  cb(q_conv, "q_conv", il);
694
734
  ggml_tensor * k_conv =
695
- ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
735
+ ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
696
736
  head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
697
737
  cb(k_conv, "k_conv", il);
698
738
  ggml_tensor * v_conv =
699
- ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
739
+ ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
700
740
  2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
701
741
  cb(v_conv, "v_conv", il);
702
742
 
@@ -705,8 +745,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
705
745
  k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
706
746
  v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
707
747
 
708
- beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
709
-
710
748
  ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
711
749
  state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
712
750
  cb(state, "state_predelta", il);
@@ -738,45 +776,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
738
776
  cb(v_conv, "v_conv_predelta", il);
739
777
 
740
778
  // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
741
- ggml_tensor * attn_out;
779
+ std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
742
780
  if (n_seq_tokens == 1) {
743
781
  attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
744
782
  } else {
745
783
  attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
746
784
  }
747
- cb(attn_out, "attn_out", il);
748
-
749
- // The tensors were concatenated 1d, so we need to extract them 1d as well
750
- const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
751
- ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
752
- cb(attn_out_1d, "attn_out_1d", il);
753
-
754
- ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
755
- cb(attn_out_final, "attn_out_reshaped", il);
756
-
757
- // Extract the state part (second part of the concatenated tensor)
758
- // State starts after n_tokens elements along dimension 1
759
- const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
760
-
761
- ggml_tensor * state_1d =
762
- ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
763
- cb(state_1d, "state_1d", il);
785
+ ggml_tensor * output = attn_out.first;
786
+ ggml_tensor * new_state = attn_out.second;
787
+ cb(output, "attn_output", il);
788
+ cb(new_state, "new_state", il);
764
789
 
765
790
  // Update the recurrent states
766
791
  ggml_build_forward_expand(gf,
767
- ggml_cpy(ctx0, state_1d,
792
+ ggml_cpy(ctx0, new_state,
768
793
  ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
769
794
  kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
770
795
 
771
- GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
772
-
773
796
  // Reshape both attn_out_final and z to 2D tensors for normalization
774
797
  // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
775
- ggml_tensor * attn_out_2d_final =
776
- ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
798
+ ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
777
799
 
778
800
  // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
779
- ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
801
+ ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
780
802
 
781
803
  // Apply gated normalization: self.norm(core_attn_out, z)
782
804
  ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
@@ -828,12 +850,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
828
850
  shared_gate = ggml_sigmoid(ctx0, shared_gate);
829
851
  cb(shared_gate, "shared_expert_gate_sigmoid", il);
830
852
 
831
- // The gate needs to be broadcast to match the dimensions of ffn_shexp
832
- // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
833
- // We need to repeat the gate along the feature dimension
834
- shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
835
- cb(shared_gate, "shared_expert_gate_broadcast", il);
836
-
837
853
  // Apply the gate to the shared expert output
838
854
  ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
839
855
  cb(ffn_shexp, "ffn_shexp_gated", il);