cui-llama.rn 1.6.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285) hide show
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +22 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
  4. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  5. package/android/src/main/jni.cpp +173 -18
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  16. package/cpp/LICENSE +21 -0
  17. package/cpp/chat.cpp +129 -107
  18. package/cpp/chat.h +2 -0
  19. package/cpp/common.cpp +58 -78
  20. package/cpp/common.h +29 -21
  21. package/cpp/ggml-alloc.c +4 -1
  22. package/cpp/ggml-backend.cpp +9 -5
  23. package/cpp/ggml-backend.h +4 -4
  24. package/cpp/ggml-cpp.h +1 -1
  25. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  26. package/cpp/ggml-cpu/amx/amx.h +8 -0
  27. package/cpp/ggml-cpu/amx/common.h +91 -0
  28. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  29. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  31. package/cpp/ggml-cpu/common.h +72 -0
  32. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
  33. package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
  34. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
  35. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
  36. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
  37. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  38. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  39. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  40. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  41. package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
  42. package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
  43. package/cpp/ggml-cpu.h +5 -0
  44. package/cpp/ggml-impl.h +16 -9
  45. package/cpp/ggml-llama-sim.metallib +0 -0
  46. package/cpp/ggml-llama.metallib +0 -0
  47. package/cpp/ggml-metal-impl.h +36 -11
  48. package/cpp/ggml-metal.m +810 -176
  49. package/cpp/ggml-opt.cpp +373 -190
  50. package/cpp/ggml-opt.h +49 -28
  51. package/cpp/ggml-quants.c +0 -6
  52. package/cpp/ggml.c +227 -282
  53. package/cpp/ggml.h +82 -101
  54. package/cpp/gguf.cpp +33 -33
  55. package/cpp/json-schema-to-grammar.cpp +3 -0
  56. package/cpp/llama-adapter.cpp +6 -0
  57. package/cpp/llama-arch.cpp +49 -17
  58. package/cpp/llama-arch.h +9 -0
  59. package/cpp/llama-batch.cpp +8 -2
  60. package/cpp/llama-batch.h +2 -1
  61. package/cpp/llama-chat.cpp +39 -16
  62. package/cpp/llama-chat.h +4 -2
  63. package/cpp/llama-context.cpp +440 -611
  64. package/cpp/llama-context.h +44 -33
  65. package/cpp/llama-cparams.h +1 -0
  66. package/cpp/llama-graph.cpp +214 -291
  67. package/cpp/llama-graph.h +69 -21
  68. package/cpp/llama-hparams.cpp +17 -1
  69. package/cpp/llama-hparams.h +39 -5
  70. package/cpp/llama-kv-cache.cpp +2067 -620
  71. package/cpp/llama-kv-cache.h +410 -108
  72. package/cpp/llama-memory.h +12 -1
  73. package/cpp/llama-model-loader.cpp +24 -15
  74. package/cpp/llama-model-saver.cpp +281 -0
  75. package/cpp/llama-model-saver.h +37 -0
  76. package/cpp/llama-model.cpp +1089 -359
  77. package/cpp/llama-model.h +19 -3
  78. package/cpp/llama-sampling.cpp +20 -7
  79. package/cpp/llama-vocab.cpp +54 -9
  80. package/cpp/llama-vocab.h +6 -0
  81. package/cpp/llama.cpp +14 -0
  82. package/cpp/llama.h +86 -142
  83. package/cpp/minja/chat-template.hpp +9 -5
  84. package/cpp/minja/minja.hpp +69 -36
  85. package/cpp/rn-llama.cpp +602 -190
  86. package/cpp/rn-llama.h +34 -8
  87. package/cpp/sampling.cpp +57 -50
  88. package/cpp/tools/mtmd/clip-impl.h +462 -0
  89. package/cpp/tools/mtmd/clip.cpp +4024 -0
  90. package/cpp/tools/mtmd/clip.h +101 -0
  91. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  92. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  93. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  94. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  95. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  96. package/cpp/tools/mtmd/mtmd.h +362 -0
  97. package/cpp/tools/mtmd/stb_image.h +7988 -0
  98. package/ios/CMakeLists.txt +20 -10
  99. package/ios/RNLlama.h +6 -0
  100. package/ios/RNLlama.mm +82 -3
  101. package/ios/RNLlamaContext.h +5 -1
  102. package/ios/RNLlamaContext.mm +131 -38
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  160. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  161. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  188. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  189. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  190. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  191. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  192. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  193. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  194. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  195. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  196. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  197. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  198. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  199. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  200. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  201. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  202. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  203. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  204. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  205. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  206. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  207. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  208. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  209. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  210. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  211. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  212. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  213. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  214. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  215. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  216. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  217. package/jest/mock.js +33 -7
  218. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  219. package/lib/commonjs/index.js +153 -21
  220. package/lib/commonjs/index.js.map +1 -1
  221. package/lib/module/NativeRNLlama.js.map +1 -1
  222. package/lib/module/index.js +152 -20
  223. package/lib/module/index.js.map +1 -1
  224. package/lib/typescript/NativeRNLlama.d.ts +54 -4
  225. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  226. package/lib/typescript/index.d.ts +72 -6
  227. package/lib/typescript/index.d.ts.map +1 -1
  228. package/package.json +1 -1
  229. package/src/NativeRNLlama.ts +72 -4
  230. package/src/index.ts +212 -38
  231. package/cpp/binary-ops.h +0 -16
  232. package/cpp/ops.h +0 -128
  233. package/cpp/simd-mappings.h +0 -888
  234. package/cpp/unary-ops.h +0 -28
  235. package/cpp/vec.h +0 -802
  236. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  237. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  238. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  239. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  240. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  241. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  242. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  243. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  244. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  245. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  246. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  247. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  248. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  249. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  250. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  251. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  271. package/lib/commonjs/chat.js +0 -37
  272. package/lib/commonjs/chat.js.map +0 -1
  273. package/lib/module/chat.js +0 -33
  274. package/lib/module/chat.js.map +0 -1
  275. package/lib/typescript/chat.d.ts +0 -10
  276. package/lib/typescript/chat.d.ts.map +0 -1
  277. package/src/chat.ts +0 -44
  278. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  279. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  280. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  281. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  282. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  283. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  284. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  285. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
@@ -23,6 +23,7 @@ enum llm_arch {
23
23
  LLM_ARCH_REFACT,
24
24
  LLM_ARCH_BERT,
25
25
  LLM_ARCH_NOMIC_BERT,
26
+ LLM_ARCH_NOMIC_BERT_MOE,
26
27
  LLM_ARCH_JINA_BERT_V2,
27
28
  LLM_ARCH_BLOOM,
28
29
  LLM_ARCH_STABLELM,
@@ -58,6 +59,7 @@ enum llm_arch {
58
59
  LLM_ARCH_DEEPSEEK,
59
60
  LLM_ARCH_DEEPSEEK2,
60
61
  LLM_ARCH_CHATGLM,
62
+ LLM_ARCH_GLM4,
61
63
  LLM_ARCH_BITNET,
62
64
  LLM_ARCH_T5,
63
65
  LLM_ARCH_T5ENCODER,
@@ -109,6 +111,7 @@ enum llm_kv {
109
111
  LLM_KV_EXPERT_WEIGHTS_SCALE,
110
112
  LLM_KV_EXPERT_WEIGHTS_NORM,
111
113
  LLM_KV_EXPERT_GATING_FUNC,
114
+ LLM_KV_MOE_EVERY_N_LAYERS,
112
115
  LLM_KV_POOLING_TYPE,
113
116
  LLM_KV_LOGIT_SCALE,
114
117
  LLM_KV_DECODER_START_TOKEN_ID,
@@ -143,6 +146,8 @@ enum llm_kv {
143
146
  LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
144
147
  LLM_KV_ATTENTION_SLIDING_WINDOW,
145
148
  LLM_KV_ATTENTION_SCALE,
149
+ LLM_KV_ATTENTION_KEY_LENGTH_MLA,
150
+ LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
146
151
 
147
152
  LLM_KV_ROPE_DIMENSION_COUNT,
148
153
  LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -256,6 +261,8 @@ enum llm_tensor {
256
261
  LLM_TENSOR_ATTN_Q_NORM,
257
262
  LLM_TENSOR_ATTN_K_NORM,
258
263
  LLM_TENSOR_LAYER_OUT_NORM,
264
+ LLM_TENSOR_POST_ATTN_NORM,
265
+ LLM_TENSOR_POST_MLP_NORM,
259
266
  LLM_TENSOR_SSM_IN,
260
267
  LLM_TENSOR_SSM_CONV1D,
261
268
  LLM_TENSOR_SSM_X,
@@ -303,6 +310,8 @@ enum llm_tensor {
303
310
  LLM_TENSOR_ATTN_Q_B,
304
311
  LLM_TENSOR_ATTN_KV_A_MQA,
305
312
  LLM_TENSOR_ATTN_KV_B,
313
+ LLM_TENSOR_ATTN_K_B,
314
+ LLM_TENSOR_ATTN_V_B,
306
315
  LLM_TENSOR_ATTN_Q_A_NORM,
307
316
  LLM_TENSOR_ATTN_KV_A_NORM,
308
317
  LLM_TENSOR_ATTN_SUB_NORM,
@@ -70,7 +70,8 @@ struct llama_sbatch {
70
70
  // sequence-wise split
71
71
  llama_ubatch split_seq(size_t n_ubatch);
72
72
 
73
- void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
73
+ llama_sbatch() = default;
74
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
74
75
  };
75
76
 
76
77
  // temporary allocate memory for the input batch if needed
@@ -14,6 +14,7 @@ enum llm_chat_template {
14
14
  LLM_CHAT_TEMPLATE_MISTRAL_V3,
15
15
  LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
16
16
  LLM_CHAT_TEMPLATE_MISTRAL_V7,
17
+ LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
17
18
  LLM_CHAT_TEMPLATE_PHI_3,
18
19
  LLM_CHAT_TEMPLATE_PHI_4,
19
20
  LLM_CHAT_TEMPLATE_FALCON_3,
@@ -29,8 +30,8 @@ enum llm_chat_template {
29
30
  LLM_CHAT_TEMPLATE_DEEPSEEK_3,
30
31
  LLM_CHAT_TEMPLATE_COMMAND_R,
31
32
  LLM_CHAT_TEMPLATE_LLAMA_3,
32
- LLM_CHAT_TEMPLATE_CHATGML_3,
33
- LLM_CHAT_TEMPLATE_CHATGML_4,
33
+ LLM_CHAT_TEMPLATE_CHATGLM_3,
34
+ LLM_CHAT_TEMPLATE_CHATGLM_4,
34
35
  LLM_CHAT_TEMPLATE_GLMEDGE,
35
36
  LLM_CHAT_TEMPLATE_MINICPM,
36
37
  LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -41,6 +42,7 @@ enum llm_chat_template {
41
42
  LLM_CHAT_TEMPLATE_YANDEX,
42
43
  LLM_CHAT_TEMPLATE_BAILING,
43
44
  LLM_CHAT_TEMPLATE_LLAMA4,
45
+ LLM_CHAT_TEMPLATE_SMOLVLM,
44
46
  LLM_CHAT_TEMPLATE_UNKNOWN,
45
47
  };
46
48
 
@@ -7,6 +7,7 @@
7
7
  #include "llama-adapter.h"
8
8
 
9
9
  #include "ggml-cpp.h"
10
+ #include "ggml-opt.h"
10
11
 
11
12
  #include <map>
12
13
  #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
27
28
 
28
29
  void synchronize();
29
30
 
30
- const llama_model & get_model() const;
31
+ const llama_model & get_model() const;
32
+ const llama_cparams & get_cparams() const;
33
+
34
+ lm_ggml_backend_sched_t get_sched() const;
35
+
36
+ lm_ggml_context * get_ctx_compute() const;
31
37
 
32
38
  uint32_t n_ctx() const;
33
39
  uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
128
134
  llama_perf_context_data perf_get_data() const;
129
135
  void perf_reset();
130
136
 
137
+ //
138
+ // training
139
+ //
140
+
141
+ void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
142
+
143
+ void opt_epoch(
144
+ lm_ggml_opt_dataset_t dataset,
145
+ lm_ggml_opt_result_t result_train,
146
+ lm_ggml_opt_result_t result_eval,
147
+ int64_t idata_split,
148
+ lm_ggml_opt_epoch_callback callback_train,
149
+ lm_ggml_opt_epoch_callback callback_eval);
150
+
151
+ void opt_epoch_iter(
152
+ lm_ggml_opt_dataset_t dataset,
153
+ lm_ggml_opt_result_t result,
154
+ const std::vector<llama_token> & tokens,
155
+ const std::vector<llama_token> & labels_sparse,
156
+ llama_batch & batch,
157
+ lm_ggml_opt_epoch_callback callback,
158
+ bool train,
159
+ int64_t idata_in_loop,
160
+ int64_t ndata_in_loop,
161
+ int64_t t_loop_start);
162
+
131
163
  private:
132
164
  //
133
165
  // output
@@ -137,50 +169,30 @@ private:
137
169
  // Returns max number of outputs for which space was reserved.
138
170
  int32_t output_reserve(int32_t n_outputs);
139
171
 
140
- // make the outputs have the same order they had in the user-provided batch
141
- // TODO: maybe remove this
142
- void output_reorder();
143
-
144
172
  //
145
173
  // graph
146
174
  //
147
175
 
176
+ public:
148
177
  int32_t graph_max_nodes() const;
149
178
 
150
179
  // zero-out inputs and create the ctx_compute for the compute graph
151
180
  lm_ggml_cgraph * graph_init();
152
181
 
182
+ // returns the result of lm_ggml_backend_sched_graph_compute_async execution
183
+ lm_ggml_status graph_compute(
184
+ lm_ggml_cgraph * gf,
185
+ bool batched);
186
+
187
+ private:
153
188
  llm_graph_result_ptr graph_build(
154
189
  lm_ggml_context * ctx,
155
190
  lm_ggml_cgraph * gf,
156
191
  const llama_ubatch & ubatch,
157
192
  llm_graph_type gtype);
158
193
 
159
- // returns the result of lm_ggml_backend_sched_graph_compute_async execution
160
- lm_ggml_status graph_compute(
161
- lm_ggml_cgraph * gf,
162
- bool batched);
163
-
164
194
  llm_graph_cb graph_get_cb() const;
165
195
 
166
- // used by kv_self_update()
167
- lm_ggml_tensor * build_rope_shift(
168
- lm_ggml_context * ctx0,
169
- lm_ggml_tensor * cur,
170
- lm_ggml_tensor * shift,
171
- lm_ggml_tensor * factors,
172
- float freq_base,
173
- float freq_scale,
174
- lm_ggml_backend_buffer * bbuf) const;
175
-
176
- llm_graph_result_ptr build_kv_self_shift(
177
- lm_ggml_context * ctx0,
178
- lm_ggml_cgraph * gf) const;
179
-
180
- llm_graph_result_ptr build_kv_self_defrag(
181
- lm_ggml_context * ctx0,
182
- lm_ggml_cgraph * gf) const;
183
-
184
196
  // TODO: read/write lora adapters and cvec
185
197
  size_t state_write_data(llama_io_write_i & io);
186
198
  size_t state_read_data (llama_io_read_i & io);
@@ -197,14 +209,10 @@ private:
197
209
  llama_cparams cparams;
198
210
  llama_adapter_cvec cvec;
199
211
  llama_adapter_loras loras;
200
- llama_sbatch sbatch;
201
212
 
202
213
  llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
203
214
 
204
- std::unique_ptr<llama_kv_cache_unified> kv_self;
205
-
206
- // TODO: remove
207
- bool logits_all = false;
215
+ std::unique_ptr<llama_memory_i> memory;
208
216
 
209
217
  // decode output (2-dimensional array: [n_outputs][n_vocab])
210
218
  size_t logits_size = 0; // capacity (of floats) for logits
@@ -231,6 +239,9 @@ private:
231
239
 
232
240
  lm_ggml_context_ptr ctx_compute;
233
241
 
242
+ // training
243
+ lm_ggml_opt_context_t opt_ctx = nullptr;
244
+
234
245
  lm_ggml_threadpool_t threadpool = nullptr;
235
246
  lm_ggml_threadpool_t threadpool_batch = nullptr;
236
247
 
@@ -30,6 +30,7 @@ struct llama_cparams {
30
30
  bool flash_attn;
31
31
  bool no_perf;
32
32
  bool warmup;
33
+ bool op_offload;
33
34
 
34
35
  enum llama_pooling_type pooling_type;
35
36
 
@@ -19,6 +19,8 @@ struct llama_cparams;
19
19
 
20
20
  class llama_memory_i;
21
21
  class llama_kv_cache_unified;
22
+ class llama_kv_cache_unified_iswa;
23
+ class llama_kv_cache_recurrent;
22
24
 
23
25
  // certain models (typically multi-modal) can produce different types of graphs
24
26
  enum llm_graph_type {
@@ -90,29 +92,27 @@ public:
90
92
 
91
93
  class llm_graph_input_pos : public llm_graph_input_i {
92
94
  public:
93
- llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
95
+ llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
94
96
  virtual ~llm_graph_input_pos() = default;
95
97
 
96
98
  void set_input(const llama_ubatch * ubatch) override;
97
99
 
98
100
  lm_ggml_tensor * pos = nullptr; // I32 [n_batch]
99
101
 
100
- const int64_t n_pos_per_token = 1;
102
+ const int64_t n_pos_per_embd = 1;
101
103
  };
102
104
 
103
105
  // temperature tuning, used by llama4
104
106
  class llm_graph_input_attn_temp : public llm_graph_input_i {
105
107
  public:
106
- llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
107
- : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
108
+ llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
109
+ : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
108
110
  virtual ~llm_graph_input_attn_temp() = default;
109
111
 
110
112
  void set_input(const llama_ubatch * ubatch) override;
111
113
 
112
114
  lm_ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
113
115
 
114
- const int64_t n_pos_per_token = 1;
115
-
116
116
  const uint32_t n_attn_temp_floor_scale;
117
117
  const float f_attn_temp_scale;
118
118
  };
@@ -188,26 +188,26 @@ public:
188
188
 
189
189
  class llm_graph_input_s_copy : public llm_graph_input_i {
190
190
  public:
191
- llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
191
+ llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
192
192
  virtual ~llm_graph_input_s_copy() = default;
193
193
 
194
194
  void set_input(const llama_ubatch * ubatch) override;
195
195
 
196
196
  lm_ggml_tensor * s_copy; // I32 [kv_size]
197
197
 
198
- const llama_kv_cache_unified * kv_self;
198
+ const llama_kv_cache_recurrent * kv_self;
199
199
  };
200
200
 
201
201
  class llm_graph_input_s_mask : public llm_graph_input_i {
202
202
  public:
203
- llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
203
+ llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
204
204
  virtual ~llm_graph_input_s_mask() = default;
205
205
 
206
206
  void set_input(const llama_ubatch * ubatch) override;
207
207
 
208
208
  lm_ggml_tensor * s_mask; // F32 [1, n_kv]
209
209
 
210
- const llama_kv_cache_unified * kv_self;
210
+ const llama_kv_cache_recurrent * kv_self;
211
211
  };
212
212
 
213
213
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -256,6 +256,31 @@ public:
256
256
 
257
257
  void set_input(const llama_ubatch * ubatch) override;
258
258
 
259
+ lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
260
+
261
+ lm_ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
262
+ lm_ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
263
+
264
+ const llama_hparams & hparams;
265
+ const llama_cparams & cparams;
266
+
267
+ const llama_kv_cache_unified * kv_self;
268
+ };
269
+
270
+ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
271
+ public:
272
+ llm_graph_input_attn_kv_unified_iswa(
273
+ const llama_hparams & hparams,
274
+ const llama_cparams & cparams,
275
+ const llama_kv_cache_unified_iswa * kv_self) :
276
+ hparams(hparams),
277
+ cparams(cparams),
278
+ kv_self(kv_self) {
279
+ }
280
+ ~llm_graph_input_attn_kv_unified_iswa() = default;
281
+
282
+ void set_input(const llama_ubatch * ubatch) override;
283
+
259
284
  lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
260
285
  lm_ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
261
286
 
@@ -267,7 +292,7 @@ public:
267
292
  const llama_hparams & hparams;
268
293
  const llama_cparams & cparams;
269
294
 
270
- const llama_kv_cache_unified * kv_self;
295
+ const llama_kv_cache_unified_iswa * kv_self;
271
296
  };
272
297
 
273
298
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -299,6 +324,7 @@ class llm_graph_result_i {
299
324
  public:
300
325
  virtual ~llm_graph_result_i() = default;
301
326
 
327
+ virtual lm_ggml_tensor * get_tokens() = 0;
302
328
  virtual lm_ggml_tensor * get_logits() = 0;
303
329
  virtual lm_ggml_tensor * get_embd() = 0;
304
330
  virtual lm_ggml_tensor * get_embd_pooled() = 0;
@@ -313,6 +339,7 @@ class llm_graph_result : public llm_graph_result_i {
313
339
  public:
314
340
  virtual ~llm_graph_result() = default;
315
341
 
342
+ lm_ggml_tensor * get_tokens() override { return t_tokens; }
316
343
  lm_ggml_tensor * get_logits() override { return t_logits; }
317
344
  lm_ggml_tensor * get_embd() override { return t_embd; }
318
345
  lm_ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -329,6 +356,7 @@ public:
329
356
  }
330
357
 
331
358
  // important graph nodes
359
+ lm_ggml_tensor * t_tokens = nullptr;
332
360
  lm_ggml_tensor * t_logits = nullptr;
333
361
  lm_ggml_tensor * t_embd = nullptr;
334
362
  lm_ggml_tensor * t_embd_pooled = nullptr;
@@ -352,8 +380,8 @@ struct llm_graph_params {
352
380
  const llama_cparams & cparams;
353
381
  const llama_ubatch & ubatch;
354
382
 
355
- lm_ggml_backend_sched * sched;
356
- lm_ggml_backend * backend_cpu;
383
+ lm_ggml_backend_sched_t sched;
384
+ lm_ggml_backend_t backend_cpu;
357
385
 
358
386
  const llama_adapter_cvec * cvec;
359
387
  const llama_adapter_loras * loras;
@@ -376,7 +404,6 @@ struct llm_graph_context {
376
404
  const int64_t n_layer;
377
405
  const int64_t n_rot;
378
406
  const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
379
- const int64_t n_ctx_per_seq;
380
407
  const int64_t n_head;
381
408
  const int64_t n_head_kv;
382
409
  const int64_t n_embd_head_k;
@@ -404,9 +431,9 @@ struct llm_graph_context {
404
431
 
405
432
  lm_ggml_context * ctx0 = nullptr;
406
433
 
407
- lm_ggml_backend_sched * sched;
434
+ lm_ggml_backend_sched_t sched;
408
435
 
409
- lm_ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
436
+ lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
410
437
 
411
438
  const llama_adapter_cvec * cvec;
412
439
  const llama_adapter_loras * loras;
@@ -419,7 +446,7 @@ struct llm_graph_context {
419
446
 
420
447
  llm_graph_context(const llm_graph_params & params);
421
448
 
422
- int64_t n_pos_per_token() const;
449
+ int64_t n_pos_per_embd() const;
423
450
 
424
451
  void cb(lm_ggml_tensor * cur, const char * name, int il) const;
425
452
 
@@ -505,12 +532,12 @@ struct llm_graph_context {
505
532
 
506
533
  lm_ggml_tensor * build_attn_mha(
507
534
  lm_ggml_cgraph * gf,
508
- lm_ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
509
- lm_ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
510
- lm_ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
535
+ lm_ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
536
+ lm_ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
537
+ lm_ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
511
538
  lm_ggml_tensor * kq_b,
512
539
  lm_ggml_tensor * kq_mask,
513
- bool v_trans,
540
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
514
541
  float kq_scale) const;
515
542
 
516
543
  llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -524,6 +551,7 @@ struct llm_graph_context {
524
551
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
525
552
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
526
553
  lm_ggml_tensor * kq_b,
554
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
527
555
  float kq_scale,
528
556
  int il) const;
529
557
 
@@ -538,6 +566,22 @@ struct llm_graph_context {
538
566
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
539
567
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
540
568
  lm_ggml_tensor * kq_b,
569
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
570
+ float kq_scale,
571
+ int il) const;
572
+
573
+ llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
574
+
575
+ lm_ggml_tensor * build_attn(
576
+ llm_graph_input_attn_kv_unified_iswa * inp,
577
+ lm_ggml_cgraph * gf,
578
+ lm_ggml_tensor * wo,
579
+ lm_ggml_tensor * wo_b,
580
+ lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
581
+ lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
582
+ lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
583
+ lm_ggml_tensor * kq_b,
584
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
541
585
  float kq_scale,
542
586
  int il) const;
543
587
 
@@ -552,6 +596,7 @@ struct llm_graph_context {
552
596
  lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
553
597
  lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
554
598
  lm_ggml_tensor * kq_b,
599
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
555
600
  float kq_scale,
556
601
  int il) const;
557
602
 
@@ -590,3 +635,6 @@ struct llm_graph_context {
590
635
  lm_ggml_tensor * cls_out,
591
636
  lm_ggml_tensor * cls_out_b) const;
592
637
  };
638
+
639
+ // TODO: better name
640
+ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
14
14
  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
15
15
  };
16
16
 
17
+ enum llama_swa_type {
18
+ LLAMA_SWA_TYPE_NONE = 0,
19
+ LLAMA_SWA_TYPE_STANDARD = 1,
20
+ LLAMA_SWA_TYPE_CHUNKED = 2,
21
+ };
22
+
17
23
  struct llama_hparams_posnet {
18
24
  uint32_t n_embd;
19
25
  uint32_t n_layer;
@@ -35,14 +41,16 @@ struct llama_hparams {
35
41
  uint32_t n_embd_features = 0;
36
42
  uint32_t n_layer;
37
43
  uint32_t n_rot;
38
- uint32_t n_swa = 0; // sliding window attention (SWA)
39
- uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
40
44
  uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
41
45
  uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
42
46
  uint32_t n_expert = 0;
43
47
  uint32_t n_expert_used = 0;
44
48
  uint32_t n_rel_attn_bkts = 0;
45
49
 
50
+ // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
51
+ uint32_t n_embd_head_k_mla = 0;
52
+ uint32_t n_embd_head_v_mla = 0;
53
+
46
54
  // for WavTokenizer
47
55
  struct llama_hparams_posnet posnet;
48
56
  struct llama_hparams_convnext convnext;
@@ -62,6 +70,7 @@ struct llama_hparams {
62
70
  float expert_weights_scale = 0.0;
63
71
  bool expert_weights_norm = false;
64
72
  uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
73
+ uint32_t moe_every_n_layers = 0;
65
74
 
66
75
  float f_norm_eps;
67
76
  float f_norm_rms_eps;
@@ -91,6 +100,15 @@ struct llama_hparams {
91
100
 
92
101
  std::array<int, 4> rope_sections;
93
102
 
103
+ // Sliding Window Attention (SWA)
104
+ llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
105
+ // the size of the sliding window (0 - no SWA)
106
+ uint32_t n_swa = 0;
107
+ // if swa_layers[il] == true, then layer il is SWA
108
+ // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
109
+ // by default, all layers are dense
110
+ std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
111
+
94
112
  // for State Space Models
95
113
  uint32_t ssm_d_conv = 0;
96
114
  uint32_t ssm_d_inner = 0;
@@ -111,11 +129,10 @@ struct llama_hparams {
111
129
  bool causal_attn = true;
112
130
  bool use_alibi = false;
113
131
  bool attn_soft_cap = false;
132
+ bool use_kq_norm = true;
114
133
 
134
+ // llama4
115
135
  uint32_t n_moe_layer_step = 0;
116
- bool use_kq_norm = true;
117
- uint32_t n_attn_chunk = 0;
118
- // values below seems to be fixed on llama4
119
136
  uint32_t n_no_rope_layer_step = 4;
120
137
  uint32_t n_attn_temp_floor_scale = 8192;
121
138
  float f_attn_temp_scale = 0.1;
@@ -128,6 +145,23 @@ struct llama_hparams {
128
145
  enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
129
146
  enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
130
147
 
148
+ // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
149
+ // note that if n_pattern == 0, all layers are SWA
150
+ // if n_pattern == 1, all layers are dense
151
+ // example: n_pattern = 3
152
+ // il == 0: swa
153
+ // il == 1: swa
154
+ // il == 2: dense
155
+ // il == 3: swa
156
+ // il == 4: swa
157
+ // il == 5: dense
158
+ // il == 6: swa
159
+ // etc ...
160
+ void set_swa_pattern(uint32_t n_pattern);
161
+
162
+ // return true if one of the layers is SWA
163
+ bool is_swa_any() const;
164
+
131
165
  uint32_t n_head(uint32_t il = 0) const;
132
166
 
133
167
  uint32_t n_head_kv(uint32_t il = 0) const;