cui-llama.rn 1.5.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (324)
  1. package/LICENSE +20 -20
  2. package/README.md +345 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +129 -124
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +648 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1279 -1263
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/LICENSE +21 -0
  23. package/cpp/README.md +4 -4
  24. package/cpp/chat.cpp +1 -1
  25. package/cpp/common.cpp +17 -2
  26. package/cpp/common.h +7 -3
  27. package/cpp/ggml-alloc.c +4 -1
  28. package/cpp/ggml-cpp.h +1 -1
  29. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  30. package/cpp/ggml-cpu/amx/amx.h +8 -0
  31. package/cpp/ggml-cpu/amx/common.h +91 -0
  32. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  33. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  34. package/cpp/{binary-ops.h → ggml-cpu/binary-ops.h} +1 -1
  35. package/cpp/ggml-cpu/common.h +72 -0
  36. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  37. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  38. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  39. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  40. package/cpp/{ops.h → ggml-cpu/ops.h} +2 -20
  41. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  42. package/cpp/{simd-mappings.h → ggml-cpu/simd-mappings.h} +7 -3
  43. package/cpp/{unary-ops.h → ggml-cpu/unary-ops.h} +1 -1
  44. package/cpp/ggml-cpu.h +5 -0
  45. package/cpp/ggml-impl.h +16 -9
  46. package/cpp/ggml-llama-sim.metallib +0 -0
  47. package/cpp/ggml-llama.metallib +0 -0
  48. package/cpp/ggml-metal-impl.h +597 -597
  49. package/cpp/ggml-metal.m +496 -47
  50. package/cpp/ggml.c +134 -244
  51. package/cpp/ggml.h +62 -95
  52. package/cpp/json-schema-to-grammar.cpp +3 -0
  53. package/cpp/llama-arch.cpp +46 -17
  54. package/cpp/llama-arch.h +9 -0
  55. package/cpp/llama-batch.cpp +5 -1
  56. package/cpp/llama-batch.h +2 -1
  57. package/cpp/llama-chat.cpp +31 -10
  58. package/cpp/llama-chat.h +3 -2
  59. package/cpp/llama-context.cpp +104 -489
  60. package/cpp/llama-context.h +14 -30
  61. package/cpp/llama-graph.cpp +69 -62
  62. package/cpp/llama-graph.h +21 -18
  63. package/cpp/llama-hparams.h +5 -0
  64. package/cpp/llama-kv-cache.cpp +1497 -391
  65. package/cpp/llama-kv-cache.h +272 -80
  66. package/cpp/llama-memory.h +11 -1
  67. package/cpp/llama-model.cpp +502 -176
  68. package/cpp/llama-model.h +13 -3
  69. package/cpp/llama-sampling.cpp +2 -1
  70. package/cpp/llama-vocab.cpp +8 -1
  71. package/cpp/llama.h +14 -11
  72. package/cpp/rn-llama.cpp +721 -873
  73. package/cpp/rn-llama.h +134 -138
  74. package/cpp/sampling.h +107 -107
  75. package/cpp/unicode-data.cpp +7034 -7034
  76. package/cpp/unicode-data.h +20 -20
  77. package/cpp/unicode.cpp +849 -849
  78. package/cpp/unicode.h +66 -66
  79. package/ios/CMakeLists.txt +119 -108
  80. package/ios/RNLlama.h +13 -7
  81. package/ios/RNLlama.mm +423 -405
  82. package/ios/RNLlamaContext.h +57 -57
  83. package/ios/RNLlamaContext.mm +833 -835
  84. package/ios/rnllama.xcframework/Info.plist +74 -74
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +681 -0
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2189 -0
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +249 -0
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +419 -0
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1437 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  184. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +681 -0
  188. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  189. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  190. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2189 -0
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +249 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  218. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +419 -0
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1437 -0
  225. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  226. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  227. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  228. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
  229. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  230. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  231. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  232. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  233. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  234. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  235. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
  259. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
  260. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  271. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  274. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
  275. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  276. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  277. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  278. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  283. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  284. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  285. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  286. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  287. package/jest/mock.js +203 -203
  288. package/lib/commonjs/NativeRNLlama.js +1 -2
  289. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  290. package/lib/commonjs/chat.js.map +1 -1
  291. package/lib/commonjs/grammar.js +12 -31
  292. package/lib/commonjs/grammar.js.map +1 -1
  293. package/lib/commonjs/index.js +47 -47
  294. package/lib/commonjs/index.js.map +1 -1
  295. package/lib/commonjs/package.json +1 -0
  296. package/lib/module/NativeRNLlama.js +2 -0
  297. package/lib/module/NativeRNLlama.js.map +1 -1
  298. package/lib/module/chat.js +2 -0
  299. package/lib/module/chat.js.map +1 -1
  300. package/lib/module/grammar.js +14 -31
  301. package/lib/module/grammar.js.map +1 -1
  302. package/lib/module/index.js +47 -45
  303. package/lib/module/index.js.map +1 -1
  304. package/lib/module/package.json +1 -0
  305. package/lib/typescript/NativeRNLlama.d.ts +10 -4
  306. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  307. package/lib/typescript/index.d.ts.map +1 -1
  308. package/llama-rn.podspec +48 -48
  309. package/package.json +233 -233
  310. package/src/NativeRNLlama.ts +431 -426
  311. package/src/chat.ts +44 -44
  312. package/src/grammar.ts +854 -854
  313. package/src/index.ts +495 -487
  314. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  315. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  316. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  317. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  318. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  319. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  320. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  321. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  322. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  323. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  324. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml-metal.m CHANGED
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
 // note: assumes single GPU device - the default one
 // TODO: support multiple GPU devices
 static struct lm_ggml_backend_metal_device_context {
-    id<MTLDevice> mtl_device;
-    int mtl_device_ref_count;
+    id<MTLDevice> mtl_device;
+    int mtl_device_ref_count;
     id<MTLLibrary> mtl_library;

     bool has_simdgroup_reduction;
@@ -354,6 +354,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -362,6 +363,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -370,6 +372,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -378,6 +381,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -386,6 +390,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -394,6 +399,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -402,6 +408,14 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -430,6 +444,13 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
     LM_GGML_METAL_KERNEL_TYPE_SET_I32,
     LM_GGML_METAL_KERNEL_TYPE_SET_F32,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -460,6 +481,7 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_SQRT,
     LM_GGML_METAL_KERNEL_TYPE_SIN,
     LM_GGML_METAL_KERNEL_TYPE_COS,
+    LM_GGML_METAL_KERNEL_TYPE_NEG,
     LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -468,7 +490,259 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_COUNT
 };

+//
+// lm_ggml_metal_heap
+//
+
+struct lm_ggml_metal_heap {
+    // number of times the heap was unused
+    int n_unused;
+
+    // total number of buffer allocations in this heap across all computes
+    int64_t n_alloc;
+
+    // current offset in the heap - we reset this after each node in order to reuse the memory
+    size_t offs;
+
+    // the currently allocated MTLBuffer objects in this heap
+    id<MTLHeap> obj;
+
+    NSMutableArray * bufs;
+};
+
+static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+    struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
+
+    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+    desc.storageMode = MTLStorageModePrivate;
+    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+    desc.type = MTLHeapTypePlacement;
+    desc.size = size;
+
+    heap->n_unused = 0;
+    heap->n_alloc = 0;
+
+    heap->obj = [device newHeapWithDescriptor:desc];
+    if (!heap->obj) {
+        LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+        free(heap);
+
+        return false;
+    }
+
+    [desc release];
+
+    heap->bufs = [[NSMutableArray alloc] init];
+
+    return heap;
+}
+
+static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
+    heap->offs = 0;
+
+    // count how many graph computes the heap ended up being unused
+    if ([heap->bufs count] > 0) {
+        heap->n_unused = 0;
+    } else {
+        heap->n_unused++;
+    }
+
+    for (id<MTLBuffer> buf in heap->bufs) {
+        [buf release];
+    }
+    [heap->bufs removeAllObjects];
+
+    // tell the OS that it can reuse this memory if needed
+    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+}
+
+static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
+    if (heap == nil) {
+        return;
+    }
+
+    lm_ggml_metal_heap_reset(heap);
+
+    [heap->obj release];
+    [heap->bufs release];
+
+    free(heap);
+}
+
+@interface lm_ggml_metal_heap_ptr : NSObject
+
+@property (nonatomic, assign) struct lm_ggml_metal_heap * data;
+
+@end
+
+@implementation lm_ggml_metal_heap_ptr
+@end
+
+//
+// lm_ggml_metal_mem_pool
+//
+
+struct lm_ggml_metal_mem_pool {
+    id<MTLDevice> device;
+
+    int n_heaps; // total number of heaps ever created (including those that were removed)
+
+    NSMutableArray * heaps;
+    NSMutableArray * heaps_to_remove;
+};
+
+static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
+    struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
+
+    mem_pool->n_heaps = 0;
+
+    mem_pool->heaps = [[NSMutableArray alloc] init];
+    mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+    return mem_pool;
+}
+
+static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
+    LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+    size_t size_all = 0;
+    size_t size_cur = 0;
+
+    for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
+        LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
+        LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
+        LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+        LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
+
+        if ([ptr.data->bufs count] > 0) {
+            size_cur += [ptr.data->obj size];
+        }
+        size_all += [ptr.data->obj size];
+
+        lm_ggml_metal_heap_free(ptr.data);
+        [ptr release];
+    }
+    [mem_pool->heaps release];
+    [mem_pool->heaps_to_remove release];
+
+    if (size_all > 0) {
+        LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+        LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+    }
+
+    free(mem_pool);
+}
+
+static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
+    for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+        lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+        struct lm_ggml_metal_heap * heap = ptr.data;
+        lm_ggml_metal_heap_reset(heap);
+
+        // if the heap hasn't been used for a while, remove it
+        if (heap->n_unused >= 128) {
+            [mem_pool->heaps_to_remove addObject:@(i)];
+        }
+    }
+
+    if (mem_pool->heaps_to_remove.count > 0) {
+        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+            NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+            lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+            struct lm_ggml_metal_heap * heap = ptr.data;
+            lm_ggml_metal_heap_free(heap);
+
+            [mem_pool->heaps removeObjectAtIndex:index];
+            [ptr release];
+        }
+
+        [mem_pool->heaps_to_remove removeAllObjects];
+    }
+}
+
+static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
+    for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        ptr.data->offs = 0;
+    }
+}
+
+static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
+    const size_t alignment = 32;
+
+    const size_t size_aligned = LM_GGML_PAD(size, alignment);
+
+    // try one of the existing heaps
+    for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        struct lm_ggml_metal_heap * heap = ptr.data;
+        if (heap->offs + size_aligned <= [heap->obj size]) {
+            // if this is the first buffer in the heap for the current command buffer, tell the OS that
+            // it cannot free the memory used by the heap
+            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+            if ([heap->bufs count] == 0) {
+                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+            }
+
+            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+            if (buf == nil) {
+                LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+                return nil;
+            }
+
+            heap->n_alloc++;
+            heap->offs += size_aligned;
+
+            [heap->bufs addObject:buf];
+
+            return buf;
+        }
+    }
+
+    // create a new heap that can fit this buffer
+    lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
+
+    struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
+    if (heap == NULL) {
+        LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    //LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+    heap_ptr.data = heap;
+    lm_ggml_metal_heap_reset(heap);
+
+    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+    if (buf == nil) {
+        LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    heap->n_alloc++;
+    heap->offs += size_aligned;
+
+    [heap->bufs addObject:buf];
+
+    [mem_pool->heaps addObject:heap_ptr];
+    mem_pool->n_heaps++;
+
+    return buf;
+}
+
+struct lm_ggml_metal_command_buffer {
+    id<MTLCommandBuffer> obj;
+
+    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+    struct lm_ggml_metal_mem_pool * mem_pool;
+};
+
 struct lm_ggml_backend_metal_context {
+    id<MTLDevice> device;
     id<MTLCommandQueue> queue;

     dispatch_queue_t d_queue;
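The hunk above is the core of the new heap-backed scratch allocator: lm_ggml_metal_mem_pool_alloc sub-allocates placement buffers out of MTLHeaps, lm_ggml_metal_mem_pool_clear rewinds the heap offsets between graph nodes so the same memory is reused, and lm_ggml_metal_mem_pool_reset releases the buffers and marks idle heaps purgeable once a command buffer finishes. A minimal usage sketch of that lifecycle, assuming it sits in the same translation unit as the static helpers above (the standalone device and the 1 MiB size are illustrative, not taken from the package):

    // Sketch only: mirrors the alloc/clear/reset lifecycle the encoder uses
    // per command buffer. Assumes Metal.h is already imported by this file.
    static void mem_pool_example(void) {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();

        struct lm_ggml_metal_mem_pool * pool = lm_ggml_metal_mem_pool_init();
        pool->device = device;

        // before encoding a node: rewind heap offsets so scratch memory is reused
        lm_ggml_metal_mem_pool_clear(pool);

        // grab a transient buffer; a new MTLHeap is created on demand if none fits
        id<MTLBuffer> tmp = lm_ggml_metal_mem_pool_alloc(pool, 1024 * 1024); // illustrative size
        if (tmp == nil) {
            return; // the encoder propagates this as a failed graph compute
        }

        // ... encode kernels that read/write `tmp` ...

        // after the command buffer completes: drop buffers, mark idle heaps purgeable
        lm_ggml_metal_mem_pool_reset(pool);

        lm_ggml_metal_mem_pool_free(pool);
    }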
@@ -493,7 +767,7 @@ struct lm_ggml_backend_metal_context {
     void (^encode_async)(size_t ith);

     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];

     // abort lm_ggml_metal_graph_compute if callback returns true
     lm_ggml_abort_callback abort_callback;
@@ -560,7 +834,11 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
         NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
 #endif

+#if TARGET_OS_SIMULATOR
+        NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
+#else
         NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
+#endif
         if (path_lib == nil) {
             // Try to find the resource in the directory where the current binary located.
             NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
@@ -683,9 +961,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
     struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;

     id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
     LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

-    ctx->queue = [device newCommandQueue];
+    ctx->device = device;
+    ctx->queue = [device newCommandQueue];
     if (ctx->queue == nil) {
         LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
         return NULL;
@@ -746,7 +1026,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
+        ctx->cmd_bufs[i].mem_pool->device = device;
     }

 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1011,6 +1294,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1019,6 +1303,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1027,6 +1312,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1035,6 +1321,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1043,6 +1330,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1051,6 +1339,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1059,6 +1348,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
@@ -1087,6 +1384,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1117,6 +1421,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
+       LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1137,6 +1442,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {

     [ctx->queue release];

+    for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        // ctx->cmd_bufs[i].obj is auto released
+
+        lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+    }
+
     dispatch_release(ctx->d_queue);

     free(ctx);
@@ -1278,6 +1589,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1278
1589
  case LM_GGML_UNARY_OP_GELU_QUICK:
1279
1590
  case LM_GGML_UNARY_OP_SILU:
1280
1591
  case LM_GGML_UNARY_OP_ELU:
1592
+ case LM_GGML_UNARY_OP_NEG:
1281
1593
  return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
1282
1594
  default:
1283
1595
  return false;
@@ -1334,8 +1646,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1334
1646
  return op->src[0]->type == LM_GGML_TYPE_F16;
1335
1647
  case LM_GGML_OP_POOL_1D:
1336
1648
  return false;
1337
- case LM_GGML_OP_POOL_2D:
1338
1649
  case LM_GGML_OP_UPSCALE:
1650
+ return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
1651
+ case LM_GGML_OP_POOL_2D:
1339
1652
  case LM_GGML_OP_PAD:
1340
1653
  case LM_GGML_OP_PAD_REFLECT_1D:
1341
1654
  case LM_GGML_OP_TIMESTEP_EMBEDDING:
@@ -1350,6 +1663,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  // TODO: not sure if it is worth adding kernels for this size
  return false;
  }
+ if (op->src[0]->ne[0] == 576) {
+ // DeepSeek sizes
+ // TODO: disabled for now, until optimized
+ return false;
+ }
  if (op->src[1]->type != op->src[2]->type) {
  return false;
  }
@@ -1435,10 +1753,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  }
  }
 
- static void lm_ggml_metal_encode_node(
+ static bool lm_ggml_metal_encode_node(
  lm_ggml_backend_t backend,
  int idx,
- id<MTLComputeCommandEncoder> encoder) {
+ id<MTLComputeCommandEncoder> encoder,
+ struct lm_ggml_metal_mem_pool * mem_pool) {
  struct lm_ggml_backend_metal_context * ctx = backend->context;
  struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -1454,7 +1773,7 @@ static void lm_ggml_metal_encode_node(
  struct lm_ggml_tensor * dst = node;
 
  if (lm_ggml_is_empty(dst)) {
- return;
+ return true;
  }
 
  switch (dst->op) {
@@ -1465,7 +1784,7 @@ static void lm_ggml_metal_encode_node(
  case LM_GGML_OP_PERMUTE:
  {
  // noop -> next node
- } return;
+ } return true;
  default:
  {
  } break;
@@ -1476,6 +1795,8 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ABORT("unsupported op");
  }
 
+ lm_ggml_metal_mem_pool_clear(mem_pool);
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
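lm_ggml_metal_encode_node now reports success instead of returning void, and receives the per-command-buffer pool it may draw transient buffers from; the pool is cleared at the top of every node. The pool API used across this diff reduces to four calls — the signatures below are inferred from the call sites in this file, not copied from a header:

    struct lm_ggml_metal_mem_pool; // opaque; defined earlier in ggml-metal.m

    id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * pool, size_t size);
    void          lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * pool); // per node
    void          lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * pool); // per command buffer
    void          lm_ggml_metal_mem_pool_free (struct lm_ggml_metal_mem_pool * pool); // at teardown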
@@ -1962,6 +2283,18 @@ static void lm_ggml_metal_encode_node(
 
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case LM_GGML_UNARY_OP_NEG:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = lm_ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  default:
  {
  LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
@@ -2110,26 +2443,76 @@ static void lm_ggml_metal_encode_node(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
- lm_ggml_metal_kargs_soft_max args = {
+ // use this branch to test the lm_ggml_metal_mem_pool functionality
+ #if 0
+ // cpy to tmp buffer in MTLHeap
+
+ id<MTLBuffer> h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
+ if (!h_src0) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
+ return false;
+ }
+
+ offs_src0 = 0;
+
+ lm_ggml_metal_kargs_cpy args_cpy = {
  /*.ne00 =*/ ne00,
  /*.ne01 =*/ ne01,
  /*.ne02 =*/ ne02,
- /*.scale =*/ scale,
- /*.max_bias =*/ max_bias,
- /*.m0 =*/ m0,
- /*.m1 =*/ m1,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne00,
+ /*.ne1 =*/ ne01,
+ /*.ne2 =*/ ne02,
+ /*.ne3 =*/ ne03,
+ /*.nb0 =*/ nb00,
+ /*.nb1 =*/ nb01,
+ /*.nb2 =*/ nb02,
+ /*.nb3 =*/ nb03,
+ };
+
+ if (src0->type == LM_GGML_TYPE_F16) {
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+ } else {
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+ }
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
+
+ LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+ int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+ #else
+ id<MTLBuffer> h_src0 = id_src0;
+ #endif
+ // softmax
+
+ lm_ggml_metal_kargs_soft_max args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
  /*.n_head_log2 =*/ n_head_log2,
  };
 
  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
  if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
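The #if 0 branch above is a built-in self-test for the memory pool: it stages src0 through a transient heap buffer (via the cpy_f16_f16/cpy_f32_f32 pipelines) and runs softmax on the copy, exercising both lm_ggml_metal_mem_pool_alloc and the new failure path. In normal builds the #else arm simply aliases h_src0 to id_src0, so softmax encoding is unchanged. The essential allocation pattern:

    // A transient buffer from the command buffer's pool; a nil result aborts
    // encoding of this command buffer through the new bool return value.
    id<MTLBuffer> h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
    if (!h_src0) {
        return false;
    }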
 
@@ -3842,12 +4225,14 @@ static void lm_ggml_metal_encode_node(
  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
  // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
- if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
+ if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
  switch (src1->type) {
  case LM_GGML_TYPE_F16:
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3870,6 +4255,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3892,6 +4279,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3914,6 +4303,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3936,6 +4327,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3958,6 +4351,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3980,6 +4375,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
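The BF16, Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 hunks above are mechanical repeats of the F16 one: each K-cache type gains the same head-size branch ahead of its generic switch. Schematically, with TYPE standing in for the seven concrete type tags (this is a template, not compilable as-is):

    } else if (ne00 == 576 && ne20 == 512) {
        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_TYPE_HK576_HV512].pipeline;
    }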
@@ -4009,6 +4406,24 @@ static void lm_ggml_metal_encode_node(
  use_vec_kernel = true;
 
  switch (ne00) {
+ case 96:
+ {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
  case 128:
  {
  switch (src1->type) {
@@ -4081,12 +4496,36 @@ static void lm_ggml_metal_encode_node(
  }
  }
  } break;
+ case 576:
+ {
+ if (ne20 == 512) {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } else {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ } break;
  default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
  }
  }
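The vector-kernel path gains the same DeepSeek specialization, with one difference: because the K (576) and V (512) head sizes differ, the case 576: arm validates ne20 explicitly and aborts on any other V size rather than falling through. Note that both this branch and the batched one above remain dormant for now, since the supports_op hunk earlier in this diff still short-circuits the whole path:

    if (op->src[0]->ne[0] == 576) {
        // DeepSeek sizes
        // TODO: disabled for now, until optimized
        return false;
    }

The kernels are registered and wired up, but unreachable until that guard is lifted.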
 
@@ -4482,6 +4921,8 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ABORT("fatal error");
  }
  }
+
+ return true;
  }
 
  static enum lm_ggml_status lm_ggml_metal_graph_compute(
@@ -4535,25 +4976,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  }
 
  // the main thread commits the first few commands immediately
- // command_buffer[n_cb]
+ // cmd_buf[n_cb]
  {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[n_cb] = command_buffer;
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
- [command_buffer enqueue];
+ [cmd_buf enqueue];
  ctx->encode_async(n_cb);
  }
 
  // prepare the rest of the command buffers asynchronously
- // command_buffer[0.. n_cb)
+ // cmd_buf[0.. n_cb)
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[cb_idx] = command_buffer;
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
  // always enqueue the first two command buffers
  // enqueue all of the command buffers if we don't need to abort
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer enqueue];
+ [cmd_buf enqueue];
  }
  }
 
@@ -4562,14 +5003,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  // wait for completion and check status of each command buffer
  // needed to detect if the device ran out-of-memory for example (#1881)
  {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
- [command_buffer waitUntilCompleted];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+ [cmd_buf waitUntilCompleted];
 
- MTLCommandBufferStatus status = [command_buffer status];
+ MTLCommandBufferStatus status = [cmd_buf status];
  if (status != MTLCommandBufferStatusCompleted) {
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
  if (status == MTLCommandBufferStatusError) {
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
  }
 
  return LM_GGML_STATUS_FAILED;
@@ -4577,20 +5018,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  }
 
  for (int i = 0; i < n_cb; ++i) {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
- [command_buffer waitUntilCompleted];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+ [cmd_buf waitUntilCompleted];
 
- MTLCommandBufferStatus status = [command_buffer status];
+ MTLCommandBufferStatus status = [cmd_buf status];
  if (status != MTLCommandBufferStatusCompleted) {
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  if (status == MTLCommandBufferStatusError) {
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
  }
 
  return LM_GGML_STATUS_FAILED;
  }
 
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
  if (!next_buffer) {
  continue;
  }
@@ -4973,8 +5414,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
  int node_start = 0;
  int node_end = n_nodes_0;
@@ -4986,22 +5428,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
  const bool should_capture = ctx->capture_next_compute;
 
+ struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+ lm_ggml_metal_mem_pool_reset(mem_pool);
+
  for (int idx = node_start; idx < node_end; ++idx) {
  if (should_capture) {
  [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
  }
 
- lm_ggml_metal_encode_node(backend, idx, encoder);
+ const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
  if (should_capture) {
  [encoder popDebugGroup];
  }
+
+ if (!res) {
+ break;
+ }
  }
 
  [encoder endEncoding];
 
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer commit];
+ [cmd_buf commit];
  }
  });
  }
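Condensed, the new encoding flow per command buffer is: reset that buffer's pool once, let each node clear it on entry (inside lm_ggml_metal_encode_node), and stop at the first node that fails to encode while still closing the encoder so the command buffer can be committed or discarded cleanly:

    lm_ggml_metal_mem_pool_reset(mem_pool);

    for (int idx = node_start; idx < node_end; ++idx) {
        if (!lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool)) {
            break; // e.g. transient-buffer allocation failed
        }
    }

    [encoder endEncoding];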