cui-llama.rn 1.7.4 → 1.7.6

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (276):
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
  4. package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/ggml-metal.m CHANGED
@@ -48,22 +48,28 @@ static struct lm_ggml_backend_metal_device_context {
     int mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
+    NSLock * mtl_lock;
+
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
 
+    size_t max_size;
+
     char name[128];
 } g_lm_ggml_ctx_dev_main = {
     /*.mtl_device =*/ nil,
     /*.mtl_device_ref_count =*/ 0,
     /*.mtl_library =*/ nil,
+    /*.mtl_lock =*/ nil,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm =*/ false,
     /*.has_residency_sets =*/ false,
     /*.has_bfloat =*/ false,
     /*.use_bfloat =*/ false,
+    /*.max_size =*/ 0,
     /*.name =*/ "",
 };
 
@@ -71,6 +77,10 @@ static struct lm_ggml_backend_metal_device_context {
 static id<MTLDevice> lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_metal_device_context * ctx) {
     assert(ctx != NULL);
 
+    if (ctx->mtl_lock == nil) {
+        ctx->mtl_lock = [[NSLock alloc] init];
+    }
+
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
     }
@@ -94,6 +104,8 @@ static id<MTLDevice> lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_met
         ctx->use_bfloat = false;
 #endif
 
+        ctx->max_size = ctx->mtl_device.maxBufferLength;
+
         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
     }
 
@@ -110,6 +122,11 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->mtl_lock) {
+            [ctx->mtl_lock release];
+            ctx->mtl_lock = nil;
+        }
+
         if (ctx->mtl_library) {
             [ctx->mtl_library release];
             ctx->mtl_library = nil;
@@ -194,11 +211,14 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -498,6 +518,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_COS,
     LM_GGML_METAL_KERNEL_TYPE_NEG,
     LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    LM_GGML_METAL_KERNEL_TYPE_MEAN,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     LM_GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -976,7 +997,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
     struct lm_ggml_backend_metal_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_context));
     struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;
 
-    id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
@@ -990,9 +1011,16 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     // load library
-    if (ctx_dev->mtl_library == nil) {
-        ctx_dev->mtl_library = lm_ggml_metal_load_library(device, ctx_dev->use_bfloat);
+    {
+        [ctx_dev->mtl_lock lock];
+
+        if (ctx_dev->mtl_library == nil) {
+            ctx_dev->mtl_library = lm_ggml_metal_load_library(device, ctx_dev->use_bfloat);
+        }
+
+        [ctx_dev->mtl_lock unlock];
     }
+
     id<MTLLibrary> metal_library = ctx_dev->mtl_library;
     if (metal_library == nil) {
         LM_GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
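Note on the hunks above: the library load is now guarded by the per-device `mtl_lock`, so two contexts initializing concurrently can no longer both invoke `lm_ggml_metal_load_library`. A minimal sketch of the same lazy-init-under-lock pattern in plain C (a pthread mutex stands in for `NSLock`; all names here are illustrative, not from the package):

```c
#include <pthread.h>
#include <stddef.h>

static pthread_mutex_t lib_lock = PTHREAD_MUTEX_INITIALIZER;
static void *lib = NULL; /* plays the role of ctx_dev->mtl_library */

/* only the first caller pays the load cost; later callers reuse the result */
static void *load_library_once(void *(*load)(void)) {
    pthread_mutex_lock(&lib_lock);
    if (lib == NULL) {
        lib = load();
    }
    pthread_mutex_unlock(&lib_lock);
    return lib;
}
```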
@@ -1150,11 +1178,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1454,6 +1485,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1685,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
         case LM_GGML_OP_LOG:
             return false; // TODO: implement
         case LM_GGML_OP_SUM_ROWS:
+        case LM_GGML_OP_MEAN:
         case LM_GGML_OP_SOFT_MAX:
         case LM_GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && lm_ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2433,31 @@ static bool lm_ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case LM_GGML_OP_SUM_ROWS:
+        case LM_GGML_OP_MEAN:
             {
                 LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case LM_GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case LM_GGML_OP_MEAN:
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        LM_GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
+                nth = MIN(nth, ne00);
 
                 lm_ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2434,11 +2487,12 @@ static bool lm_ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case LM_GGML_OP_SOFT_MAX:
             {
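The two hunks above move SUM_ROWS (and the new MEAN) from a single-thread threadgroup to a sized one: `nth` starts at the 32-wide SIMD width, doubles until it covers the row or hits the pipeline's thread limit, and is then clamped; the dispatch also reserves 32 floats of threadgroup memory for the simdgroup reduction. A standalone sketch of that sizing rule (illustrative values; `pick_nth` is a hypothetical helper name):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

/* mirrors the nth selection in the SUM_ROWS/MEAN hunk above */
static int pick_nth(int ne00, int max_threads_per_tg) {
    int nth = 32; /* SIMD width */
    while (nth < ne00 && nth < max_threads_per_tg) {
        nth *= 2;
    }
    nth = MIN(nth, max_threads_per_tg);
    nth = MIN(nth, ne00);
    return nth;
}

int main(void) {
    printf("%d\n", pick_nth(300, 1024));  /* doubles 32 -> 512, clamps to 300 */
    printf("%d\n", pick_nth(4096, 1024)); /* capped by the pipeline limit: 1024 */
    return 0;
}
```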
@@ -3063,14 +3117,23 @@ static bool lm_ggml_metal_encode_node(
                         nsg = 1;
                         nr0 = 1;
                         nr1 = 4;
-                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                        if (ne00 == 4) {
+                            nr0 = 32;
+                            pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                        }
                     } break;
                 case LM_GGML_TYPE_F16:
                     {
                         nsg = 1;
                         nr0 = 1;
                         if (src1t == LM_GGML_TYPE_F32) {
-                            if (ne11 * ne12 < 4) {
+                            if (ne00 == 4) {
+                                nr0 = 32;
+                                nr1 = 4;
+                                pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
+                            } else if (ne11 * ne12 < 4) {
                                 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3089,7 +3152,11 @@ static bool lm_ggml_metal_encode_node(
                         nsg = 1;
                         nr0 = 1;
                         if (src1t == LM_GGML_TYPE_F32) {
-                            if (ne11 * ne12 < 4) {
+                            if (ne00 == 4) {
+                                nr0 = 32;
+                                nr1 = 4;
+                                pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
+                            } else if (ne11 * ne12 < 4) {
                                 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
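The two hunks above (together with the F32 case) route matrix-vector products whose rows hold exactly four elements (`ne00 == 4`) to the new `_c4` kernels, widening `nr0` to 32 so each threadgroup covers many short rows at once. A condensed sketch of the F16 selection order (hypothetical enum and function names; the real code picks Metal pipeline states):

```c
/* hypothetical condensation of the F16-src0 / F32-src1 kernel choice above */
typedef enum { MV_F16_F32_C4, MV_F16_F32_1ROW, MV_F16_F32_L4, MV_F16_F32 } mv_kernel;

static mv_kernel pick_f16_mv_kernel(int ne00, int ne01, int ne11, int ne12) {
    if (ne00 == 4)                                 return MV_F16_F32_C4;   /* 4-wide rows */
    if (ne11 * ne12 < 4)                           return MV_F16_F32_1ROW; /* small src1 batch */
    if (ne00 >= 128 && ne01 >= 8 && ne00 % 4 == 0) return MV_F16_F32_L4;   /* long rows */
    return MV_F16_F32;
}
```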
@@ -3733,6 +3800,7 @@ static bool lm_ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 lm_ggml_metal_kargs_rms_norm args = {
@@ -3769,6 +3837,7 @@ static bool lm_ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 lm_ggml_metal_kargs_l2_norm args = {
@@ -3841,6 +3910,7 @@ static bool lm_ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 lm_ggml_metal_kargs_norm args = {
@@ -4766,6 +4836,8 @@ static bool lm_ggml_metal_encode_node(
                 LM_GGML_ASSERT(nqptg % 8 == 0);
                 LM_GGML_ASSERT(ncpsg % 32 == 0);
 
+                const int is_q = lm_ggml_is_quantized(src1->type) ? 1 : 0;
+
                 // 2*(2*ncpsg + nqptg)*(nsg)
                 // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
                 //
@@ -4773,7 +4845,7 @@ static bool lm_ggml_metal_encode_node(
                 // the shared memory needed for the simdgroups to load the KV cache
                 // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                 //
-#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
 
@@ -4810,9 +4882,9 @@ static bool lm_ggml_metal_encode_node(
                 // and store the soft_max values and the mask
                 //
                 // ne00*(nsg)
-                // each simdgroup has a full f16 head vector in shared mem to accumulate results
+                // each simdgroup has a full f32 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(LM_GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(LM_GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
                 while (true) {
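The FATTN_SMEM hunks above change the flash-attention shared-memory budget: the per-simdgroup accumulator moves from f16 to f32, which is why the head-vector term doubles (the macro's `sizeof(float)/2` factor converts element counts to 2-byte units), and the 16x32 dequantization scratch is now charged only when the KV cache is quantized (`is_q`). A small C check that evaluates the new vector-kernel formula under illustrative sizes (all parameter values here are assumptions for the example):

```c
#include <stdio.h>
#include <stddef.h>

#define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

/* the updated vector-kernel FATTN_SMEM from the hunk above, as a function */
static size_t fattn_smem_vec(int nqptg, int ne00, int ne20, int ncpsg, int nsg) {
    return PAD((size_t)(nqptg*(PAD(ne00, 128) + 4*ncpsg*nsg) + 2*ne20*nsg)*(sizeof(float)/2), 16);
}

int main(void) {
    /* illustrative: head size 128, 1 query per tg, 32 cache values, 2 simdgroups:
       (1*(128 + 4*32*2) + 2*128*2) * 2 bytes = (384 + 512) * 2 = 1792 */
    printf("%zu bytes\n", fattn_smem_vec(1, 128, 128, 32, 2));
    return 0;
}
```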
@@ -4925,8 +4997,39 @@ static bool lm_ggml_metal_encode_node(
                 default: LM_GGML_ABORT("not implemented");
             }
 
+            LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+
+            // TODO: support
+            //const int32_t nk00 = ne00/lm_ggml_blck_size(dst->type);
+            const int32_t nk00 = ne00;
+
+            int nth = 32; // SIMD width
+
+            while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                nth *= 2;
+            }
+
+            nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+            // when rows are small, we can batch them together in a single threadgroup
+            int nrptg = 1;
+
+            // TODO: relax this constraint in the future
+            if (lm_ggml_blck_size(src0->type) == 1 && lm_ggml_blck_size(dst->type) == 1) {
+                if (nth > nk00) {
+                    nrptg = (nth + nk00 - 1)/nk00;
+                    nth   = nk00;
+
+                    if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nrptg--;
+                    }
+                }
+            }
+
+            nth = MIN(nth, nk00);
+
             lm_ggml_metal_kargs_cpy args = {
-                /*.ne00 =*/ ne00,
+                /*.ne00 =*/ nk00,
                 /*.ne01 =*/ ne01,
                 /*.ne02 =*/ ne02,
                 /*.ne03 =*/ ne03,
@@ -4949,11 +5052,7 @@ static bool lm_ggml_metal_encode_node(
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
             [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
-            LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
-            int nth = MIN(1024, ne00/lm_ggml_blck_size(src0->type));
-
-            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
+            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
         } break;
     case LM_GGML_OP_SET:
         {
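In the CPY hunks above, the comment "when rows are small, we can batch them together in a single threadgroup" is implemented through `nrptg`: if the chosen `nth` exceeds the row width `nk00`, several rows share one threadgroup and the grid shrinks to `(ne01 + nrptg - 1)/nrptg` along the row axis. A standalone sketch of that computation (illustrative names and values):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

/* mirrors the nth/nrptg selection above for blck_size == 1 tensors */
static void pick_cpy_tg(int nk00, int max_threads_per_tg, int *nth_out, int *nrptg_out) {
    int nth = 32;
    while (nth < nk00 && nth < max_threads_per_tg) {
        nth *= 2;
    }
    nth = MIN(nth, max_threads_per_tg);

    int nrptg = 1;
    if (nth > nk00) {                    /* rows narrower than the threadgroup */
        nrptg = (nth + nk00 - 1)/nk00;
        nth   = nk00;
        if (nrptg*nth > max_threads_per_tg) {
            nrptg--;
        }
    }
    *nth_out   = MIN(nth, nk00);
    *nrptg_out = nrptg;
}

int main(void) {
    int nth, nrptg;
    pick_cpy_tg(24, 1024, &nth, &nrptg);
    printf("nth=%d nrptg=%d\n", nth, nrptg); /* 24-wide rows -> 2 rows per threadgroup */
    return 0;
}
```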
@@ -5259,7 +5358,6 @@ static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t bu
     }
 
     lm_ggml_backend_metal_buffer_rset_free(ctx);
-    lm_ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
 #if TARGET_OS_OSX
@@ -5368,7 +5466,10 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l
     }
 
     struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)buft->device->context;
-    id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
+    LM_GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     ctx->all_data = lm_ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -5391,14 +5492,12 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        lm_ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
     if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        lm_ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5409,17 +5508,14 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l
 
 static size_t lm_ggml_backend_metal_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
     return 32;
+
     LM_GGML_UNUSED(buft);
 }
 
 static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = lm_ggml_backend_metal_device_acq(buft->device->context);
-    const size_t max_size = device.maxBufferLength;
-    lm_ggml_backend_metal_device_rel(buft->device->context);
+    const size_t max_size = ((struct lm_ggml_backend_metal_device_context *)buft->device->context)->max_size;
 
     return max_size;
-
-    LM_GGML_UNUSED(buft);
 }
 
 static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) {
@@ -5492,7 +5588,10 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
     }
 
     struct lm_ggml_backend_metal_device_context * ctx_dev = &g_lm_ggml_ctx_dev_main;
-    id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
+    LM_GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5548,7 +5647,6 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
     if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        lm_ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5564,10 +5662,8 @@ static const char * lm_ggml_backend_metal_name(lm_ggml_backend_t backend) {
 }
 
 static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) {
-    struct lm_ggml_backend_metal_context * ctx = backend->context;
-    struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+    struct lm_ggml_backend_metal_context * ctx = backend->context;
 
-    lm_ggml_backend_metal_device_rel(ctx_dev);
     lm_ggml_metal_free(ctx);
 
     free(backend);
@@ -5707,6 +5803,8 @@ bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family
 
     struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    LM_GGML_ASSERT(ctx_dev->mtl_device != nil);
+
     return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
@@ -5726,10 +5824,7 @@ static const char * lm_ggml_backend_metal_device_get_name(lm_ggml_backend_dev_t
 }
 
 static const char * lm_ggml_backend_metal_device_get_description(lm_ggml_backend_dev_t dev) {
-    // acq/rel just to populate ctx->name in case it hasn't been done yet
     struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context;
-    lm_ggml_backend_metal_device_acq(ctx_dev);
-    lm_ggml_backend_metal_device_rel(ctx_dev);
 
     return ctx_dev->name;
 }
@@ -5737,12 +5832,10 @@ static const char * lm_ggml_backend_metal_device_get_description(lm_ggml_backend
 static void lm_ggml_backend_metal_device_get_memory(lm_ggml_backend_dev_t dev, size_t * free, size_t * total) {
     if (@available(macOS 10.12, iOS 16.0, *)) {
         struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context;
-        id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+        id<MTLDevice> device = ctx_dev->mtl_device;
 
         *total = device.recommendedMaxWorkingSetSize;
         *free = *total - device.currentAllocatedSize;
-
-        lm_ggml_backend_metal_device_rel(ctx_dev);
     } else {
         *free = 1;
         *total = 1;
@@ -5820,7 +5913,10 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_
     }
 
     struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)dev->context;
-    id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
+    LM_GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5876,7 +5972,6 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_
     if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        lm_ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5890,8 +5985,9 @@ static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev,
 }
 
 static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name ||
-           buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name;
+    return
+        buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name ||
+        buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name;
 
     LM_GGML_UNUSED(dev);
 }
@@ -5976,8 +6072,19 @@ static struct lm_ggml_backend_reg_i lm_ggml_backend_metal_reg_i = {
     /* .get_proc_address = */ lm_ggml_backend_metal_get_proc_address,
 };
 
+// called upon program exit
+static void lm_ggml_metal_cleanup(void) {
+    lm_ggml_backend_metal_device_rel(&g_lm_ggml_ctx_dev_main);
+}
+
+// TODO: make thread-safe
 lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) {
-    // TODO: make this thread-safe somehow?
+    lm_ggml_backend_metal_device_acq(&g_lm_ggml_ctx_dev_main);
+
+    // register cleanup callback
+    // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
+    atexit(lm_ggml_metal_cleanup);
+
     {
         g_lm_ggml_backend_metal_reg = (struct lm_ggml_backend_reg) {
             /* .api_version = */ LM_GGML_BACKEND_API_VERSION,
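The final hunk above moves device ownership to registration time: `lm_ggml_backend_metal_reg` acquires the device once, and an `atexit` hook releases it at process exit, which is what allows the per-buffer acquire/release pairs earlier in this diff to be deleted. The shape of that pattern in plain C (illustrative names, not the package's API):

```c
#include <stdio.h>
#include <stdlib.h>

static int device_refs = 0; /* plays the role of mtl_device_ref_count */

static void device_acq(void) { device_refs++; }

static void device_rel(void) {
    if (--device_refs == 0) {
        printf("releasing device, lock and library\n");
    }
}

/* acquire once at registration; release via atexit instead of per buffer */
static void backend_reg(void) {
    device_acq();
    atexit(device_rel);
}

int main(void) {
    backend_reg();
    return 0; /* device_rel runs here through the atexit hook */
}
```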
package/cpp/ggml-quants.c CHANGED
@@ -2425,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * LM_GGML_RESTRICT x, float * LM_GGM
     }
 }
 
-static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
 void dequantize_row_iq4_nl(const block_iq4_nl * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k) {
     assert(k % QK4_NL == 0);
     const int64_t nb = k / QK4_NL;