@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -55,7 +55,37 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-        ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
+        } else {
+            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
+        }
+    }
+}
+
+void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && attn_scale) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        std::vector<float> attn_scale_data(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const float pos = ubatch->pos[i];
+            attn_scale_data[i] = std::log(
+                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            ) * f_attn_temp_scale + 1.0;
+        }
+
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
     }
 }
 
@@ -402,120 +432,94 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv = kv_self->n;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            float * data = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv = kv_self->n;
+        const int64_t n_tokens = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs = ubatch->n_seqs;
 
-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data = nullptr;
+        float * data_swa = nullptr;
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }
 
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }
 
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
+                                f = 0.0f;
                             }
+                        }
 
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                        }
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
+                        // may need to cut off old tokens for sliding window
+                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
+                        if (data_swa) {
+                            if (hparams.n_attn_chunk) {
+                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                    f = -INFINITY;
+                                }
+                            } else {
                                 if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
                                     f = -INFINITY;
                                 }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                             }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-            } else {
-                const int64_t n_tokens = ubatch->n_tokens;
-                const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-                const int64_t n_seqs = ubatch->n_seqs;
-                // when using kv cache, the mask needs to match the kv cache size
-                const int64_t n_stride = n_tokens;
-
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-                float * data = (float *) self_kq_mask->data;
-
-                for (int h = 0; h < 1; ++h) {
-                    for (int s1 = 0; s1 < n_seqs; ++s1) {
-                        const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                        for (int j = 0; j < n_seq_tokens; ++j) {
-                            const int32_t tj = s1*n_seq_tokens + j;
 
-                            for (int s0 = 0; s0 < n_seqs; ++s0) {
-                                for (int i = 0; i < n_seq_tokens; ++i) {
-                                    const int32_t ti = s0*n_seq_tokens + i;
-                                    float f = -INFINITY;
-
-                                    for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                        if (ubatch->seq_id[s0][s] == seq_id) {
-                                            if (hparams.use_alibi) {
-                                                f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                            } else {
-                                                f = 0.0f;
-                                            }
-                                            break;
-                                        }
-                                    }
-
-                                    data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                                }
-                            }
-
-                            for (int i = n_tokens; i < n_stride; ++i) {
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                            }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
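For reference, a minimal standalone sketch of how the pos_chunk_start computation in the new chunked-attention branch limits which cached positions a query can attend to once the causal mask is also applied. Chunk size, cache length, and all names below are toy values chosen for illustration, not values from the diff:

    // Toy illustration of chunked attention visibility (chunk size 4, 10 cache cells).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t n_attn_chunk = 4;   // stands in for hparams.n_attn_chunk
        const int32_t n_kv         = 10;  // number of cells in the toy cache

        for (int32_t pos = 0; pos < n_kv; ++pos) {
            // same derivation as in the diff: start of the chunk the query falls into
            const int32_t pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk;

            std::printf("pos %2d | chunk start %2d | visible: ", pos, pos_chunk_start);
            for (int32_t i = 0; i < n_kv; ++i) {
                // a cell is visible iff it is inside the current chunk and not in the future
                const bool visible = i >= pos_chunk_start && i <= pos;
                std::printf("%c", visible ? 'x' : '-');
            }
            std::printf("\n");
        }
        return 0;
    }

Each printed row mirrors the "xxx-------" style mask rows in the comment above, with the window resetting at every chunk boundary.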
@@ -602,7 +606,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::n_pos_per_token() const {
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
@@ -813,6 +817,10 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }
 
     if (down_b) {
@@ -846,8 +854,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il) const {
-    int64_t n_embd = cur->ne[0];
-    int64_t n_tokens = cur->ne[1];
+    const int64_t n_embd = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
 
     ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
@@ -875,6 +884,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }
 
+    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+    if (arch == LLM_ARCH_LLAMA4) {
+        selection_probs = logits;
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -901,34 +916,53 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    if (weight_before_ffn) {
+        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
+        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
+        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        cur = ggml_mul(ctx0, repeated, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
+
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-                gate = ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
-    experts = ggml_mul(ctx0, experts, weights);
+    if (!weight_before_ffn) {
+        experts = ggml_mul(ctx0, experts, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
 
     // aggregate experts
     ggml_tensor * moe_out = nullptr;
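The weight_before_ffn path matters because the expert FFN is non-linear, so scaling its input by the routing weight is not the same as scaling its output. A minimal sketch with made-up one-dimensional values; the silu-based expert_ffn and all constants are stand-ins for illustration, not code from llama.cpp:

    // Toy comparison of router-weight placement around a non-linear expert FFN.
    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // one-dimensional stand-in for an expert FFN: down * silu(gate * x)
    static float expert_ffn(float x) { return 0.5f * silu(2.0f * x); }

    int main() {
        const float x = 1.0f;   // toy activation
        const float w = 0.25f;  // toy sigmoid-ed router weight

        const float weight_before = expert_ffn(w * x);  // llama4 ordering per the diff comment
        const float weight_after  = w * expert_ffn(x);  // ordering used by the pre-existing MoE path

        std::printf("weight before FFN: %f\n", weight_before);
        std::printf("weight after  FFN: %f\n", weight_after);
        return 0;
    }

The two results differ, which is why the graph builder needs the explicit weight_before_ffn branch instead of always multiplying by the weights after build_lora_mm_id.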
@@ -948,6 +982,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         moe_out = ggml_cont(ctx0, moe_out);
     }
 
+    cb(moe_out, "ffn_moe_out", il);
+
     return moe_out;
 }
 
@@ -1003,11 +1039,25 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+
+    auto & cur = inp->attn_scale;
+
+    // this need to be 1x1xN for broadcasting
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
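The per-token values for this attn_scale input are filled by llm_graph_input_attn_temp::set_input (first hunk of this file). Written out, a token at position p receives

    \mathrm{scale}(p) = 1 + f_{\mathrm{attn\_temp\_scale}} \cdot \log\!\left( \left\lfloor \frac{p + 1}{n_{\mathrm{attn\_temp\_floor\_scale}}} \right\rfloor + 1 \right)

so positions with p + 1 below the floor scale keep a scale of exactly 1, and later positions are scaled up logarithmically in steps of the floor scale.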
@@ -1164,6 +1214,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1175,8 +1226,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
-
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
     const auto n_kv = k->ne[1];
@@ -1191,12 +1240,26 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1234,9 +1297,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
 
-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
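In both the flash-attention and the regular path, the new v_mla argument performs the "decompress from MQA back to MHA" step: ggml_mul_mat contracts the first dimension of its two arguments and broadcasts over the rest, so the per-head attention output of length v_mla->ne[0] is mapped to a vector of length v_mla->ne[1] for each head. Per head h this amounts to a plain matrix product (notation ours, shapes inferred from the reshape in the hunk above):

    o_h = V^{(h)\top}\,\tilde{o}_h, \qquad V^{(h)} \in \mathbb{R}^{ne_0 \times ne_1},\; \tilde{o}_h \in \mathbb{R}^{ne_0},\; o_h \in \mathbb{R}^{ne_1}

which restores one full-size value vector per attention head before the outputs are merged.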
@@ -1271,6 +1339,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1292,7 +1361,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1346,6 +1415,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1431,7 +1501,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1471,6 +1541,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1490,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1659,4 +1730,3 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
-
@@ -90,14 +90,29 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_token = 1;
+    const int64_t n_pos_per_embd = 1;
+};
+
+// temperature tuning, used by llama4
+class llm_graph_input_attn_temp : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    virtual ~llm_graph_input_attn_temp() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+    const uint32_t n_attn_temp_floor_scale;
+    const float f_attn_temp_scale;
 };
 
 class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -402,7 +417,7 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_token() const;
+    int64_t n_pos_per_embd() const;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -470,6 +485,7 @@ struct llm_graph_context {
 
     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
     ggml_tensor * build_inp_pos() const;
+    ggml_tensor * build_inp_attn_scale() const;
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
@@ -487,11 +503,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
-            ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
-            ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+            ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
+            ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
+            ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             bool v_trans,
             float kq_scale) const;
 
@@ -506,6 +523,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -520,6 +538,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -534,6 +553,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -43,6 +43,10 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
 
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet posnet;
     struct llama_hparams_convnext convnext;
@@ -62,6 +66,7 @@ struct llama_hparams {
     float expert_weights_scale = 0.0;
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -112,6 +117,14 @@ struct llama_hparams {
     bool use_alibi = false;
     bool attn_soft_cap = false;
 
+    uint32_t n_moe_layer_step = 0;
+    bool use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
+    // values below seems to be fixed on llama4
+    uint32_t n_no_rope_layer_step = 4;
+    uint32_t n_attn_temp_floor_scale = 8192;
+    float f_attn_temp_scale = 0.1;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
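Since the last three values are hard-coded for llama4, a tiny standalone check (illustrative only, positions chosen arbitrarily) of what the formula used by llm_graph_input_attn_temp::set_input yields with these constants:

    // Numeric check of the llama4 attention-temperature constants (8192, 0.1).
    #include <cmath>
    #include <cstdio>

    int main() {
        const float n_floor = 8192.0f; // n_attn_temp_floor_scale
        const float s_attn  = 0.1f;    // f_attn_temp_scale

        const int positions[] = {0, 8190, 8191, 65536, 1048576};
        for (int pos : positions) {
            // same expression as in llm_graph_input_attn_temp::set_input
            const float scale = std::log(std::floor((pos + 1.0f) / n_floor) + 1.0f) * s_attn + 1.0f;
            std::printf("pos %8d -> attn scale %.4f\n", pos, scale);
        }
        return 0;
    }

The scale stays at 1.0 until position 8191 and then grows slowly (roughly 1.22 at 65536 and 1.49 at about one million tokens), which matches the intent of temperature tuning for very long contexts.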