cui-llama.rn 1.6.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +22 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
  4. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  5. package/android/src/main/jni.cpp +173 -18
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  16. package/cpp/LICENSE +21 -0
  17. package/cpp/chat.cpp +129 -107
  18. package/cpp/chat.h +2 -0
  19. package/cpp/common.cpp +58 -78
  20. package/cpp/common.h +29 -21
  21. package/cpp/ggml-alloc.c +4 -1
  22. package/cpp/ggml-backend.cpp +9 -5
  23. package/cpp/ggml-backend.h +4 -4
  24. package/cpp/ggml-cpp.h +1 -1
  25. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  26. package/cpp/ggml-cpu/amx/amx.h +8 -0
  27. package/cpp/ggml-cpu/amx/common.h +91 -0
  28. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  29. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  31. package/cpp/ggml-cpu/common.h +72 -0
  32. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
  33. package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
  34. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
  35. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
  36. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
  37. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  38. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  39. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  40. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  41. package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
  42. package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
  43. package/cpp/ggml-cpu.h +5 -0
  44. package/cpp/ggml-impl.h +16 -9
  45. package/cpp/ggml-llama-sim.metallib +0 -0
  46. package/cpp/ggml-llama.metallib +0 -0
  47. package/cpp/ggml-metal-impl.h +36 -11
  48. package/cpp/ggml-metal.m +810 -176
  49. package/cpp/ggml-opt.cpp +373 -190
  50. package/cpp/ggml-opt.h +49 -28
  51. package/cpp/ggml-quants.c +0 -6
  52. package/cpp/ggml.c +227 -282
  53. package/cpp/ggml.h +82 -101
  54. package/cpp/gguf.cpp +33 -33
  55. package/cpp/json-schema-to-grammar.cpp +3 -0
  56. package/cpp/llama-adapter.cpp +6 -0
  57. package/cpp/llama-arch.cpp +49 -17
  58. package/cpp/llama-arch.h +9 -0
  59. package/cpp/llama-batch.cpp +8 -2
  60. package/cpp/llama-batch.h +2 -1
  61. package/cpp/llama-chat.cpp +39 -16
  62. package/cpp/llama-chat.h +4 -2
  63. package/cpp/llama-context.cpp +440 -611
  64. package/cpp/llama-context.h +44 -33
  65. package/cpp/llama-cparams.h +1 -0
  66. package/cpp/llama-graph.cpp +214 -291
  67. package/cpp/llama-graph.h +69 -21
  68. package/cpp/llama-hparams.cpp +17 -1
  69. package/cpp/llama-hparams.h +39 -5
  70. package/cpp/llama-kv-cache.cpp +2067 -620
  71. package/cpp/llama-kv-cache.h +410 -108
  72. package/cpp/llama-memory.h +12 -1
  73. package/cpp/llama-model-loader.cpp +24 -15
  74. package/cpp/llama-model-saver.cpp +281 -0
  75. package/cpp/llama-model-saver.h +37 -0
  76. package/cpp/llama-model.cpp +1089 -359
  77. package/cpp/llama-model.h +19 -3
  78. package/cpp/llama-sampling.cpp +20 -7
  79. package/cpp/llama-vocab.cpp +54 -9
  80. package/cpp/llama-vocab.h +6 -0
  81. package/cpp/llama.cpp +14 -0
  82. package/cpp/llama.h +86 -142
  83. package/cpp/minja/chat-template.hpp +9 -5
  84. package/cpp/minja/minja.hpp +69 -36
  85. package/cpp/rn-llama.cpp +602 -190
  86. package/cpp/rn-llama.h +34 -8
  87. package/cpp/sampling.cpp +57 -50
  88. package/cpp/tools/mtmd/clip-impl.h +462 -0
  89. package/cpp/tools/mtmd/clip.cpp +4024 -0
  90. package/cpp/tools/mtmd/clip.h +101 -0
  91. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  92. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  93. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  94. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  95. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  96. package/cpp/tools/mtmd/mtmd.h +362 -0
  97. package/cpp/tools/mtmd/stb_image.h +7988 -0
  98. package/ios/CMakeLists.txt +20 -10
  99. package/ios/RNLlama.h +6 -0
  100. package/ios/RNLlama.mm +82 -3
  101. package/ios/RNLlamaContext.h +5 -1
  102. package/ios/RNLlamaContext.mm +131 -38
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  160. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  161. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  188. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  189. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  190. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  191. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  192. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  193. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  194. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  195. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  196. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  197. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  198. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  199. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  200. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  201. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  202. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  203. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  204. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  205. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  206. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  207. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  208. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  209. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  210. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  211. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  212. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  213. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  214. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  215. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  216. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  217. package/jest/mock.js +33 -7
  218. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  219. package/lib/commonjs/index.js +153 -21
  220. package/lib/commonjs/index.js.map +1 -1
  221. package/lib/module/NativeRNLlama.js.map +1 -1
  222. package/lib/module/index.js +152 -20
  223. package/lib/module/index.js.map +1 -1
  224. package/lib/typescript/NativeRNLlama.d.ts +54 -4
  225. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  226. package/lib/typescript/index.d.ts +72 -6
  227. package/lib/typescript/index.d.ts.map +1 -1
  228. package/package.json +1 -1
  229. package/src/NativeRNLlama.ts +72 -4
  230. package/src/index.ts +212 -38
  231. package/cpp/binary-ops.h +0 -16
  232. package/cpp/ops.h +0 -128
  233. package/cpp/simd-mappings.h +0 -888
  234. package/cpp/unary-ops.h +0 -28
  235. package/cpp/vec.h +0 -802
  236. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  237. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  238. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  239. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  240. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  241. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  242. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  243. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  244. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  245. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  246. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  247. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  248. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  249. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  250. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  251. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  271. package/lib/commonjs/chat.js +0 -37
  272. package/lib/commonjs/chat.js.map +0 -1
  273. package/lib/module/chat.js +0 -33
  274. package/lib/module/chat.js.map +0 -1
  275. package/lib/typescript/chat.d.ts +0 -10
  276. package/lib/typescript/chat.d.ts.map +0 -1
  277. package/src/chat.ts +0 -44
  278. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  279. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  280. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  281. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  282. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  283. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  284. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  285. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/llama-graph.cpp

@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -55,7 +28,21 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-        lm_ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*lm_ggml_element_size(pos));
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            lm_ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*lm_ggml_element_size(pos));
+        } else {
+            lm_ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*lm_ggml_element_size(pos));
+        }
     }
 }
 
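Note on the M-RoPE hunk above: for text tokens the 1D positions are expanded into a 4D, section-per-dimension layout (the first three dimensions repeat the 1D position, the fourth is all zeros, and the buffer is laid out as four contiguous sections rather than interleaved). A minimal standalone sketch of that expansion, with made-up sample positions:

    #include <cstdint>
    #include <vector>

    using llama_pos = int32_t;

    // Expand 1D text-token positions to the 4D layout M-RoPE expects:
    // [dim0 x n_tokens][dim1 x n_tokens][dim2 x n_tokens][dim3 x n_tokens]
    std::vector<llama_pos> expand_mrope_pos(const std::vector<llama_pos> & pos) {
        const size_t n_tokens = pos.size();
        std::vector<llama_pos> pos_data(4 * n_tokens);
        for (size_t i = 0; i < n_tokens; ++i) {
            pos_data[               i] = pos[i]; // dim 0
            pos_data[    n_tokens + i] = pos[i]; // dim 1
            pos_data[2 * n_tokens + i] = pos[i]; // dim 2
            pos_data[3 * n_tokens + i] = 0;      // dim 3 is all 0 for text
        }
        return pos_data;
    }

    // e.g. expand_mrope_pos({0, 1, 2}) == {0,1,2, 0,1,2, 0,1,2, 0,0,0}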
@@ -71,7 +58,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
             ) * f_attn_temp_scale + 1.0;
         }
 
-        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*lm_ggml_element_size(attn_scale));
+        lm_ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*lm_ggml_element_size(attn_scale));
     }
 }
 
@@ -96,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(pos_bucket->buffer));
-        LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -270,24 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }
@@ -303,18 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
        }
    }
 }
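Note on the two hunks above: the per-cell bookkeeping (clamping out-of-range copy sources, flagging unused states) moves out of the graph-input code and behind llama_kv_cache_recurrent::s_copy / s_mask. A rough sketch of what such accessors compute, distilled from the removed loops; the struct and field names here are illustrative, not the library's actual internals:

    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for the recurrent KV cells (not the real class).
    struct recurrent_cells {
        std::vector<int32_t> src;  // per cell: which cell to copy state from (< 0 = none)
        uint32_t head = 0;         // first cell of the active [head, head + n) window

        // source cell for window slot i; out-of-range sources fall back to the cell itself
        int32_t s_copy(uint32_t i) const {
            const uint32_t cell_id = head + i;
            const int32_t  s       = src[cell_id];
            return (s < 0 || (uint32_t) s >= src.size()) ? (int32_t) cell_id : s;
        }

        // 1.0 keeps the slot's state, 0.0 clears an unused slot
        float s_mask(uint32_t i) const {
            return src[head + i] >= 0 ? 1.0f : 0.0f;
        }
    };

The mutation-heavy bookkeeping from the removed loops now lives behind the cache class instead of inside graph-input setup.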
@@ -417,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv         = kv_self->n;
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
-        float * data     = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id)                   // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-        // mask padded tokens
-        if (data) {
-            for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
-            }
-        }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-        // mask padded tokens
-        if (data_swa) {
-            for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
-                    data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
-            }
-        }
-    }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
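Note on the hunk above: the mask-filling loop moves into llama_kv_cache_unified::set_input_kq_mask, but its per-entry semantics are unchanged. A hedged distillation of the removed loop for a single (query, key) pair, covering the causal, ALiBi, and sliding-window cases (the chunked-attention variant is omitted; function and parameter names are illustrative):

    #include <cmath>
    #include <cstdint>
    #include <cstdlib>

    // Additive mask value fed to softmax for one cached key vs. one query token.
    // seq_match: does the cached cell belong to the query token's sequence?
    float kq_mask_entry(int32_t cell_pos, int32_t query_pos, bool seq_match,
                        bool causal, bool use_alibi, uint32_t n_swa) {
        // wrong sequence, or a future token under causal attention: fully masked
        if (!seq_match || (causal && cell_pos > query_pos)) {
            return -INFINITY;
        }
        // sliding-window attention: keys that fell out of the window are masked too
        if (n_swa > 0 && query_pos - cell_pos >= (int32_t) n_swa) {
            return -INFINITY;
        }
        // ALiBi biases by distance instead of masking; otherwise no penalty
        return use_alibi ? -(float) std::abs(cell_pos - query_pos) : 0.0f;
    }

Padded rows beyond the real tokens are filled with -INFINITY as well, exactly as in the removed tail loops.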
@@ -559,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -592,7 +454,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res              (std::make_unique<llm_graph_result>()) {
     }
 
-int64_t llm_graph_context::n_pos_per_token() const {
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
@@ -796,13 +658,17 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
             } break;
     }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = lm_ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
     }
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
     }
 
     if (down_b) {
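Note on the GLM4 hunk above: lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32) forces a wider accumulator for that matrix multiply. The failure mode it works around is plain accumulation drift; a rough analogue using float vs. double accumulators as stand-ins for F16 vs. F32 (the values are illustrative):

    #include <cstdio>

    // Summing many small terms in a narrow type drifts once the running sum
    // grows and the addend falls below the sum's precision. The same effect,
    // one precision step down, is why some matmuls need F32 accumulators.
    int main() {
        float  acc32 = 0.0f;
        double acc64 = 0.0;
        for (int i = 0; i < 10000000; ++i) {
            acc32 += 0.1f;
            acc64 += 0.1f;
        }
        // exact answer is 1,000,000; the float accumulator typically lands
        // well away from it, while the double stays close
        std::printf("float: %.1f  double: %.1f\n", acc32, acc64);
    }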
@@ -910,28 +776,35 @@ lm_ggml_tensor * llm_graph_context::build_moe_ffn(
     lm_ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    lm_ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    lm_ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = lm_ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = lm_ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-                gate = lm_ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = lm_ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
            } break;
        default:
            LM_GGML_ABORT("fatal error");
    }
 
-    lm_ggml_tensor * par = lm_ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = lm_ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    lm_ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (!weight_before_ffn) {
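Note on the build_moe_ffn hunk above: gate_exps is now optional, so each selected expert computes either a gated (SwiGLU-style) FFN or a plain activated one. A scalar sketch of the two data flows; the weights stand in for the expert matrices:

    #include <cmath>

    float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // gated expert FFN (gate_exps present): down( silu(gate(x)) * up(x) )
    float expert_gated(float x, float w_up, float w_gate, float w_down) {
        return w_down * (silu(w_gate * x) * (w_up * x));
    }

    // ungated expert FFN (gate_exps == nullptr): down( silu(up(x)) )
    float expert_ungated(float x, float w_up, float w_down) {
        return w_down * silu(w_up * x);
    }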
@@ -974,6 +847,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
         inp->tokens = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         lm_ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = lm_ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1014,11 +888,11 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens*n_pos_per_embd());
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1027,11 +901,12 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
 
     auto & cur = inp->attn_scale;
 
-    cur = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
+    // this need to be 1x1xN for broadcasting
+    cur = lm_ggml_new_tensor_3d(ctx0, LM_GGML_TYPE_F32, 1, 1, n_tokens);
     lm_ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1079,7 +954,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1096,7 +971,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 lm_ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1154,7 +1029,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,18 +1063,13 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_tensor * v,
         lm_ggml_tensor * kq_b,
         lm_ggml_tensor * kq_mask,
-        bool v_trans,
+        lm_ggml_tensor * v_mla,
         float kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+    q = lm_ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = lm_ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = lm_ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1229,7 +1099,23 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
-        cur = lm_ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
+            cur = lm_ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_cont(ctx0, cur); // Needed because lm_ggml_reshape_2d expects contiguous inputs.
+#endif
+        }
+
+        cur = lm_ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q);
 
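Note on the #else branch above: ggml's mul_mat(A[k, n], B[k, m, d2]) produces [n, m, d2], broadcasting A across B's higher dimensions, so moving n_tokens into dimension 1 turns the MLA "decompression" into one large matrix-matrix product per head instead of one matrix-vector product per token. A shape walk-through with made-up sizes (kv_lora_rank, head and token counts are hypothetical, not read from any model):

    #include <array>
    #include <cstdio>

    // ggml shape rule: mul_mat(A[k, n], B[k, m, d2]) -> [n, m, d2]
    static std::array<long, 3> mul_mat_shape(std::array<long, 2> a, std::array<long, 3> b) {
        return { a[1], b[1], b[2] };
    }

    int main() {
        const long kv_lora_rank = 512, n_embd_head_v = 128, n_head = 16, n_tokens = 32;

        std::array<long, 2> v_mla = { kv_lora_rank, n_embd_head_v };
        // flash-attention output [kv_lora_rank, n_head, n_tokens], after
        // permute(0, 2, 1, 3): n_tokens moves into dimension 1
        std::array<long, 3> cur = { kv_lora_rank, n_tokens, n_head };

        const auto out = mul_mat_shape(v_mla, cur);
        // {128, 32, 16}: decompressed heads; the second permute swaps dims 1
        // and 2 back to [n_embd_head_v, n_head, n_tokens]
        std::printf("[%ld, %ld, %ld]\n", out[0], out[1], out[2]);
    }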
@@ -1267,9 +1153,14 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
 
         lm_ggml_tensor * kqv = lm_ggml_mul_mat(ctx0, v, kq);
 
-        lm_ggml_tensor * kqv_merged = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = lm_ggml_mul_mat(ctx0, v_mla, kqv);
+        }
 
-        cur = lm_ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = lm_ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+
+        cur = lm_ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
     if (!cparams.offload_kqv) {
         // all nodes between the KV store and the attention output are run on the CPU
@@ -1304,6 +1195,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     LM_GGML_UNUSED(n_tokens);
@@ -1316,17 +1208,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1349,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
-
-    inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    lm_ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    if (hparams.n_swa_pattern > 1) {
-        LM_GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        lm_ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1379,6 +1259,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1388,87 +1269,108 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
 
-    const auto n_tokens = q_cur->ne[2];
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const bool v_trans = !cparams.flash_attn;
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv_self->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    // store to KV cache
-    {
-        LM_GGML_ASSERT(!kv_self->recurrent);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-        const auto kv_head = kv_self->head;
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
 
-        LM_GGML_ASSERT(kv_self->size == n_ctx);
+    if (wo_b) {
+        cur = lm_ggml_add(ctx0, cur, wo_b);
+    }
 
-        lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
+    return cur;
+}
 
-        // note: storing RoPE-ed version of K in the KV cache
-        lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, k_cur, k_cache_view));
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-        v_cur = lm_ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-        lm_ggml_tensor * v_cache_view = nullptr;
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-        if (!v_trans) {
-            v_cache_view = lm_ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = lm_ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*lm_ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*lm_ggml_element_size(kv_self->v_l[il]));
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-            v_cur = lm_ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        lm_ggml_set_input(inp->self_kq_mask_swa);
 
-        lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, v_cur, v_cache_view));
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+lm_ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        lm_ggml_cgraph * gf,
+        lm_ggml_tensor * wo,
+        lm_ggml_tensor * wo_b,
+        lm_ggml_tensor * q_cur,
+        lm_ggml_tensor * k_cur,
+        lm_ggml_tensor * v_cur,
+        lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    lm_ggml_build_forward_expand(gf, q_cur);
+    lm_ggml_build_forward_expand(gf, k_cur);
+    lm_ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k =
-        lm_ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = !v_trans ?
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
     }
 
     if (wo_b) {
@@ -1504,6 +1406,7 @@ lm_ggml_tensor * llm_graph_context::build_attn(
         lm_ggml_tensor * k_cur,
         lm_ggml_tensor * v_cur,
         lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1514,17 +1417,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1549,7 +1446,7 @@ lm_ggml_tensor * llm_graph_context::build_copy_mask_state(
         lm_ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1581,7 +1478,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         lm_ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1602,7 +1499,7 @@ lm_ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         lm_ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1693,3 +1590,29 @@ void llm_graph_context::build_pooling(
     lm_ggml_build_forward_expand(gf, cur);
 }
 
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
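For reference, in the decoder case (bidirectional = false) with the typical n_buckets = 32, max_exact is 16: distances below 16 each get their own bucket, larger distances are binned logarithmically up to max_distance = 128, and everything beyond clamps to bucket 31. A small worked example; the function body mirrors the one added above, and the driver is illustrative:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <initializer_list>

    using llama_pos = int32_t;

    // same body as the function added in this diff
    int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
        const int64_t max_distance = 128;

        if (bidirectional) {
            n_buckets >>= 1;
        }

        const int64_t max_exact = n_buckets >> 1;

        int32_t relative_position = x - y;
        int32_t relative_bucket = 0;

        if (bidirectional) {
            relative_bucket += (relative_position > 0) * n_buckets;
            relative_position = abs(relative_position);
        } else {
            relative_position = -std::min<int32_t>(relative_position, 0);
        }

        int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
        relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
        relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

        return relative_bucket;
    }

    int main() {
        // key at distance d behind a query at position 100, n_buckets = 32
        for (int d : {0, 1, 15, 16, 32, 64, 127, 500}) {
            std::printf("distance %3d -> bucket %d\n", d, llama_relative_position_bucket(100 - d, 100, 32, false));
        }
        // 0 -> 0, 1 -> 1, 15 -> 15 (exact range), then 16 -> 16, 32 -> 21,
        // 64 -> 26, 127 -> 31, and 500 (>= 128) clamps to 31
    }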