@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
109
109
  }
110
110
 
111
111
  void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
112
+ #pragma METAL fp math_mode(safe)
112
113
  float amax = 0.0f; // absolute max
113
114
  float max = 0.0f;
114
115
 
@@ -138,6 +139,7 @@ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
138
139
  }
139
140
 
140
141
  void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
142
+ #pragma METAL fp math_mode(safe)
141
143
  float min = FLT_MAX;
142
144
  float max = -FLT_MAX;
143
145
 
@@ -166,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
166
168
  }
167
169
 
168
170
  void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
171
+ #pragma METAL fp math_mode(safe)
169
172
  float amax = 0.0f; // absolute max
170
173
  float max = 0.0f;
171
174
 
@@ -203,6 +206,7 @@ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
203
206
  }
204
207
 
205
208
  void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
209
+ #pragma METAL fp math_mode(safe)
206
210
  float max = src[0];
207
211
  float min = src[0];
208
212
 
@@ -239,6 +243,7 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
239
243
  }
240
244
 
241
245
  void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
246
+ #pragma METAL fp math_mode(safe)
242
247
  float amax = 0.0f; // absolute max
243
248
  float max = 0.0f;
244
249
 
@@ -458,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
458
463
  }
459
464
 
460
465
  void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
466
+ #pragma METAL fp math_mode(safe)
461
467
  float amax = 0.0f; // absolute max
462
468
 
463
469
  for (int j = 0; j < QK8_0; j++) {
@@ -826,7 +832,8 @@ enum ggml_sort_order {
826
832
  // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
827
833
  // pros: works for non-contiguous tensors, supports broadcast across all dims
828
834
  // cons: not very efficient
829
- kernel void kernel_add(
835
+ template <int F>
836
+ kernel void kernel_add_fuse_impl(
830
837
  constant ggml_metal_kargs_bin & args,
831
838
  device const char * src0,
832
839
  device const char * src1,
@@ -842,16 +849,39 @@ kernel void kernel_add(
842
849
  const int i12 = i02%args.ne12;
843
850
  const int i11 = i01%args.ne11;
844
851
 
845
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
846
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
847
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
852
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
853
+ device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
854
+
855
+ device const float * src1_ptr[F];
856
+ for (short j = 0; j < F; ++j) {
857
+ src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
858
+ }
848
859
 
849
860
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
850
861
  const int i10 = i0%args.ne10;
851
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
862
+
863
+ float res = src0_ptr[i0];
864
+
865
+ #pragma unroll
866
+ for (short j = 0; j < F; ++j) {
867
+ res += src1_ptr[j][i10];
868
+ }
869
+
870
+ dst_ptr[i0] = res;
852
871
  }
853
872
  }
854
873
 
874
+ typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
875
+
876
+ template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
877
+ template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
878
+ template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
879
+ template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
880
+ template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
881
+ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
882
+ template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
883
+ template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
884
+
855
885
  kernel void kernel_sub(
856
886
  constant ggml_metal_kargs_bin & args,
857
887
  device const char * src0,
@@ -869,7 +899,7 @@ kernel void kernel_sub(
869
899
  const int i11 = i01%args.ne11;
870
900
 
871
901
  device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
872
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
902
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
873
903
  device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
874
904
 
875
905
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -894,9 +924,9 @@ kernel void kernel_mul(
894
924
  const int i12 = i02%args.ne12;
895
925
  const int i11 = i01%args.ne11;
896
926
 
897
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
898
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
899
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
927
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
928
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
929
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
900
930
 
901
931
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
902
932
  const int i10 = i0%args.ne10;
@@ -920,9 +950,9 @@ kernel void kernel_div(
920
950
  const int i12 = i02%args.ne12;
921
951
  const int i11 = i01%args.ne11;
922
952
 
923
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
924
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
925
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
953
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
954
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
955
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
926
956
 
927
957
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
928
958
  const int i10 = i0%args.ne10;
@@ -964,60 +994,161 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
964
994
 
965
995
  // assumption: src1 is a row
966
996
  // broadcast src1 into src0
967
- kernel void kernel_add_row(
997
+ template <short F>
998
+ kernel void kernel_add_row_c4_fuse_impl(
968
999
  constant ggml_metal_kargs_bin & args,
969
- device const float4 * src0,
970
- device const float4 * src1,
971
- device float4 * dst,
1000
+ device const char * src0,
1001
+ device const char * src1,
1002
+ device char * dst,
972
1003
  uint tpig[[thread_position_in_grid]]) {
1004
+
973
1005
  const uint nb = args.ne00/4;
974
- dst[tpig] = src0[tpig] + src1[tpig % nb];
1006
+ const uint i = tpig % nb;
1007
+
1008
+ device const float4 * src0_row = (device const float4 *) (src0);
1009
+ device float4 * dst_row = (device float4 *) (dst);
1010
+
1011
+ device const float4 * src1_row[F];
1012
+ for (short j = 0; j < F; ++j) {
1013
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1014
+ }
1015
+
1016
+ float4 res = src0_row[tpig];
1017
+
1018
+ #pragma unroll(F)
1019
+ for (short j = 0; j < F; ++j) {
1020
+ res += src1_row[j][i];
1021
+ }
1022
+
1023
+ dst_row[tpig] = res;
975
1024
  }
976
1025
 
977
- kernel void kernel_sub_row(
1026
+ typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
1027
+
1028
+ template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
1029
+ template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
1030
+ template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
1031
+ template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
1032
+ template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
1033
+ template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
1034
+ template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
1035
+ template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
1036
+
1037
+ template <short F>
1038
+ kernel void kernel_sub_row_c4_fuse_impl(
978
1039
  constant ggml_metal_kargs_bin & args,
979
- device const float4 * src0,
980
- device const float4 * src1,
981
- device float4 * dst,
1040
+ device const char * src0,
1041
+ device const char * src1,
1042
+ device char * dst,
982
1043
  uint tpig[[thread_position_in_grid]]) {
1044
+
983
1045
  const uint nb = args.ne00/4;
984
- dst[tpig] = src0[tpig] - src1[tpig % nb];
1046
+ const uint i = tpig % nb;
1047
+
1048
+ device const float4 * src0_row = (device const float4 *) (src0);
1049
+ device float4 * dst_row = (device float4 *) (dst);
1050
+
1051
+ device const float4 * src1_row[F];
1052
+ for (short j = 0; j < F; ++j) {
1053
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1054
+ }
1055
+
1056
+ float4 res = src0_row[tpig];
1057
+
1058
+ #pragma unroll(F)
1059
+ for (short j = 0; j < F; ++j) {
1060
+ res -= src1_row[j][i];
1061
+ }
1062
+
1063
+ dst_row[tpig] = res;
985
1064
  }
986
1065
 
987
- kernel void kernel_mul_row(
1066
+ typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
1067
+
1068
+ template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
1069
+
1070
+ template <short F>
1071
+ kernel void kernel_mul_row_c4_fuse_impl(
988
1072
  constant ggml_metal_kargs_bin & args,
989
- device const float4 * src0,
990
- device const float4 * src1,
991
- device float4 * dst,
1073
+ device const char * src0,
1074
+ device const char * src1,
1075
+ device char * dst,
992
1076
  uint tpig[[thread_position_in_grid]]) {
1077
+
993
1078
  const uint nb = args.ne00/4;
994
- dst[tpig] = src0[tpig] * src1[tpig % nb];
1079
+ const uint i = tpig % nb;
1080
+
1081
+ device const float4 * src0_row = (device const float4 *) (src0);
1082
+ device float4 * dst_row = (device float4 *) (dst);
1083
+
1084
+ device const float4 * src1_row[F];
1085
+ for (short j = 0; j < F; ++j) {
1086
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1087
+ }
1088
+
1089
+ float4 res = src0_row[tpig];
1090
+
1091
+ #pragma unroll(F)
1092
+ for (short j = 0; j < F; ++j) {
1093
+ res *= src1_row[j][i];
1094
+ }
1095
+
1096
+ dst_row[tpig] = res;
995
1097
  }
996
1098
 
997
- kernel void kernel_div_row(
1099
+ typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
1100
+
1101
+ template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
1102
+
1103
+ template <short F>
1104
+ kernel void kernel_div_row_c4_fuse_impl(
998
1105
  constant ggml_metal_kargs_bin & args,
999
- device const float4 * src0,
1000
- device const float4 * src1,
1001
- device float4 * dst,
1106
+ device const char * src0,
1107
+ device const char * src1,
1108
+ device char * dst,
1002
1109
  uint tpig[[thread_position_in_grid]]) {
1110
+
1003
1111
  const uint nb = args.ne00/4;
1004
- dst[tpig] = src0[tpig] / src1[tpig % nb];
1112
+ const uint i = tpig % nb;
1113
+
1114
+ device const float4 * src0_row = (device const float4 *) (src0);
1115
+ device float4 * dst_row = (device float4 *) (dst);
1116
+
1117
+ device const float4 * src1_row[F];
1118
+ for (short j = 0; j < F; ++j) {
1119
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1120
+ }
1121
+
1122
+ float4 res = src0_row[tpig];
1123
+
1124
+ #pragma unroll(F)
1125
+ for (short j = 0; j < F; ++j) {
1126
+ res /= src1_row[j][i];
1127
+ }
1128
+
1129
+ dst_row[tpig] = res;
1005
1130
  }
1006
1131
 
1132
+ typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
1133
+
1134
+ template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
1135
+
1007
1136
  kernel void kernel_scale(
1008
1137
  device const float * src0,
1009
1138
  device float * dst,
1010
1139
  constant float & scale,
1140
+ constant float & bias,
1011
1141
  uint tpig[[thread_position_in_grid]]) {
1012
- dst[tpig] = src0[tpig] * scale;
1142
+ dst[tpig] = src0[tpig] * scale + bias;
1013
1143
  }
1014
1144
 
1015
1145
  kernel void kernel_scale_4(
1016
1146
  device const float4 * src0,
1017
1147
  device float4 * dst,
1018
1148
  constant float & scale,
1149
+ constant float & bias,
1019
1150
  uint tpig[[thread_position_in_grid]]) {
1020
- dst[tpig] = src0[tpig] * scale;
1151
+ dst[tpig] = src0[tpig] * scale + bias;
1021
1152
  }
1022
1153
 
1023
1154
  kernel void kernel_clamp(
@@ -1191,6 +1322,159 @@ kernel void kernel_neg(
1191
1322
  dst[tpig] = -src0[tpig];
1192
1323
  }
1193
1324
 
1325
+ kernel void kernel_abs(
1326
+ device const float * src0,
1327
+ device float * dst,
1328
+ uint tpig[[thread_position_in_grid]]) {
1329
+ dst[tpig] = fabs(src0[tpig]);
1330
+ }
1331
+
1332
+ kernel void kernel_sgn(
1333
+ device const float * src0,
1334
+ device float * dst,
1335
+ uint tpig[[thread_position_in_grid]]) {
1336
+ device const float & x = src0[tpig];
1337
+ dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
1338
+ }
1339
+
1340
+ kernel void kernel_step(
1341
+ device const float * src0,
1342
+ device float * dst,
1343
+ uint tpig[[thread_position_in_grid]]) {
1344
+ dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
1345
+ }
1346
+
1347
+ kernel void kernel_hardswish(
1348
+ device const float * src0,
1349
+ device float * dst,
1350
+ uint tpig[[thread_position_in_grid]]) {
1351
+ device const float & x = src0[tpig];
1352
+ dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1353
+ }
1354
+
1355
+ kernel void kernel_hardsigmoid(
1356
+ device const float * src0,
1357
+ device float * dst,
1358
+ uint tpig[[thread_position_in_grid]]) {
1359
+ device const float & x = src0[tpig];
1360
+ dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1361
+ }
1362
+
1363
+ kernel void kernel_exp(
1364
+ device const float * src0,
1365
+ device float * dst,
1366
+ uint tpig[[thread_position_in_grid]]) {
1367
+ dst[tpig] = exp(src0[tpig]);
1368
+ }
1369
+
1370
+ kernel void kernel_reglu(
1371
+ device const char * src0,
1372
+ device const char * src1,
1373
+ device char * dst,
1374
+ constant ggml_metal_kargs_glu & args,
1375
+ uint tgpig[[threadgroup_position_in_grid]],
1376
+ uint tpitg[[thread_position_in_threadgroup]],
1377
+ uint ntg[[threads_per_threadgroup]]) {
1378
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1379
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1380
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1381
+
1382
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1383
+ const float x0 = src0_row[i0];
1384
+ const float x1 = src1_row[i0];
1385
+
1386
+ dst_row[i0] = x0*x1*(x0 > 0.0f);
1387
+ }
1388
+ }
1389
+
1390
+ kernel void kernel_geglu(
1391
+ device const char * src0,
1392
+ device const char * src1,
1393
+ device char * dst,
1394
+ constant ggml_metal_kargs_glu & args,
1395
+ uint tgpig[[threadgroup_position_in_grid]],
1396
+ uint tpitg[[thread_position_in_threadgroup]],
1397
+ uint ntg[[threads_per_threadgroup]]) {
1398
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1399
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1400
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1401
+
1402
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1403
+ const float x0 = src0_row[i0];
1404
+ const float x1 = src1_row[i0];
1405
+
1406
+ const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1407
+
1408
+ dst_row[i0] = gelu*x1;
1409
+ }
1410
+ }
1411
+
1412
+ kernel void kernel_swiglu(
1413
+ device const char * src0,
1414
+ device const char * src1,
1415
+ device char * dst,
1416
+ constant ggml_metal_kargs_glu & args,
1417
+ uint tgpig[[threadgroup_position_in_grid]],
1418
+ uint tpitg[[thread_position_in_threadgroup]],
1419
+ uint ntg[[threads_per_threadgroup]]) {
1420
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1421
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1422
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1423
+
1424
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1425
+ const float x0 = src0_row[i0];
1426
+ const float x1 = src1_row[i0];
1427
+
1428
+ const float silu = x0 / (1.0f + exp(-x0));
1429
+
1430
+ dst_row[i0] = silu*x1;
1431
+ }
1432
+ }
1433
+
1434
+ kernel void kernel_geglu_erf(
1435
+ device const char * src0,
1436
+ device const char * src1,
1437
+ device char * dst,
1438
+ constant ggml_metal_kargs_glu & args,
1439
+ uint tgpig[[threadgroup_position_in_grid]],
1440
+ uint tpitg[[thread_position_in_threadgroup]],
1441
+ uint ntg[[threads_per_threadgroup]]) {
1442
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1443
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1444
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1445
+
1446
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1447
+ const float x0 = src0_row[i0];
1448
+ const float x1 = src1_row[i0];
1449
+
1450
+ const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
1451
+
1452
+ dst_row[i0] = gelu_erf*x1;
1453
+ }
1454
+ }
1455
+
1456
+ kernel void kernel_geglu_quick(
1457
+ device const char * src0,
1458
+ device const char * src1,
1459
+ device char * dst,
1460
+ constant ggml_metal_kargs_glu & args,
1461
+ uint tgpig[[threadgroup_position_in_grid]],
1462
+ uint tpitg[[thread_position_in_threadgroup]],
1463
+ uint ntg[[threads_per_threadgroup]]) {
1464
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1465
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1466
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1467
+
1468
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1469
+ const float x0 = src0_row[i0];
1470
+ const float x1 = src1_row[i0];
1471
+
1472
+ const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
1473
+
1474
+ dst_row[i0] = gelu_quick*x1;
1475
+ }
1476
+ }
1477
+
1194
1478
  template <bool norm>
1195
1479
  kernel void kernel_sum_rows(
1196
1480
  constant ggml_metal_kargs_sum_rows & args,
@@ -1253,24 +1537,28 @@ kernel void kernel_soft_max(
1253
1537
  device char * dst,
1254
1538
  constant ggml_metal_kargs_soft_max & args,
1255
1539
  threadgroup float * buf [[threadgroup(0)]],
1256
- uint tgpig[[threadgroup_position_in_grid]],
1257
- uint tpitg[[thread_position_in_threadgroup]],
1540
+ uint3 tgpig[[threadgroup_position_in_grid]],
1541
+ uint3 tpitg[[thread_position_in_threadgroup]],
1258
1542
  uint sgitg[[simdgroup_index_in_threadgroup]],
1259
1543
  uint tiisg[[thread_index_in_simdgroup]],
1260
- uint ntg[[threads_per_threadgroup]]) {
1261
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1262
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1263
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1544
+ uint3 tptg[[threads_per_threadgroup]]) {
1545
+ const int32_t i03 = tgpig.z;
1546
+ const int32_t i02 = tgpig.y;
1547
+ const int32_t i01 = tgpig.x;
1548
+
1549
+ const int32_t i13 = i03%args.ne13;
1550
+ const int32_t i12 = i02%args.ne12;
1551
+ const int32_t i11 = i01;
1264
1552
 
1265
- device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1266
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1267
- device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1553
+ device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1554
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1555
+ device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1268
1556
 
1269
1557
  float slope = 1.0f;
1270
1558
 
1271
1559
  // ALiBi
1272
1560
  if (args.max_bias > 0.0f) {
1273
- const int64_t h = i02;
1561
+ const int32_t h = i02;
1274
1562
 
1275
1563
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1276
1564
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1281,13 +1569,13 @@ kernel void kernel_soft_max(
1281
1569
  // parallel max
1282
1570
  float lmax = -INFINITY;
1283
1571
 
1284
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1572
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1285
1573
  lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1286
1574
  }
1287
1575
 
1288
1576
  // find the max value in the block
1289
1577
  float max_val = simd_max(lmax);
1290
- if (ntg > N_SIMDWIDTH) {
1578
+ if (tptg.x > N_SIMDWIDTH) {
1291
1579
  if (sgitg == 0) {
1292
1580
  buf[tiisg] = -INFINITY;
1293
1581
  }
@@ -1306,7 +1594,7 @@ kernel void kernel_soft_max(
1306
1594
 
1307
1595
  // parallel sum
1308
1596
  float lsum = 0.0f;
1309
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1597
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1310
1598
  const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1311
1599
  lsum += exp_psrc0;
1312
1600
  pdst[i00] = exp_psrc0;
@@ -1318,7 +1606,7 @@ kernel void kernel_soft_max(
1318
1606
 
1319
1607
  float sum = simd_sum(lsum);
1320
1608
 
1321
- if (ntg > N_SIMDWIDTH) {
1609
+ if (tptg.x > N_SIMDWIDTH) {
1322
1610
  if (sgitg == 0) {
1323
1611
  buf[tiisg] = 0.0f;
1324
1612
  }
@@ -1337,7 +1625,7 @@ kernel void kernel_soft_max(
1337
1625
 
1338
1626
  const float inv_sum = 1.0f/sum;
1339
1627
 
1340
- for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1628
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1341
1629
  pdst[i00] *= inv_sum;
1342
1630
  }
1343
1631
  }
@@ -1349,23 +1637,27 @@ kernel void kernel_soft_max_4(
1349
1637
  device char * dst,
1350
1638
  constant ggml_metal_kargs_soft_max & args,
1351
1639
  threadgroup float * buf [[threadgroup(0)]],
1352
- uint tgpig[[threadgroup_position_in_grid]],
1353
- uint tpitg[[thread_position_in_threadgroup]],
1640
+ uint3 tgpig[[threadgroup_position_in_grid]],
1641
+ uint3 tpitg[[thread_position_in_threadgroup]],
1354
1642
  uint sgitg[[simdgroup_index_in_threadgroup]],
1355
1643
  uint tiisg[[thread_index_in_simdgroup]],
1356
- uint ntg[[threads_per_threadgroup]]) {
1357
- const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1358
- const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1359
- const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1644
+ uint3 tptg[[threads_per_threadgroup]]) {
1645
+ const int32_t i03 = tgpig.z;
1646
+ const int32_t i02 = tgpig.y;
1647
+ const int32_t i01 = tgpig.x;
1648
+
1649
+ const int32_t i13 = i03%args.ne13;
1650
+ const int32_t i12 = i02%args.ne12;
1651
+ const int32_t i11 = i01;
1360
1652
 
1361
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1362
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1363
- device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1653
+ device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1654
+ device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1655
+ device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1364
1656
 
1365
1657
  float slope = 1.0f;
1366
1658
 
1367
1659
  if (args.max_bias > 0.0f) {
1368
- const int64_t h = i02;
1660
+ const int32_t h = i02;
1369
1661
 
1370
1662
  const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1371
1663
  const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1376,14 +1668,14 @@ kernel void kernel_soft_max_4(
1376
1668
  // parallel max
1377
1669
  float4 lmax4 = -INFINITY;
1378
1670
 
1379
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1671
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1380
1672
  lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1381
1673
  }
1382
1674
 
1383
1675
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1384
1676
 
1385
1677
  float max_val = simd_max(lmax);
1386
- if (ntg > N_SIMDWIDTH) {
1678
+ if (tptg.x > N_SIMDWIDTH) {
1387
1679
  if (sgitg == 0) {
1388
1680
  buf[tiisg] = -INFINITY;
1389
1681
  }
@@ -1402,7 +1694,7 @@ kernel void kernel_soft_max_4(
1402
1694
 
1403
1695
  // parallel sum
1404
1696
  float4 lsum4 = 0.0f;
1405
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1697
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1406
1698
  const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1407
1699
  lsum4 += exp_psrc4;
1408
1700
  pdst4[i00] = exp_psrc4;
@@ -1416,7 +1708,7 @@ kernel void kernel_soft_max_4(
1416
1708
 
1417
1709
  float sum = simd_sum(lsum);
1418
1710
 
1419
- if (ntg > N_SIMDWIDTH) {
1711
+ if (tptg.x > N_SIMDWIDTH) {
1420
1712
  if (sgitg == 0) {
1421
1713
  buf[tiisg] = 0.0f;
1422
1714
  }
@@ -1435,7 +1727,7 @@ kernel void kernel_soft_max_4(
1435
1727
 
1436
1728
  const float inv_sum = 1.0f/sum;
1437
1729
 
1438
- for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1730
+ for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1439
1731
  pdst4[i00] *= inv_sum;
1440
1732
  }
1441
1733
  }
@@ -1521,7 +1813,7 @@ kernel void kernel_ssm_conv_f32(
1521
1813
  x[0] = sumf;
1522
1814
  }
1523
1815
 
1524
- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1816
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
1525
1817
  kernel void kernel_ssm_scan_f32(
1526
1818
  device const void * src0,
1527
1819
  device const void * src1,
@@ -1529,47 +1821,222 @@ kernel void kernel_ssm_scan_f32(
1529
1821
  device const void * src3,
1530
1822
  device const void * src4,
1531
1823
  device const void * src5,
1824
+ device const void * src6,
1532
1825
  device float * dst,
1826
+ threadgroup float * shared [[threadgroup(0)]],
1533
1827
  constant ggml_metal_kargs_ssm_scan & args,
1534
- uint3 tgpig[[threadgroup_position_in_grid]],
1535
- uint3 tpitg[[thread_position_in_threadgroup]],
1536
- uint3 ntg[[threads_per_threadgroup]]) {
1537
- const int64_t ir = tgpig.x;
1538
- const int64_t i3 = tgpig.y;
1828
+ uint3 tgpig[[threadgroup_position_in_grid]],
1829
+ uint3 tpitg[[thread_position_in_threadgroup]],
1830
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1831
+ ushort tiisg[[thread_index_in_simdgroup]],
1832
+ ushort sgptg[[simdgroups_per_threadgroup]],
1833
+ uint3 tgpg[[threadgroups_per_grid]]) {
1834
+
1835
+ const int64_t i0 = tpitg.x;
1836
+ const int64_t i1 = 0;
1837
+ const int64_t ir = tgpig.x; // current head
1838
+ const int64_t i3 = tgpig.y; // current seq
1839
+
1840
+ const uint64_t nb00 = sizeof(float);
1841
+ const uint64_t nb10 = sizeof(float);
1842
+ const uint64_t nb20 = sizeof(float);
1539
1843
 
1540
1844
  const int64_t nc = args.d_state;
1541
- // const int64_t nr = args.d_inner;
1845
+ const int64_t nr = args.d_inner;
1846
+ const int64_t nh = args.n_head;
1847
+ const int64_t ng = args.n_group;
1542
1848
  const int64_t n_t = args.n_seq_tokens;
1543
- // const int64_t n_s = args.n_seqs;
1849
+
1850
+ const int64_t s_off = args.s_off;
1851
+
1852
+ device const int32_t * ids = (device const int32_t *) src6;
1853
+
1854
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1855
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1856
+ const int64_t i = i0 + i1*nc;
1857
+ float s0 = s0_buff[i];
1858
+ float s = s_buff[i];
1859
+
1860
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1861
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
1862
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
1863
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
1864
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
1865
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
1544
1866
 
1545
1867
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1546
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1547
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1548
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1549
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1550
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1551
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1552
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1553
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
1868
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
1869
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
1870
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
1871
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
1872
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
1873
+
1874
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1875
+ const float x_dt = x[0] * dt_soft_plus;
1876
+
1877
+ const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
1878
+ s = state;
1879
+
1880
+ // Parallel sum: This relies on the fact that this kernel will be
1881
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
1882
+ // are subdivided into SIMD groups of size `sgptg`. The goal is to
1883
+ // compute y = sum({state * C[i] for i in range(d_state)}).
1884
+ // To parallelize this effectively, we first use simd_sum over each SIMD
1885
+ // group to compute the sum of each SIMD group, then place the result in
1886
+ // the SIMD group's indexed bucket in the shared memory. We then sum
1887
+ // over the individual group sums to compute the final sum.
1554
1888
 
1555
- if (i2 > 0) {
1556
- s0 = s;
1889
+ // Computed for each thread
1890
+ float sumf = state * C[i0];
1891
+
1892
+ // Sum the threads in the simd group => simd sum
1893
+ sumf = simd_sum(sumf);
1894
+
1895
+ if (sgptg > 1) {
1896
+
1897
+ // Once per simd group, place the group sum into the shared buffer
1898
+ if (tiisg == 0) {
1899
+ shared[sgitg] = sumf;
1900
+ }
1901
+
1902
+ // Wait for all threads in the threadgroup to reach this point. This
1903
+ // ensures that all elements of the shared buffer are populated with the
1904
+ // sum of the individual simd groups.
1905
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1906
+
1907
+ // For simd group 0 at indices < num simd groups, extract the shared
1908
+ // simd sum
1909
+ sumf = 0.0f;
1910
+ if (sgitg == 0) {
1911
+ if (tiisg < sgptg) {
1912
+ sumf = shared[tiisg];
1913
+ }
1914
+ sumf = simd_sum(sumf);
1915
+ if (tiisg == 0) {
1916
+ y[0] = sumf;
1917
+ }
1918
+ }
1919
+ } else if (tiisg == 0) {
1920
+ y[0] = sumf;
1557
1921
  }
1558
1922
 
1559
- // i1 == 0
1560
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1561
- float x_dt = x[0] * dt_soft_plus;
1562
- float sumf = 0.0f;
1923
+ // recurse
1924
+ s0 = s;
1925
+ }
1926
+
1927
+ // Assign the final state to the output buffer
1928
+ s_buff[i] = s;
1929
+ }
1563
1930
 
1564
- for (int64_t i0 = 0; i0 < nc; ++i0) {
1565
- int64_t i = i0;
1566
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
1567
- sumf += state * C[i0];
1568
- s[i] = state;
1931
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
1932
+ kernel void kernel_ssm_scan_f32_group(
1933
+ device const void * src0,
1934
+ device const void * src1,
1935
+ device const void * src2,
1936
+ device const void * src3,
1937
+ device const void * src4,
1938
+ device const void * src5,
1939
+ device const void * src6,
1940
+ device float * dst,
1941
+ threadgroup float * shared [[threadgroup(0)]],
1942
+ constant ggml_metal_kargs_ssm_scan & args,
1943
+ uint3 tgpig[[threadgroup_position_in_grid]],
1944
+ uint3 tpitg[[thread_position_in_threadgroup]],
1945
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1946
+ ushort tiisg[[thread_index_in_simdgroup]],
1947
+ ushort sgptg[[simdgroups_per_threadgroup]],
1948
+ uint3 tgpg[[threadgroups_per_grid]]) {
1949
+
1950
+ const int64_t i0 = tpitg.x;
1951
+ const int64_t i1 = tgpig.x;
1952
+ const int64_t ir = tgpig.y; // current head
1953
+ const int64_t i3 = tgpig.z; // current seq
1954
+
1955
+ const uint64_t nb00 = sizeof(float);
1956
+ const uint64_t nb10 = sizeof(float);
1957
+ const uint64_t nb20 = sizeof(float);
1958
+
1959
+ const int64_t nc = args.d_state;
1960
+ const int64_t nr = args.d_inner;
1961
+ const int64_t nh = args.n_head;
1962
+ const int64_t ng = args.n_group;
1963
+ const int64_t n_t = args.n_seq_tokens;
1964
+
1965
+ const int64_t s_off = args.s_off;
1966
+
1967
+ device const int32_t * ids = (device const int32_t *) src6;
1968
+
1969
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1970
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1971
+ const int64_t i = i0 + i1*nc;
1972
+ float s0 = s0_buff[i];
1973
+ float s = s_buff[i];
1974
+
1975
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1976
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
1977
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
1978
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
1979
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
1980
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
1981
+
1982
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
1983
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
1984
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
1985
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
1986
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
1987
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
1988
+
1989
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1990
+ const float x_dt = x[0] * dt_soft_plus;
1991
+ const float dA = exp(dt_soft_plus * A[0]);
1992
+
1993
+ const float state = (s0 * dA) + (B[i0] * x_dt);
1994
+ s = state;
1995
+
1996
+ // Parallel sum: This relies on the fact that this kernel will be
1997
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
1998
+ // are subdivided into SIMD groups of size `sgptg`. The goal is to
1999
+ // compute y = sum({state * C[i] for i in range(d_state)}).
2000
+ // To parallelize this effectively, we first use simd_sum over each SIMD
2001
+ // group to compute the sum of each SIMD group, then place the result in
2002
+ // the SIMD group's indexed bucket in the shared memory. We then sum
2003
+ // over the individual group sums to compute the final sum.
2004
+
2005
+ // Computed for each thread
2006
+ float sumf = state * C[i0];
2007
+
2008
+ // Sum the threads in the simd group => simd sum
2009
+ sumf = simd_sum(sumf);
2010
+
2011
+ // Once per simd group, place the group sum into the shared buffer
2012
+ if (tiisg == 0) {
2013
+ shared[sgitg] = sumf;
1569
2014
  }
1570
2015
 
1571
- y[0] = sumf;
2016
+ // Wait for all threads in the threadgroup to reach this point. This
2017
+ // ensures that all elements of the shared buffer are populated with the
2018
+ // sum of the individual simd groups.
2019
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2020
+
2021
+ // For simd group 0 at indices < num simd groups, extract the shared
2022
+ // simd sum
2023
+ sumf = 0.0f;
2024
+ if (sgitg == 0) {
2025
+ if (tiisg < sgptg) {
2026
+ sumf = shared[tiisg];
2027
+ }
2028
+ sumf = simd_sum(sumf);
2029
+ if (tiisg == 0) {
2030
+ y[0] = sumf;
2031
+ }
2032
+ }
2033
+
2034
+ // recurse
2035
+ s0 = s;
1572
2036
  }
2037
+
2038
+ // Assign the final state to the output buffer
2039
+ s_buff[i] = s;
1573
2040
  }
1574
2041
 
1575
2042
  kernel void kernel_rwkv_wkv6_f32(
@@ -1874,26 +2341,39 @@ kernel void kernel_norm(
1874
2341
  }
1875
2342
  }
1876
2343
 
1877
- kernel void kernel_rms_norm(
2344
+ // F == 1 : rms_norm (no fuse)
2345
+ // F == 2 : rms_norm + mul
2346
+ // F == 3 : rms_norm + mul + add
2347
+ template <short F>
2348
+ kernel void kernel_rms_norm_fuse_impl(
1878
2349
  constant ggml_metal_kargs_rms_norm & args,
1879
2350
  device const char * src0,
2351
+ device const char * src1_0,
2352
+ device const char * src1_1,
1880
2353
  device char * dst,
1881
2354
  threadgroup float * shmem_f32 [[threadgroup(0)]],
1882
- uint tgpig[[threadgroup_position_in_grid]],
1883
- ushort tpitg[[thread_position_in_threadgroup]],
1884
- ushort sgitg[[simdgroup_index_in_threadgroup]],
1885
- ushort tiisg[[thread_index_in_simdgroup]],
1886
- ushort ntg[[threads_per_threadgroup]]) {
2355
+ uint3 tgpig[[threadgroup_position_in_grid]],
2356
+ ushort3 tpitg[[thread_position_in_threadgroup]],
2357
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2358
+ ushort tiisg[[thread_index_in_simdgroup]],
2359
+ ushort3 ntg[[threads_per_threadgroup]]) {
1887
2360
  if (sgitg == 0) {
1888
2361
  shmem_f32[tiisg] = 0.0f;
1889
2362
  }
1890
2363
 
1891
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
2364
+ const int i01 = tgpig.x;
2365
+ const int i02 = tgpig.y;
2366
+ const int i03 = tgpig.z;
2367
+
2368
+ device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
2369
+
2370
+ device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
2371
+ device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
1892
2372
 
1893
2373
  float sumf = 0.0f;
1894
2374
 
1895
2375
  // parallel sum
1896
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2376
+ for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
1897
2377
  sumf += dot(x[i00], x[i00]);
1898
2378
  }
1899
2379
  sumf = simd_sum(sumf);
@@ -1912,12 +2392,26 @@ kernel void kernel_rms_norm(
1912
2392
  const float mean = sumf/args.ne00;
1913
2393
  const float scale = 1.0f/sqrt(mean + args.eps);
1914
2394
 
1915
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1916
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1917
- y[i00] = x[i00] * scale;
2395
+ device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2396
+ for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
2397
+ if (F == 1) {
2398
+ y[i00] = (x[i00]*scale);
2399
+ }
2400
+ if (F == 2) {
2401
+ y[i00] = (x[i00]*scale)*f0[i00];
2402
+ }
2403
+ if (F == 3) {
2404
+ y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
2405
+ }
1918
2406
  }
1919
2407
  }
1920
2408
 
2409
+ typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
2410
+
2411
+ template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
2412
+ template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
2413
+ template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
2414
+
1921
2415
  kernel void kernel_l2_norm(
1922
2416
  constant ggml_metal_kargs_l2_norm & args,
1923
2417
  device const char * src0,
@@ -3709,7 +4203,7 @@ kernel void kernel_flash_attn_ext(
3709
4203
  // load the mask in shared memory
3710
4204
  #pragma unroll(Q)
3711
4205
  for (short j = 0; j < Q; ++j) {
3712
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
4206
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
3713
4207
 
3714
4208
  const float m = pm[ic + tiisg];
3715
4209
 
@@ -4195,7 +4689,7 @@ kernel void kernel_flash_attn_ext_vec(
4195
4689
  const bool has_mask = mask != q;
4196
4690
 
4197
4691
  // pointer to the mask
4198
- device const half * pm = (device const half *) (mask + iq1*args.nb31);
4692
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
4199
4693
 
4200
4694
  float slope = 1.0f;
4201
4695