@fugood/llama.node 1.4.7 → 1.4.8

Files changed (63)
  1. package/lib/binding.ts +8 -0
  2. package/package.json +15 -15
  3. package/scripts/llama.cpp.patch +22 -23
  4. package/src/LlamaContext.cpp +2 -2
  5. package/src/llama.cpp/common/CMakeLists.txt +2 -0
  6. package/src/llama.cpp/common/arg.cpp +364 -193
  7. package/src/llama.cpp/common/arg.h +43 -2
  8. package/src/llama.cpp/common/chat-peg-parser.cpp +16 -2
  9. package/src/llama.cpp/common/chat.cpp +140 -0
  10. package/src/llama.cpp/common/common.cpp +130 -67
  11. package/src/llama.cpp/common/common.h +40 -16
  12. package/src/llama.cpp/common/console.cpp +98 -18
  13. package/src/llama.cpp/common/console.h +30 -8
  14. package/src/llama.cpp/common/download.cpp +69 -25
  15. package/src/llama.cpp/common/json-schema-to-grammar.cpp +132 -3
  16. package/src/llama.cpp/common/json-schema-to-grammar.h +20 -0
  17. package/src/llama.cpp/common/log.cpp +5 -0
  18. package/src/llama.cpp/common/log.h +1 -0
  19. package/src/llama.cpp/common/peg-parser.cpp +1 -1
  20. package/src/llama.cpp/common/preset.cpp +206 -0
  21. package/src/llama.cpp/common/preset.h +32 -0
  22. package/src/llama.cpp/common/sampling.cpp +91 -92
  23. package/src/llama.cpp/common/sampling.h +11 -6
  24. package/src/llama.cpp/common/speculative.cpp +1 -1
  25. package/src/llama.cpp/ggml/CMakeLists.txt +4 -0
  26. package/src/llama.cpp/ggml/include/ggml-alloc.h +9 -0
  27. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -0
  28. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
  29. package/src/llama.cpp/ggml/include/ggml.h +7 -8
  30. package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +60 -39
  33. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +2 -1
  35. package/src/llama.cpp/include/llama.h +18 -1
  36. package/src/llama.cpp/src/llama-arch.cpp +1890 -2248
  37. package/src/llama.cpp/src/llama-arch.h +9 -2
  38. package/src/llama.cpp/src/llama-batch.cpp +12 -2
  39. package/src/llama.cpp/src/llama-batch.h +4 -2
  40. package/src/llama.cpp/src/llama-context.cpp +93 -23
  41. package/src/llama.cpp/src/llama-context.h +8 -2
  42. package/src/llama.cpp/src/llama-graph.cpp +84 -16
  43. package/src/llama.cpp/src/llama-graph.h +17 -4
  44. package/src/llama.cpp/src/llama-hparams.cpp +6 -0
  45. package/src/llama.cpp/src/llama-hparams.h +5 -1
  46. package/src/llama.cpp/src/llama-impl.cpp +4 -0
  47. package/src/llama.cpp/src/llama-kv-cache.cpp +90 -42
  48. package/src/llama.cpp/src/llama-kv-cache.h +19 -2
  49. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -1
  50. package/src/llama.cpp/src/llama-model-loader.cpp +2 -0
  51. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  52. package/src/llama.cpp/src/llama-model.cpp +103 -44
  53. package/src/llama.cpp/src/llama-model.h +1 -0
  54. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  55. package/src/llama.cpp/src/llama-vocab.cpp +2 -1
  56. package/src/llama.cpp/src/llama.cpp +675 -1
  57. package/src/llama.cpp/src/models/deepseek2.cpp +9 -5
  58. package/src/llama.cpp/src/models/glm4-moe.cpp +28 -11
  59. package/src/llama.cpp/src/models/glm4.cpp +27 -4
  60. package/src/llama.cpp/src/models/models.h +5 -5
  61. package/src/llama.cpp/src/models/nemotron-h.cpp +35 -6
  62. package/src/llama.cpp/src/models/qwen2.cpp +12 -3
  63. package/src/llama.cpp/src/models/qwen3next.cpp +81 -266
@@ -17,13 +17,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_out_ids = build_inp_out_ids();

     ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
                  GGML_TRI_TYPE_LOWER);

-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

     ggml_build_forward_expand(gf, causal_mask);
     ggml_build_forward_expand(gf, identity);
+    ggml_build_forward_expand(gf, diag_mask);

     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
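Note (editorial, not part of the diff): the three masks above are now fixed CHUNK_SIZE x CHUNK_SIZE tensors (previously sized by ubatch.n_seq_tokens). A minimal sketch of what they hold, assuming GGML_TRI_TYPE_LOWER excludes the diagonal:

    causal_mask[i][j] = 1 if j < i, else 0    (strict lower triangle)
    identity          = I
    diag_mask         = causal_mask + I       (lower triangle including the diagonal)

Precomputing diag_mask here lets the delta-net helpers below drop their per-call ggml_add(causal_mask, identity).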
@@ -34,7 +36,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -93,14 +95,8 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
         ggml_tensor * state,
         ggml_tensor * causal_mask,
         ggml_tensor * identity,
+        ggml_tensor * diag_mask,
         int il) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
     const int64_t S_k = q->ne[0];
     const int64_t H_k = q->ne[1];
     const int64_t n_tokens = q->ne[2];
@@ -120,15 +116,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(

     GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

-    // TODO: can this ever be false?
-    const bool use_qk_l2norm = true;
-
-    if (use_qk_l2norm) {
-        const float eps_norm = hparams.f_norm_rms_eps;
+    const float eps_norm = hparams.f_norm_rms_eps;

-        q = ggml_l2_norm(ctx0, q, eps_norm);
-        k = ggml_l2_norm(ctx0, k, eps_norm);
-    }
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);

     const float scale = 1.0f / sqrtf(S_v);

@@ -136,8 +127,6 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(

     beta = ggml_sigmoid(ctx0, beta);

-    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
-
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
@@ -188,36 +177,21 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     cb(v_beta, "v_beta", il);
     cb(k_beta, "k_beta", il);

-    ggml_tensor * chunked_mask =
-        ggml_view_4d(ctx0, causal_mask, chunk_size,
-                     chunk_size, causal_mask->ne[2], causal_mask->ne[3],
-                     causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0);
+    q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
+    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);

-    ggml_tensor * chunked_diag_mask =
-        ggml_view_4d(ctx0, causal_diag_mask, chunk_size,
-                     chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3],
-                     causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0);
-
-    ggml_tensor * chunked_identity =
-        ggml_view_4d(ctx0, identity, chunk_size,
-                     chunk_size, identity->ne[2], identity->ne[3],
-                     identity->nb[1], identity->nb[2], identity->nb[3], 0);
-
-    q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+    g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
+    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

     ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);

     cb(g_cumsum, "g_cumsum", il);

-    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);

     ggml_tensor * gcs_j_broadcast =
         ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
@@ -226,23 +200,23 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(

     cb(decay_mask, "decay_mask", il);

-    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
     decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

     ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);

     ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask));
+    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));

     cb(attn, "attn_pre_solve", il);

-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask);
-    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower);
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
+    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

     ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn = ggml_mul(ctx0, lin_solve, chunked_mask);
-    attn = ggml_add(ctx0, attn, chunked_identity);
+    attn = ggml_mul(ctx0, lin_solve, causal_mask);
+    attn = ggml_add(ctx0, attn, identity);

     cb(attn, "attn_solved", il);

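Note (editorial sketch, not part of the diff): in conventional notation, with c = cumsum(g) per chunk, the block above builds roughly

    D[i][j] = exp(c[i] - c[j])   for j <= i within a chunk, else 0       (decay_mask)
    A       = -((K_beta K^T) .* D) .* causal_mask
    T       = (I - A)^{-1} A,   then   T = (T .* causal_mask) + I

i.e. ggml_solve_tri stands in for the reference implementation's row-by-row forward substitution; the only functional change in this hunk is that the shared causal_mask/identity/diag_mask tensors replace the old per-call chunked_* views.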
@@ -291,7 +265,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
         // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
         attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask));
+        attn = ggml_mul(ctx0, attn, diag_mask);

         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

@@ -361,23 +335,14 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     return ggml_concat(ctx0, flat_output, flat_state, 0);
 }

-ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
+ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * g,
         ggml_tensor * beta,
         ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
         int il) {
-    GGML_ASSERT(ggml_is_contiguous(q));
-    GGML_ASSERT(ggml_is_contiguous(k));
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(g));
-    GGML_ASSERT(ggml_is_contiguous(beta));
-    GGML_ASSERT(ggml_is_contiguous(state));
-
     const int64_t S_k = q->ne[0];
     const int64_t H_k = q->ne[1];
     const int64_t n_tokens = q->ne[2];
@@ -386,6 +351,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];

+    GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
     GGML_ASSERT(v->ne[2] == n_tokens);
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -397,215 +363,65 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(

     GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

-    // TODO: can this ever be false?
-    const bool use_qk_l2norm = true;
-
-    if (use_qk_l2norm) {
-        const float eps_norm = hparams.f_norm_rms_eps;
+    const float eps_norm = hparams.f_norm_rms_eps;

-        q = ggml_l2_norm(ctx0, q, eps_norm);
-        k = ggml_l2_norm(ctx0, k, eps_norm);
-    }
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);

     const float scale = 1.0f / sqrtf(S_v);

-    q = ggml_scale(ctx0, q, scale);
-
+    q = ggml_scale(ctx0, q, scale);
     beta = ggml_sigmoid(ctx0, beta);

-    ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity);
-
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
     cb(beta, "beta_in", il);
     cb(g, "g_in", il);

-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
     state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-
-    cb(k_beta, "k_beta", il);
-    cb(v_beta, "v_beta", il);
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
-    ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
-
-    // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-    // ggml_tensor * gcs_i_broadcast =
-    //     ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
-    //                    n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-    // Don't need this, this one will get auto-broadcast
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
-    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
-    // Apply exponential to get the decay mask values
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    // Apply lower triangular mask again to ensure only lower triangular values remain
-    decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask);
-
-    cb(decay_mask, "decay_mask", il);
-
-    // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    cb(kmulkbeta, "kmulkbeta", il);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_rec", il);
-
-    // for i in range(1, chunk_size):
-    //     row = attn[..., i, :i].clone()
-    //     sub = attn[..., :i, :i].clone()
-    //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    //
-    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn = ggml_add(ctx0, attn, identity);
-
-    // value = attn @ v_beta
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    cb(v, "value_beta", il);
-
-    // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
-
-    cb(gexp, "g_cum_exp", il);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-
-    cb(k_cumdecay, "k_cumdecay", il);
-
-    // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-    attn = ggml_mul_mat(ctx0, k, q);
-    attn = ggml_mul(ctx0, attn, decay_mask);
-    attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask));
-
-    cb(attn, "attn_decay_key", il);
-
-    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
-
-    // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-    ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay);
-
-    cb(v_prime, "v_prime", il);
-
-    // v_new = v_i - v_prime
-    ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime);
-
-    ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-
-    cb(v_new, "v_new", il);
-
-    // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp);
-    ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-
-    cb(attn_inter, "attn_inter", il);
-
-    // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-    ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
-
-    cb(v_attn, "v_attn", il);
-
-    ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn);
-
-    cb(core_attn_out, "core_attn_out", il);
-
-    // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    // key_gdiff = key * g_diff.unsqueeze(-1)
-    // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    ggml_tensor * g_cum_last =
-        ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3],
-                                     g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3],
-                                     g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
-
-    cb(g_cum_last, "g_cum_last", il);
-
-    ggml_tensor * gexp_last =
-        ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-    cb(gexp_last, "gexp_last", il);
-
-    ggml_tensor * g_cum_last_3d =
-        ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-    cb(g_cum_last_3d, "g_cum_last_3d", il);
-
-    ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
-
-    cb(g_cumsum_3d, "g_cumsum_3d", il);
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-    cb(g_diff, "g_diff", il);
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-    cb(g_diff_exp, "g_diff_exp", il);
-
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k,
-        ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                        g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-    cb(key_gdiff, "key_gdiff", il);
-
-    ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
-
-    cb(kgdmulvnew, "kgdmulvnew", il);
-
-    state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew);
-
+    ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
+    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to g_t
+    g_t = ggml_exp(ctx0, g_t);
+
+    // Apply the gated delta rule for the single timestep
+    // last_recurrent_state = last_recurrent_state * g_t
+    state = ggml_mul(ctx0, state, g_t);
+
+    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
+    ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
+    // we need to sum over dim=-2, so we transpose, sum, then transpose again
+    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
+
+    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
+    ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    // delta = (v_t - kv_mem) * beta_t
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
+    ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
+
+    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
+    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
+    state = ggml_add(ctx0, state, k_t_delta);
+
+    // Compute the attention output
+    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
+    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
+    ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
+    // again, since it's over dim = -2, transpose, sum, transpose back
+    ggml_tensor * core_attn_out =
+        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
+
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);

-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
+    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
+    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
+    ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);

     return ggml_concat(ctx0, flat_output, flat_state, 0);
 }
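Note (editorial sketch, not part of the diff): with a single token there is nothing to chunk, and the new build_delta_net_autoregressive path reduces to the per-head gated delta rule, roughly

    S      <- exp(g_t) * S                      (decay the S_v x S_v recurrent state)
    kv_mem  = sum_k S[k][:] * k_t[k]            (the (state * k).sum(-2) above)
    delta   = beta_t * (v_t - kv_mem)
    S      <- S + outer(k_t, delta)             (rank-1 state update)
    o_t     = sum_k S[k][:] * q_t[k]

so the causal/identity masks, cumulative sums, and the triangular solve of the old build_delta_net_recurrent implementation drop out entirely on the decode path.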
@@ -712,6 +528,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         ggml_tensor * cur,
         ggml_tensor * causal_mask,
         ggml_tensor * identity,
+        ggml_tensor * diag_mask,
         int il) {
     const auto * mctx_cur = inp->mctx;

@@ -737,11 +554,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(mixed_ba, "linear_attn_mixed_ba", il);

     int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
-    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);

     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);

     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -762,8 +579,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);

-    GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
-
     ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
@@ -799,9 +614,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                  (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);

-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) ==
-                ggml_nelements(mixed_qkvz));
-
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
     ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
@@ -925,10 +737,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);

-    // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
-    ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
-        build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) :
-        build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
+    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
+    ggml_tensor * attn_out;
+    if (n_seq_tokens == 1) {
+        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
+    } else {
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+    }
     cb(attn_out, "attn_out", il);

     // The tensors were concatenated 1d, so we need to extract them 1d as well