cui-llama.rn 1.6.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +22 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
  4. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  5. package/android/src/main/jni.cpp +173 -18
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  16. package/cpp/LICENSE +21 -0
  17. package/cpp/chat.cpp +129 -107
  18. package/cpp/chat.h +2 -0
  19. package/cpp/common.cpp +58 -78
  20. package/cpp/common.h +29 -21
  21. package/cpp/ggml-alloc.c +4 -1
  22. package/cpp/ggml-backend.cpp +9 -5
  23. package/cpp/ggml-backend.h +4 -4
  24. package/cpp/ggml-cpp.h +1 -1
  25. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  26. package/cpp/ggml-cpu/amx/amx.h +8 -0
  27. package/cpp/ggml-cpu/amx/common.h +91 -0
  28. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  29. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  31. package/cpp/ggml-cpu/common.h +72 -0
  32. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
  33. package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
  34. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
  35. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
  36. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
  37. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  38. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  39. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  40. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  41. package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
  42. package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
  43. package/cpp/ggml-cpu.h +5 -0
  44. package/cpp/ggml-impl.h +16 -9
  45. package/cpp/ggml-llama-sim.metallib +0 -0
  46. package/cpp/ggml-llama.metallib +0 -0
  47. package/cpp/ggml-metal-impl.h +36 -11
  48. package/cpp/ggml-metal.m +810 -176
  49. package/cpp/ggml-opt.cpp +373 -190
  50. package/cpp/ggml-opt.h +49 -28
  51. package/cpp/ggml-quants.c +0 -6
  52. package/cpp/ggml.c +227 -282
  53. package/cpp/ggml.h +82 -101
  54. package/cpp/gguf.cpp +33 -33
  55. package/cpp/json-schema-to-grammar.cpp +3 -0
  56. package/cpp/llama-adapter.cpp +6 -0
  57. package/cpp/llama-arch.cpp +49 -17
  58. package/cpp/llama-arch.h +9 -0
  59. package/cpp/llama-batch.cpp +8 -2
  60. package/cpp/llama-batch.h +2 -1
  61. package/cpp/llama-chat.cpp +39 -16
  62. package/cpp/llama-chat.h +4 -2
  63. package/cpp/llama-context.cpp +440 -611
  64. package/cpp/llama-context.h +44 -33
  65. package/cpp/llama-cparams.h +1 -0
  66. package/cpp/llama-graph.cpp +214 -291
  67. package/cpp/llama-graph.h +69 -21
  68. package/cpp/llama-hparams.cpp +17 -1
  69. package/cpp/llama-hparams.h +39 -5
  70. package/cpp/llama-kv-cache.cpp +2067 -620
  71. package/cpp/llama-kv-cache.h +410 -108
  72. package/cpp/llama-memory.h +12 -1
  73. package/cpp/llama-model-loader.cpp +24 -15
  74. package/cpp/llama-model-saver.cpp +281 -0
  75. package/cpp/llama-model-saver.h +37 -0
  76. package/cpp/llama-model.cpp +1089 -359
  77. package/cpp/llama-model.h +19 -3
  78. package/cpp/llama-sampling.cpp +20 -7
  79. package/cpp/llama-vocab.cpp +54 -9
  80. package/cpp/llama-vocab.h +6 -0
  81. package/cpp/llama.cpp +14 -0
  82. package/cpp/llama.h +86 -142
  83. package/cpp/minja/chat-template.hpp +9 -5
  84. package/cpp/minja/minja.hpp +69 -36
  85. package/cpp/rn-llama.cpp +602 -190
  86. package/cpp/rn-llama.h +34 -8
  87. package/cpp/sampling.cpp +57 -50
  88. package/cpp/tools/mtmd/clip-impl.h +462 -0
  89. package/cpp/tools/mtmd/clip.cpp +4024 -0
  90. package/cpp/tools/mtmd/clip.h +101 -0
  91. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  92. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  93. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  94. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  95. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  96. package/cpp/tools/mtmd/mtmd.h +362 -0
  97. package/cpp/tools/mtmd/stb_image.h +7988 -0
  98. package/ios/CMakeLists.txt +20 -10
  99. package/ios/RNLlama.h +6 -0
  100. package/ios/RNLlama.mm +82 -3
  101. package/ios/RNLlamaContext.h +5 -1
  102. package/ios/RNLlamaContext.mm +131 -38
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  160. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  161. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  188. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  189. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  190. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  191. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  192. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  193. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  194. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  195. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  196. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  197. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  198. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  199. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  200. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  201. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  202. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  203. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  204. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  205. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  206. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  207. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  208. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  209. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  210. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  211. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  212. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  213. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  214. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  215. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  216. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  217. package/jest/mock.js +33 -7
  218. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  219. package/lib/commonjs/index.js +153 -21
  220. package/lib/commonjs/index.js.map +1 -1
  221. package/lib/module/NativeRNLlama.js.map +1 -1
  222. package/lib/module/index.js +152 -20
  223. package/lib/module/index.js.map +1 -1
  224. package/lib/typescript/NativeRNLlama.d.ts +54 -4
  225. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  226. package/lib/typescript/index.d.ts +72 -6
  227. package/lib/typescript/index.d.ts.map +1 -1
  228. package/package.json +1 -1
  229. package/src/NativeRNLlama.ts +72 -4
  230. package/src/index.ts +212 -38
  231. package/cpp/binary-ops.h +0 -16
  232. package/cpp/ops.h +0 -128
  233. package/cpp/simd-mappings.h +0 -888
  234. package/cpp/unary-ops.h +0 -28
  235. package/cpp/vec.h +0 -802
  236. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  237. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  238. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  239. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  240. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  241. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  242. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  243. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  244. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  245. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  246. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  247. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  248. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  249. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  250. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  251. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  271. package/lib/commonjs/chat.js +0 -37
  272. package/lib/commonjs/chat.js.map +0 -1
  273. package/lib/module/chat.js +0 -33
  274. package/lib/module/chat.js.map +0 -1
  275. package/lib/typescript/chat.d.ts +0 -10
  276. package/lib/typescript/chat.d.ts.map +0 -1
  277. package/src/chat.ts +0 -44
  278. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  279. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  280. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  281. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  282. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  283. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  284. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  285. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/ggml-metal.m CHANGED
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
  // note: assumes single GPU device - the default one
  // TODO: support multiple GPU devices
  static struct lm_ggml_backend_metal_device_context {
- id<MTLDevice> mtl_device;
- int mtl_device_ref_count;
+ id<MTLDevice> mtl_device;
+ int mtl_device_ref_count;
  id<MTLLibrary> mtl_library;

  bool has_simdgroup_reduction;
@@ -149,6 +149,8 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
  LM_GGML_METAL_KERNEL_TYPE_GELU,
  LM_GGML_METAL_KERNEL_TYPE_GELU_4,
+ LM_GGML_METAL_KERNEL_TYPE_GELU_ERF,
+ LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
  LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
  LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
  LM_GGML_METAL_KERNEL_TYPE_SILU,
@@ -306,30 +308,36 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -354,6 +362,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -362,6 +371,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -370,6 +380,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -378,6 +389,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -386,6 +398,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -394,6 +407,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -402,6 +416,21 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -430,6 +459,13 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_SET_I32,
  LM_GGML_METAL_KERNEL_TYPE_SET_F32,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -460,6 +496,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_SQRT,
  LM_GGML_METAL_KERNEL_TYPE_SIN,
  LM_GGML_METAL_KERNEL_TYPE_COS,
+ LM_GGML_METAL_KERNEL_TYPE_NEG,
  LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -468,7 +505,264 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_COUNT
  };

+ //
+ // lm_ggml_metal_heap
+ //
+
+ struct lm_ggml_metal_heap {
+ // number of times the heap was unused
+ int n_unused;
+
+ // total number of buffer allocations in this heap across all computes
+ int64_t n_alloc;
+
+ // current offset in the heap - we reset this after each node in order to reuse the memory
+ size_t offs;
+
+ // the currently allocated MTLBuffer objects in this heap
+ id<MTLHeap> obj;
+
+ NSMutableArray * bufs;
+ };
+
+ static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+ struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
+
+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+ desc.storageMode = MTLStorageModePrivate;
+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+ desc.type = MTLHeapTypePlacement;
+ desc.size = size;
+
+ heap->n_unused = 0;
+ heap->n_alloc = 0;
+
+ heap->obj = [device newHeapWithDescriptor:desc];
+ if (!heap->obj) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+ free(heap);
+
+ return false;
+ }
+
+ [desc release];
+
+ heap->bufs = [[NSMutableArray alloc] init];
+
+ return heap;
+ }
+
+ static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
+ heap->offs = 0;
+
+ // count how many graph computes the heap ended up being unused
+ if ([heap->bufs count] > 0) {
+ heap->n_unused = 0;
+ } else {
+ heap->n_unused++;
+ }
+
+ for (id<MTLBuffer> buf in heap->bufs) {
+ [buf release];
+ }
+ [heap->bufs removeAllObjects];
+
+ // tell the OS that it can reuse this memory if needed
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+ [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+ }
+
+ static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
+ if (heap == nil) {
+ return;
+ }
+
+ lm_ggml_metal_heap_reset(heap);
+
+ [heap->obj release];
+ [heap->bufs release];
+
+ free(heap);
+ }
+
+ @interface lm_ggml_metal_heap_ptr : NSObject
+
+ @property (nonatomic, assign) struct lm_ggml_metal_heap * data;
+
+ @end
+
+ @implementation lm_ggml_metal_heap_ptr
+ @end
+
+ //
+ // lm_ggml_metal_mem_pool
+ //
+
+ struct lm_ggml_metal_mem_pool {
+ id<MTLDevice> device;
+
+ int n_heaps; // total number of heaps ever created (including those that were removed)
+
+ NSMutableArray * heaps;
+ NSMutableArray * heaps_to_remove;
+ };
+
+ static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
+ struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
+
+ mem_pool->n_heaps = 0;
+
+ mem_pool->heaps = [[NSMutableArray alloc] init];
+ mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+ return mem_pool;
+ }
+
+ static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
+ LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+ size_t size_all = 0;
+ size_t size_cur = 0;
+
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
+ LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
+ LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
+ LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+ LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
+
+ if ([ptr.data->bufs count] > 0) {
+ size_cur += [ptr.data->obj size];
+ }
+ size_all += [ptr.data->obj size];
+
+ lm_ggml_metal_heap_free(ptr.data);
+ [ptr release];
+ }
+ [mem_pool->heaps release];
+ [mem_pool->heaps_to_remove release];
+
+ if (size_all > 0) {
+ LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+ LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+ }
+
+ free(mem_pool);
+ }
+
+ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
+ for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+ lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ lm_ggml_metal_heap_reset(heap);
+
+ // if the heap hasn't been used for a while, remove it
+ if (heap->n_unused >= 128) {
+ [mem_pool->heaps_to_remove addObject:@(i)];
+ }
+ }
+
+ if (mem_pool->heaps_to_remove.count > 0) {
+ // remove in reverse order
+ for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
+ NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+ lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ lm_ggml_metal_heap_free(heap);
+
+ [mem_pool->heaps removeObjectAtIndex:index];
+ [ptr release];
+
+ if (i == 0) {
+ break;
+ }
+ }
+
+ [mem_pool->heaps_to_remove removeAllObjects];
+ }
+ }
+
+ static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ ptr.data->offs = 0;
+ }
+ }
+
+ static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
+ const size_t alignment = 256;
+
+ const size_t size_aligned = LM_GGML_PAD(size, alignment);
+
+ // try one of the existing heaps
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ if (heap->offs + size_aligned <= [heap->obj size]) {
+ // if this is the first buffer in the heap for the current command buffer, tell the OS that
+ // it cannot free the memory used by the heap
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+ if ([heap->bufs count] == 0) {
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+ }
+
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+ if (buf == nil) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+ return nil;
+ }
+
+ heap->n_alloc++;
+ heap->offs += size_aligned;
+
+ [heap->bufs addObject:buf];
+
+ return buf;
+ }
+ }
+
+ // create a new heap that can fit this buffer
+ lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
+
+ struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
+ if (heap == NULL) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+ return NULL;
+ }
+
+ //LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+ heap_ptr.data = heap;
+ lm_ggml_metal_heap_reset(heap);
+
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+ if (buf == nil) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+ return NULL;
+ }
+
+ heap->n_alloc++;
+ heap->offs += size_aligned;
+
+ [heap->bufs addObject:buf];
+
+ [mem_pool->heaps addObject:heap_ptr];
+ mem_pool->n_heaps++;
+
+ return buf;
+ }
+
+ struct lm_ggml_metal_command_buffer {
+ id<MTLCommandBuffer> obj;
+
+ // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+ struct lm_ggml_metal_mem_pool * mem_pool;
+ };
+
  struct lm_ggml_backend_metal_context {
+ id<MTLDevice> device;
  id<MTLCommandQueue> queue;

  dispatch_queue_t d_queue;
@@ -493,7 +787,7 @@ struct lm_ggml_backend_metal_context {
  void (^encode_async)(size_t ith);

  // n_cb command buffers + 1 used by the main thread
- id<MTLCommandBuffer> command_buffers[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+ struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];

  // abort lm_ggml_metal_graph_compute if callback returns true
  lm_ggml_abort_callback abort_callback;
@@ -560,11 +854,7 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
  NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
  #endif

- #if TARGET_OS_SIMULATOR
- NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
- #else
- NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
- #endif
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
  if (path_lib == nil) {
  // Try to find the resource in the directory where the current binary located.
  NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
@@ -687,9 +977,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;

  id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
  LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

- ctx->queue = [device newCommandQueue];
+ ctx->device = device;
+ ctx->queue = [device newCommandQueue];
  if (ctx->queue == nil) {
  LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
  return NULL;
@@ -750,7 +1042,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  ctx->gf = nil;
  ctx->encode_async = nil;
  for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
- ctx->command_buffers[i] = nil;
+ ctx->cmd_bufs[i].obj = nil;
+
+ ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
+ ctx->cmd_bufs[i].mem_pool->device = device;
  }

  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -810,6 +1105,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -967,30 +1264,36 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1015,6 +1318,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1023,6 +1327,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1031,6 +1336,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1039,6 +1345,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1047,6 +1354,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1055,6 +1363,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1063,6 +1372,21 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
1063
1372
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
1064
1373
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
1065
1374
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
1375
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
1376
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
1377
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
1378
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
1379
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
1380
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
1381
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
1382
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
1383
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
1384
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
1385
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
1386
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
1387
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
1388
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
1389
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
1066
1390
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
1067
1391
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
1068
1392
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
@@ -1091,6 +1415,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
1091
1415
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
1092
1416
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
1093
1417
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
1418
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
1419
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
1420
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
1421
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
1422
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
1423
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
1424
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
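
Note on the block above: the new *_HK576_HV512 variants (in both the simdgroup-matrix and the _VEC kernel families) cover flash attention with a K head size of 576 and a V head size of 512, the head shape used by DeepSeek-style MLA attention. A minimal sketch of how a head-size pair maps to one of these kernel enums, mirroring the selection done later in lm_ggml_metal_encode_node (the helper name is made up for illustration):

    // sketch only: non-square head sizes get dedicated HK/HV variants,
    // square head sizes use the per-size templates (64, 80, 96, 112, 128, 192, 256)
    static int pick_fa_f16_kernel(int hk, int hv) {
        if (hk == 192 && hv == 128) return LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128;
        if (hk == 576 && hv == 512) return LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512; // new in this release
        switch (hk) {
            case  64: return LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64;
            case 128: return LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128;
            case 256: return LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256;
            default:  return -1; // no template for this size
        }
    }
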
1094
1425
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
1095
1426
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
1096
1427
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1121,6 +1452,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
1121
1452
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
1122
1453
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1123
1454
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
1455
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1124
1456
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1125
1457
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1126
1458
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1141,6 +1473,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {
1141
1473
 
1142
1474
  [ctx->queue release];
1143
1475
 
1476
+ for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1477
+ // ctx->cmd_bufs[i].obj is auto released
1478
+
1479
+ lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
1480
+ }
1481
+
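
The loop added above tears down one memory pool per command buffer when the backend context is freed. The rest of this diff wires the same pools through graph encoding; a rough sketch of the lifecycle implied by the calls that appear in this file (only the reset/clear/alloc/free calls below are taken from the diff, the surrounding function is illustrative):

    // lifecycle sketch, not the real encoder
    static bool encode_one_node(struct lm_ggml_metal_mem_pool * mem_pool, size_t nbytes) {
        lm_ggml_metal_mem_pool_reset(mem_pool);          // once per command buffer, before encoding
        lm_ggml_metal_mem_pool_clear(mem_pool);          // once per node, drops the previous node's temporaries
        id<MTLBuffer> tmp = lm_ggml_metal_mem_pool_alloc(mem_pool, nbytes);
        if (!tmp) {
            return false;                                // caller stops encoding this command buffer
        }
        // ... encode kernels that read/write tmp ...
        return true;
        // at teardown: lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
    }
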
1144
1482
  dispatch_release(ctx->d_queue);
1145
1483
 
1146
1484
  free(ctx);
@@ -1279,9 +1617,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1279
1617
  case LM_GGML_UNARY_OP_RELU:
1280
1618
  case LM_GGML_UNARY_OP_SIGMOID:
1281
1619
  case LM_GGML_UNARY_OP_GELU:
1620
+ case LM_GGML_UNARY_OP_GELU_ERF:
1282
1621
  case LM_GGML_UNARY_OP_GELU_QUICK:
1283
1622
  case LM_GGML_UNARY_OP_SILU:
1284
1623
  case LM_GGML_UNARY_OP_ELU:
1624
+ case LM_GGML_UNARY_OP_NEG:
1285
1625
  return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
1286
1626
  default:
1287
1627
  return false;
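
For reference, the two unary ops newly whitelisted above have simple element-wise semantics: GELU_ERF is the exact, erf-based GELU (as opposed to the tanh approximation used by the existing GELU kernel), and NEG is plain negation. Scalar reference versions:

    #include <math.h>

    // exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    static float gelu_erf_ref(float x) {
        return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
    }

    static float neg_ref(float x) {
        return -x;
    }
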
@@ -1324,22 +1664,14 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1324
1664
  case LM_GGML_OP_NORM:
1325
1665
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
1326
1666
  case LM_GGML_OP_ROPE:
1327
- {
1328
- const int mode = ((const int32_t *) op->op_params)[2];
1329
- if (mode & LM_GGML_ROPE_TYPE_MROPE) {
1330
- return false;
1331
- }
1332
- if (mode & LM_GGML_ROPE_TYPE_VISION) {
1333
- return false;
1334
- }
1335
- return true;
1336
- }
1667
+ return true;
1337
1668
  case LM_GGML_OP_IM2COL:
1338
1669
  return op->src[0]->type == LM_GGML_TYPE_F16;
1339
1670
  case LM_GGML_OP_POOL_1D:
1340
1671
  return false;
1341
- case LM_GGML_OP_POOL_2D:
1342
1672
  case LM_GGML_OP_UPSCALE:
1673
+ return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
1674
+ case LM_GGML_OP_POOL_2D:
1343
1675
  case LM_GGML_OP_PAD:
1344
1676
  case LM_GGML_OP_PAD_REFLECT_1D:
1345
1677
  case LM_GGML_OP_TIMESTEP_EMBEDDING:
@@ -1354,6 +1686,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1354
1686
  // TODO: not sure if it is worth adding kernels for this size
1355
1687
  return false;
1356
1688
  }
1689
+ if (op->src[0]->ne[0] == 576) {
1690
+ // DeepSeek sizes
1691
+ // TODO: disabled for now, until optimized
1692
+ return false;
1693
+ }
1357
1694
  if (op->src[1]->type != op->src[2]->type) {
1358
1695
  return false;
1359
1696
  }
@@ -1439,10 +1776,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
1439
1776
  }
1440
1777
  }
1441
1778
 
1442
- static void lm_ggml_metal_encode_node(
1779
+ static bool lm_ggml_metal_encode_node(
1443
1780
  lm_ggml_backend_t backend,
1444
1781
  int idx,
1445
- id<MTLComputeCommandEncoder> encoder) {
1782
+ id<MTLComputeCommandEncoder> encoder,
1783
+ struct lm_ggml_metal_mem_pool * mem_pool) {
1446
1784
  struct lm_ggml_backend_metal_context * ctx = backend->context;
1447
1785
  struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
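
The signature change above (void to bool, plus the new mem_pool argument) lets a node encoder report failure, for example when a pool allocation does not succeed. The caller-side contract, as wired up near the end of this diff:

    // per-node encode loop (restates the change in lm_ggml_backend_metal_set_n_cb below)
    const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
    if (!res) {
        break; // stop encoding the remaining nodes of this command buffer
    }
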
1448
1786
 
@@ -1458,7 +1796,7 @@ static void lm_ggml_metal_encode_node(
1458
1796
  struct lm_ggml_tensor * dst = node;
1459
1797
 
1460
1798
  if (lm_ggml_is_empty(dst)) {
1461
- return;
1799
+ return true;
1462
1800
  }
1463
1801
 
1464
1802
  switch (dst->op) {
@@ -1469,7 +1807,7 @@ static void lm_ggml_metal_encode_node(
1469
1807
  case LM_GGML_OP_PERMUTE:
1470
1808
  {
1471
1809
  // noop -> next node
1472
- } return;
1810
+ } return true;
1473
1811
  default:
1474
1812
  {
1475
1813
  } break;
@@ -1480,6 +1818,8 @@ static void lm_ggml_metal_encode_node(
1480
1818
  LM_GGML_ABORT("unsupported op");
1481
1819
  }
1482
1820
 
1821
+ lm_ggml_metal_mem_pool_clear(mem_pool);
1822
+
1483
1823
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
1484
1824
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
1485
1825
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1916,6 +2256,25 @@ static void lm_ggml_metal_encode_node(
1916
2256
 
1917
2257
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1918
2258
  } break;
2259
+ case LM_GGML_UNARY_OP_GELU_ERF:
2260
+ {
2261
+ int64_t n = lm_ggml_nelements(dst);
2262
+
2263
+ id<MTLComputePipelineState> pipeline = nil;
2264
+
2265
+ if (n % 4 == 0) {
2266
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
2267
+ n /= 4;
2268
+ } else {
2269
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
2270
+ }
2271
+
2272
+ [encoder setComputePipelineState:pipeline];
2273
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2274
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2275
+
2276
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2277
+ } break;
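
As with the other element-wise kernels, the GELU_ERF case above prefers the vectorized pipeline when the element count is divisible by 4, so the dispatched grid shrinks by the same factor:

    // dispatch sizing used above (one thread per element, or per group of four)
    int64_t n = lm_ggml_nelements(dst);
    if (n % 4 == 0) {
        n /= 4; // the *_4 pipeline consumes four floats per thread
    }
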
1919
2278
  case LM_GGML_UNARY_OP_GELU_QUICK:
1920
2279
  {
1921
2280
  int64_t n = lm_ggml_nelements(dst);
@@ -1966,6 +2325,18 @@ static void lm_ggml_metal_encode_node(
1966
2325
 
1967
2326
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1968
2327
  } break;
2328
+ case LM_GGML_UNARY_OP_NEG:
2329
+ {
2330
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
2331
+
2332
+ [encoder setComputePipelineState:pipeline];
2333
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2334
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2335
+
2336
+ const int64_t n = lm_ggml_nelements(dst);
2337
+
2338
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2339
+ } break;
1969
2340
  default:
1970
2341
  {
1971
2342
  LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
@@ -2114,26 +2485,76 @@ static void lm_ggml_metal_encode_node(
2114
2485
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2115
2486
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2116
2487
 
2117
- lm_ggml_metal_kargs_soft_max args = {
2488
+ // use this branch to test the lm_ggml_metal_mem_pool functionality
2489
+ #if 0
2490
+ // cpy to tmp buffer in MTLHeap
2491
+
2492
+ id<MTLBuffer> h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
2493
+ if (!h_src0) {
2494
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
2495
+ return false;
2496
+ }
2497
+
2498
+ offs_src0 = 0;
2499
+
2500
+ lm_ggml_metal_kargs_cpy args_cpy = {
2118
2501
  /*.ne00 =*/ ne00,
2119
2502
  /*.ne01 =*/ ne01,
2120
2503
  /*.ne02 =*/ ne02,
2121
- /*.scale =*/ scale,
2122
- /*.max_bias =*/ max_bias,
2123
- /*.m0 =*/ m0,
2124
- /*.m1 =*/ m1,
2504
+ /*.ne03 =*/ ne03,
2505
+ /*.nb00 =*/ nb00,
2506
+ /*.nb01 =*/ nb01,
2507
+ /*.nb02 =*/ nb02,
2508
+ /*.nb03 =*/ nb03,
2509
+ /*.ne0 =*/ ne00,
2510
+ /*.ne1 =*/ ne01,
2511
+ /*.ne2 =*/ ne02,
2512
+ /*.ne3 =*/ ne03,
2513
+ /*.nb0 =*/ nb00,
2514
+ /*.nb1 =*/ nb01,
2515
+ /*.nb2 =*/ nb02,
2516
+ /*.nb3 =*/ nb03,
2517
+ };
2518
+
2519
+ if (src0->type == LM_GGML_TYPE_F16) {
2520
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
2521
+ } else {
2522
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
2523
+ }
2524
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
2525
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2526
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
2527
+
2528
+ LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
2529
+ int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
2530
+
2531
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
2532
+
2533
+ #else
2534
+ id<MTLBuffer> h_src0 = id_src0;
2535
+ #endif
2536
+ // softmax
2537
+
2538
+ lm_ggml_metal_kargs_soft_max args = {
2539
+ /*.ne00 =*/ ne00,
2540
+ /*.ne01 =*/ ne01,
2541
+ /*.ne02 =*/ ne02,
2542
+ /*.scale =*/ scale,
2543
+ /*.max_bias =*/ max_bias,
2544
+ /*.m0 =*/ m0,
2545
+ /*.m1 =*/ m1,
2125
2546
  /*.n_head_log2 =*/ n_head_log2,
2126
2547
  };
2127
2548
 
2128
2549
  [encoder setComputePipelineState:pipeline];
2129
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2550
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
2130
2551
  if (id_src1) {
2131
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2552
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2132
2553
  } else {
2133
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2554
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
2134
2555
  }
2135
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2136
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
2556
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2557
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2137
2558
 
2138
2559
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2139
2560
 
@@ -2624,7 +3045,7 @@ static void lm_ggml_metal_encode_node(
2624
3045
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2625
3046
 
2626
3047
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2627
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3048
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2628
3049
  } else {
2629
3050
  id<MTLComputePipelineState> pipeline = nil;
2630
3051
 
@@ -2844,8 +3265,6 @@ static void lm_ggml_metal_encode_node(
2844
3265
  } break;
2845
3266
  case LM_GGML_OP_MUL_MAT_ID:
2846
3267
  {
2847
- const int n_as = src0->ne[2];
2848
-
2849
3268
  // src2 = ids
2850
3269
  const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t);
2851
3270
 
@@ -2859,24 +3278,21 @@ static void lm_ggml_metal_encode_node(
2859
3278
  LM_GGML_ASSERT(ne03 == 1);
2860
3279
  LM_GGML_ASSERT(ne13 == 1);
2861
3280
 
3281
+ const uint32_t r2 = 1;
3282
+ const uint32_t r3 = 1;
3283
+
2862
3284
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2863
3285
  // to the matrix-vector kernel
2864
3286
  // ne20 = n_used_experts
2865
- // ne21 = n_rows
2866
- const int dst_rows = ne20*ne21;
2867
- const int dst_rows_min = n_as;
2868
- const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
2869
-
2870
- // max size of the rowids array in the kernel shared buffer
2871
- //LM_GGML_ASSERT(dst_rows <= dst_rows_max);
3287
+ // ne21 = n_rows (batch size)
3288
+ const int ne21_mm_id_min = 32;
2872
3289
 
2873
3290
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
2874
3291
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
2875
3292
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2876
3293
  ne00 % 32 == 0 && ne00 >= 64 &&
2877
- //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
2878
- dst_rows > dst_rows_min &&
2879
- dst_rows <= dst_rows_max) {
3294
+ (ne21 >= ne21_mm_id_min)) {
3295
+ LM_GGML_ASSERT(ne00 % 4 == 0);
2880
3296
 
2881
3297
  // some Metal matrix data types require aligned pointers
2882
3298
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -2887,62 +3303,169 @@ static void lm_ggml_metal_encode_node(
2887
3303
  default: break;
2888
3304
  }
2889
3305
 
2890
- id<MTLComputePipelineState> pipeline = nil;
3306
+ const int64_t neh10 = ne10; // n_embd
3307
+ const int64_t neh11 = ne21; // n_tokens
3308
+ const int64_t neh12 = ne02; // n_expert
2891
3309
 
2892
- switch (src0->type) {
2893
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2894
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
2895
- case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
2896
- case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2897
- case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2898
- case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
2899
- case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
2900
- case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
2901
- case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
2902
- case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
2903
- case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
2904
- case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
2905
- case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
2906
- case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
2907
- case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
2908
- case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
2909
- case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
2910
- case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
2911
- case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
2912
- case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
2913
- case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
2914
- case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
2915
- default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
3310
+ const uint64_t nbh10 = lm_ggml_type_size(LM_GGML_TYPE_F16);
3311
+ const uint64_t nbh11 = nbh10*neh10;
3312
+ const uint64_t nbh12 = nbh11*neh11;
3313
+ const uint64_t nbh13 = nbh12*neh12;
3314
+
3315
+ const size_t s_src1 = lm_ggml_type_size(LM_GGML_TYPE_F16)*neh10*neh11*neh12;
3316
+ id<MTLBuffer> h_src1 = lm_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3317
+ if (!h_src1) {
3318
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3319
+ return false;
2916
3320
  }
2917
3321
 
2918
- lm_ggml_metal_kargs_mul_mm_id args = {
2919
- /*.nei0 =*/ ne20,
2920
- /*.nei1 =*/ ne21,
2921
- /*.nbi1 =*/ nb21,
2922
- /*.ne00 =*/ ne00,
2923
- /*.ne02 =*/ ne02,
2924
- /*.nb01 =*/ nb01,
2925
- /*.nb02 =*/ nb02,
2926
- /*.ne11 =*/ ne11,
2927
- /*.ne12 =*/ ne12,
2928
- /*.ne13 =*/ ne13,
2929
- /*.nb10 =*/ nb10,
2930
- /*.nb11 =*/ nb11,
2931
- /*.nb12 =*/ nb12,
2932
- /*.ne0 =*/ ne0,
2933
- /*.ne1 =*/ ne1,
2934
- };
3322
+ const int64_t neh0 = ne0;
3323
+ const int64_t neh1 = ne21;
3324
+ const int64_t neh2 = ne02;
2935
3325
 
2936
- [encoder setComputePipelineState:pipeline];
2937
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
2938
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2939
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2940
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2941
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
3326
+ const uint64_t nbh0 = lm_ggml_type_size(LM_GGML_TYPE_F32);
3327
+ const uint64_t nbh1 = nbh0*neh0;
3328
+ const uint64_t nbh2 = nbh1*neh1;
3329
+ //const uint64_t nbh3 = nbh2*neh2;
3330
+
3331
+ const size_t s_dst = lm_ggml_type_size(LM_GGML_TYPE_F32)*neh0*neh1*neh2;
3332
+ id<MTLBuffer> h_dst = lm_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3333
+ if (!h_dst) {
3334
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3335
+ return false;
3336
+ }
3337
+
3338
+ // tokens per expert
3339
+ const size_t s_tpe = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne02;
3340
+ id<MTLBuffer> h_tpe = lm_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
3341
+ if (!h_tpe) {
3342
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3343
+ return false;
3344
+ }
3345
+
3346
+ // id map
3347
+ // [n_expert_used, n_tokens]
3348
+ const size_t s_ids = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne20*ne21;
3349
+ id<MTLBuffer> h_ids = lm_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
3350
+ if (!h_ids) {
3351
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3352
+ return false;
3353
+ }
3354
+
3355
+ {
3356
+ const int nth = MIN(1024, ne10/4);
3357
+
3358
+ lm_ggml_metal_kargs_mul_mm_id_map0 args = {
3359
+ ne10,
3360
+ ne11, // n_expert_used (bcast)
3361
+ nb11,
3362
+ nb12,
3363
+ neh11, // n_tokens
3364
+ nbh11,
3365
+ ne20, // n_expert_used
3366
+ nb21,
3367
+ };
3368
+
3369
+ id<MTLComputePipelineState> pipeline = nil;
3370
+
3371
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3372
+
3373
+ [encoder setComputePipelineState:pipeline];
3374
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3375
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3376
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3377
+ [encoder setBuffer: h_src1 offset:0 atIndex:3];
3378
+ [encoder setBuffer: h_tpe offset:0 atIndex:4];
3379
+ [encoder setBuffer: h_ids offset:0 atIndex:5];
3380
+
3381
+ [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3382
+ }
3383
+
3384
+ {
3385
+ id<MTLComputePipelineState> pipeline = nil;
3386
+
3387
+ switch (src0->type) {
3388
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
3389
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
3390
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
3391
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
3392
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
3393
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3394
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3395
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
3396
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3397
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3398
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
3399
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
3400
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
3401
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
3402
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
3403
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
3404
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
3405
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
3406
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
3407
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
3408
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
3409
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
3410
+ default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
3411
+ }
3412
+
3413
+ lm_ggml_metal_kargs_mul_mm_id args = {
3414
+ /*.ne00 =*/ ne00,
3415
+ /*.ne02 =*/ ne02,
3416
+ /*.nb01 =*/ nb01,
3417
+ /*.nb02 =*/ nb02,
3418
+ /*.nb03 =*/ nb03,
3419
+ /*.neh12 =*/ neh12,
3420
+ /*.nbh10 =*/ nbh10,
3421
+ /*.nbh11 =*/ nbh11,
3422
+ /*.nbh12 =*/ nbh12,
3423
+ /*.nbh13 =*/ nbh13,
3424
+ /*.neh0 =*/ neh0,
3425
+ /*.neh1 =*/ neh1,
3426
+ /*.r2 =*/ r2,
3427
+ /*.r3 =*/ r3,
3428
+ };
3429
+
3430
+ [encoder setComputePipelineState:pipeline];
3431
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3432
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3433
+ [encoder setBuffer: h_src1 offset:0 atIndex:2];
3434
+ [encoder setBuffer: h_tpe offset:0 atIndex:3];
3435
+ [encoder setBuffer: h_dst offset:0 atIndex:4];
3436
+
3437
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
3438
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3439
+ }
2942
3440
 
2943
- [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
3441
+ {
3442
+ LM_GGML_ASSERT(ne0 % 4 == 0);
2944
3443
 
2945
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
3444
+ const int nth = MIN(1024, ne0/4);
3445
+
3446
+ lm_ggml_metal_kargs_mul_mm_id_map1 args = {
3447
+ ne20, // n_expert_used
3448
+ neh0,
3449
+ neh1,
3450
+ nbh1,
3451
+ nbh2,
3452
+ ne0,
3453
+ nb1,
3454
+ nb2,
3455
+ };
3456
+
3457
+ id<MTLComputePipelineState> pipeline = nil;
3458
+
3459
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
3460
+
3461
+ [encoder setComputePipelineState:pipeline];
3462
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3463
+ [encoder setBuffer: h_dst offset:0 atIndex:1];
3464
+ [encoder setBuffer: h_ids offset:0 atIndex:2];
3465
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3466
+
3467
+ [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3468
+ }
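
Taken together, the rewritten MUL_MAT_ID branch above runs in three stages: a map0 kernel gathers the routed token rows into a pool-backed F16 buffer (h_src1) and fills the tokens-per-expert (h_tpe) and id-map (h_ids) tables; the matrix-matrix kernel then multiplies each expert against its gathered rows into h_dst; finally map1 scatters h_dst back into the real destination using h_ids. A CPU-side sketch of the bookkeeping the id map encodes (layouts are assumptions inferred from the buffer sizes allocated above):

    #include <stdint.h>

    // ids : ne20 x ne21 entries (n_expert_used x n_tokens), the expert chosen per slot
    // tpe : ne02 entries (n_expert), number of rows gathered for each expert
    // h_src1[expert][row] <- src1 row of token t; h_dst[expert][row] -> dst row of token t
    static void build_id_map(const int32_t * ids, int64_t ne20, int64_t ne21, int64_t ne02, int32_t * tpe) {
        for (int64_t e = 0; e < ne02; ++e) {
            tpe[e] = 0;
        }
        for (int64_t t = 0; t < ne21; ++t) {
            for (int64_t i = 0; i < ne20; ++i) {
                const int32_t expert = ids[t*ne20 + i];
                const int32_t row    = tpe[expert]++; // position of this token inside the expert's batch
                (void) row; // map0 records (t, row) so map1 can scatter the result back later
            }
        }
    }
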
2946
3469
  } else {
2947
3470
  id<MTLComputePipelineState> pipeline = nil;
2948
3471
 
@@ -3136,7 +3659,7 @@ static void lm_ggml_metal_encode_node(
3136
3659
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
3137
3660
 
3138
3661
  const int64_t _ne1 = 1;
3139
- const int64_t ne123 = dst_rows;
3662
+ const int64_t ne123 = ne20*ne21;
3140
3663
 
3141
3664
  if (smem > 0) {
3142
3665
  [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -3340,6 +3863,7 @@ static void lm_ggml_metal_encode_node(
3340
3863
  } break;
3341
3864
  case LM_GGML_OP_ROPE:
3342
3865
  {
3866
+
3343
3867
  // make sure we have one or more position id(ne10) per token(ne02)
3344
3868
  LM_GGML_ASSERT(ne10 % ne02 == 0);
3345
3869
  LM_GGML_ASSERT(ne10 >= ne02);
@@ -3366,20 +3890,42 @@ static void lm_ggml_metal_encode_node(
3366
3890
  memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
3367
3891
  memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
3368
3892
 
3369
- const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
3893
+ const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
3894
+ const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE;
3895
+ const bool is_vision = mode == LM_GGML_ROPE_TYPE_VISION;
3896
+
3897
+ // mrope
3898
+ const int sect_0 = ((const int32_t *) dst->op_params)[11];
3899
+ const int sect_1 = ((const int32_t *) dst->op_params)[12];
3900
+ const int sect_2 = ((const int32_t *) dst->op_params)[13];
3901
+ const int sect_3 = ((const int32_t *) dst->op_params)[14];
3370
3902
 
3371
3903
  id<MTLComputePipelineState> pipeline = nil;
3372
3904
 
3373
- if (!is_neox) {
3905
+ if (is_neox) {
3374
3906
  switch (src0->type) {
3375
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3376
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3907
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3908
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3909
+ default: LM_GGML_ABORT("fatal error");
3910
+ };
3911
+ } else if (is_mrope && !is_vision) {
3912
+ LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3913
+ switch (src0->type) {
3914
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
3915
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
3916
+ default: LM_GGML_ABORT("fatal error");
3917
+ };
3918
+ } else if (is_vision) {
3919
+ LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3920
+ switch (src0->type) {
3921
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
3922
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
3377
3923
  default: LM_GGML_ABORT("fatal error");
3378
3924
  };
3379
3925
  } else {
3380
3926
  switch (src0->type) {
3381
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3382
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3927
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3928
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3383
3929
  default: LM_GGML_ABORT("fatal error");
3384
3930
  };
3385
3931
  }
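
The pipeline selection above now handles all three RoPE variants instead of bailing out in supports_op: NEOX and MROPE are bit flags tested with &, while VISION is a distinct full mode value, which is why that check uses equality. For the multi-section variants, sect_0..sect_3 carry the per-section rotary dimensions read from op_params. A condensed restatement of the dispatch:

    // restates the if/else chain above; returns which kernel family handles a given mode
    static const char * rope_kernel_family(int mode) {
        if (mode & LM_GGML_ROPE_TYPE_NEOX)    return "rope_neox";
        if (mode == LM_GGML_ROPE_TYPE_VISION) return "rope_vision"; // full mode value, not a single bit
        if (mode & LM_GGML_ROPE_TYPE_MROPE)   return "rope_multi";
        return "rope_norm";
    }
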
@@ -3410,6 +3956,10 @@ static void lm_ggml_metal_encode_node(
3410
3956
  /*.attn_factor =*/ attn_factor,
3411
3957
  /*.beta_fast =*/ beta_fast,
3412
3958
  /*.beta_slow =*/ beta_slow,
3959
+ /*.sect_0 =*/ sect_0,
3960
+ /*.sect_1 =*/ sect_1,
3961
+ /*.sect_2 =*/ sect_2,
3962
+ /*.sect_3 =*/ sect_3,
3413
3963
  };
3414
3964
 
3415
3965
  [encoder setComputePipelineState:pipeline];
@@ -3846,12 +4396,14 @@ static void lm_ggml_metal_encode_node(
3846
4396
  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
3847
4397
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
3848
4398
  // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
3849
- if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
4399
+ if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
3850
4400
  switch (src1->type) {
3851
4401
  case LM_GGML_TYPE_F16:
3852
4402
  {
3853
4403
  if (ne00 == 192 && ne20 == 128) {
3854
4404
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
4405
+ } else if (ne00 == 576 && ne20 == 512) {
4406
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
3855
4407
  } else {
3856
4408
  switch (ne00) {
3857
4409
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3874,6 +4426,8 @@ static void lm_ggml_metal_encode_node(
3874
4426
  {
3875
4427
  if (ne00 == 192 && ne20 == 128) {
3876
4428
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
4429
+ } else if (ne00 == 576 && ne20 == 512) {
4430
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
3877
4431
  } else {
3878
4432
  switch (ne00) {
3879
4433
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3896,6 +4450,8 @@ static void lm_ggml_metal_encode_node(
3896
4450
  {
3897
4451
  if (ne00 == 192 && ne20 == 128) {
3898
4452
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
4453
+ } else if (ne00 == 576 && ne20 == 512) {
4454
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
3899
4455
  } else {
3900
4456
  switch (ne00) {
3901
4457
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3918,6 +4474,8 @@ static void lm_ggml_metal_encode_node(
3918
4474
  {
3919
4475
  if (ne00 == 192 && ne20 == 128) {
3920
4476
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
4477
+ } else if (ne00 == 576 && ne20 == 512) {
4478
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
3921
4479
  } else {
3922
4480
  switch (ne00) {
3923
4481
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3940,6 +4498,8 @@ static void lm_ggml_metal_encode_node(
3940
4498
  {
3941
4499
  if (ne00 == 192 && ne20 == 128) {
3942
4500
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
4501
+ } else if (ne00 == 576 && ne20 == 512) {
4502
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
3943
4503
  } else {
3944
4504
  switch (ne00) {
3945
4505
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3962,6 +4522,8 @@ static void lm_ggml_metal_encode_node(
3962
4522
  {
3963
4523
  if (ne00 == 192 && ne20 == 128) {
3964
4524
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
4525
+ } else if (ne00 == 576 && ne20 == 512) {
4526
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
3965
4527
  } else {
3966
4528
  switch (ne00) {
3967
4529
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3984,6 +4546,8 @@ static void lm_ggml_metal_encode_node(
3984
4546
  {
3985
4547
  if (ne00 == 192 && ne20 == 128) {
3986
4548
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
4549
+ } else if (ne00 == 576 && ne20 == 512) {
4550
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
3987
4551
  } else {
3988
4552
  switch (ne00) {
3989
4553
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4013,6 +4577,42 @@ static void lm_ggml_metal_encode_node(
4013
4577
  use_vec_kernel = true;
4014
4578
 
4015
4579
  switch (ne00) {
4580
+ case 64:
4581
+ {
4582
+ switch (src1->type) {
4583
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
4584
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
4585
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
4586
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
4587
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
4588
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
4589
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
4590
+ default:
4591
+ {
4592
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4593
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
4594
+ LM_GGML_ABORT("add template specialization for this type");
4595
+ }
4596
+ }
4597
+ } break;
4598
+ case 96:
4599
+ {
4600
+ switch (src1->type) {
4601
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
4602
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
4603
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
4604
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
4605
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
4606
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
4607
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
4608
+ default:
4609
+ {
4610
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4611
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
4612
+ LM_GGML_ABORT("add template specialization for this type");
4613
+ }
4614
+ }
4615
+ } break;
4016
4616
  case 128:
4017
4617
  {
4018
4618
  switch (src1->type) {
@@ -4085,12 +4685,36 @@ static void lm_ggml_metal_encode_node(
4085
4685
  }
4086
4686
  }
4087
4687
  } break;
4688
+ case 576:
4689
+ {
4690
+ if (ne20 == 512) {
4691
+ switch (src1->type) {
4692
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
4693
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
4694
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
4695
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
4696
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
4697
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
4698
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
4699
+ default:
4700
+ {
4701
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4702
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
4703
+ LM_GGML_ABORT("add template specialization for this type");
4704
+ }
4705
+ }
4706
+ } else {
4707
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
4708
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
4709
+ LM_GGML_ABORT("add template specialization for this size");
4710
+ }
4711
+ } break;
4088
4712
  default:
4089
- {
4090
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4091
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
4092
- LM_GGML_ABORT("add template specialization for this size");
4093
- }
4713
+ {
4714
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
4715
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
4716
+ LM_GGML_ABORT("add template specialization for this size");
4717
+ }
4094
4718
  }
4095
4719
  }
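
With this hunk, head size 576 (paired with a V head size of 512) gets dedicated kernels in both the simdgroup-matrix and the vector paths, although supports_op still rejects it for now (the "DeepSeek sizes ... disabled for now" check earlier in this diff), and the batch threshold for taking the matrix path was raised from 4 to 20 query rows. The routing condition, restated:

    // ne01 = number of query rows in the batch, ne00 = K head size
    const bool use_mat_kernel =
        ne01 >= 20 ||                                   // large batches: simdgroup-matrix kernels
        (ne00 % 128 != 0 && ne00 != 64 && ne00 != 96 &&
         ne00 != 192   && ne00 != 576);                 // head sizes that have no vec template
    const bool use_vec_kernel = !use_mat_kernel;        // small batches with a supported head size
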
4096
4720
 
@@ -4486,6 +5110,8 @@ static void lm_ggml_metal_encode_node(
4486
5110
  LM_GGML_ABORT("fatal error");
4487
5111
  }
4488
5112
  }
5113
+
5114
+ return true;
4489
5115
  }
4490
5116
 
4491
5117
  static enum lm_ggml_status lm_ggml_metal_graph_compute(
@@ -4539,25 +5165,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
4539
5165
  }
4540
5166
 
4541
5167
  // the main thread commits the first few commands immediately
4542
- // command_buffer[n_cb]
5168
+ // cmd_buf[n_cb]
4543
5169
  {
4544
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4545
- ctx->command_buffers[n_cb] = command_buffer;
5170
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5171
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
4546
5172
 
4547
- [command_buffer enqueue];
5173
+ [cmd_buf enqueue];
4548
5174
  ctx->encode_async(n_cb);
4549
5175
  }
4550
5176
 
4551
5177
  // prepare the rest of the command buffers asynchronously
4552
- // command_buffer[0.. n_cb)
5178
+ // cmd_buf[0.. n_cb)
4553
5179
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
4554
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4555
- ctx->command_buffers[cb_idx] = command_buffer;
5180
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5181
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
4556
5182
 
4557
5183
  // always enqueue the first two command buffers
4558
5184
  // enqueue all of the command buffers if we don't need to abort
4559
5185
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
4560
- [command_buffer enqueue];
5186
+ [cmd_buf enqueue];
4561
5187
  }
4562
5188
  }
4563
5189
 
@@ -4566,14 +5192,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
4566
5192
  // wait for completion and check status of each command buffer
4567
5193
  // needed to detect if the device ran out-of-memory for example (#1881)
4568
5194
  {
4569
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
4570
- [command_buffer waitUntilCompleted];
5195
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
5196
+ [cmd_buf waitUntilCompleted];
4571
5197
 
4572
- MTLCommandBufferStatus status = [command_buffer status];
5198
+ MTLCommandBufferStatus status = [cmd_buf status];
4573
5199
  if (status != MTLCommandBufferStatusCompleted) {
4574
5200
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
4575
5201
  if (status == MTLCommandBufferStatusError) {
4576
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
5202
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
4577
5203
  }
4578
5204
 
4579
5205
  return LM_GGML_STATUS_FAILED;
@@ -4581,20 +5207,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
4581
5207
  }
4582
5208
 
4583
5209
  for (int i = 0; i < n_cb; ++i) {
4584
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
4585
- [command_buffer waitUntilCompleted];
5210
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
5211
+ [cmd_buf waitUntilCompleted];
4586
5212
 
4587
- MTLCommandBufferStatus status = [command_buffer status];
5213
+ MTLCommandBufferStatus status = [cmd_buf status];
4588
5214
  if (status != MTLCommandBufferStatusCompleted) {
4589
5215
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
4590
5216
  if (status == MTLCommandBufferStatusError) {
4591
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
5217
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
4592
5218
  }
4593
5219
 
4594
5220
  return LM_GGML_STATUS_FAILED;
4595
5221
  }
4596
5222
 
4597
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
5223
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
4598
5224
  if (!next_buffer) {
4599
5225
  continue;
4600
5226
  }
@@ -4977,8 +5603,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
4977
5603
 
4978
5604
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
4979
5605
 
4980
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
4981
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
5606
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
5607
+
5608
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
4982
5609
 
4983
5610
  int node_start = 0;
4984
5611
  int node_end = n_nodes_0;
@@ -4990,22 +5617,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
4990
5617
 
4991
5618
  const bool should_capture = ctx->capture_next_compute;
4992
5619
 
5620
+ struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
5621
+ lm_ggml_metal_mem_pool_reset(mem_pool);
5622
+
4993
5623
  for (int idx = node_start; idx < node_end; ++idx) {
4994
5624
  if (should_capture) {
4995
5625
  [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
4996
5626
  }
4997
5627
 
4998
- lm_ggml_metal_encode_node(backend, idx, encoder);
5628
+ const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
4999
5629
 
5000
5630
  if (should_capture) {
5001
5631
  [encoder popDebugGroup];
5002
5632
  }
5633
+
5634
+ if (!res) {
5635
+ break;
5636
+ }
5003
5637
  }
5004
5638
 
5005
5639
  [encoder endEncoding];
5006
5640
 
5007
5641
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
5008
- [command_buffer commit];
5642
+ [cmd_buf commit];
5009
5643
  }
5010
5644
  });
5011
5645
  }
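
The renaming from command_buffers to cmd_bufs throughout the graph-compute path reflects that each slot now carries a memory pool alongside its Metal command buffer. The .obj and .mem_pool accesses in this diff suggest a pairing along these lines; the struct shown here is an assumption, the real definition lives earlier in ggml-metal.m:

    // assumed shape, inferred from the ctx->cmd_bufs[i].obj / .mem_pool accesses above
    struct lm_ggml_metal_command_buffer {
        id<MTLCommandBuffer>            obj;      // unretained command buffer for this slot
        struct lm_ggml_metal_mem_pool * mem_pool; // per-command-buffer scratch allocator
    };
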