cui-llama.rn 1.7.3 → 1.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +94 -8
  4. package/android/src/main/java/com/rnllama/RNLlama.java +247 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
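
The hunks below are the changes to package/cpp/ggml-cpu/ops.cpp (item 39 in the list above). Most of them are one mechanical substitution: the generic conversion macros LM_GGML_FP16_TO_FP32 / LM_GGML_FP32_TO_FP16 become the CPU-backend-local LM_GGML_CPU_FP16_TO_FP32 / LM_GGML_CPU_FP32_TO_FP16, which lets the CPU backend pick its own conversion (hardware instructions where available). The sketch below is illustrative only, not the package's actual definition (the real ones live in package/cpp/ggml-cpu/simd-mappings.h); it shows what a software FP16 -> FP32 decode of an lm_ggml_fp16_t bit pattern looks like:

    /* Illustrative stand-in for LM_GGML_CPU_FP16_TO_FP32 (hypothetical code,
       not from the package): decode an IEEE binary16 bit pattern to float. */
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    static float fp16_bits_to_fp32(uint16_t h) {
        const uint32_t sign = (uint32_t)(h >> 15) << 31;
        uint32_t exp = (h >> 10) & 0x1F;
        uint32_t man = h & 0x3FF;
        uint32_t bits;
        if (exp == 0x1F) {                  // inf / NaN
            bits = sign | 0x7F800000u | (man << 13);
        } else if (exp != 0) {              // normal numbers: rebias 15 -> 127
            bits = sign | ((exp + 112) << 23) | (man << 13);
        } else if (man == 0) {              // signed zero
            bits = sign;
        } else {                            // subnormals: renormalize mantissa
            int e = -1;
            do { man <<= 1; e++; } while (!(man & 0x400));
            bits = sign | ((uint32_t)(112 - e) << 23) | ((man & 0x3FF) << 13);
        }
        float f;
        memcpy(&f, &bits, sizeof(f));       // bit-cast, strict-aliasing safe
        return f;
    }

    int main(void) {
        printf("%f %f %f\n",
               fp16_bits_to_fp32(0x3C00),   // 1.0
               fp16_bits_to_fp32(0x3800),   // 0.5
               fp16_bits_to_fp32(0xC000));  // -2.0
        return 0;
    }
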
@@ -108,7 +108,7 @@ static void lm_ggml_compute_forward_dup_f16(
                 for (int i01 = ir0; i01 < ir1; i01++) {
                     const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
-                        dst_ptr[id] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]);
+                        dst_ptr[id] = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                         id++;
                     }
                 }
@@ -130,7 +130,7 @@ static void lm_ggml_compute_forward_dup_f16(
                     const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                     for (int i00 = 0; i00 < ne00; i00++) {
-                        src0_f32[i00] = LM_GGML_FP16_TO_FP32(src0_ptr[i00]);
+                        src0_f32[i00] = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
                     }

                     quantize_row_q(src0_f32, dst_ptr + id, ne00);
@@ -156,7 +156,7 @@ static void lm_ggml_compute_forward_dup_f16(
                         for (int i00 = 0; i00 < ne00; i00++) {
                             const lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                            dst_ptr[id] = LM_GGML_FP16_TO_FP32(*src0_ptr);
+                            dst_ptr[id] = LM_GGML_CPU_FP16_TO_FP32(*src0_ptr);
                             id++;
                         }
                     }
@@ -267,7 +267,7 @@ static void lm_ggml_compute_forward_dup_f16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-                        *(float *) dst_ptr = LM_GGML_FP16_TO_FP32(*(const lm_ggml_fp16_t *) src0_ptr);
+                        *(float *) dst_ptr = LM_GGML_CPU_FP16_TO_FP32(*(const lm_ggml_fp16_t *) src0_ptr);

                         if (++i10 == ne0) {
                             i10 = 0;
@@ -372,7 +372,7 @@ static void lm_ggml_compute_forward_dup_bf16(
                 for (int i01 = ir0; i01 < ir1; i01++) {
                     const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
-                        dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(src0_ptr[i00]));
+                        dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(src0_ptr[i00]));
                         id++;
                     }
                 }
@@ -473,7 +473,7 @@ static void lm_ggml_compute_forward_dup_bf16(
                         for (int i00 = 0; i00 < ne00; i00++) {
                             const lm_ggml_bf16_t * src0_ptr = (lm_ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                            dst_ptr[id] = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*src0_ptr));
+                            dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*src0_ptr));
                             id++;
                         }
                     }
@@ -566,7 +566,7 @@ static void lm_ggml_compute_forward_dup_bf16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-                        *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr));
+                        *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_BF16_TO_FP32(*(const lm_ggml_bf16_t *) src0_ptr));

                         if (++i10 == ne0) {
                             i10 = 0;
@@ -765,7 +765,7 @@ static void lm_ggml_compute_forward_dup_f32(
                         for (int i00 = 0; i00 < ne00; i00++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

-                            dst_ptr[id] = LM_GGML_FP32_TO_FP16(*src0_ptr);
+                            dst_ptr[id] = LM_GGML_CPU_FP32_TO_FP16(*src0_ptr);
                             id++;
                         }
                     }
@@ -878,7 +878,7 @@ static void lm_ggml_compute_forward_dup_f32(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                         char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

-                        *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+                        *(lm_ggml_fp16_t *) dst_ptr = LM_GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);

                         if (++i10 == ne0) {
                             i10 = 0;
@@ -1419,7 +1419,7 @@ static void lm_ggml_compute_forward_add1_f16_f32(
         lm_ggml_fp16_t * dst_ptr  = (lm_ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1435,7 +1435,7 @@ static void lm_ggml_compute_forward_add1_f16_f16(
     LM_GGML_ASSERT(lm_ggml_is_scalar(src1));

     // scalar to add
-    const float v = LM_GGML_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data);
+    const float v = LM_GGML_CPU_FP16_TO_FP32(*(lm_ggml_fp16_t *) src1->data);

     const int ith = params->ith;
     const int nth = params->nth;
@@ -1467,7 +1467,7 @@ static void lm_ggml_compute_forward_add1_f16_f16(
         lm_ggml_fp16_t * dst_ptr  = (lm_ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
         for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = LM_GGML_FP32_TO_FP16(LM_GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
         }
     }
 }
@@ -1889,7 +1889,7 @@ static void lm_ggml_compute_forward_sum_f16(
             }
         }
     }
-    ((lm_ggml_fp16_t *) dst->data)[0] = LM_GGML_FP32_TO_FP16(sum);
+    ((lm_ggml_fp16_t *) dst->data)[0] = LM_GGML_CPU_FP32_TO_FP16(sum);
 }

 static void lm_ggml_compute_forward_sum_bf16(
@@ -2660,7 +2660,7 @@ static void lm_ggml_compute_forward_gelu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = LM_GGML_FP16_TO_FP32(x);
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2763,7 +2763,7 @@ static void lm_ggml_compute_forward_gelu_erf_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = LM_GGML_FP16_TO_FP32(x);
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2866,7 +2866,7 @@ static void lm_ggml_compute_forward_gelu_quick_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = LM_GGML_FP16_TO_FP32(x);
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -2969,7 +2969,7 @@ static void lm_ggml_compute_forward_silu_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const lm_ggml_fp16_t x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
-            const float v = LM_GGML_FP16_TO_FP32(x);
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -3163,7 +3163,7 @@ static void lm_ggml_compute_forward_silu_back_f16(
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const float x = ((lm_ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            const float v = LM_GGML_FP16_TO_FP32(x);
+            const float v = LM_GGML_CPU_FP16_TO_FP32(x);
             LM_GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
@@ -4500,7 +4500,7 @@ static void lm_ggml_compute_forward_get_rows_back_f32_f16(

         for (int j = 0; j < nc; ++j) {
             lm_ggml_fp16_t v = ((lm_ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += LM_GGML_FP16_TO_FP32(v);
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += LM_GGML_CPU_FP16_TO_FP32(v);
         }
     }
 }
@@ -4792,7 +4792,7 @@ static void lm_ggml_compute_forward_soft_max_f32(
         if (mp_f32) {
             if (use_f16) {
                 for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*LM_GGML_FP16_TO_FP32(mp_f16[i]);
+                    wp[i] += slope*LM_GGML_CPU_FP16_TO_FP32(mp_f16[i]);
                 }
             } else {
                 for (int i = 0; i < nc; ++i) {
@@ -5018,8 +5018,8 @@ static void lm_ggml_compute_forward_clamp_f16(
         lm_ggml_fp16_t * src0_ptr = (lm_ggml_fp16_t *) ((char *) src0->data + j*nb01);

         for (int i = 0; i < nc; i++) {
-            float v = LM_GGML_FP16_TO_FP32(src0_ptr[i]);
-            dst_ptr[i] = LM_GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+            float v = LM_GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = LM_GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
         }
     }
 }
@@ -5476,11 +5476,11 @@ static void lm_ggml_compute_forward_rope_f16(
                         const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                         lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                        const float x0 = LM_GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims]);
+                        const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                        const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims]);

-                        dst_data[0]      = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[0]      = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@@ -5492,11 +5492,11 @@ static void lm_ggml_compute_forward_rope_f16(
                         const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                         lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                        const float x0 = LM_GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                        const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims/2]);

-                        dst_data[0]        = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[0]        = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 }
             } else {
@@ -5507,11 +5507,11 @@ static void lm_ggml_compute_forward_rope_f16(
                     const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-                    const float x0 = LM_GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = LM_GGML_FP16_TO_FP32(src[1]);
+                    const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                    const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[1]);

-                    dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0] = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             }

@@ -5525,11 +5525,11 @@ static void lm_ggml_compute_forward_rope_f16(
                     const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
                     lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                    const float x0 = LM_GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims]);
+                    const float x0 = LM_GGML_CPU_FP16_TO_FP32(src[0]);
+                    const float x1 = LM_GGML_CPU_FP16_TO_FP32(src[n_dims]);

-                    dst_data[0]      = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0]      = LM_GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[n_dims] = LM_GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
                 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
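
Only the conversion macros change in the four rope_f16 hunks above; the rotation itself is untouched. For reference, each pair of values at distance d (1, n_dims/2, or n_dims depending on the mode) is rotated by the position-dependent angle theta, exactly as the dst_data[...] lines compute:

    \begin{pmatrix} x_0' \\ x_d' \end{pmatrix}
    =
    \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
    \begin{pmatrix} x_0 \\ x_d \end{pmatrix}
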
@@ -5640,7 +5640,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32(
             for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = LM_GGML_FP32_TO_FP16(src[i10]);
+                    dst_data[i10*ne11 + i11] = LM_GGML_CPU_FP32_TO_FP16(src[i10]);
                 }
             }
         }
@@ -5933,7 +5933,7 @@ static void lm_ggml_compute_forward_im2col_f16(
                             if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
                             } else {
-                                dst_data[iic*(KH*KW) + ikh*KW + ikw] = LM_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                dst_data[iic*(KH*KW) + ikh*KW + ikw] = LM_GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
                             }
                         }
                     }
@@ -6109,7 +6109,7 @@ void lm_ggml_compute_forward_conv_transpose_2d(
                 const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
                 lm_ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
                 for (int i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne12 + i12] = LM_GGML_FP32_TO_FP16(src[i10]);
+                    dst_data[i10*ne12 + i12] = LM_GGML_CPU_FP32_TO_FP16(src[i10]);
                 }
             }
         }
@@ -6358,7 +6358,7 @@ static void lm_ggml_compute_forward_pool_1d_sk_p0(
                 case LM_GGML_OP_POOL_COUNT: LM_GGML_ABORT("fatal error");
             }
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
+                const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
                 switch (op) {
                     case LM_GGML_OP_POOL_AVG: drow[i] += srow_j; break;
                     case LM_GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
@@ -6450,7 +6450,7 @@ void lm_ggml_compute_forward_pool_2d(
                 for (int kx = 0; kx < k0; ++kx) {
                     int j = ix + kx;
                     if (j < 0 || j >= src->ne[0]) continue;
-                    const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
+                    const float srow_j = (src->type == LM_GGML_TYPE_F32) ? ((const float*)srow)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t*)srow)[j]);
                     switch (op) {
                         case LM_GGML_OP_POOL_AVG: *out += srow_j; break;
                         case LM_GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
@@ -6538,7 +6538,7 @@ void lm_ggml_compute_forward_pool_2d_back(
                         }

                         const float val = dst->type == LM_GGML_TYPE_F32 ?
-                            ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]);
+                            ((const float *) drowf)[j] : LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]);
                         if (val <= maxval) {
                             continue;
                         }
@@ -6558,7 +6558,7 @@ void lm_ggml_compute_forward_pool_2d_back(
                     if (dst->type == LM_GGML_TYPE_F32) {
                         ((float *) drow)[j] += grad0;
                     } else {
-                        ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j]));
+                        ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_CPU_FP32_TO_FP16(grad0 + LM_GGML_CPU_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j]));
                     }
                 } else if (op == LM_GGML_OP_POOL_AVG) {
                     const float grad = grad0 / ka;
@@ -6577,7 +6577,7 @@ void lm_ggml_compute_forward_pool_2d_back(
                         if (dst->type == LM_GGML_TYPE_F32) {
                             ((float *) drow)[j] += grad;
                         } else {
-                            ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad);
+                            ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_CPU_FP32_TO_FP16(grad);
                         }
                     }
                 }
@@ -6793,6 +6793,73 @@ void lm_ggml_compute_forward_pad_reflect_1d(
     }
 }

+// lm_ggml_compute_forward_roll
+
+static int64_t lm_ggml_wrap_index(int64_t i, int64_t ne) {
+    if (i < 0) {
+        return i + ne;
+    } else if (i >= ne) {
+        return i - ne;
+    }
+    return i;
+}
+
+static void lm_ggml_compute_forward_roll_f32(
+        const lm_ggml_compute_params * params,
+        lm_ggml_tensor * dst) {
+
+    const lm_ggml_tensor * src0 = dst->src[0];
+    const float * src_data = (const float *) src0->data;
+    float * dst_data = (float *) dst->data;
+
+    LM_GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int s0 = lm_ggml_get_op_params_i32(dst, 0);
+    const int s1 = lm_ggml_get_op_params_i32(dst, 1);
+    const int s2 = lm_ggml_get_op_params_i32(dst, 2);
+    const int s3 = lm_ggml_get_op_params_i32(dst, 3);
+
+    const int64_t total      = ne1 * ne2 * ne3;
+    const int64_t per_thread = (total + params->nth) / params->nth;
+    const int64_t start      = params->ith * per_thread;
+    const int64_t end        = std::min(start + per_thread, total);
+
+    for (int64_t i = start; i < end; ++i) {
+        const int64_t i1 = i % ne1;
+        const int64_t i2 = (i / ne1) % ne2;
+        const int64_t i3 = i / (ne2 * ne1);
+        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
+
+        const int64_t i01 = lm_ggml_wrap_index(i1 - s1, ne01);
+        const int64_t i02 = lm_ggml_wrap_index(i2 - s2, ne02);
+        const int64_t i03 = lm_ggml_wrap_index(i3 - s3, ne03);
+        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
+
+        const int64_t s = lm_ggml_wrap_index(-s0, ne00);
+        const int64_t n = ne00 - s;
+        lm_ggml_vec_cpy_f32(n, dst_row,     src_row + s);
+        lm_ggml_vec_cpy_f32(s, dst_row + n, src_row);
+    }
+}
+
+void lm_ggml_compute_forward_roll(
+        const lm_ggml_compute_params * params,
+        lm_ggml_tensor * dst) {
+
+    const lm_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case LM_GGML_TYPE_F32:
+            {
+                lm_ggml_compute_forward_roll_f32(params, dst);
+            } break;
+        default:
+            {
+                LM_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // lm_ggml_compute_forward_arange

 static void lm_ggml_compute_forward_arange_f32(
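
The new roll op above shifts a tensor cyclically by (s0, s1, s2, s3); the only subtle part is the wrap-around indexing, which assumes each shift magnitude is smaller than the corresponding dimension. A standalone C illustration of the same relation (helper name is hypothetical, not package code):

    /* dst[i] = src[wrap(i - s, ne)] -- the relation implemented above via
       lm_ggml_wrap_index plus the two lm_ggml_vec_cpy_f32 calls per row. */
    #include <stdio.h>

    static long wrap_index(long i, long ne) {
        if (i < 0)   return i + ne;
        if (i >= ne) return i - ne;
        return i;
    }

    int main(void) {
        const long ne = 5, s = 2;  // roll a 5-element row right by 2
        for (long i = 0; i < ne; ++i) {
            printf("dst[%ld] = src[%ld]\n", i, wrap_index(i - s, ne));
        }
        // prints: dst[0] = src[3], dst[1] = src[4], dst[2] = src[0], ...
        return 0;
    }
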
@@ -7075,7 +7142,7 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(
                 // loop over n_kv and n_head_kv
                 // ref: https://arxiv.org/pdf/2112.05682.pdf
                 for (int64_t ic = 0; ic < nek1; ++ic) {
-                    const float mv = mp ? slope*LM_GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+                    const float mv = mp ? slope*LM_GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
                     if (mv == -INFINITY) {
                         continue;
                     }
@@ -7143,7 +7210,7 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(

                 if (v->type == LM_GGML_TYPE_F16) {
                     for (int64_t d = 0; d < DV; ++d) {
-                        VKQ32[d] = LM_GGML_FP16_TO_FP32(VKQ16[d]);
+                        VKQ32[d] = LM_GGML_CPU_FP16_TO_FP32(VKQ16[d]);
                     }
                 }

@@ -7633,39 +7700,83 @@ static void lm_ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = LM_GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = LM_GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = LM_GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA  = LM_GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB  = LM_GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC  = LM_GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = LM_GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = LM_GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = LM_GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = LM_GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = LM_GGML_F32_VEC_ADD(LM_GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    LM_GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = LM_GGML_F32xt_REDUCE_ONE(r1_vector);
+            }
+        }
+    }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }

 void lm_ggml_compute_forward_ssm_scan(
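
The SVE branch added above computes the same selective-scan recurrence as the scalar branch, just a vector of d_state elements per step (svcntw() is the runtime count of 32-bit lanes in an SVE vector). Reading the recurrence off the scalar code, with dt+ = softplus(dt[i1]) (shortcut to the identity for dt > 20):

    % per inner channel i_1 and state index i_0:
    s_{i_0,i_1} \leftarrow s_{i_0,i_1}\, e^{\,dt^{+} A_{i_0,i_1}} + B_{i_0}\, x_{i_1}\, dt^{+},
    \qquad
    y_{i_1} = \sum_{i_0} s_{i_0,i_1}\, C_{i_0}
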
@@ -8070,6 +8181,14 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
     #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL
     #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define LM_GGML_F32X LM_GGML_F32xt
+    #define LM_GGML_F32X_SET1 LM_GGML_F32xt_SET1
+    #define LM_GGML_F32X_LOAD LM_GGML_F32xt_LOAD
+    #define LM_GGML_F32X_STORE LM_GGML_F32xt_STORE
+    #define LM_GGML_F32X_MUL LM_GGML_F32xt_MUL
+    #define LM_GGML_F32X_FMA LM_GGML_F32xt_FMA
+    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define LM_GGML_F32X LM_GGML_F32x4
     #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1
@@ -8081,7 +8200,13 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
 #endif

 #ifdef WKV_VECTOR_SIZE
-    const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+    int wkv_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        wkv_vector_size = svcntw();
+    #else
+        wkv_vector_size = WKV_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / wkv_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
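
This hunk and the GLA ones below share a pattern: under SVE the lane count is only known at run time, so the vector-loop stride and the scalar-tail start must come from a variable rather than the compile-time constant. A standalone toy of that split (runtime_vector_width() is a stand-in for svcntw(), which only exists when compiling for SVE):

    /* Illustration only: process full vectors, then finish the tail scalar. */
    #include <stdint.h>
    #include <stdio.h>

    static int runtime_vector_width(void) { return 8; } /* svcntw() stand-in */

    int main(void) {
        const int64_t head_size = 70;
        const int     width     = runtime_vector_width();
        const int64_t vec_count = head_size / width;   /* 8 full vectors */

        for (int64_t j = 0; j < vec_count; j++) {
            printf("vector step over elements %lld..%lld\n",
                   (long long)(j * width), (long long)(j * width + width - 1));
        }
        for (int64_t j = vec_count * width; j < head_size; j++) {
            printf("scalar tail element %lld\n", (long long) j);
        }
        return 0;
    }
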
@@ -8111,7 +8236,7 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
             LM_GGML_F32X time_decay_vec = LM_GGML_F32X_SET1(time_decay_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * WKV_VECTOR_SIZE;
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8261,7 @@ static void lm_ggml_compute_forward_rwkv_wkv6_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8272,6 +8397,14 @@ static void lm_ggml_compute_forward_gla_f32(
     #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL
     #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define LM_GGML_F32X LM_GGML_F32xt
+    #define LM_GGML_F32X_SET1 LM_GGML_F32xt_SET1
+    #define LM_GGML_F32X_LOAD LM_GGML_F32xt_LOAD
+    #define LM_GGML_F32X_STORE LM_GGML_F32xt_STORE
+    #define LM_GGML_F32X_MUL LM_GGML_F32xt_MUL
+    #define LM_GGML_F32X_FMA LM_GGML_F32xt_FMA
+    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define LM_GGML_F32X LM_GGML_F32x4
     #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1
@@ -8283,7 +8416,13 @@ static void lm_ggml_compute_forward_gla_f32(
 #endif

 #ifdef GLA_VECTOR_SIZE
-    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+    int gla_vector_size;
+    #if defined(__ARM_FEATURE_SVE)
+        gla_vector_size = svcntw();
+    #else
+        gla_vector_size = GLA_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / gla_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8310,7 +8449,7 @@ static void lm_ggml_compute_forward_gla_f32(
             LM_GGML_F32X g_vec = LM_GGML_F32X_SET1(g_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * GLA_VECTOR_SIZE;
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8473,7 @@ static void lm_ggml_compute_forward_gla_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8443,83 +8582,126 @@ static void lm_ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

 #if defined(LM_GGML_SIMD)
-    for (int64_t t = 0; t < T; t++) {
-        int64_t t_offset = t * t_stride;
-        int64_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-        for (int64_t h = h_start; h < h_end; h++) {
-            int64_t h_offset = h * h_stride;
-            int64_t t_h_offset = t_offset + h_offset;
-            int64_t h_2d_offset = h * h_stride_2d;
-
-            for (int64_t ii = 0; ii < head_size; ii++) {
-                int64_t t_h_i_offset = t_h_offset + ii;
-                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                LM_GGML_F32_VEC v_vec = LM_GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                float sa = 0;
-                {
-                    LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
-                    LM_GGML_F32_VEC ax[LM_GGML_F32_ARR];
-                    LM_GGML_F32_VEC ay[LM_GGML_F32_ARR];
-                    for (int64_t j = 0; j < head_size; j += LM_GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
-                            ax[kk] = LM_GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * LM_GGML_F32_EPR]);
-                            ay[kk] = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * LM_GGML_F32_EPR]);
-                            sum[kk] = LM_GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                        }
-                    }
-                    LM_GGML_F32_VEC_REDUCE(sa, sum);
-                }
-
-                LM_GGML_F32_VEC sa_vec = LM_GGML_F32_VEC_SET1(sa);
-
-                int64_t j = 0;
-                LM_GGML_F32_VEC result_vec[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
-                for (; j < head_size; j += LM_GGML_F32_STEP) {
-                    for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
-                        int64_t t_h_j_offset = t_h_offset + j + kk * LM_GGML_F32_EPR;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * LM_GGML_F32_EPR;
-
-                        LM_GGML_F32_VEC r_vec = LM_GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                        LM_GGML_F32_VEC w_vec = LM_GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                        LM_GGML_F32_VEC k_vec = LM_GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                        LM_GGML_F32_VEC b_vec = LM_GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
-
-                        k_vec = LM_GGML_F32_VEC_MUL(v_vec, k_vec);
-
-                        LM_GGML_F32_VEC state_vec = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                        // kv + s * decay + sa * b
-                        state_vec = LM_GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                        state_vec = LM_GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                        LM_GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
-
-                        result_vec[kk] = LM_GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                    }
-                }
-                LM_GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                // There shouldn't be left-overs though.
-                for (; j < head_size; j++) {
-                    int64_t t_h_j_offset = t_h_offset + j;
-                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float r_val = r[t_h_j_offset];
-                    float w_val = w[t_h_j_offset];
-                    float k_val = k[t_h_j_offset];
-                    float b_val = b[t_h_j_offset];
-                    float kv_val = v[t_h_i_offset] * k_val;
-
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
-                }
-            }
-        }
-    }
+    #if defined(__ARM_FEATURE_SVE)
+        // Route to the scalar implementation. // TODO: write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    LM_GGML_F32_VEC v_vec = LM_GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
+                        LM_GGML_F32_VEC ax[LM_GGML_F32_ARR];
+                        LM_GGML_F32_VEC ay[LM_GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += LM_GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
+                                ax[kk] = LM_GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * LM_GGML_F32_EPR]);
+                                ay[kk] = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * LM_GGML_F32_EPR]);
+                                sum[kk] = LM_GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        LM_GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    LM_GGML_F32_VEC sa_vec = LM_GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    LM_GGML_F32_VEC result_vec[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += LM_GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < LM_GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * LM_GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * LM_GGML_F32_EPR;

+                            LM_GGML_F32_VEC r_vec = LM_GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            LM_GGML_F32_VEC w_vec = LM_GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            LM_GGML_F32_VEC k_vec = LM_GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            LM_GGML_F32_VEC b_vec = LM_GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = LM_GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            LM_GGML_F32_VEC state_vec = LM_GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = LM_GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = LM_GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            LM_GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = LM_GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    LM_GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #endif
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;