@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m

@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];
 
     size_t max_size;
 
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
     /*.has_residency_sets =*/ false,
     /*.has_bfloat         =*/ false,
     /*.use_bfloat         =*/ false,
+    /*.use_fusion         =*/ true,
+    /*.debug_fusion       =*/ 0,
+    /*.fuse_cnt           =*/ { 0 },
     /*.max_size           =*/ 0,
     /*.name               =*/ "",
 };
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
-    }
 
-    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
 #endif
 
         ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
         ctx->use_bfloat = false;
 #endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
         ctx->max_size = ctx->mtl_device.maxBufferLength;
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
         if (ctx->mtl_lock) {
            [ctx->mtl_lock release];
            ctx->mtl_lock = nil;
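
The two environment variables read above gate the new fusion path: GGML_METAL_FUSION_DISABLE turns it off entirely, and GGML_METAL_FUSION_DEBUG picks the verbosity (a value above 0 prints the per-op fusion counts shown in this hunk when the device is released; above 1 also logs each fused window as it is encoded). A minimal sketch of driving them from a host program, assuming only standard C; the surrounding program is illustrative:

    #include <stdlib.h>

    int main(void) {
        // Any value disables fusion; leaving it unset keeps the default (enabled).
        setenv("GGML_METAL_FUSION_DISABLE", "1", 1);
        // 1 prints fusion counts at device release; 2 also logs each fused window.
        setenv("GGML_METAL_FUSION_DEBUG", "2", 1);
        // ... initialize the ggml Metal backend here ...
        return 0;
    }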
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
     GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -173,6 +214,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -212,11 +259,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -526,6 +576,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
+    GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1123,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1149,6 +1218,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1188,11 +1263,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1502,6 +1580,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1676,10 +1759,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                default:
+                    return false;
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -1710,7 +1810,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -1852,9 +1952,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
                         ggml_backend_t backend,
                                    int idx,
+                                   int idx_end,
          id<MTLComputeCommandEncoder> encoder,
            struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context * ctx = backend->context;
@@ -1862,7 +1963,10 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1872,7 +1976,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1883,7 +1987,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1950,6 +2054,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2021,37 +2127,15 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(src0t == GGML_TYPE_F32);
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
                 const size_t offs = 0;
 
                 bool bcast_row = false;
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2078,12 +2162,119 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ nb2,
                     /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
+                    // across splits. idx_end indicates the last node in the current split
+                    for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 if (bcast_row) {
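
The loop in this hunk scans forward from the current node for a window of up to 8 chained ADDs that share one src1 layout and Metal buffer, then selects the matching *_FUSE_N pipeline and redirects the output to the last node's buffer. For orientation, a plain-C reference of what one fused window computes (an illustrative sketch, not the Metal kernel source):

    #include <stdint.h>

    // dst receives the output of the last node in the window:
    //   dst = src0 + src1[0] + src1[1] + ... + src1[n_fuse-1]
    static void add_fuse_ref(float * dst, const float * src0,
                             const float * const * src1, int n_fuse, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            float acc = src0[i];
            for (int f = 0; f < n_fuse; ++f) {
                acc += src1[f][i];
            }
            dst[i] = acc;
        }
    }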
@@ -2091,7 +2282,11 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
@@ -2216,12 +2411,13 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ pnb2,
                     /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0 atIndex:2];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2233,7 +2429,9 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(scale));
+                float bias;
+                memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&bias,  ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
                 int64_t n = ggml_nelements(dst);
 
@@ -2250,6 +2448,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+                [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
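
With the second op parameter wired through above, the Metal SCALE path now applies a full affine transform in one pass rather than a bare multiply. A plain-C reference of the intended semantics (illustrative, not the kernel source):

    #include <stdint.h>

    // dst = src0*scale + bias, with both floats unpacked from dst->op_params
    // exactly as in the encoder code above.
    static void scale_ref(float * dst, const float * src0, int64_t n,
                          float scale, float bias) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i]*scale + bias;
        }
    }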
@@ -2413,12 +2612,146 @@ static bool ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+            case GGML_UNARY_OP_ABS:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+            case GGML_UNARY_OP_SGN:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+            case GGML_UNARY_OP_STEP:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+            case GGML_UNARY_OP_HARDSWISH:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+            case GGML_UNARY_OP_HARDSIGMOID:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+            case GGML_UNARY_OP_EXP:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
             default:
             {
                 GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
                 GGML_ABORT("fatal error");
             }
         } break;
+        case GGML_OP_GLU:
+        {
+            GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+            if (src1) {
+                GGML_ASSERT(ggml_are_same_shape(src0, src1));
+            }
+
+            id<MTLComputePipelineState> pipeline = nil;
+
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+
+            const int32_t swp = ((const int32_t *) dst->op_params)[1];
+
+            const int32_t i00 = swp ? ne0 : 0;
+            const int32_t i10 = swp ? 0 : ne0;
+
+            ggml_metal_kargs_glu args = {
+                /*.ne00 =*/ ne00,
+                /*.nb01 =*/ nb01,
+                /*.ne10 =*/ src1 ? ne10 : ne00,
+                /*.nb11 =*/ src1 ? nb11 : nb01,
+                /*.ne0  =*/ ne0,
+                /*.nb1  =*/ nb1,
+                /*.i00  =*/ src1 ? 0 : i00,
+                /*.i10  =*/ src1 ? 0 : i10,
+            };
+
+            [encoder setComputePipelineState:pipeline];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            if (src1) {
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+            } else {
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+            [encoder setBytes:&args length:sizeof(args) atIndex:3];
+
+            const int64_t nrows = ggml_nrows(src0);
+
+            const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+        } break;
         case GGML_OP_SQR:
         {
             GGML_ASSERT(ggml_is_contiguous(src0));
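
The GLU path above splits each row of src0 into a gate half and a linear half (or gates src0 with a same-shape src1 when one is present; op_params[1] swaps which half is which). For orientation, a plain-C reference for one SWIGLU row under those assumptions (illustrative, not the Metal source):

    #include <math.h>
    #include <stdint.h>

    // One output row: dst[i] = silu(a[i]) * b[i], where a is the gated half
    // and b the linear half selected by the swap flag.
    static void swiglu_row_ref(float * dst, const float * a, const float * b, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            const float gate = a[i] / (1.0f + expf(-a[i])); // SiLU(a)
            dst[i] = gate * b[i];
        }
    }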
@@ -2573,10 +2906,7 @@ static bool ggml_metal_encode_node(
2573
2906
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2574
2907
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2575
2908
 
2576
- const int64_t nrows_x = ggml_nrows(src0);
2577
- const int64_t nrows_y = src0->ne[1];
2578
-
2579
- const uint32_t n_head = nrows_x/nrows_y;
2909
+ const uint32_t n_head = src0->ne[2];
2580
2910
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2581
2911
 
2582
2912
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2589,7 +2919,7 @@ static bool ggml_metal_encode_node(
2589
2919
  id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2590
2920
  if (!h_src0) {
2591
2921
  GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2592
- return false;
2922
+ return 0;
2593
2923
  }
2594
2924
 
2595
2925
  offs_src0 = 0;
@@ -2636,6 +2966,18 @@ static bool ggml_metal_encode_node(
2636
2966
  /*.ne00 =*/ ne00,
2637
2967
  /*.ne01 =*/ ne01,
2638
2968
  /*.ne02 =*/ ne02,
2969
+ /*.nb01 =*/ nb01,
2970
+ /*.nb02 =*/ nb02,
2971
+ /*.nb03 =*/ nb03,
2972
+ /*.ne11 =*/ ne11,
2973
+ /*.ne12 =*/ ne12,
2974
+ /*.ne13 =*/ ne13,
2975
+ /*.nb11 =*/ nb11,
2976
+ /*.nb12 =*/ nb12,
2977
+ /*.nb13 =*/ nb13,
2978
+ /*.nb1 =*/ nb1,
2979
+ /*.nb2 =*/ nb2,
2980
+ /*.nb3 =*/ nb3,
2639
2981
  /*.scale =*/ scale,
2640
2982
  /*.max_bias =*/ max_bias,
2641
2983
  /*.m0 =*/ m0,
@@ -2655,7 +2997,7 @@ static bool ggml_metal_encode_node(
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
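The dispatch change above trades a flattened 1D grid for a 3D grid, so the soft-max kernel can take i01/i02/i03 straight from its threadgroup position instead of recovering them by division; together with the nb* strides added to the kernel arguments, that is what allows non-contiguous layouts. The arithmetic a flat grid would otherwise need, as a hypothetical helper:

```c
#include <stdint.h>

// recover (i01, i02, i03) from a flat threadgroup index over ne01*ne02*ne03;
// the 3D dispatch makes this div/mod chain unnecessary
static void unflatten(int64_t tgid, int64_t ne01, int64_t ne02,
                      int64_t * i01, int64_t * i02, int64_t * i03) {
    *i01 =  tgid % ne01;
    *i02 = (tgid / ne01) % ne02;
    *i03 =  tgid / (ne01 * ne02);
}
```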
@@ -2729,71 +3071,92 @@ static bool ggml_metal_encode_node(
                 struct ggml_tensor * src3 = node->src[3];
                 struct ggml_tensor * src4 = node->src[4];
                 struct ggml_tensor * src5 = node->src[5];
+                struct ggml_tensor * src6 = node->src[6];
 
                 GGML_ASSERT(src3);
                 GGML_ASSERT(src4);
                 GGML_ASSERT(src5);
+                GGML_ASSERT(src6);
 
                 size_t offs_src3 = 0;
                 size_t offs_src4 = 0;
                 size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                 id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                 id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
 
-                const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t ne30 = src3->ne[0];
                 const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
 
-                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
                 const uint64_t nb31 = src3->nb[1];
 
                 const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t ne41 = src4->ne[1];
                 const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+                const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
 
-                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
                 const uint64_t nb41 = src4->nb[1];
                 const uint64_t nb42 = src4->nb[2];
+                const uint64_t nb43 = src4->nb[3];
 
                 const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
                 const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
                 const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+                const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
 
-                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
                 const uint64_t nb51 = src5->nb[1];
                 const uint64_t nb52 = src5->nb[2];
+                const uint64_t nb53 = src5->nb[3];
+
+                const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+                const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
                 const int64_t d_state      = ne00;
                 const int64_t d_inner      = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs       = ne02;
+                const int64_t n_head       = ne02;
+                const int64_t n_group      = ne41;
+                const int64_t n_seq_tokens = ne12;
+                const int64_t n_seqs       = ne13;
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (ne30 == 1) {
+                    // Mamba-2
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                }
 
                 ggml_metal_kargs_ssm_scan args = {
-                    /*.d_state      =*/ d_state,
-                    /*.d_inner      =*/ d_inner,
+                    /*.d_state      =*/ d_state,
+                    /*.d_inner      =*/ d_inner,
+                    /*.n_head       =*/ n_head,
+                    /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
-                    /*.n_seqs       =*/ n_seqs,
-                    /*.nb00         =*/ nb00,
-                    /*.nb01         =*/ nb01,
-                    /*.nb02         =*/ nb02,
-                    /*.nb10         =*/ nb10,
-                    /*.nb11         =*/ nb11,
-                    /*.nb12         =*/ nb12,
-                    /*.nb13         =*/ nb13,
-                    /*.nb20         =*/ nb20,
-                    /*.nb21         =*/ nb21,
-                    /*.nb22         =*/ nb22,
-                    /*.nb30         =*/ nb30,
-                    /*.nb31         =*/ nb31,
-                    /*.nb40         =*/ nb40,
-                    /*.nb41         =*/ nb41,
-                    /*.nb42         =*/ nb42,
-                    /*.nb50         =*/ nb50,
-                    /*.nb51         =*/ nb51,
-                    /*.nb52         =*/ nb52,
+                    /*.n_seqs       =*/ n_seqs,
+                    /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
+                    /*.nb01         =*/ nb01,
+                    /*.nb02         =*/ nb02,
+                    /*.nb03         =*/ nb03,
+                    /*.nb11         =*/ nb11,
+                    /*.nb12         =*/ nb12,
+                    /*.nb13         =*/ nb13,
+                    /*.nb21         =*/ nb21,
+                    /*.nb22         =*/ nb22,
+                    /*.nb31         =*/ nb31,
+                    /*.nb41         =*/ nb41,
+                    /*.nb42         =*/ nb42,
+                    /*.nb43         =*/ nb43,
+                    /*.nb51         =*/ nb51,
+                    /*.nb52         =*/ nb52,
+                    /*.nb53         =*/ nb53,
                 };
 
                 [encoder setComputePipelineState:pipeline];
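The new `s_off` argument is the byte offset separating the scan output from the updated states in the destination buffer; it equals the float size of src1 (the per-token inputs x), which suggests dst packs all y values first and the recurrent state region after them. A sketch under that assumption:

```c
#include <stddef.h>
#include <stdint.h>

// byte offset of the state region in dst, matching args.s_off above
static size_t state_offset(int64_t nelements_x) {
    return (size_t) nelements_x * sizeof(float);
}

// usage (illustrative):
//   float * y = (float *) dst_base;                                    // scan outputs
//   float * s = (float *) ((char *) dst_base + state_offset(nx));      // updated SSM states
```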
@@ -2803,10 +3166,27 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-                [encoder setBytes:&args length:sizeof(args) atIndex:7];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
+                [encoder setBytes:&args length:sizeof(args) atIndex:8];
+
+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
 
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                if (ne30 == 1) {
+                    // Mamba-2
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
+                } else {
+                    GGML_ASSERT(d_inner == 1);
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
+                }
             } break;
         case GGML_OP_RWKV_WKV6:
             {
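Kernel selection above keys off `ne30`, the first dimension of the A tensor (src3): a width of 1 indicates Mamba-2's scalar per-head decay, anything larger the original per-state A layout, and the same test also picks the dispatch geometry. Condensed to a hypothetical host-side helper:

```c
#include <stdint.h>

typedef enum { SSM_SCAN_F32, SSM_SCAN_F32_GROUP } ssm_kernel_t;

// mirrors the `ne30 == 1` branch in the diff: Mamba-2 stores one decay
// value per head, the original SSM a full d_state-wide A row
static ssm_kernel_t pick_ssm_kernel(int64_t ne30) {
    return ne30 == 1 ? SSM_SCAN_F32_GROUP : SSM_SCAN_F32;
}
```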
@@ -3426,7 +3806,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                 if (!h_src1) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                 }
 
                 const int64_t neh0 = ne0;
@@ -3442,7 +3822,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                 if (!h_dst) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                 }
 
                 // tokens per expert
@@ -3450,7 +3830,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                 if (!h_tpe) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                 }
 
                 // id map
@@ -3459,7 +3839,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                 if (!h_ids) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                 }
 
                 {
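All four hunks above follow the same scratch-allocation pattern, updated for the new integer return contract: transient buffers come from a per-command-buffer memory pool, and exhaustion aborts the node by reporting zero nodes encoded rather than `false`. A stand-alone sketch, with malloc standing in for the pool (hypothetical stand-ins, for illustration only):

```c
#include <stdio.h>
#include <stdlib.h>

// hypothetical stand-in for ggml_metal_mem_pool_alloc
static void * mem_pool_alloc(size_t size) {
    return malloc(size);
}

// returns the number of graph nodes encoded; 0 signals failure
static int encode_with_scratch(size_t s_src1) {
    void * h_src1 = mem_pool_alloc(s_src1);
    if (!h_src1) {
        fprintf(stderr, "failed to allocate buffer from memory pool, size = %zu\n", s_src1);
        return 0; // previously `return false`
    }
    // ... encode kernels that stage data through h_src1 ...
    free(h_src1);
    return 1; // one node consumed
}
```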
@@ -3891,12 +4271,95 @@ static bool ggml_metal_encode_node(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }
 
                 int nth = 32; // SIMD width
 
@@ -3907,23 +4370,16 @@ static bool ggml_metal_encode_node(
                 nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_L2_NORM:
             {
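The fusion loop in the RMS_NORM hunk accepts a following MUL and ADD only when each op consumes the previous result, the broadcast operand matches the RMS_NORM row length, its rows are densely packed, and the fused output is F32; the first failed condition stops the scan, leaving n_fuse at the number of ops already accepted. A distilled form of that guard, with hypothetical types (the real code walks ggml_tensor nodes via ggml_can_fuse):

```c
#include <stdbool.h>
#include <stdint.h>

struct node_like {
    const struct node_like * src0;   // first operand of the candidate MUL/ADD
    int64_t ne0;                     // row length
    bool    contiguous_rows;
    bool    is_f32;
};

// per-step check mirroring the break conditions in the diff
static bool can_fuse_next(const struct node_like * prev,        // node fused so far
                          const struct node_like * next,        // candidate MUL/ADD
                          const struct node_like * next_src1) { // its broadcast operand
    return next->src0 == prev            // must consume the previous result
        && next_src1->ne0 == next->ne0   // operand matches the row length
        && next_src1->contiguous_rows    // operand rows densely packed
        && next->is_f32;                 // only F32 outputs are fused
}
```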
@@ -4908,7 +5364,11 @@ static bool ggml_metal_encode_node(
                 /*.nb21  =*/ nb21,
                 /*.nb22  =*/ nb22,
                 /*.nb23  =*/ nb23,
+                /*.ne32  =*/ ne32,
+                /*.ne33  =*/ ne33,
                 /*.nb31  =*/ nb31,
+                /*.nb32  =*/ nb32,
+                /*.nb33  =*/ nb33,
                 /*.ne1   =*/ ne1,
                 /*.ne2   =*/ ne2,
                 /*.scale =*/ scale,
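The ne32/ne33 and nb32/nb33 fields added here give the flash-attention mask full addressing over dims 2 and 3. Assuming ggml's usual broadcast convention, where singleton dims are reused via modulo indexing (a sketch for intuition, not the kernel's actual code):

```c
#include <stdint.h>

// address one mask row for position (i1, i2, i3); a mask with
// ne32/ne33 == 1 is shared across heads/sequences via the modulo
static const char * mask_row(const char * mask_base,
                             int64_t i1, int64_t i2, int64_t i3,
                             int64_t ne32, int64_t ne33,
                             uint64_t nb31, uint64_t nb32, uint64_t nb33) {
    return mask_base + i1*nb31 + (i2 % ne32)*nb32 + (i3 % ne33)*nb33;
}
```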
@@ -5314,7 +5774,7 @@ static bool ggml_metal_encode_node(
         }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -5820,20 +6280,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
         ggml_metal_mem_pool_reset(mem_pool);
 
-        for (int idx = node_start; idx < node_end; ++idx) {
+        for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            if (idx + res > node_end) {
+                GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+                    "https://github.com/ggml-org/llama.cpp/pull/14849");
+            }
 
             if (should_capture) {
                 [encoder popDebugGroup];
            }
 
-            if (!res) {
+            if (res == 0) {
                 break;
             }
+
+            idx += res;
         }
 
         [encoder endEncoding];
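Taken together with the `return n_fuse` change above, the reworked loop advances by however many graph nodes each encode call consumed and treats zero as failure, while the `idx + res > node_end` guard catches fused spans leaking past the encoder's node range. The control flow, stripped of Metal specifics (hypothetical stand-in for ggml_metal_encode_node):

```c
// stand-in encoder: 0 = failure, >= 1 = number of consecutive graph
// nodes encoded (fused runs return > 1)
static int encode_node(int idx, int idx_end) {
    (void) idx; (void) idx_end;
    return 1;
}

static void encode_range(int node_start, int node_end) {
    for (int idx = node_start; idx < node_end; /* advanced below */) {
        const int res = encode_node(idx, node_end);
        if (res == 0) {
            break;   // error: abandon the rest of this command buffer
        }
        idx += res;  // skip every node the (possibly fused) dispatch covered
    }
}
```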