cui-llama.rn 1.6.1 → 1.7.1

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (196)
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
  3. package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +38 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
package/cpp/ggml-metal.m CHANGED
@@ -149,6 +149,8 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
  LM_GGML_METAL_KERNEL_TYPE_GELU,
  LM_GGML_METAL_KERNEL_TYPE_GELU_4,
+ LM_GGML_METAL_KERNEL_TYPE_GELU_ERF,
+ LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
  LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
  LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
  LM_GGML_METAL_KERNEL_TYPE_SILU,
@@ -306,30 +308,36 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+ LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -409,6 +417,13 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -650,7 +665,8 @@ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_poo
  }

  if (mem_pool->heaps_to_remove.count > 0) {
- for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+ // remove in reverse order
+ for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
  NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
  lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];

@@ -659,6 +675,10 @@ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_poo

  [mem_pool->heaps removeObjectAtIndex:index];
  [ptr release];
+
+ if (i == 0) {
+ break;
+ }
  }

  [mem_pool->heaps_to_remove removeAllObjects];
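Note: the hunk above walks heaps_to_remove from the last index down to 0 so that each removal from mem_pool->heaps never shifts an index that is still pending, and the explicit "if (i == 0) break;" avoids decrementing the unsigned counter past zero. A minimal C++ analogue of the same pattern (hypothetical names, not part of the package):

    #include <cstddef>
    #include <vector>

    // Remove the elements of `data` at the positions listed in `to_remove`
    // (assumed sorted ascending). Walking the index list from the back keeps
    // the remaining, smaller indices valid, and the explicit break avoids
    // wrapping the unsigned loop counter below zero.
    void remove_indices(std::vector<int> & data, const std::vector<size_t> & to_remove) {
        if (to_remove.empty()) {
            return;
        }
        for (size_t i = to_remove.size() - 1; ; --i) {
            data.erase(data.begin() + to_remove[i]);
            if (i == 0) {
                break;
            }
        }
    }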
@@ -672,7 +692,7 @@ static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_poo
  }

  static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
- const size_t alignment = 32;
+ const size_t alignment = 256;

  const size_t size_aligned = LM_GGML_PAD(size, alignment);

@@ -834,11 +854,7 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
  NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
  #endif

- #if TARGET_OS_SIMULATOR
- NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
- #else
- NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
- #endif
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
  if (path_lib == nil) {
  // Try to find the resource in the directory where the current binary located.
  NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
@@ -1089,6 +1105,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1246,30 +1264,36 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1349,6 +1373,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
@@ -1586,6 +1617,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_UNARY_OP_RELU:
  case LM_GGML_UNARY_OP_SIGMOID:
  case LM_GGML_UNARY_OP_GELU:
+ case LM_GGML_UNARY_OP_GELU_ERF:
  case LM_GGML_UNARY_OP_GELU_QUICK:
  case LM_GGML_UNARY_OP_SILU:
  case LM_GGML_UNARY_OP_ELU:
@@ -1632,16 +1664,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_OP_NORM:
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
  case LM_GGML_OP_ROPE:
- {
- const int mode = ((const int32_t *) op->op_params)[2];
- if (mode & LM_GGML_ROPE_TYPE_MROPE) {
- return false;
- }
- if (mode & LM_GGML_ROPE_TYPE_VISION) {
- return false;
- }
- return true;
- }
+ return true;
  case LM_GGML_OP_IM2COL:
  return op->src[0]->type == LM_GGML_TYPE_F16;
  case LM_GGML_OP_POOL_1D:
@@ -2233,6 +2256,25 @@ static bool lm_ggml_metal_encode_node(

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case LM_GGML_UNARY_OP_GELU_ERF:
+ {
+ int64_t n = lm_ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case LM_GGML_UNARY_OP_GELU_QUICK:
  {
  int64_t n = lm_ggml_nelements(dst);
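Note: the new GELU_ERF path dispatches one threadgroup per element, switching to the vectorized GELU_ERF_4 pipeline when the element count is divisible by 4. The Metal kernel body is not part of this diff; for reference, the erf-based GELU it is expected to implement is the exact form, as opposed to the tanh/"quick" approximations used by the existing GELU kernels:

    #include <cmath>

    // Exact GELU: gelu_erf(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    float gelu_erf(float x) {
        return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
    }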
@@ -3003,7 +3045,7 @@ static bool lm_ggml_metal_encode_node(
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  id<MTLComputePipelineState> pipeline = nil;

@@ -3223,8 +3265,6 @@ static bool lm_ggml_metal_encode_node(
  } break;
  case LM_GGML_OP_MUL_MAT_ID:
  {
- const int n_as = src0->ne[2];
-
  // src2 = ids
  const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t);

@@ -3238,24 +3278,21 @@ static bool lm_ggml_metal_encode_node(
  LM_GGML_ASSERT(ne03 == 1);
  LM_GGML_ASSERT(ne13 == 1);

+ const uint32_t r2 = 1;
+ const uint32_t r3 = 1;
+
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
  // ne20 = n_used_experts
- // ne21 = n_rows
- const int dst_rows = ne20*ne21;
- const int dst_rows_min = n_as;
- const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
-
- // max size of the rowids array in the kernel shared buffer
- //LM_GGML_ASSERT(dst_rows <= dst_rows_max);
+ // ne21 = n_rows (batch size)
+ const int ne21_mm_id_min = 32;

  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
- dst_rows > dst_rows_min &&
- dst_rows <= dst_rows_max) {
+ (ne21 >= ne21_mm_id_min)) {
+ LM_GGML_ASSERT(ne00 % 4 == 0);

  // some Metal matrix data types require aligned pointers
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -3266,62 +3303,169 @@ static bool lm_ggml_metal_encode_node(
  default: break;
  }

- id<MTLComputePipelineState> pipeline = nil;
+ const int64_t neh10 = ne10; // n_embd
+ const int64_t neh11 = ne21; // n_tokens
+ const int64_t neh12 = ne02; // n_expert

- switch (src0->type) {
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
- case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
- case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
- case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
- case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
- case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
- default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+ const uint64_t nbh10 = lm_ggml_type_size(LM_GGML_TYPE_F16);
+ const uint64_t nbh11 = nbh10*neh10;
+ const uint64_t nbh12 = nbh11*neh11;
+ const uint64_t nbh13 = nbh12*neh12;
+
+ const size_t s_src1 = lm_ggml_type_size(LM_GGML_TYPE_F16)*neh10*neh11*neh12;
+ id<MTLBuffer> h_src1 = lm_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
+ if (!h_src1) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
+ return false;
  }

- lm_ggml_metal_kargs_mul_mm_id args = {
- /*.nei0 =*/ ne20,
- /*.nei1 =*/ ne21,
- /*.nbi1 =*/ nb21,
- /*.ne00 =*/ ne00,
- /*.ne02 =*/ ne02,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.ne11 =*/ ne11,
- /*.ne12 =*/ ne12,
- /*.ne13 =*/ ne13,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.ne0 =*/ ne0,
- /*.ne1 =*/ ne1,
- };
+ const int64_t neh0 = ne0;
+ const int64_t neh1 = ne21;
+ const int64_t neh2 = ne02;

- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
+ const uint64_t nbh0 = lm_ggml_type_size(LM_GGML_TYPE_F32);
+ const uint64_t nbh1 = nbh0*neh0;
+ const uint64_t nbh2 = nbh1*neh1;
+ //const uint64_t nbh3 = nbh2*neh2;
+
+ const size_t s_dst = lm_ggml_type_size(LM_GGML_TYPE_F32)*neh0*neh1*neh2;
+ id<MTLBuffer> h_dst = lm_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
+ if (!h_dst) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
+ return false;
+ }
+
+ // tokens per expert
+ const size_t s_tpe = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne02;
+ id<MTLBuffer> h_tpe = lm_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
+ if (!h_tpe) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
+ return false;
+ }
+
+ // id map
+ // [n_expert_used, n_tokens]
+ const size_t s_ids = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne20*ne21;
+ id<MTLBuffer> h_ids = lm_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
+ if (!h_ids) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
+ return false;
+ }

- [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+ {
+ const int nth = MIN(1024, ne10/4);
+
+ lm_ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne10,
+ ne11, // n_expert_used (bcast)
+ nb11,
+ nb12,
+ neh11, // n_tokens
+ nbh11,
+ ne20, // n_expert_used
+ nb21,
+ };
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer: h_src1 offset:0 atIndex:3];
+ [encoder setBuffer: h_tpe offset:0 atIndex:4];
+ [encoder setBuffer: h_ids offset:0 atIndex:5];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+
+ {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
+ default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+ }
+
+ lm_ggml_metal_kargs_mul_mm_id args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.neh12 =*/ neh12,
+ /*.nbh10 =*/ nbh10,
+ /*.nbh11 =*/ nbh11,
+ /*.nbh12 =*/ nbh12,
+ /*.nbh13 =*/ nbh13,
+ /*.neh0 =*/ neh0,
+ /*.neh1 =*/ neh1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer: h_src1 offset:0 atIndex:2];
+ [encoder setBuffer: h_tpe offset:0 atIndex:3];
+ [encoder setBuffer: h_dst offset:0 atIndex:4];
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ }
+
+ {
+ LM_GGML_ASSERT(ne0 % 4 == 0);
+
+ const int nth = MIN(1024, ne0/4);

- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ lm_ggml_metal_kargs_mul_mm_id_map1 args = {
+ ne20, // n_expert_used
+ neh0,
+ neh1,
+ nbh1,
+ nbh2,
+ ne0,
+ nb1,
+ nb2,
+ };
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer: h_dst offset:0 atIndex:1];
+ [encoder setBuffer: h_ids offset:0 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
  } else {
  id<MTLComputePipelineState> pipeline = nil;

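Note: read as a whole, the rewritten MUL_MAT_ID path replaces the single kernel with three dispatches that stage data through pool buffers: map0 gathers each token's activations into a per-expert F16 staging buffer (h_src1), counts tokens per expert (h_tpe), and records where each (expert, token) result belongs (h_ids); the per-expert matrix-matrix kernel then writes into the staged h_dst; map1 finally scatters the staged rows back to the real dst using the id map. A rough CPU-side sketch of that flow, with hypothetical helper names and simplified shapes (not the package's code):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Gather / per-expert matmul / scatter structure, mirroring map0 -> mm_id -> map1.
    // n_expert ~ ne02, n_tokens ~ ne21, n_expert_used ~ ne20, n_embd ~ ne10.
    void mul_mat_id_sketch(int n_expert, int n_tokens, int n_expert_used, int n_embd,
                           const float * src1, const int32_t * expert_ids, float * dst) {
        std::vector<float>   src1_staged((size_t)n_expert * n_tokens * n_embd, 0.0f); // h_src1
        std::vector<int32_t> tokens_per_expert(n_expert, 0);                          // h_tpe
        std::vector<int32_t> ids((size_t)n_tokens * n_expert_used, 0);                // h_ids

        // pass 1 (map0): bucket each (token, used-expert) pair by expert
        for (int t = 0; t < n_tokens; ++t) {
            for (int u = 0; u < n_expert_used; ++u) {
                const int e    = expert_ids[t*n_expert_used + u];
                const int slot = tokens_per_expert[e]++;
                std::copy(src1 + (size_t)t*n_embd, src1 + (size_t)(t + 1)*n_embd,
                          src1_staged.begin() + ((size_t)e*n_tokens + slot)*n_embd);
                ids[t*n_expert_used + u] = e*n_tokens + slot;
            }
        }

        // pass 2 (mul_mm_id): one dense matmul per expert over its staged rows
        //   dst_staged[e][slot] = W[e] * src1_staged[e][slot]   (weights omitted here)

        // pass 3 (map1): scatter staged results back via the id map
        //   dst[t][u] = dst_staged[ ids[t*n_expert_used + u] ]
        (void)dst;
    }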
@@ -3515,7 +3659,7 @@ static bool lm_ggml_metal_encode_node(
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

  const int64_t _ne1 = 1;
- const int64_t ne123 = dst_rows;
+ const int64_t ne123 = ne20*ne21;

  if (smem > 0) {
  [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -3719,6 +3863,7 @@ static bool lm_ggml_metal_encode_node(
  } break;
  case LM_GGML_OP_ROPE:
  {
+
  // make sure we have one or more position id(ne10) per token(ne02)
  LM_GGML_ASSERT(ne10 % ne02 == 0);
  LM_GGML_ASSERT(ne10 >= ne02);
@@ -3745,20 +3890,42 @@ static bool lm_ggml_metal_encode_node(
  memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));

- const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
+ const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == LM_GGML_ROPE_TYPE_VISION;
+
+ // mrope
+ const int sect_0 = ((const int32_t *) dst->op_params)[11];
+ const int sect_1 = ((const int32_t *) dst->op_params)[12];
+ const int sect_2 = ((const int32_t *) dst->op_params)[13];
+ const int sect_3 = ((const int32_t *) dst->op_params)[14];

  id<MTLComputePipelineState> pipeline = nil;

- if (!is_neox) {
+ if (is_neox) {
  switch (src0->type) {
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ default: LM_GGML_ABORT("fatal error");
+ };
+ } else if (is_mrope && !is_vision) {
+ LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+ default: LM_GGML_ABORT("fatal error");
+ };
+ } else if (is_vision) {
+ LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
  default: LM_GGML_ABORT("fatal error");
  };
  } else {
  switch (src0->type) {
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
  default: LM_GGML_ABORT("fatal error");
  };
  }
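Note: the sect_0..sect_3 values read from op_params[11..14] describe how multimodal RoPE (mrope) splits the rotary dimension pairs into sections that draw their positions from different streams of the positions tensor, which is why the new code asserts ne10*4 >= ne02 (four position values per token). A conceptual sketch of the section lookup, hedged: the exact boundary handling lives in the rope_multi Metal kernel, which is not part of this diff:

    // Given a rotary dimension-pair index and the four section sizes,
    // return which position stream (0..3, e.g. time/height/width/extra)
    // that pair should use. Conceptual only; not the package's kernel.
    int mrope_section_for_dim(int pair, const int sect[4]) {
        int acc = 0;
        for (int s = 0; s < 4; ++s) {
            acc += sect[s];
            if (pair < acc) {
                return s;
            }
        }
        return 3; // pairs past the listed sections fall back to the last stream
    }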
@@ -3789,6 +3956,10 @@ static bool lm_ggml_metal_encode_node(
  /*.attn_factor =*/ attn_factor,
  /*.beta_fast =*/ beta_fast,
  /*.beta_slow =*/ beta_slow,
+ /* sect_0 =*/ sect_0,
+ /* sect_1 =*/ sect_1,
+ /* sect_2 =*/ sect_2,
+ /* sect_3 =*/ sect_3,
  };

  [encoder setComputePipelineState:pipeline];
@@ -4225,7 +4396,7 @@ static bool lm_ggml_metal_encode_node(
  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
  // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
- if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+ if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
  switch (src1->type) {
  case LM_GGML_TYPE_F16:
  {
@@ -4406,6 +4577,24 @@ static bool lm_ggml_metal_encode_node(
  use_vec_kernel = true;

  switch (ne00) {
+ case 64:
+ {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
  case 96:
  {
  switch (src1->type) {