cui-llama.rn 1.4.6 → 1.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (366)
  1. package/LICENSE +20 -20
  2. package/README.md +317 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +124 -117
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1263 -1245
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/README.md +4 -4
  23. package/cpp/binary-ops.cpp +158 -0
  24. package/cpp/binary-ops.h +16 -0
  25. package/cpp/chat.cpp +1769 -1779
  26. package/cpp/chat.h +9 -1
  27. package/cpp/common.cpp +20 -522
  28. package/cpp/common.h +13 -36
  29. package/cpp/cpu-common.h +72 -0
  30. package/cpp/ggml-common.h +12 -6
  31. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  32. package/cpp/ggml-cpu-impl.h +2 -21
  33. package/cpp/ggml-cpu-quants.c +904 -405
  34. package/cpp/ggml-cpu.c +909 -13237
  35. package/cpp/ggml-impl.h +50 -23
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +597 -523
  39. package/cpp/ggml-metal.m +798 -580
  40. package/cpp/ggml.c +92 -3
  41. package/cpp/ggml.h +30 -6
  42. package/cpp/gguf.cpp +1 -0
  43. package/cpp/llama-adapter.cpp +55 -20
  44. package/cpp/llama-adapter.h +11 -9
  45. package/cpp/llama-arch.cpp +217 -16
  46. package/cpp/llama-arch.h +25 -0
  47. package/cpp/llama-batch.h +2 -2
  48. package/cpp/llama-chat.cpp +54 -2
  49. package/cpp/llama-chat.h +3 -0
  50. package/cpp/llama-context.cpp +2294 -1238
  51. package/cpp/llama-context.h +214 -77
  52. package/cpp/llama-cparams.h +1 -0
  53. package/cpp/llama-graph.cpp +1695 -0
  54. package/cpp/llama-graph.h +592 -0
  55. package/cpp/llama-hparams.cpp +8 -0
  56. package/cpp/llama-hparams.h +17 -0
  57. package/cpp/llama-io.cpp +15 -0
  58. package/cpp/llama-io.h +35 -0
  59. package/cpp/llama-kv-cache.cpp +965 -303
  60. package/cpp/llama-kv-cache.h +145 -151
  61. package/cpp/llama-memory.cpp +1 -0
  62. package/cpp/llama-memory.h +21 -0
  63. package/cpp/llama-mmap.cpp +1 -1
  64. package/cpp/llama-model-loader.cpp +10 -5
  65. package/cpp/llama-model-loader.h +5 -3
  66. package/cpp/llama-model.cpp +9194 -201
  67. package/cpp/llama-model.h +40 -1
  68. package/cpp/llama-sampling.cpp +5 -0
  69. package/cpp/llama-vocab.cpp +36 -5
  70. package/cpp/llama.cpp +51 -9984
  71. package/cpp/llama.h +102 -22
  72. package/cpp/log.cpp +34 -0
  73. package/cpp/minja/chat-template.hpp +15 -7
  74. package/cpp/minja/minja.hpp +120 -94
  75. package/cpp/ops.cpp +8723 -0
  76. package/cpp/ops.h +128 -0
  77. package/cpp/rn-llama.cpp +873 -882
  78. package/cpp/rn-llama.h +138 -148
  79. package/cpp/sampling.cpp +3 -0
  80. package/cpp/sampling.h +107 -107
  81. package/cpp/sgemm.cpp +533 -88
  82. package/cpp/simd-mappings.h +888 -0
  83. package/cpp/speculative.cpp +4 -4
  84. package/cpp/unary-ops.cpp +186 -0
  85. package/cpp/unary-ops.h +28 -0
  86. package/cpp/unicode-data.cpp +7034 -7034
  87. package/cpp/unicode-data.h +20 -20
  88. package/cpp/unicode.cpp +849 -849
  89. package/cpp/unicode.h +66 -66
  90. package/cpp/vec.cpp +258 -0
  91. package/cpp/vec.h +802 -0
  92. package/ios/CMakeLists.txt +116 -105
  93. package/ios/RNLlama.h +7 -7
  94. package/ios/RNLlama.mm +418 -405
  95. package/ios/RNLlamaContext.h +57 -57
  96. package/ios/RNLlamaContext.mm +835 -819
  97. package/ios/rnllama.xcframework/Info.plist +74 -74
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  143. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/chat-template.hpp +15 -7
  144. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/minja.hpp +120 -94
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  184. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  191. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  192. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  193. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  194. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  195. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  196. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  197. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  198. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  199. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  200. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  201. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  202. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  203. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  204. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  205. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  206. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  207. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  208. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  209. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  210. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  211. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  212. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  213. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  214. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  215. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  216. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  217. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  218. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  225. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  226. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  227. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  228. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  229. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  230. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  231. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  232. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  233. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  234. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  235. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  236. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  237. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  238. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  239. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  240. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  241. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  242. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  243. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  244. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  245. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  246. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  247. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  248. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  249. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  250. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  251. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  261. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  262. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
  263. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  264. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  265. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  266. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
  267. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  268. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  269. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  270. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  272. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  273. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  274. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  275. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
  276. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  277. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  278. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  283. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  284. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  285. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  286. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  287. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  288. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  289. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  290. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  291. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  292. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  293. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  294. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  295. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  296. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  297. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  298. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  299. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  300. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  301. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  302. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  303. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  304. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  305. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  306. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  307. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  308. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  309. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  310. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  311. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  312. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  313. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  314. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  315. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  316. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  317. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  318. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  319. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  320. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  321. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  322. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  323. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  324. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  325. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  326. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  327. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  328. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  329. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  330. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  331. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  332. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  333. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  334. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  335. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  336. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  337. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  338. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  339. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  340. package/jest/mock.js +203 -203
  341. package/lib/commonjs/NativeRNLlama.js +1 -2
  342. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  343. package/lib/commonjs/chat.js.map +1 -1
  344. package/lib/commonjs/grammar.js +12 -31
  345. package/lib/commonjs/grammar.js.map +1 -1
  346. package/lib/commonjs/index.js +47 -47
  347. package/lib/commonjs/index.js.map +1 -1
  348. package/lib/commonjs/package.json +1 -0
  349. package/lib/module/NativeRNLlama.js +2 -0
  350. package/lib/module/NativeRNLlama.js.map +1 -1
  351. package/lib/module/chat.js +2 -0
  352. package/lib/module/chat.js.map +1 -1
  353. package/lib/module/grammar.js +14 -31
  354. package/lib/module/grammar.js.map +1 -1
  355. package/lib/module/index.js +47 -45
  356. package/lib/module/index.js.map +1 -1
  357. package/lib/module/package.json +1 -0
  358. package/lib/typescript/NativeRNLlama.d.ts +6 -4
  359. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  360. package/lib/typescript/index.d.ts.map +1 -1
  361. package/llama-rn.podspec +48 -48
  362. package/package.json +233 -233
  363. package/src/NativeRNLlama.ts +426 -424
  364. package/src/chat.ts +44 -44
  365. package/src/grammar.ts +854 -854
  366. package/src/index.ts +495 -485
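Of the 366 changed files, the bulk of the additions are the regenerated header sets inside the prebuilt ios/rnllama.xcframework (one copy per iOS/tvOS device and simulator slice), new ggml CPU sources (ops.cpp, vec.cpp, binary-ops.cpp, unary-ops.cpp, simd-mappings.h), and the llama.cpp refactor that splits graph, IO, and memory code into llama-graph.*, llama-io.*, and llama-memory.*. Only the diff for package/cpp/ggml-metal.m is reproduced below.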
package/cpp/ggml-metal.m CHANGED
@@ -184,10 +184,13 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    LM_GGML_METAL_KERNEL_TYPE_L2_NORM,
     LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     LM_GGML_METAL_KERNEL_TYPE_NORM,
     LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
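For orientation: each LM_GGML_METAL_KERNEL_TYPE_* value added here (L2_NORM, RWKV_WKV6_F32, RWKV_WKV7_F32) gets a matching pipeline registration later in the file via LM_GGML_METAL_ADD_KERNEL(type, metal_function, supported), the macro visible in the last hunk below. A minimal C sketch of that enum-to-registration pattern, with simplified stand-in types and under the assumption that the macro resolves the Metal function named kernel_<name>:

```c
#include <stdbool.h>
#include <stdio.h>

enum kernel_type {
    KERNEL_TYPE_RMS_NORM,
    KERNEL_TYPE_L2_NORM,       // new in this release
    KERNEL_TYPE_RWKV_WKV6_F32, // new
    KERNEL_TYPE_RWKV_WKV7_F32, // new
    KERNEL_TYPE_COUNT,
};

struct kernel { const char * name; bool supported; };

static struct kernel kernels[KERNEL_TYPE_COUNT];

// Stand-in for LM_GGML_METAL_ADD_KERNEL(e, name, supported): record the
// Metal function name "kernel_<name>" and whether this device supports it.
#define ADD_KERNEL(e, name, supported) \
    kernels[e] = (struct kernel){ "kernel_" #name, (supported) }

int main(void) {
    bool has_simdgroup_reduction = true; // queried from the MTLDevice in the real code
    ADD_KERNEL(KERNEL_TYPE_RMS_NORM,      rms_norm,      has_simdgroup_reduction);
    ADD_KERNEL(KERNEL_TYPE_L2_NORM,       l2_norm,       has_simdgroup_reduction);
    ADD_KERNEL(KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
    ADD_KERNEL(KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
    for (int i = 0; i < KERNEL_TYPE_COUNT; ++i) {
        printf("%-24s supported=%d\n", kernels[i].name, kernels[i].supported);
    }
    return 0;
}
```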
@@ -348,42 +351,56 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
@@ -392,6 +409,20 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
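These two hunks extend the specialized flash-attention kernels with a head size of 192 and an asymmetric K=192/V=128 variant, likely intended for models whose K and V head widths differ (e.g. DeepSeek-style attention). A hypothetical C sketch of how a dispatcher might map head sizes onto these variants (the names here mirror the enum suffixes above and are illustrative, not the library's API):

```c
#include <stdio.h>

enum fa_variant {
    FA_H64, FA_H80, FA_H96, FA_H112, FA_H128,
    FA_H192,        // new: K and V heads both 192
    FA_HK192_HV128, // new: K heads of 192 with V heads of 128
    FA_H256,
    FA_UNSUPPORTED,
};

// Pick a specialized variant from the K/V head sizes.
static enum fa_variant fa_select(int head_k, int head_v) {
    if (head_k == head_v) {
        switch (head_k) {
            case  64: return FA_H64;
            case  80: return FA_H80;
            case  96: return FA_H96;
            case 112: return FA_H112;
            case 128: return FA_H128;
            case 192: return FA_H192;
            case 256: return FA_H256;
        }
    } else if (head_k == 192 && head_v == 128) {
        return FA_HK192_HV128;
    }
    return FA_UNSUPPORTED; // caller falls back to a non-specialized path
}

int main(void) {
    printf("128/128 -> %d, 192/128 -> %d\n", fa_select(128, 128), fa_select(192, 128));
    return 0;
}
```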
@@ -529,7 +560,11 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
     NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
 #endif
 
+#if TARGET_OS_SIMULATOR
+    NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
+#else
     NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
+#endif
     if (path_lib == nil) {
         // Try to find the resource in the directory where the current binary located.
         NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
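Together with the ggml-llama-sim.metallib binaries shipped in package/cpp (file 36 above) and in the simulator slices of rnllama.xcframework (files 217 and 338), this hunk makes the loader pick a simulator-specific Metal library when the build targets the iOS/tvOS simulator, instead of always loading ggml-llama.metallib.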
@@ -755,310 +790,341 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB, sub, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
-        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
-        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
-        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
-        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
959
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
960
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
961
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
962
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
963
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
964
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
965
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
966
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
967
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
968
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
969
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
970
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
971
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
972
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
973
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
974
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
975
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
976
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
977
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
978
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
979
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
980
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
981
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
982
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
983
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
984
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
985
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
986
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
987
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
988
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
989
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
990
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
991
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
992
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
993
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
994
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
995
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
996
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
997
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
998
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
999
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
1000
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
1001
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
1002
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
1003
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
1004
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
1005
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
1006
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
1007
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
1008
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
1009
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
1010
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
1011
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
1012
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
1013
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
1014
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
1015
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
1016
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
1017
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
1018
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
1019
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
1020
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
1021
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
1022
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
1023
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
1024
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
1025
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
1026
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
1027
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
1028
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
1029
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
1030
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
1031
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
1032
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
1033
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
1034
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
1035
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
1036
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
1037
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
1038
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
1039
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
1040
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
1041
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
1042
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
1043
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1044
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
1045
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1046
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
1047
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1048
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
1049
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1050
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
1051
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
1052
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
1053
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
1054
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
1055
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
1056
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1057
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
1058
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1059
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1060
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
1061
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
793
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB, sub, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
  }
 
  return ctx;
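
Every registration above gates its pipeline on the third macro argument: `has_simdgroup_reduction`, `has_simdgroup_mm`, `use_bfloat`, or `true` for unconditional kernels. The macro body itself is not part of this diff, so the following Objective-C sketch only illustrates that gating convention; `metal_library`, `device`, `error`, and the `ctx->kernels` slot layout are assumed surrounding names, not values taken from the diff.

// Hedged sketch of the conditional-registration pattern; the real macro
// body lives outside this diff and the names below are illustrative.
#define LM_GGML_METAL_ADD_KERNEL(e, name, supported)                               \
    if (supported) {                                                               \
        id<MTLFunction> fn = [metal_library newFunctionWithName:@"kernel_" #name]; \
        ctx->kernels[e].pipeline =                                                 \
            [device newComputePipelineStateWithFunction:fn error:&error];          \
    } else {                                                                       \
        ctx->kernels[e].pipeline = nil; /* op is later reported as unsupported */  \
    }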
@@ -1251,6 +1317,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_OP_GROUP_NORM:
  return has_simdgroup_reduction && lm_ggml_is_contiguous(op->src[0]);
  case LM_GGML_OP_RMS_NORM:
+ case LM_GGML_OP_L2_NORM:
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
  case LM_GGML_OP_ARGMAX:
  return true;
@@ -1282,12 +1349,19 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_OP_ARANGE:
  return true;
  case LM_GGML_OP_FLASH_ATTN_EXT:
+ if (op->src[0]->ne[0] == 32) {
+ // head size == 32 (e.g. bert-bge-small)
+ // TODO: not sure if it is worth adding kernels for this size
+ return false;
+ }
  if (op->src[1]->type != op->src[2]->type) {
  return false;
  }
  return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
  case LM_GGML_OP_SSM_CONV:
  case LM_GGML_OP_SSM_SCAN:
+ case LM_GGML_OP_RWKV_WKV6:
+ case LM_GGML_OP_RWKV_WKV7:
  return true;
  case LM_GGML_OP_MUL_MAT:
  case LM_GGML_OP_MUL_MAT_ID:
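
The hunk above declares RWKV_WKV6/RWKV_WKV7 as supported; the next hunk encodes them. There, B is read from the state tensor (`src[5]` for WKV6, `src[6]` for WKV7), T from `src[0]->ne[2]`, C from `dst->ne[0]`, and H from `src[0]->ne[1]`, with a fixed head size of 64 enforced via `LM_GGML_ASSERT(C / H == 64)`. Both kernels launch one threadgroup per (sequence, head) pair with one thread per head channel; a self-contained Objective-C sketch of just that dispatch geometry (the function name and any concrete sizes are illustrative, not from the diff):

#import <Metal/Metal.h>
#include <assert.h>

// Dispatch geometry used by the RWKV_WKV6/WKV7 cases in the hunk below:
// B*H threadgroups of C/H (= 64) threads, i.e. one group per (sequence, head).
static void dispatch_wkv(id<MTLComputeCommandEncoder> encoder,
                         int64_t B, int64_t H, int64_t C) {
    assert(C % H == 0 && C / H == 64); // fixed head size, as asserted in the diff
    [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
}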
@@ -2216,6 +2290,83 @@ static void lm_ggml_metal_encode_node(

  [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case LM_GGML_OP_RWKV_WKV6:
+ {
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ LM_GGML_ASSERT(dst->src[5]->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT(C % H == 0);
+ LM_GGML_ASSERT(C / H == 64);
+
+ size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
+ size_t offs_src5 = 0;
+
+ id<MTLBuffer> id_src3 = dst->src[3] ? lm_ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = dst->src[4] ? lm_ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+ id<MTLBuffer> id_src5 = dst->src[5] ? lm_ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+ [encoder setBytes:&B length:sizeof(B) atIndex:7];
+ [encoder setBytes:&T length:sizeof(T) atIndex:8];
+ [encoder setBytes:&C length:sizeof(C) atIndex:9];
+ [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+ } break;
+ case LM_GGML_OP_RWKV_WKV7:
+ {
+ const int64_t B = dst->src[6]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ LM_GGML_ASSERT(dst->src[6]->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT(C % H == 0);
+ LM_GGML_ASSERT(C / H == 64);
+
+ size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
+ size_t offs_src5 = 0;
+ size_t offs_src6 = 0;
+
+ id<MTLBuffer> id_src3 = dst->src[3] ? lm_ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = dst->src[4] ? lm_ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+ id<MTLBuffer> id_src5 = dst->src[5] ? lm_ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+ id<MTLBuffer> id_src6 = dst->src[6] ? lm_ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+ [encoder setBytes:&B length:sizeof(B) atIndex:8];
+ [encoder setBytes:&T length:sizeof(T) atIndex:9];
+ [encoder setBytes:&C length:sizeof(C) atIndex:10];
+ [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+ } break;
  case LM_GGML_OP_MUL_MAT:
  {
  LM_GGML_ASSERT(ne00 == ne10);
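The WKV6/WKV7 encoders added in the hunk above share one launch geometry: a threadgroup per (sequence, head) pair and a thread per channel inside the head, which is why C / H == 64 is asserted. A worked sketch with hypothetical RWKV-style shapes:

    // sketch: grid arithmetic for the WKV dispatches above (example values)
    const int64_t B = 2;                    // sequences in the batch (from src[5]/src[6])
    const int64_t H = 32;                   // heads
    const int64_t C = 2048;                 // channels; C % H == 0 and C / H == 64 are asserted
    const int64_t n_threadgroups = B * H;   // 64 threadgroups along the grid's x axis
    const int64_t n_tg_threads   = C / H;   // 64 threads per threadgroup, one per channel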
@@ -2475,171 +2626,180 @@ static void lm_ggml_metal_encode_node(
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
  id<MTLComputePipelineState> pipeline = nil;

+ int nsg = 0; // number of simdgroups
+ int nr0 = 0; // number of src0 rows per simdgroup
+ int nr1 = 1; // number of src1 rows per threadgroup
+
+ size_t smem = 0; // shared memory
+
  // use custom matrix x vector kernel
  switch (src0t) {
  case LM_GGML_TYPE_F32:
  {
  LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
+ nsg = 1;
+ nr0 = 1;
+ nr1 = 4;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
- nrows = 4;
  } break;
  case LM_GGML_TYPE_F16:
  {
- nth0 = 32;
- nth1 = 1;
+ nsg = 1;
+ nr0 = 1;
  if (src1t == LM_GGML_TYPE_F32) {
  if (ne11 * ne12 < 4) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
- nrows = ne11;
+ nr1 = ne11;
  } else {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
- nrows = 4;
+ nr1 = 4;
  }
  } else {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
- nrows = 4;
+ nr1 = 4;
  }
  } break;
  case LM_GGML_TYPE_BF16:
  {
- nth0 = 32;
- nth1 = 1;
+ nsg = 1;
+ nr0 = 1;
  if (src1t == LM_GGML_TYPE_F32) {
  if (ne11 * ne12 < 4) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
- nrows = ne11;
+ nr1 = ne11;
  } else {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
- nrows = 4;
+ nr1 = 4;
  }
  } else {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
- nrows = 4;
+ nr1 = 4;
  }
  } break;
  case LM_GGML_TYPE_Q4_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q4_0;
+ nr0 = N_R0_Q4_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q4_1:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q4_1;
+ nr0 = N_R0_Q4_1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q5_0;
+ nr0 = N_R0_Q5_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_1:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q5_1;
+ nr0 = N_R0_Q5_1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q8_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q8_0;
+ nr0 = N_R0_Q8_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q2_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q2_K;
+ nr0 = N_R0_Q2_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q3_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q3_K;
+ nr0 = N_R0_Q3_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q4_K:
  {
- nth0 = 4; //1;
- nth1 = 8; //32;
+ nsg = N_SG_Q4_K;
+ nr0 = N_R0_Q4_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q5_K;
+ nr0 = N_R0_Q5_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q6_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q6_K;
+ nr0 = N_R0_Q6_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_XXS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_XXS;
+ nr0 = N_R0_IQ2_XXS;
+ smem = 256*8+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_XS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_XS;
+ nr0 = N_R0_IQ2_XS;
+ smem = 512*8+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ3_XXS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ3_XXS;
+ nr0 = N_R0_IQ3_XXS;
+ smem = 256*4+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ3_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ3_S;
+ nr0 = N_R0_IQ3_S;
+ smem = 512*4;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_S;
+ nr0 = N_R0_IQ2_S;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ1_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ1_S;
+ nr0 = N_R0_IQ1_S;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ1_M:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ1_M;
+ nr0 = N_R0_IQ1_M;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ4_NL:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ4_NL;
+ nr0 = N_R0_IQ4_NL;
+ smem = 32*sizeof(float);
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ4_XS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ4_XS;
+ nr0 = N_R0_IQ4_XS;
+ smem = 32*sizeof(float);
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
  } break;
  default:
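The smem values introduced for the IQ types are the byte sizes of the lookup tables the corresponding kernels stage into threadgroup memory. Assuming the grid and sign tables keep their usual ggml sizes (an assumption; those constants are defined with the quant formats, not in this hunk), the arithmetic works out as:

    // assumed table sizes behind the smem constants above (not taken from this diff):
    // IQ2_XXS: 256 x 8-byte grid entries + 128-byte sign table -> 256*8 + 128 = 2176
    // IQ2_XS : 512 x 8-byte grid entries + 128-byte sign table -> 512*8 + 128 = 4224
    // IQ3_XXS: 256 x 4-byte grid entries + 128-byte sign table -> 256*4 + 128 = 1152
    // IQ3_S  : 512 x 4-byte grid entries                       -> 512*4       = 2048
    // IQ4_NL/IQ4_XS: one 32-entry float codebook               -> 32*sizeof(float) = 128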
@@ -2676,41 +2836,10 @@ static void lm_ggml_metal_encode_node(
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

- if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
- src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
- src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
- const int mem_size = 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ if (smem > 0) {
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
  }
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
  }
  } break;
  case LM_GGML_OP_MUL_MAT_ID:
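The payoff of carrying (nsg, nr0, nr1, smem) per type shows at the end of the previous hunk: the nine-way dispatch ladder collapses into one shared-memory assignment and a grid computed with two ceiling divisions, with a fixed (32, nsg, 1) threadgroup. A worked sketch with made-up constants (the real N_R0_*/N_SG_* values live in a shared header, not in this hunk):

    // sketch: the unified mul-mv grid, with illustrative numbers
    const int nr0 = 2, nsg = 4, nr1 = 1;                  // hypothetical per-type values
    const int64_t ne01 = 4096, ne11 = 7;                  // src0 rows, src1 rows
    const int64_t tg_x = (ne01 + nr0*nsg - 1)/(nr0*nsg);  // ceil(4096/8) = 512
    const int64_t tg_y = (ne11 + nr1 - 1)/nr1;            // ceil(7/1)    = 7
    // grid = (tg_x, tg_y, ne12*ne13); threads per threadgroup = (32, nsg, 1)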
@@ -2736,20 +2865,19 @@ static void lm_ggml_metal_encode_node(
  // ne21 = n_rows
  const int dst_rows = ne20*ne21;
  const int dst_rows_min = n_as;
- const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+ const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;

  // max size of the rowids array in the kernel shared buffer
- LM_GGML_ASSERT(dst_rows <= dst_rows_max);
+ //LM_GGML_ASSERT(dst_rows <= dst_rows_max);

  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- // !!!
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
- // indirect matrix multiplication
- // !!!
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- dst_rows > dst_rows_min) {
+ //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
+ dst_rows > dst_rows_min &&
+ dst_rows <= dst_rows_max) {
+
  // some Metal matrix data types require aligned pointers
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
  switch (src0->type) {
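The hard assert on dst_rows becomes a soft condition in this hunk: oversized MoE batches now fall back to the mat-vec path instead of aborting, and the shared-memory budget for the rowids array is roughly halved. For a device reporting a 32 KiB maxThreadgroupMemoryLength the bound moves as follows (a worked example, not measured on hardware):

    // sketch: rowids capacity before/after, assuming maxThreadgroupMemoryLength == 32768
    const int max_tg_mem = 32768;
    const int old_max = (max_tg_mem - 32 - 8192)/4;   // = 6136 rowids
    const int new_max = (max_tg_mem/2  - 8192)/4;     // = 2048 rowids
    // dst_rows > new_max no longer asserts; the encoder just picks the mat-vec kernels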
@@ -2816,146 +2944,155 @@ static void lm_ggml_metal_encode_node(

  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
  id<MTLComputePipelineState> pipeline = nil;

+ int nsg = 0; // number of simdgroups
+ int nr0 = 0; // number of src0 rows per simdgroup
+ int nr1 = 1; // number of src1 rows per threadgroup
+
+ size_t smem = 0; // shared memory
+
  // use custom matrix x vector kernel
  switch (src0t) {
  case LM_GGML_TYPE_F32:
  {
  LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
+ nsg = 1;
+ nr0 = 1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
  } break;
  case LM_GGML_TYPE_F16:
  {
  LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
+ nsg = 1;
+ nr0 = 1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
  } break;
  case LM_GGML_TYPE_BF16:
  {
  LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
+ nsg = 1;
+ nr0 = 1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q4_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q4_0;
+ nr0 = N_R0_Q4_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q4_1:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q4_1;
+ nr0 = N_R0_Q4_1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q5_0;
+ nr0 = N_R0_Q5_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_1:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q5_1;
+ nr0 = N_R0_Q5_1;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q8_0:
  {
- nth0 = 8;
- nth1 = 8;
+ nsg = N_SG_Q8_0;
+ nr0 = N_R0_Q8_0;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q2_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q2_K;
+ nr0 = N_R0_Q2_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q3_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q3_K;
+ nr0 = N_R0_Q3_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q4_K:
  {
- nth0 = 4; //1;
- nth1 = 8; //32;
+ nsg = N_SG_Q4_K;
+ nr0 = N_R0_Q4_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q5_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q5_K;
+ nr0 = N_R0_Q5_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_Q6_K:
  {
- nth0 = 2;
- nth1 = 32;
+ nsg = N_SG_Q6_K;
+ nr0 = N_R0_Q6_K;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_XXS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_XXS;
+ nr0 = N_R0_IQ2_XXS;
+ smem = 256*8+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_XS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_XS;
+ nr0 = N_R0_IQ2_XS;
+ smem = 512*8+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ3_XXS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ3_XXS;
+ nr0 = N_R0_IQ3_XXS;
+ smem = 256*4+128;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ3_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ3_S;
+ nr0 = N_R0_IQ3_S;
+ smem = 512*4;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ2_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ2_S;
+ nr0 = N_R0_IQ2_S;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ1_S:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ1_S;
+ nr0 = N_R0_IQ1_S;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ1_M:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ1_M;
+ nr0 = N_R0_IQ1_M;
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ4_NL:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ4_NL;
+ nr0 = N_R0_IQ4_NL;
+ smem = 32*sizeof(float);
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
  } break;
  case LM_GGML_TYPE_IQ4_XS:
  {
- nth0 = 4;
- nth1 = 16;
+ nsg = N_SG_IQ4_XS;
+ nr0 = N_R0_IQ4_XS;
+ smem = 32*sizeof(float);
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
  } break;
  default:
@@ -2966,7 +3103,7 @@ static void lm_ggml_metal_encode_node(
  };

  if (lm_ggml_is_quantized(src0t)) {
- LM_GGML_ASSERT(ne00 >= nth0*nth1);
+ LM_GGML_ASSERT(ne00 >= nsg*nr0);
  }

  lm_ggml_metal_kargs_mul_mv_id args = {
@@ -2999,43 +3136,12 @@ static void lm_ggml_metal_encode_node(
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

  const int64_t _ne1 = 1;
- const int tgz = dst_rows;
+ const int64_t ne123 = dst_rows;

- if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
- src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
- src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
- const int mem_size = 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == LM_GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ if (smem > 0) {
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
  }
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
  }
  } break;
  case LM_GGML_OP_GET_ROWS:
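The indirect (MUL_MAT_ID) path adopts the same rule with two twists: _ne1 is pinned to 1, so the y dimension always collapses to a single threadgroup, and the z dimension (ne123 = dst_rows) carries one slot per routed (expert, token) output row. The degenerate y term, spelled out:

    // sketch: with _ne1 == 1 the middle grid dimension is always 1
    const int64_t _ne1 = 1;
    const int nr1 = 1;                              // default from the type switch above
    const int64_t tg_y = (_ne1 + nr1 - 1)/nr1;      // (1 + 1 - 1)/1 = 1 for any nr1 >= 1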
@@ -3122,6 +3228,42 @@ static void lm_ggml_metal_encode_node(

  const int64_t nrows = lm_ggml_nrows(src0);

+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case LM_GGML_OP_L2_NORM:
+ {
+ LM_GGML_ASSERT(ne00 % 4 == 0);
+ LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+
+ nth = MIN(nth, ne00/4);
+
+ lm_ggml_metal_kargs_l2_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne00_4 =*/ ne00/4,
+ /*.nb01 =*/ nb01,
+ /*.eps =*/ eps,
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = lm_ggml_nrows(src0);
+
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case LM_GGML_OP_GROUP_NORM:
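The new L2_NORM encoder sizes its threadgroup by doubling from one SIMD width until it covers the row in float4 units or hits the pipeline's thread cap, then clamps with MIN. Worked through for a hypothetical 4096-wide row:

    // sketch: threadgroup sizing for L2_NORM, assuming ne00 == 4096
    int ne00 = 4096;                   // row length; ne00/4 == 1024 float4 loads per row
    int max_tptg = 1024;               // stand-in for pipeline.maxTotalThreadsPerThreadgroup
    int nth = 32;                      // start at one simdgroup
    while (nth < ne00/4 && nth < max_tptg) {
        nth *= 2;                      // 32 -> 64 -> 128 -> 256 -> 512 -> 1024
    }
    nth = nth < ne00/4 ? nth : ne00/4; // MIN clamp; here nth settles at 1024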
@@ -3654,7 +3796,9 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
  LM_GGML_ASSERT(src1->type == src2->type);

- LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
+ //LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
+ LM_GGML_ASSERT(ne11 == ne21);
+ LM_GGML_ASSERT(ne12 == ne22);

  struct lm_ggml_tensor * src3 = node->src[3];

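Relaxing the same-shape assert is what makes asymmetric K/V heads legal in flash attention: only the token counts (ne11/ne21) and head counts (ne12/ne22) must still agree, while dimension 0 may differ, e.g. the 192/128 split handled below. In ggml's ne order the admitted shapes look like this (illustrative values):

    // sketch: K/V shapes accepted after the relaxation
    // src1 (K): ne10 = 192, ne11 = n_kv, ne12 = n_head_kv
    // src2 (V): ne20 = 128, ne21 = n_kv, ne22 = n_head_kv
    // asserted: ne11 == ne21 && ne12 == ne22; ne10 != ne20 is now permitted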
@@ -3701,125 +3845,161 @@ static void lm_ggml_metal_encode_node(

  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
- if (ne01 >= 4 || (ne00%128 != 0)) {
+ // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
+ if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
  switch (src1->type) {
  case LM_GGML_TYPE_F16:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_BF16:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_Q4_0:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_Q4_1:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_Q5_0:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_Q5_1:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  case LM_GGML_TYPE_Q8_0:
  {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
- case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
- default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+ case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
  }
  } break;
  default:
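Head-size selection now branches on both tensors: ne00 is the K head size and ne20 the V head size, so the asymmetric HK192_HV128 templates (the MLA-style layout used by DeepSeek-family models) are picked ahead of the square-head switch, and plain 192-wide heads get their own H192 case. A condensed sketch of the decision shape:

    // sketch: simplified form of the per-type pipeline choice above
    if (ne00 == 192 && ne20 == 128) {
        // asymmetric heads: K = 192, V = 128 -> *_HK192_HV128 template
    } else {
        // square heads: one *_H<ne00> template per supported size
        // (64, 80, 96, 112, 128, 192, 256); anything else aborts
    }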
@@ -3851,6 +4031,42 @@ static void lm_ggml_metal_encode_node(
  }
  }
  } break;
+ case 192:
+ {
+ if (ne20 == 128) {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } else {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ }
+ } break;
  case 256:
  {
  switch (src1->type) {
@@ -3888,9 +4104,12 @@ static void lm_ggml_metal_encode_node(
  /*.ne11 =*/ ne11,
  /*.ne_12_2 =*/ ne12,
  /*.ne_12_3 =*/ ne13,
- /*.nb_12_1 =*/ nb11,
- /*.nb_12_2 =*/ nb12,
- /*.nb_12_3 =*/ nb13,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
  /*.nb31 =*/ nb31,
  /*.ne1 =*/ ne1,
  /*.ne2 =*/ ne2,
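Because K and V may now have different head sizes, the shared nb_12_* strides are split into K strides (nb11..nb13) and V strides (nb21..nb23) in the kernel args. In ggml these are byte strides, so a kernel can walk V rows independently of K; roughly (illustrative helper, not the kernel source):

    // sketch: independent K/V addressing with the split strides (illustrative only)
    static const void * kv_row(const char * base, size_t it, size_t ih, size_t ib,
                               size_t nb1, size_t nb2, size_t nb3) {
        return base + it*nb1 + ih*nb2 + ib*nb3; // ggml nb* are byte strides
    }
    // K rows use (nb11, nb12, nb13); V rows use (nb21, nb22, nb23) and may have a
    // different head size (ne20), which is why the two sets are now passed separately.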
@@ -3969,10 +4188,9 @@ static void lm_ggml_metal_encode_node(
  // ne00*(nsg)
  // each simdgroup has a full f16 head vector in shared mem to accumulate results
  //
- #define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(LM_GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))

  int64_t nsgmax = 2;
-
  while (true) {
  const size_t smem = FATTN_SMEM(nsgmax);
  if (smem > device.maxThreadgroupMemoryLength) {
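The revised FATTN_SMEM pads the K head size up to a multiple of 128, doubles the per-simdgroup chunk scratch (4*ncpsg instead of 2*ncpsg), and sizes the output accumulator by the V head (ne20) instead of the K head. Plugging in the asymmetric 192/128 case (nqptg and ncpsg values are assumptions for illustration, not taken from this hunk):

    // sketch: new FATTN_SMEM for a 192/128 head with 2 simdgroups
    // assumed: nqptg = 8 queries and ncpsg = 32 cache positions per step
    const int ne00 = 192, ne20 = 128, nsg = 2, nqptg = 8, ncpsg = 32;
    const int k_pad = (ne00 + 127) & ~127;                     // LM_GGML_PAD(192, 128) = 256
    const int elems = nqptg*(k_pad + 4*ncpsg*nsg) + ne20*nsg;  // 8*(256 + 256) + 256 = 4352
    const int smem  = ((elems*2) + 15) & ~15;                  // x2 bytes (half), 16-pad -> 8704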