cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
ggml-cpu.h

@@ -133,6 +133,11 @@ extern "C" {
 
  LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void);
 
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_fp16(const float *, lm_ggml_fp16_t *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp16_to_fp32(const lm_ggml_fp16_t *, float *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_fp32_to_bf16(const float *, lm_ggml_bf16_t *, int64_t);
+ LM_GGML_BACKEND_API void lm_ggml_cpu_bf16_to_fp32(const lm_ggml_bf16_t *, float *, int64_t);
+
  #ifdef __cplusplus
  }
  #endif
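The four declarations added here expose the CPU backend's bulk float conversion routines. As a rough sketch only (the wrapper names below are invented and nothing beyond the signatures shown in the hunk is assumed), they can be called on plain buffers like this:

    // Hedged usage sketch of the new CPU conversion helpers; assumes only the
    // declarations shown in the ggml-cpu.h hunk above.
    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <vector>

    static std::vector<lm_ggml_fp16_t> to_fp16(const std::vector<float> & src) {
        std::vector<lm_ggml_fp16_t> dst(src.size());
        // convert src.size() values from fp32 to fp16 in one call
        lm_ggml_cpu_fp32_to_fp16(src.data(), dst.data(), (int64_t) src.size());
        return dst;
    }

    static std::vector<float> to_fp32(const std::vector<lm_ggml_fp16_t> & src) {
        std::vector<float> dst(src.size());
        lm_ggml_cpu_fp16_to_fp32(src.data(), dst.data(), (int64_t) src.size());
        return dst;
    }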
ggml-impl.h

@@ -16,6 +16,14 @@
  #include <arm_sve.h>
  #endif // __ARM_FEATURE_SVE
 
+ #if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+ //
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+ //
+ #include <arm_neon.h>
+ #endif
+
  #if defined(__F16C__)
  #include <immintrin.h>
  #endif
@@ -140,8 +148,14 @@ struct lm_ggml_map_custom2_op_params {
 
  struct lm_ggml_map_custom3_op_params {
  lm_ggml_custom3_op_t fun;
- int n_tasks;
- void * userdata;
+ int n_tasks;
+ void * userdata;
+ };
+
+ struct lm_ggml_custom_op_params {
+ lm_ggml_custom_op_t fun;
+ int n_tasks;
+ void * userdata;
  };
 
  // bitset
@@ -311,13 +325,6 @@ LM_GGML_API void lm_ggml_aligned_free(void * ptr, size_t size);
  // for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
  //
  #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
- //
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
- //
- #include <arm_neon.h>
-
  #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x)
  #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x)
 
ggml.h

@@ -394,8 +394,8 @@ extern "C" {
 
  // precision
  enum lm_ggml_prec {
- LM_GGML_PREC_DEFAULT,
- LM_GGML_PREC_F32,
+ LM_GGML_PREC_DEFAULT = 0, // stored as lm_ggml_tensor.op_params, 0 by default
+ LM_GGML_PREC_F32 = 10,
  };
 
  // model file types
@@ -482,6 +482,7 @@ extern "C" {
  LM_GGML_OP_CONV_TRANSPOSE_1D,
  LM_GGML_OP_IM2COL,
  LM_GGML_OP_IM2COL_BACK,
+ LM_GGML_OP_CONV_2D_DW,
  LM_GGML_OP_CONV_TRANSPOSE_2D,
  LM_GGML_OP_POOL_1D,
  LM_GGML_OP_POOL_2D,
@@ -508,17 +509,12 @@ extern "C" {
 
  LM_GGML_OP_UNARY,
 
- LM_GGML_OP_MAP_UNARY,
- LM_GGML_OP_MAP_BINARY,
-
- LM_GGML_OP_MAP_CUSTOM1_F32,
- LM_GGML_OP_MAP_CUSTOM2_F32,
- LM_GGML_OP_MAP_CUSTOM3_F32,
-
  LM_GGML_OP_MAP_CUSTOM1,
  LM_GGML_OP_MAP_CUSTOM2,
  LM_GGML_OP_MAP_CUSTOM3,
 
+ LM_GGML_OP_CUSTOM,
+
  LM_GGML_OP_CROSS_ENTROPY_LOSS,
  LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
  LM_GGML_OP_OPT_STEP_ADAMW,
@@ -683,6 +679,9 @@ extern "C" {
  LM_GGML_API bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1
  LM_GGML_API bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2
 
+ // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+ LM_GGML_API bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor);
+
  LM_GGML_API bool lm_ggml_are_same_shape (const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1);
  LM_GGML_API bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1);
 
@@ -1666,7 +1665,7 @@ extern "C" {
  struct lm_ggml_tensor * a,
  struct lm_ggml_tensor * b);
 
- // depthwise
+ // depthwise (via im2col and mul_mat)
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d_dw(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a, // convolution kernel
@@ -1678,6 +1677,22 @@ extern "C" {
  int d0, // dilation dimension 0
  int d1); // dilation dimension 1
 
+ // Depthwise 2D convolution
+ // may be faster than lm_ggml_conv_2d_dw, but not available in all backends
+ // a: KW KH 1 C convolution kernel
+ // b: W H C N input data
+ // res: W_out H_out C N
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d_dw_direct(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor * b,
+ int stride0,
+ int stride1,
+ int pad0,
+ int pad1,
+ int dilation0,
+ int dilation1);
+
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a,
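A minimal sketch of how the new direct depthwise op might be wired into a graph, following the shape comments in the header above (kernel a: KW x KH x 1 x C, input b: W x H x C x N); the helper name, kernel size and parameter values are illustrative, not taken from the package:

    #include "ggml.h"

    // Build a 3x3 depthwise convolution over `input` (layout W x H x C x N).
    static lm_ggml_tensor * build_dw_conv(lm_ggml_context * ctx, lm_ggml_tensor * input) {
        const int C = (int) input->ne[2];
        // one 3x3 filter per channel: KW = 3, KH = 3, 1, C
        lm_ggml_tensor * kernel = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, 3, 3, 1, C);
        // stride 1, padding 1, dilation 1 keeps the spatial size unchanged
        return lm_ggml_conv_2d_dw_direct(ctx, kernel, input,
                                         /*stride0*/ 1, /*stride1*/ 1,
                                         /*pad0*/ 1, /*pad1*/ 1,
                                         /*dilation0*/ 1, /*dilation1*/ 1);
    }

Note the header's own caveat: the direct op may be faster than lm_ggml_conv_2d_dw but is not available in every backend.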
@@ -1723,24 +1738,29 @@ extern "C" {
  float p0,
  float p1);
 
- // nearest interpolate
+ enum lm_ggml_scale_mode {
+ LM_GGML_SCALE_MODE_NEAREST = 0,
+ LM_GGML_SCALE_MODE_BILINEAR = 1,
+ };
+
+ // interpolate
  // multiplies ne0 and ne1 by scale factor
- // used in stable-diffusion
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_upscale(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a,
- int scale_factor);
+ int scale_factor,
+ enum lm_ggml_scale_mode mode);
 
- // nearest interpolate
- // nearest interpolate to specified dimensions
- // used in tortoise.cpp
+ // interpolate
+ // interpolate scale to specified dimensions
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_upscale_ext(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a,
  int ne0,
  int ne1,
  int ne2,
- int ne3);
+ int ne3,
+ enum lm_ggml_scale_mode mode);
 
  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_pad(
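Because lm_ggml_upscale and lm_ggml_upscale_ext now take a scale mode, existing call sites gain one argument; LM_GGML_SCALE_MODE_NEAREST reproduces the old behaviour and BILINEAR is the new option. A minimal before/after sketch (helper name invented, only the signatures shown above are assumed):

    #include "ggml.h"

    static lm_ggml_tensor * upscale_2x(lm_ggml_context * ctx, lm_ggml_tensor * a) {
        // before 1.6.1: return lm_ggml_upscale(ctx, a, 2);
        return lm_ggml_upscale(ctx, a, 2, LM_GGML_SCALE_MODE_BILINEAR);
    }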
@@ -1917,83 +1937,6 @@ extern "C" {
 
  // custom operators
 
- typedef void (*lm_ggml_unary_op_f32_t) (const int, float *, const float *);
- typedef void (*lm_ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
- typedef void (*lm_ggml_custom1_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
- typedef void (*lm_ggml_custom2_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
- typedef void (*lm_ggml_custom3_op_f32_t)(struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *, const struct lm_ggml_tensor *);
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_unary_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- lm_ggml_unary_op_f32_t fun),
- "use lm_ggml_map_custom1 instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- lm_ggml_unary_op_f32_t fun),
- "use lm_ggml_map_custom1_inplace instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_binary_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- lm_ggml_binary_op_f32_t fun),
- "use lm_ggml_map_custom2 instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- lm_ggml_binary_op_f32_t fun),
- "use lm_ggml_map_custom2_inplace instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom1_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- lm_ggml_custom1_op_f32_t fun),
- "use lm_ggml_map_custom1 instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom1_inplace_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- lm_ggml_custom1_op_f32_t fun),
- "use lm_ggml_map_custom1_inplace instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom2_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- lm_ggml_custom2_op_f32_t fun),
- "use lm_ggml_map_custom2 instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom2_inplace_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- lm_ggml_custom2_op_f32_t fun),
- "use lm_ggml_map_custom2_inplace instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom3_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- struct lm_ggml_tensor * c,
- lm_ggml_custom3_op_f32_t fun),
- "use lm_ggml_map_custom3 instead");
-
- LM_GGML_DEPRECATED(LM_GGML_API struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- struct lm_ggml_tensor * c,
- lm_ggml_custom3_op_f32_t fun),
- "use lm_ggml_map_custom3_inplace instead");
-
- // custom operators v2
-
  typedef void (*lm_ggml_custom1_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, int ith, int nth, void * userdata);
  typedef void (*lm_ggml_custom2_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, int ith, int nth, void * userdata);
  typedef void (*lm_ggml_custom3_op_t)(struct lm_ggml_tensor * dst , const struct lm_ggml_tensor * a, const struct lm_ggml_tensor * b, const struct lm_ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2049,6 +1992,30 @@ extern "C" {
  int n_tasks,
  void * userdata);
 
+ typedef void (*lm_ggml_custom_op_t)(struct lm_ggml_tensor * dst , int ith, int nth, void * userdata);
+
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_custom_4d(
+ struct lm_ggml_context * ctx,
+ enum lm_ggml_type type,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ struct lm_ggml_tensor ** args,
+ int n_args,
+ lm_ggml_custom_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ LM_GGML_API struct lm_ggml_tensor * lm_ggml_custom_inplace(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor ** args,
+ int n_args,
+ lm_ggml_custom_op_t fun,
+ int n_tasks,
+ void * userdata);
+
  // loss function
 
  LM_GGML_API struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
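A hedged sketch of the new generic custom op, assuming only the lm_ggml_custom_op_t typedef and the lm_ggml_custom_4d signature shown above; the callback body, names and sizes are illustrative (a real callback would split its work across the nth threads and only runs when the graph is computed):

    #include "ggml.h"

    // Fill the destination tensor with a constant passed through userdata.
    static void fill_op(lm_ggml_tensor * dst, int ith, int nth, void * userdata) {
        (void) ith; (void) nth; // single-task sketch; no work splitting
        const float value = *(const float *) userdata;
        float * out = (float *) dst->data;
        const int64_t n = lm_ggml_nelements(dst);
        for (int64_t i = 0; i < n; ++i) {
            out[i] = value;
        }
    }

    static lm_ggml_tensor * build_filled(lm_ggml_context * ctx, float * value) {
        struct lm_ggml_tensor * args[1] = { NULL }; // no tensor inputs in this sketch
        return lm_ggml_custom_4d(ctx, LM_GGML_TYPE_F32, 8, 8, 1, 1,
                                 args, /*n_args*/ 0, fill_op, /*n_tasks*/ 1, value);
    }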
llama-arch.h

@@ -23,6 +23,7 @@ enum llm_arch {
  LLM_ARCH_REFACT,
  LLM_ARCH_BERT,
  LLM_ARCH_NOMIC_BERT,
+ LLM_ARCH_NOMIC_BERT_MOE,
  LLM_ARCH_JINA_BERT_V2,
  LLM_ARCH_BLOOM,
  LLM_ARCH_STABLELM,
@@ -58,6 +59,7 @@ enum llm_arch {
  LLM_ARCH_DEEPSEEK,
  LLM_ARCH_DEEPSEEK2,
  LLM_ARCH_CHATGLM,
+ LLM_ARCH_GLM4,
  LLM_ARCH_BITNET,
  LLM_ARCH_T5,
  LLM_ARCH_T5ENCODER,
@@ -109,6 +111,7 @@ enum llm_kv {
  LLM_KV_EXPERT_WEIGHTS_SCALE,
  LLM_KV_EXPERT_WEIGHTS_NORM,
  LLM_KV_EXPERT_GATING_FUNC,
+ LLM_KV_MOE_EVERY_N_LAYERS,
  LLM_KV_POOLING_TYPE,
  LLM_KV_LOGIT_SCALE,
  LLM_KV_DECODER_START_TOKEN_ID,
@@ -143,6 +146,8 @@ enum llm_kv {
  LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
  LLM_KV_ATTENTION_SLIDING_WINDOW,
  LLM_KV_ATTENTION_SCALE,
+ LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+ LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
  LLM_KV_ROPE_DIMENSION_COUNT,
  LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -256,6 +261,8 @@ enum llm_tensor {
  LLM_TENSOR_ATTN_Q_NORM,
  LLM_TENSOR_ATTN_K_NORM,
  LLM_TENSOR_LAYER_OUT_NORM,
+ LLM_TENSOR_POST_ATTN_NORM,
+ LLM_TENSOR_POST_MLP_NORM,
  LLM_TENSOR_SSM_IN,
  LLM_TENSOR_SSM_CONV1D,
  LLM_TENSOR_SSM_X,
@@ -303,6 +310,8 @@ enum llm_tensor {
  LLM_TENSOR_ATTN_Q_B,
  LLM_TENSOR_ATTN_KV_A_MQA,
  LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_K_B,
+ LLM_TENSOR_ATTN_V_B,
  LLM_TENSOR_ATTN_Q_A_NORM,
  LLM_TENSOR_ATTN_KV_A_NORM,
  LLM_TENSOR_ATTN_SUB_NORM,
llama-batch.h

@@ -70,7 +70,8 @@ struct llama_sbatch {
  // sequence-wise split
  llama_ubatch split_seq(size_t n_ubatch);
 
- void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+ llama_sbatch() = default;
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
  };
 
  // temporary allocate memory for the input batch if needed
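The practical effect of this hunk is that an sbatch is now constructed directly from a llama_batch instead of being populated through from_batch(). A before/after sketch of the internal C++ call, shown only to illustrate the signature change (argument values are illustrative; the defaults are as declared above):

    #include "llama-batch.h"

    static llama_sbatch make_sbatch(const llama_batch & batch, size_t n_embd) {
        // before 1.6.1:
        //   llama_sbatch sbatch;
        //   sbatch.from_batch(batch, n_embd, /*simple_split*/ true, /*logits_all*/ false);
        // from 1.6.1:
        return llama_sbatch(batch, n_embd, /*simple_split*/ true, /*logits_all*/ false);
    }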
llama-chat.h

@@ -29,8 +29,8 @@ enum llm_chat_template {
  LLM_CHAT_TEMPLATE_DEEPSEEK_3,
  LLM_CHAT_TEMPLATE_COMMAND_R,
  LLM_CHAT_TEMPLATE_LLAMA_3,
- LLM_CHAT_TEMPLATE_CHATGML_3,
- LLM_CHAT_TEMPLATE_CHATGML_4,
+ LLM_CHAT_TEMPLATE_CHATGLM_3,
+ LLM_CHAT_TEMPLATE_CHATGLM_4,
  LLM_CHAT_TEMPLATE_GLMEDGE,
  LLM_CHAT_TEMPLATE_MINICPM,
  LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -41,6 +41,7 @@ enum llm_chat_template {
  LLM_CHAT_TEMPLATE_YANDEX,
  LLM_CHAT_TEMPLATE_BAILING,
  LLM_CHAT_TEMPLATE_LLAMA4,
+ LLM_CHAT_TEMPLATE_SMOLVLM,
  LLM_CHAT_TEMPLATE_UNKNOWN,
  };
 
llama-context.h

@@ -27,7 +27,12 @@ struct llama_context {
 
  void synchronize();
 
- const llama_model & get_model() const;
+ const llama_model & get_model() const;
+ const llama_cparams & get_cparams() const;
+
+ lm_ggml_backend_sched_t get_sched() const;
+
+ lm_ggml_context * get_ctx_compute() const;
 
  uint32_t n_ctx() const;
  uint32_t n_ctx_per_seq() const;
@@ -137,50 +142,30 @@ private:
  // Returns max number of outputs for which space was reserved.
  int32_t output_reserve(int32_t n_outputs);
 
- // make the outputs have the same order they had in the user-provided batch
- // TODO: maybe remove this
- void output_reorder();
-
  //
  // graph
  //
 
+ public:
  int32_t graph_max_nodes() const;
 
  // zero-out inputs and create the ctx_compute for the compute graph
  lm_ggml_cgraph * graph_init();
 
+ // returns the result of lm_ggml_backend_sched_graph_compute_async execution
+ lm_ggml_status graph_compute(
+ lm_ggml_cgraph * gf,
+ bool batched);
+
+ private:
  llm_graph_result_ptr graph_build(
  lm_ggml_context * ctx,
  lm_ggml_cgraph * gf,
  const llama_ubatch & ubatch,
  llm_graph_type gtype);
 
- // returns the result of lm_ggml_backend_sched_graph_compute_async execution
- lm_ggml_status graph_compute(
- lm_ggml_cgraph * gf,
- bool batched);
-
  llm_graph_cb graph_get_cb() const;
 
- // used by kv_self_update()
- lm_ggml_tensor * build_rope_shift(
- lm_ggml_context * ctx0,
- lm_ggml_tensor * cur,
- lm_ggml_tensor * shift,
- lm_ggml_tensor * factors,
- float freq_base,
- float freq_scale,
- lm_ggml_backend_buffer * bbuf) const;
-
- llm_graph_result_ptr build_kv_self_shift(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const;
-
- llm_graph_result_ptr build_kv_self_defrag(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const;
-
  // TODO: read/write lora adapters and cvec
  size_t state_write_data(llama_io_write_i & io);
  size_t state_read_data (llama_io_read_i & io);
@@ -197,11 +182,10 @@ private:
  llama_cparams cparams;
  llama_adapter_cvec cvec;
  llama_adapter_loras loras;
- llama_sbatch sbatch;
 
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
- std::unique_ptr<llama_kv_cache_unified> kv_self;
+ std::unique_ptr<llama_memory_i> memory;
 
  // TODO: remove
  bool logits_all = false;
llama-graph.h

@@ -19,6 +19,7 @@ struct llama_cparams;
 
  class llama_memory_i;
  class llama_kv_cache_unified;
+ class llama_kv_cache_recurrent;
 
  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
@@ -90,29 +91,27 @@ public:
 
  class llm_graph_input_pos : public llm_graph_input_i {
  public:
- llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+ llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  virtual ~llm_graph_input_pos() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * pos = nullptr; // I32 [n_batch]
 
- const int64_t n_pos_per_token = 1;
+ const int64_t n_pos_per_embd = 1;
  };
 
  // temperature tuning, used by llama4
  class llm_graph_input_attn_temp : public llm_graph_input_i {
  public:
- llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
- : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+ llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+ : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
  virtual ~llm_graph_input_attn_temp() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
 
- const int64_t n_pos_per_token = 1;
-
  const uint32_t n_attn_temp_floor_scale;
  const float f_attn_temp_scale;
  };
@@ -188,26 +187,26 @@ public:
 
  class llm_graph_input_s_copy : public llm_graph_input_i {
  public:
- llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_copy() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * s_copy; // I32 [kv_size]
 
- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };
 
  class llm_graph_input_s_mask : public llm_graph_input_i {
  public:
- llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
  virtual ~llm_graph_input_s_mask() = default;
 
  void set_input(const llama_ubatch * ubatch) override;
 
  lm_ggml_tensor * s_mask; // F32 [1, n_kv]
 
- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_recurrent * kv_self;
  };
 
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -352,8 +351,8 @@ struct llm_graph_params {
  const llama_cparams & cparams;
  const llama_ubatch & ubatch;
 
- lm_ggml_backend_sched * sched;
- lm_ggml_backend * backend_cpu;
+ lm_ggml_backend_sched_t sched;
+ lm_ggml_backend_t backend_cpu;
 
  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;
@@ -404,9 +403,9 @@ struct llm_graph_context {
 
  lm_ggml_context * ctx0 = nullptr;
 
- lm_ggml_backend_sched * sched;
+ lm_ggml_backend_sched_t sched;
 
- lm_ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+ lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
  const llama_adapter_cvec * cvec;
  const llama_adapter_loras * loras;
@@ -419,7 +418,7 @@ struct llm_graph_context {
 
  llm_graph_context(const llm_graph_params & params);
 
- int64_t n_pos_per_token() const;
+ int64_t n_pos_per_embd() const;
 
  void cb(lm_ggml_tensor * cur, const char * name, int il) const;
 
@@ -505,11 +504,12 @@ struct llm_graph_context {
 
  lm_ggml_tensor * build_attn_mha(
  lm_ggml_cgraph * gf,
- lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
- lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
- lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+ lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
+ lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
+ lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
  lm_ggml_tensor * kq_b,
  lm_ggml_tensor * kq_mask,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  bool v_trans,
  float kq_scale) const;
 
@@ -524,6 +524,7 @@ struct llm_graph_context {
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
@@ -538,6 +539,7 @@ struct llm_graph_context {
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
@@ -552,6 +554,7 @@ struct llm_graph_context {
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
  lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
  int il) const;
 
llama-hparams.h

@@ -43,6 +43,10 @@ struct llama_hparams {
  uint32_t n_expert_used = 0;
  uint32_t n_rel_attn_bkts = 0;
 
+ // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+ uint32_t n_embd_head_k_mla = 0;
+ uint32_t n_embd_head_v_mla = 0;
+
  // for WavTokenizer
  struct llama_hparams_posnet posnet;
  struct llama_hparams_convnext convnext;
@@ -62,6 +66,7 @@ struct llama_hparams {
  float expert_weights_scale = 0.0;
  bool expert_weights_norm = false;
  uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+ uint32_t moe_every_n_layers = 0;
 
  float f_norm_eps;
  float f_norm_rms_eps;