cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/llama-context.h CHANGED
@@ -27,7 +27,12 @@ struct llama_context {
 
     void synchronize();
 
-    const llama_model & get_model() const;
+    const llama_model & get_model() const;
+    const llama_cparams & get_cparams() const;
+
+    lm_ggml_backend_sched_t get_sched() const;
+
+    lm_ggml_context * get_ctx_compute() const;
 
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
@@ -137,50 +142,30 @@ private:
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
 
+public:
     int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
     lm_ggml_cgraph * graph_init();
 
+    // returns the result of lm_ggml_backend_sched_graph_compute_async execution
+    lm_ggml_status graph_compute(
+            lm_ggml_cgraph * gf,
+            bool batched);
+
+private:
     llm_graph_result_ptr graph_build(
             lm_ggml_context * ctx,
             lm_ggml_cgraph * gf,
             const llama_ubatch & ubatch,
             llm_graph_type gtype);
 
-    // returns the result of lm_ggml_backend_sched_graph_compute_async execution
-    lm_ggml_status graph_compute(
-            lm_ggml_cgraph * gf,
-            bool batched);
-
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    lm_ggml_tensor * build_rope_shift(
-            lm_ggml_context * ctx0,
-            lm_ggml_tensor * cur,
-            lm_ggml_tensor * shift,
-            lm_ggml_tensor * factors,
-            float freq_base,
-            float freq_scale,
-            lm_ggml_backend_buffer * bbuf) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            lm_ggml_context * ctx0,
-            lm_ggml_cgraph * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            lm_ggml_context * ctx0,
-            lm_ggml_cgraph * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
@@ -197,11 +182,10 @@ private:
     llama_cparams cparams;
     llama_adapter_cvec cvec;
     llama_adapter_loras loras;
-    llama_sbatch sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<llama_kv_cache_unified> kv_self;
+    std::unique_ptr<llama_memory_i> memory;
 
     // TODO: remove
     bool logits_all = false;
package/cpp/llama-graph.cpp CHANGED
@@ -55,7 +55,21 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-        lm_ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*lm_ggml_element_size(pos));
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            lm_ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*lm_ggml_element_size(pos));
+        } else {
+            lm_ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*lm_ggml_element_size(pos));
+        }
     }
 }
 
@@ -71,7 +85,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
                 ) * f_attn_temp_scale + 1.0;
         }
 
-        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*lm_ggml_element_size(attn_scale));
+        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*lm_ggml_element_size(attn_scale));
     }
 }
 
@@ -270,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }
@@ -303,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
         }
     }
 }
@@ -592,7 +578,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::n_pos_per_token() const {
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
@@ -803,6 +789,10 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
     }
 
     if (down_b) {
@@ -910,28 +900,35 @@ lm_ggml_tensor * llm_graph_context::build_moe_ffn(
     lm_ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    lm_ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    lm_ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = lm_ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = lm_ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
            {
-                gate = lm_ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = lm_ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
            } break;
         default:
            LM_GGML_ABORT("fatal error");
     }
 
-    lm_ggml_tensor * par = lm_ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = lm_ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    lm_ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (!weight_before_ffn) {
@@ -1014,11 +1011,11 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_embd());
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1027,11 +1024,12 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
 
     auto & cur = inp->attn_scale;
 
-    cur = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
+    // this need to be 1x1xN for broadcasting
+    cur = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, 1, 1, n_tokens);
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1079,7 +1077,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1096,7 +1094,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1188,6 +1186,7 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_tensor * v,
         lm_ggml_tensor * kq_b,
         lm_ggml_tensor * kq_mask,
+        lm_ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,8 +1198,6 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
-
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
     const auto n_kv = k->ne[1];
@@ -1229,7 +1226,12 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
-        cur = lm_ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = lm_ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = lm_ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
 
@@ -1267,9 +1269,14 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, v, kq);
 
-        lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = lm_ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
+        cur = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = lm_ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = lm_ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1304,6 +1311,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
        lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
     LM_GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1379,6 +1387,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1399,8 +1408,6 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        LM_GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         LM_GGML_ASSERT(kv_self->size == n_ctx);
@@ -1464,7 +1471,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
                 lm_ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1504,6 +1511,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1531,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1549,7 +1557,7 @@ lm_ggml_tensor * llm_graph_context::build_copy_mask_state(
         lm_ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1581,7 +1589,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         lm_ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1602,7 +1610,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         lm_ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1692,4 +1700,3 @@ void llm_graph_context::build_pooling(
 
     lm_ggml_build_forward_expand(gf, cur);
 }
-
package/cpp/llama-graph.h CHANGED
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -90,29 +91,27 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     lm_ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_token = 1;
+    const int64_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
     virtual ~llm_graph_input_attn_temp() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     lm_ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
 
-    const int64_t n_pos_per_token = 1;
-
     const uint32_t n_attn_temp_floor_scale;
     const float f_attn_temp_scale;
 };
@@ -188,26 +187,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
    virtual ~llm_graph_input_s_copy() = default;
 
    void set_input(const llama_ubatch * ubatch) override;
 
    lm_ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
    virtual ~llm_graph_input_s_mask() = default;
 
    void set_input(const llama_ubatch * ubatch) override;
 
    lm_ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -352,8 +351,8 @@ struct llm_graph_params {
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
 
-    lm_ggml_backend_sched * sched;
-    lm_ggml_backend * backend_cpu;
+    lm_ggml_backend_sched_t sched;
+    lm_ggml_backend_t backend_cpu;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -404,9 +403,9 @@ struct llm_graph_context {
 
     lm_ggml_context * ctx0 = nullptr;
 
-    lm_ggml_backend_sched * sched;
+    lm_ggml_backend_sched_t sched;
 
-    lm_ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+    lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -419,7 +418,7 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_token() const;
+    int64_t n_pos_per_embd() const;
 
     void cb(lm_ggml_tensor * cur, const char * name, int il) const;
 
@@ -505,11 +504,12 @@ struct llm_graph_context {
 
     lm_ggml_tensor * build_attn_mha(
             lm_ggml_cgraph * gf,
-            lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
-            lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
-            lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+            lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
+            lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
+            lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
             lm_ggml_tensor * kq_b,
             lm_ggml_tensor * kq_mask,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             bool v_trans,
             float kq_scale) const;
 
@@ -524,6 +524,7 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -538,6 +539,7 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
@@ -552,6 +554,7 @@ struct llm_graph_context {
             lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             lm_ggml_tensor * kq_b,
+            lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
package/cpp/llama-hparams.h CHANGED
@@ -43,6 +43,10 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
 
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet posnet;
     struct llama_hparams_convnext convnext;
@@ -62,6 +66,7 @@ struct llama_hparams {
     float expert_weights_scale = 0.0;
     bool expert_weights_norm = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;