@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
48
48
  int mtl_device_ref_count;
49
49
  id<MTLLibrary> mtl_library;
50
50
 
51
+ NSLock * mtl_lock;
52
+
51
53
  bool has_simdgroup_reduction;
52
54
  bool has_simdgroup_mm;
53
55
  bool has_residency_sets;
54
56
  bool has_bfloat;
55
57
  bool use_bfloat;
56
58
 
59
+ size_t max_size;
60
+
57
61
  char name[128];
58
62
  } g_ggml_ctx_dev_main = {
59
63
  /*.mtl_device =*/ nil,
60
64
  /*.mtl_device_ref_count =*/ 0,
61
65
  /*.mtl_library =*/ nil,
66
+ /*.mtl_lock =*/ nil,
62
67
  /*.has_simdgroup_reduction =*/ false,
63
68
  /*.has_simdgroup_mm =*/ false,
64
69
  /*.has_residency_sets =*/ false,
65
70
  /*.has_bfloat =*/ false,
66
71
  /*.use_bfloat =*/ false,
72
+ /*.max_size =*/ 0,
67
73
  /*.name =*/ "",
68
74
  };
69
75
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
71
77
  static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
72
78
  assert(ctx != NULL);
73
79
 
80
+ if (ctx->mtl_lock == nil) {
81
+ ctx->mtl_lock = [[NSLock alloc] init];
82
+ }
83
+
74
84
  if (ctx->mtl_device == nil) {
75
85
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
86
  }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
94
104
  ctx->use_bfloat = false;
95
105
  #endif
96
106
 
107
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
108
+
97
109
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98
110
  }
99
111
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110
122
  ctx->mtl_device_ref_count--;
111
123
 
112
124
  if (ctx->mtl_device_ref_count == 0) {
125
+ if (ctx->mtl_lock) {
126
+ [ctx->mtl_lock release];
127
+ ctx->mtl_lock = nil;
128
+ }
129
+
113
130
  if (ctx->mtl_library) {
114
131
  [ctx->mtl_library release];
115
132
  ctx->mtl_library = nil;
@@ -185,20 +202,33 @@ enum ggml_metal_kernel_type {
185
202
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
186
203
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
187
204
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
188
214
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
215
  GGML_METAL_KERNEL_TYPE_L2_NORM,
190
216
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
191
217
  GGML_METAL_KERNEL_TYPE_NORM,
192
218
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
193
219
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
220
+ GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
194
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
222
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
196
223
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
224
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
197
225
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
226
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
198
227
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
199
228
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
200
229
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
201
230
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
231
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
202
232
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
203
233
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
204
234
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -497,6 +527,11 @@ enum ggml_metal_kernel_type {
497
527
  GGML_METAL_KERNEL_TYPE_SIN,
498
528
  GGML_METAL_KERNEL_TYPE_COS,
499
529
  GGML_METAL_KERNEL_TYPE_NEG,
530
+ GGML_METAL_KERNEL_TYPE_REGLU,
531
+ GGML_METAL_KERNEL_TYPE_GEGLU,
532
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
533
+ GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
534
+ GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
500
535
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501
536
  GGML_METAL_KERNEL_TYPE_MEAN,
502
537
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -977,7 +1012,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
977
1012
  struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
978
1013
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
979
1014
 
980
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
1015
+ id<MTLDevice> device = ctx_dev->mtl_device;
981
1016
 
982
1017
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
983
1018
 
@@ -991,9 +1026,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
991
1026
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
992
1027
 
993
1028
  // load library
994
- if (ctx_dev->mtl_library == nil) {
995
- ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1029
+ {
1030
+ [ctx_dev->mtl_lock lock];
1031
+
1032
+ if (ctx_dev->mtl_library == nil) {
1033
+ ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1034
+ }
1035
+
1036
+ [ctx_dev->mtl_lock unlock];
996
1037
  }
1038
+
997
1039
  id<MTLLibrary> metal_library = ctx_dev->mtl_library;
998
1040
  if (metal_library == nil) {
999
1041
  GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1142,20 +1184,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1142
1184
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1143
1185
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1144
1186
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1190
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1191
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1192
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1193
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1194
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1195
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1145
1196
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1146
1197
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1147
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1148
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1149
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1150
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1202
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1151
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1152
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1153
1205
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1206
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
1154
1207
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1208
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
1155
1209
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1156
1210
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1157
1211
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1158
1212
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1213
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
1159
1214
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1160
1215
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1161
1216
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1454,6 +1509,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1454
1509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1455
1510
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1456
1511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1512
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1513
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1514
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1515
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
1516
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1457
1517
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1458
1518
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1459
1519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1605,6 +1665,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1605
1665
  const bool use_bfloat = ctx_dev->use_bfloat;
1606
1666
 
1607
1667
  if (!use_bfloat) {
1668
+ if (op->type == GGML_TYPE_BF16) {
1669
+ return false;
1670
+ }
1671
+
1608
1672
  for (size_t i = 0, n = 3; i < n; ++i) {
1609
1673
  if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1610
1674
  return false;
@@ -1628,6 +1692,17 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1628
1692
  default:
1629
1693
  return false;
1630
1694
  }
1695
+ case GGML_OP_GLU:
1696
+ switch (ggml_get_glu_op(op)) {
1697
+ case GGML_GLU_OP_REGLU:
1698
+ case GGML_GLU_OP_GEGLU:
1699
+ case GGML_GLU_OP_SWIGLU:
1700
+ case GGML_GLU_OP_GEGLU_ERF:
1701
+ case GGML_GLU_OP_GEGLU_QUICK:
1702
+ return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1703
+ default:
1704
+ return false;
1705
+ }
1631
1706
  case GGML_OP_NONE:
1632
1707
  case GGML_OP_RESHAPE:
1633
1708
  case GGML_OP_VIEW:
@@ -1658,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1658
1733
  case GGML_OP_MEAN:
1659
1734
  case GGML_OP_SOFT_MAX:
1660
1735
  case GGML_OP_GROUP_NORM:
1661
- return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1736
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
1662
1737
  case GGML_OP_RMS_NORM:
1663
1738
  case GGML_OP_L2_NORM:
1664
1739
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -1774,6 +1849,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1774
1849
  {
1775
1850
  return op->ne[3] == 1;
1776
1851
  }
1852
+ case GGML_OP_SET_ROWS:
1853
+ {
1854
+ if (op->src[0]->type != GGML_TYPE_F32) {
1855
+ return false;
1856
+ }
1857
+
1858
+ switch (op->type) {
1859
+ case GGML_TYPE_F32:
1860
+ case GGML_TYPE_F16:
1861
+ case GGML_TYPE_BF16:
1862
+ case GGML_TYPE_Q8_0:
1863
+ case GGML_TYPE_Q4_0:
1864
+ case GGML_TYPE_Q4_1:
1865
+ case GGML_TYPE_Q5_0:
1866
+ case GGML_TYPE_Q5_1:
1867
+ case GGML_TYPE_IQ4_NL:
1868
+ return true;
1869
+ default:
1870
+ return false;
1871
+ };
1872
+ }
1777
1873
  default:
1778
1874
  return false;
1779
1875
  }
@@ -2160,7 +2256,9 @@ static bool ggml_metal_encode_node(
2160
2256
  GGML_ASSERT(ggml_is_contiguous(src0));
2161
2257
 
2162
2258
  float scale;
2163
- memcpy(&scale, dst->op_params, sizeof(scale));
2259
+ float bias;
2260
+ memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
2261
+ memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
2164
2262
 
2165
2263
  int64_t n = ggml_nelements(dst);
2166
2264
 
@@ -2177,6 +2275,7 @@ static bool ggml_metal_encode_node(
2177
2275
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2178
2276
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2179
2277
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
2278
+ [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
2180
2279
 
2181
2280
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2182
2281
  } break;
@@ -2346,6 +2445,68 @@ static bool ggml_metal_encode_node(
2346
2445
  GGML_ABORT("fatal error");
2347
2446
  }
2348
2447
  } break;
2448
+ case GGML_OP_GLU:
2449
+ {
2450
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2451
+
2452
+ if (src1) {
2453
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
2454
+ }
2455
+
2456
+ id<MTLComputePipelineState> pipeline = nil;
2457
+
2458
+ switch (ggml_get_glu_op(node)) {
2459
+ case GGML_GLU_OP_REGLU:
2460
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2461
+ break;
2462
+ case GGML_GLU_OP_GEGLU:
2463
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2464
+ break;
2465
+ case GGML_GLU_OP_SWIGLU:
2466
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2467
+ break;
2468
+ case GGML_GLU_OP_GEGLU_ERF:
2469
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
2470
+ break;
2471
+ case GGML_GLU_OP_GEGLU_QUICK:
2472
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
2473
+ break;
2474
+ default:
2475
+ GGML_ABORT("fatal error");
2476
+ }
2477
+
2478
+ const int32_t swp = ((const int32_t *) dst->op_params)[1];
2479
+
2480
+ const int32_t i00 = swp ? ne0 : 0;
2481
+ const int32_t i10 = swp ? 0 : ne0;
2482
+
2483
+ ggml_metal_kargs_glu args = {
2484
+ /*.ne00 =*/ ne00,
2485
+ /*.nb01 =*/ nb01,
2486
+ /*.ne10 =*/ src1 ? ne10 : ne00,
2487
+ /*.nb11 =*/ src1 ? nb11 : nb01,
2488
+ /*.ne0 =*/ ne0,
2489
+ /*.nb1 =*/ nb1,
2490
+ /*.i00 =*/ src1 ? 0 : i00,
2491
+ /*.i10 =*/ src1 ? 0 : i10,
2492
+ };
2493
+
2494
+ [encoder setComputePipelineState:pipeline];
2495
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2496
+ if (src1) {
2497
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2498
+ } else {
2499
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2500
+ }
2501
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2502
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2503
+
2504
+ const int64_t nrows = ggml_nrows(src0);
2505
+
2506
+ const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2507
+
2508
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2509
+ } break;
2349
2510
  case GGML_OP_SQR:
2350
2511
  {
2351
2512
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -2426,6 +2587,7 @@ static bool ggml_metal_encode_node(
2426
2587
  nth *= 2;
2427
2588
  }
2428
2589
 
2590
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2429
2591
  nth = MIN(nth, ne00);
2430
2592
 
2431
2593
  ggml_metal_kargs_sum_rows args = {
@@ -2499,10 +2661,7 @@ static bool ggml_metal_encode_node(
2499
2661
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2500
2662
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2501
2663
 
2502
- const int64_t nrows_x = ggml_nrows(src0);
2503
- const int64_t nrows_y = src0->ne[1];
2504
-
2505
- const uint32_t n_head = nrows_x/nrows_y;
2664
+ const uint32_t n_head = src0->ne[2];
2506
2665
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2507
2666
 
2508
2667
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2562,6 +2721,18 @@ static bool ggml_metal_encode_node(
2562
2721
  /*.ne00 =*/ ne00,
2563
2722
  /*.ne01 =*/ ne01,
2564
2723
  /*.ne02 =*/ ne02,
2724
+ /*.nb01 =*/ nb01,
2725
+ /*.nb02 =*/ nb02,
2726
+ /*.nb03 =*/ nb03,
2727
+ /*.ne11 =*/ ne11,
2728
+ /*.ne12 =*/ ne12,
2729
+ /*.ne13 =*/ ne13,
2730
+ /*.nb11 =*/ nb11,
2731
+ /*.nb12 =*/ nb12,
2732
+ /*.nb13 =*/ nb13,
2733
+ /*.nb1 =*/ nb1,
2734
+ /*.nb2 =*/ nb2,
2735
+ /*.nb3 =*/ nb3,
2565
2736
  /*.scale =*/ scale,
2566
2737
  /*.max_bias =*/ max_bias,
2567
2738
  /*.m0 =*/ m0,
@@ -2581,7 +2752,7 @@ static bool ggml_metal_encode_node(
2581
2752
 
2582
2753
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2583
2754
 
2584
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2755
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2585
2756
  } break;
2586
2757
  case GGML_OP_DIAG_MASK_INF:
2587
2758
  {
@@ -2655,71 +2826,91 @@ static bool ggml_metal_encode_node(
2655
2826
  struct ggml_tensor * src3 = node->src[3];
2656
2827
  struct ggml_tensor * src4 = node->src[4];
2657
2828
  struct ggml_tensor * src5 = node->src[5];
2829
+ struct ggml_tensor * src6 = node->src[6];
2658
2830
 
2659
2831
  GGML_ASSERT(src3);
2660
2832
  GGML_ASSERT(src4);
2661
2833
  GGML_ASSERT(src5);
2834
+ GGML_ASSERT(src6);
2662
2835
 
2663
2836
  size_t offs_src3 = 0;
2664
2837
  size_t offs_src4 = 0;
2665
2838
  size_t offs_src5 = 0;
2839
+ size_t offs_src6 = 0;
2666
2840
 
2667
2841
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2668
2842
  id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
2669
2843
  id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
2844
+ id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
2670
2845
 
2671
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
2846
+ const int64_t ne30 = src3->ne[0];
2672
2847
  const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
2673
2848
 
2674
- const uint64_t nb30 = src3->nb[0];
2849
+ const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
2675
2850
  const uint64_t nb31 = src3->nb[1];
2676
2851
 
2677
2852
  const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
2678
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
2853
+ const int64_t ne41 = src4->ne[1];
2679
2854
  const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
2855
+ const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
2680
2856
 
2681
- const uint64_t nb40 = src4->nb[0];
2857
+ const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
2682
2858
  const uint64_t nb41 = src4->nb[1];
2683
2859
  const uint64_t nb42 = src4->nb[2];
2860
+ const uint64_t nb43 = src4->nb[3];
2684
2861
 
2685
2862
  const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
2686
2863
  const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
2687
2864
  const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
2865
+ const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
2688
2866
 
2689
- const uint64_t nb50 = src5->nb[0];
2867
+ const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
2690
2868
  const uint64_t nb51 = src5->nb[1];
2691
2869
  const uint64_t nb52 = src5->nb[2];
2870
+ const uint64_t nb53 = src5->nb[3];
2871
+
2872
+ const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
2873
+
2874
+ const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
2692
2875
 
2693
2876
  const int64_t d_state = ne00;
2694
2877
  const int64_t d_inner = ne01;
2695
- const int64_t n_seq_tokens = ne11;
2696
- const int64_t n_seqs = ne02;
2878
+ const int64_t n_head = ne02;
2879
+ const int64_t n_group = ne41;
2880
+ const int64_t n_seq_tokens = ne12;
2881
+ const int64_t n_seqs = ne13;
2882
+
2883
+ id<MTLComputePipelineState> pipeline = nil;
2697
2884
 
2698
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2885
+ if (ne30 == 1) {
2886
+ // Mamba-2
2887
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
2888
+ } else {
2889
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2890
+ }
2699
2891
 
2700
2892
  ggml_metal_kargs_ssm_scan args = {
2701
- /*.d_state =*/ d_state,
2702
- /*.d_inner =*/ d_inner,
2893
+ /*.d_state =*/ d_state,
2894
+ /*.d_inner =*/ d_inner,
2895
+ /*.n_head =*/ n_head,
2896
+ /*.n_group =*/ n_group,
2703
2897
  /*.n_seq_tokens =*/ n_seq_tokens,
2704
- /*.n_seqs =*/ n_seqs,
2705
- /*.nb00 =*/ nb00,
2706
- /*.nb01 =*/ nb01,
2707
- /*.nb02 =*/ nb02,
2708
- /*.nb10 =*/ nb10,
2709
- /*.nb11 =*/ nb11,
2710
- /*.nb12 =*/ nb12,
2711
- /*.nb13 =*/ nb13,
2712
- /*.nb20 =*/ nb20,
2713
- /*.nb21 =*/ nb21,
2714
- /*.nb22 =*/ nb22,
2715
- /*.nb30 =*/ nb30,
2716
- /*.nb31 =*/ nb31,
2717
- /*.nb40 =*/ nb40,
2718
- /*.nb41 =*/ nb41,
2719
- /*.nb42 =*/ nb42,
2720
- /*.nb50 =*/ nb50,
2721
- /*.nb51 =*/ nb51,
2722
- /*.nb52 =*/ nb52,
2898
+ /*.n_seqs =*/ n_seqs,
2899
+ /*.nb01 =*/ nb01,
2900
+ /*.nb02 =*/ nb02,
2901
+ /*.nb03 =*/ nb03,
2902
+ /*.nb11 =*/ nb11,
2903
+ /*.nb12 =*/ nb12,
2904
+ /*.nb13 =*/ nb13,
2905
+ /*.nb21 =*/ nb21,
2906
+ /*.nb22 =*/ nb22,
2907
+ /*.nb31 =*/ nb31,
2908
+ /*.nb41 =*/ nb41,
2909
+ /*.nb42 =*/ nb42,
2910
+ /*.nb43 =*/ nb43,
2911
+ /*.nb51 =*/ nb51,
2912
+ /*.nb52 =*/ nb52,
2913
+ /*.nb53 =*/ nb53,
2723
2914
  };
2724
2915
 
2725
2916
  [encoder setComputePipelineState:pipeline];
@@ -2729,10 +2920,17 @@ static bool ggml_metal_encode_node(
2729
2920
  [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2730
2921
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2731
2922
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2732
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2733
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
2923
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
2924
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
2925
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
2734
2926
 
2735
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2927
+ if (ne30 == 1) {
2928
+ // Mamba-2
2929
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2930
+ } else {
2931
+ GGML_ASSERT(d_inner == 1);
2932
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2933
+ }
2736
2934
  } break;
2737
2935
  case GGML_OP_RWKV_WKV6:
2738
2936
  {
@@ -3086,14 +3284,23 @@ static bool ggml_metal_encode_node(
3086
3284
  nsg = 1;
3087
3285
  nr0 = 1;
3088
3286
  nr1 = 4;
3089
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3287
+ if (ne00 == 4) {
3288
+ nr0 = 32;
3289
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3290
+ } else {
3291
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3292
+ }
3090
3293
  } break;
3091
3294
  case GGML_TYPE_F16:
3092
3295
  {
3093
3296
  nsg = 1;
3094
3297
  nr0 = 1;
3095
3298
  if (src1t == GGML_TYPE_F32) {
3096
- if (ne11 * ne12 < 4) {
3299
+ if (ne00 == 4) {
3300
+ nr0 = 32;
3301
+ nr1 = 4;
3302
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3303
+ } else if (ne11 * ne12 < 4) {
3097
3304
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
3098
3305
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3099
3306
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3112,7 +3319,11 @@ static bool ggml_metal_encode_node(
3112
3319
  nsg = 1;
3113
3320
  nr0 = 1;
3114
3321
  if (src1t == GGML_TYPE_F32) {
3115
- if (ne11 * ne12 < 4) {
3322
+ if (ne00 == 4) {
3323
+ nr0 = 32;
3324
+ nr1 = 4;
3325
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3326
+ } else if (ne11 * ne12 < 4) {
3116
3327
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
3117
3328
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3118
3329
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
@@ -3733,13 +3944,74 @@ static bool ggml_metal_encode_node(
3733
3944
  };
3734
3945
 
3735
3946
  [encoder setComputePipelineState:pipeline];
3736
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3737
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3738
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3739
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3947
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3948
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3949
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3950
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3740
3951
 
3741
3952
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3742
3953
  } break;
3954
+ case GGML_OP_SET_ROWS:
3955
+ {
3956
+ id<MTLComputePipelineState> pipeline = nil;
3957
+
3958
+ switch (dst->type) {
3959
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3960
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3961
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3962
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3963
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3964
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3965
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3966
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3967
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3968
+ default: GGML_ABORT("not implemented");
3969
+ }
3970
+
3971
+ const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3972
+
3973
+ int nth = 32; // SIMD width
3974
+
3975
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3976
+ nth *= 2;
3977
+ }
3978
+
3979
+ int nrptg = 1;
3980
+ if (nth > nk0) {
3981
+ nrptg = (nth + nk0 - 1)/nk0;
3982
+ nth = nk0;
3983
+
3984
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3985
+ nrptg--;
3986
+ }
3987
+ }
3988
+
3989
+ nth = MIN(nth, nk0);
3990
+
3991
+ ggml_metal_kargs_set_rows args = {
3992
+ /*.nk0 =*/ nk0,
3993
+ /*.ne01 =*/ ne01,
3994
+ /*.nb01 =*/ nb01,
3995
+ /*.nb02 =*/ nb02,
3996
+ /*.nb03 =*/ nb03,
3997
+ /*.ne11 =*/ ne11,
3998
+ /*.ne12 =*/ ne12,
3999
+ /*.nb10 =*/ nb10,
4000
+ /*.nb11 =*/ nb11,
4001
+ /*.nb12 =*/ nb12,
4002
+ /*.nb1 =*/ nb1,
4003
+ /*.nb2 =*/ nb2,
4004
+ /*.nb3 =*/ nb3,
4005
+ };
4006
+
4007
+ [encoder setComputePipelineState:pipeline];
4008
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
4009
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4010
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
4011
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
4012
+
4013
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4014
+ } break;
3743
4015
  case GGML_OP_RMS_NORM:
3744
4016
  {
3745
4017
  GGML_ASSERT(ne00 % 4 == 0);
@@ -3756,6 +4028,7 @@ static bool ggml_metal_encode_node(
3756
4028
  nth *= 2;
3757
4029
  }
3758
4030
 
4031
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3759
4032
  nth = MIN(nth, ne00/4);
3760
4033
 
3761
4034
  ggml_metal_kargs_rms_norm args = {
@@ -3792,6 +4065,7 @@ static bool ggml_metal_encode_node(
3792
4065
  nth *= 2;
3793
4066
  }
3794
4067
 
4068
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3795
4069
  nth = MIN(nth, ne00/4);
3796
4070
 
3797
4071
  ggml_metal_kargs_l2_norm args = {
@@ -3864,6 +4138,7 @@ static bool ggml_metal_encode_node(
3864
4138
  nth *= 2;
3865
4139
  }
3866
4140
 
4141
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3867
4142
  nth = MIN(nth, ne00/4);
3868
4143
 
3869
4144
  ggml_metal_kargs_norm args = {
@@ -4757,7 +5032,11 @@ static bool ggml_metal_encode_node(
4757
5032
  /*.nb21 =*/ nb21,
4758
5033
  /*.nb22 =*/ nb22,
4759
5034
  /*.nb23 =*/ nb23,
5035
+ /*.ne32 =*/ ne32,
5036
+ /*.ne33 =*/ ne33,
4760
5037
  /*.nb31 =*/ nb31,
5038
+ /*.nb32 =*/ nb32,
5039
+ /*.nb33 =*/ nb33,
4761
5040
  /*.ne1 =*/ ne1,
4762
5041
  /*.ne2 =*/ ne2,
4763
5042
  /*.scale =*/ scale,
@@ -4950,8 +5229,39 @@ static bool ggml_metal_encode_node(
4950
5229
  default: GGML_ABORT("not implemented");
4951
5230
  }
4952
5231
 
5232
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5233
+
5234
+ // TODO: support
5235
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
5236
+ const int32_t nk00 = ne00;
5237
+
5238
+ int nth = 32; // SIMD width
5239
+
5240
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
5241
+ nth *= 2;
5242
+ }
5243
+
5244
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
5245
+
5246
+ // when rows are small, we can batch them together in a single threadgroup
5247
+ int nrptg = 1;
5248
+
5249
+ // TODO: relax this constraint in the future
5250
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
5251
+ if (nth > nk00) {
5252
+ nrptg = (nth + nk00 - 1)/nk00;
5253
+ nth = nk00;
5254
+
5255
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5256
+ nrptg--;
5257
+ }
5258
+ }
5259
+ }
5260
+
5261
+ nth = MIN(nth, nk00);
5262
+
4953
5263
  ggml_metal_kargs_cpy args = {
4954
- /*.ne00 =*/ ne00,
5264
+ /*.ne00 =*/ nk00,
4955
5265
  /*.ne01 =*/ ne01,
4956
5266
  /*.ne02 =*/ ne02,
4957
5267
  /*.ne03 =*/ ne03,
@@ -4974,11 +5284,7 @@ static bool ggml_metal_encode_node(
4974
5284
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4975
5285
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4976
5286
 
4977
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4978
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4979
-
4980
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4981
-
5287
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4982
5288
  } break;
4983
5289
  case GGML_OP_SET:
4984
5290
  {
@@ -5284,7 +5590,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
5284
5590
  }
5285
5591
 
5286
5592
  ggml_backend_metal_buffer_rset_free(ctx);
5287
- ggml_backend_metal_device_rel(buffer->buft->device->context);
5288
5593
 
5289
5594
  if (ctx->owned) {
5290
5595
  #if TARGET_OS_OSX
@@ -5393,7 +5698,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5393
5698
  }
5394
5699
 
5395
5700
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5396
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5701
+
5702
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5703
+
5704
+ id<MTLDevice> device = ctx_dev->mtl_device;
5397
5705
 
5398
5706
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
5399
5707
  ctx->all_size = size_aligned;
@@ -5416,14 +5724,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5416
5724
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
5417
5725
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
5418
5726
  free(ctx);
5419
- ggml_backend_metal_device_rel(ctx_dev);
5420
5727
  return NULL;
5421
5728
  }
5422
5729
 
5423
5730
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5424
5731
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5425
5732
  free(ctx);
5426
- ggml_backend_metal_device_rel(ctx_dev);
5427
5733
  return NULL;
5428
5734
  }
5429
5735
 
@@ -5434,17 +5740,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5434
5740
 
5435
5741
  static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
5436
5742
  return 32;
5743
+
5437
5744
  GGML_UNUSED(buft);
5438
5745
  }
5439
5746
 
5440
5747
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
5441
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5442
- const size_t max_size = device.maxBufferLength;
5443
- ggml_backend_metal_device_rel(buft->device->context);
5748
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
5444
5749
 
5445
5750
  return max_size;
5446
-
5447
- GGML_UNUSED(buft);
5448
5751
  }
5449
5752
 
5450
5753
  static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5820,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5517
5820
  }
5518
5821
 
5519
5822
  struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5520
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5823
+
5824
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5825
+
5826
+ id<MTLDevice> device = ctx_dev->mtl_device;
5521
5827
 
5522
5828
  // the buffer fits into the max buffer size allowed by the device
5523
5829
  if (size_aligned <= device.maxBufferLength) {
@@ -5573,7 +5879,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5573
5879
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5574
5880
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5575
5881
  free(ctx);
5576
- ggml_backend_metal_device_rel(ctx_dev);
5577
5882
  return NULL;
5578
5883
  }
5579
5884
 
@@ -5589,10 +5894,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
5589
5894
  }
5590
5895
 
5591
5896
  static void ggml_backend_metal_free(ggml_backend_t backend) {
5592
- struct ggml_backend_metal_context * ctx = backend->context;
5593
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5897
+ struct ggml_backend_metal_context * ctx = backend->context;
5594
5898
 
5595
- ggml_backend_metal_device_rel(ctx_dev);
5596
5899
  ggml_metal_free(ctx);
5597
5900
 
5598
5901
  free(backend);
@@ -5732,6 +6035,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
5732
6035
 
5733
6036
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5734
6037
 
6038
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
6039
+
5735
6040
  return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
5736
6041
  }
5737
6042
 
@@ -5751,10 +6056,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
5751
6056
  }
5752
6057
 
5753
6058
  static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
5754
- // acq/rel just to populate ctx->name in case it hasn't been done yet
5755
6059
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5756
- ggml_backend_metal_device_acq(ctx_dev);
5757
- ggml_backend_metal_device_rel(ctx_dev);
5758
6060
 
5759
6061
  return ctx_dev->name;
5760
6062
  }
@@ -5762,12 +6064,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
5762
6064
  static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
5763
6065
  if (@available(macOS 10.12, iOS 16.0, *)) {
5764
6066
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5765
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6067
+ id<MTLDevice> device = ctx_dev->mtl_device;
5766
6068
 
5767
6069
  *total = device.recommendedMaxWorkingSetSize;
5768
6070
  *free = *total - device.currentAllocatedSize;
5769
-
5770
- ggml_backend_metal_device_rel(ctx_dev);
5771
6071
  } else {
5772
6072
  *free = 1;
5773
6073
  *total = 1;
@@ -5845,7 +6145,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5845
6145
  }
5846
6146
 
5847
6147
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5848
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6148
+
6149
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
6150
+
6151
+ id<MTLDevice> device = ctx_dev->mtl_device;
5849
6152
 
5850
6153
  // the buffer fits into the max buffer size allowed by the device
5851
6154
  if (size_aligned <= device.maxBufferLength) {
@@ -5901,7 +6204,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5901
6204
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5902
6205
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5903
6206
  free(ctx);
5904
- ggml_backend_metal_device_rel(ctx_dev);
5905
6207
  return NULL;
5906
6208
  }
5907
6209
 
@@ -5915,8 +6217,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
5915
6217
  }
5916
6218
 
5917
6219
  static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
5918
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5919
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
6220
+ return
6221
+ buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
6222
+ buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
5920
6223
 
5921
6224
  GGML_UNUSED(dev);
5922
6225
  }
@@ -6001,8 +6304,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
6001
6304
  /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
6002
6305
  };
6003
6306
 
6307
+ // called upon program exit
6308
+ static void ggml_metal_cleanup(void) {
6309
+ ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
6310
+ }
6311
+
6312
+ // TODO: make thread-safe
6004
6313
  ggml_backend_reg_t ggml_backend_metal_reg(void) {
6005
- // TODO: make this thread-safe somehow?
6314
+ ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
6315
+
6316
+ // register cleanup callback
6317
+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6318
+ atexit(ggml_metal_cleanup);
6319
+
6006
6320
  {
6007
6321
  g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
6008
6322
  /* .api_version = */ GGML_BACKEND_API_VERSION,