@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
217
217
  GGML_METAL_KERNEL_TYPE_NORM,
218
218
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
219
219
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
220
+ GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
220
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
221
222
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
222
223
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -526,6 +527,11 @@ enum ggml_metal_kernel_type {
526
527
  GGML_METAL_KERNEL_TYPE_SIN,
527
528
  GGML_METAL_KERNEL_TYPE_COS,
528
529
  GGML_METAL_KERNEL_TYPE_NEG,
530
+ GGML_METAL_KERNEL_TYPE_REGLU,
531
+ GGML_METAL_KERNEL_TYPE_GEGLU,
532
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
533
+ GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
534
+ GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
529
535
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
530
536
  GGML_METAL_KERNEL_TYPE_MEAN,
531
537
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1193,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1193
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1194
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1195
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1202
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1196
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1197
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1198
1205
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1502,6 +1509,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1502
1509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1503
1510
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1504
1511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1512
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1513
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1514
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1515
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
1516
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1505
1517
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1506
1518
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1507
1519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1680,6 +1692,17 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1680
1692
  default:
1681
1693
  return false;
1682
1694
  }
1695
+ case GGML_OP_GLU:
1696
+ switch (ggml_get_glu_op(op)) {
1697
+ case GGML_GLU_OP_REGLU:
1698
+ case GGML_GLU_OP_GEGLU:
1699
+ case GGML_GLU_OP_SWIGLU:
1700
+ case GGML_GLU_OP_GEGLU_ERF:
1701
+ case GGML_GLU_OP_GEGLU_QUICK:
1702
+ return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1703
+ default:
1704
+ return false;
1705
+ }
1683
1706
  case GGML_OP_NONE:
1684
1707
  case GGML_OP_RESHAPE:
1685
1708
  case GGML_OP_VIEW:
@@ -1710,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1710
1733
  case GGML_OP_MEAN:
1711
1734
  case GGML_OP_SOFT_MAX:
1712
1735
  case GGML_OP_GROUP_NORM:
1713
- return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1736
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
1714
1737
  case GGML_OP_RMS_NORM:
1715
1738
  case GGML_OP_L2_NORM:
1716
1739
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -2233,7 +2256,9 @@ static bool ggml_metal_encode_node(
2233
2256
  GGML_ASSERT(ggml_is_contiguous(src0));
2234
2257
 
2235
2258
  float scale;
2236
- memcpy(&scale, dst->op_params, sizeof(scale));
2259
+ float bias;
2260
+ memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
2261
+ memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
2237
2262
 
2238
2263
  int64_t n = ggml_nelements(dst);
2239
2264
 
@@ -2250,6 +2275,7 @@ static bool ggml_metal_encode_node(
2250
2275
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2251
2276
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2252
2277
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
2278
+ [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
2253
2279
 
2254
2280
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2255
2281
  } break;
@@ -2419,6 +2445,68 @@ static bool ggml_metal_encode_node(
2419
2445
  GGML_ABORT("fatal error");
2420
2446
  }
2421
2447
  } break;
2448
+ case GGML_OP_GLU:
2449
+ {
2450
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2451
+
2452
+ if (src1) {
2453
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
2454
+ }
2455
+
2456
+ id<MTLComputePipelineState> pipeline = nil;
2457
+
2458
+ switch (ggml_get_glu_op(node)) {
2459
+ case GGML_GLU_OP_REGLU:
2460
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2461
+ break;
2462
+ case GGML_GLU_OP_GEGLU:
2463
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2464
+ break;
2465
+ case GGML_GLU_OP_SWIGLU:
2466
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2467
+ break;
2468
+ case GGML_GLU_OP_GEGLU_ERF:
2469
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
2470
+ break;
2471
+ case GGML_GLU_OP_GEGLU_QUICK:
2472
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
2473
+ break;
2474
+ default:
2475
+ GGML_ABORT("fatal error");
2476
+ }
2477
+
2478
+ const int32_t swp = ((const int32_t *) dst->op_params)[1];
2479
+
2480
+ const int32_t i00 = swp ? ne0 : 0;
2481
+ const int32_t i10 = swp ? 0 : ne0;
2482
+
2483
+ ggml_metal_kargs_glu args = {
2484
+ /*.ne00 =*/ ne00,
2485
+ /*.nb01 =*/ nb01,
2486
+ /*.ne10 =*/ src1 ? ne10 : ne00,
2487
+ /*.nb11 =*/ src1 ? nb11 : nb01,
2488
+ /*.ne0 =*/ ne0,
2489
+ /*.nb1 =*/ nb1,
2490
+ /*.i00 =*/ src1 ? 0 : i00,
2491
+ /*.i10 =*/ src1 ? 0 : i10,
2492
+ };
2493
+
2494
+ [encoder setComputePipelineState:pipeline];
2495
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2496
+ if (src1) {
2497
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2498
+ } else {
2499
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2500
+ }
2501
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2502
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2503
+
2504
+ const int64_t nrows = ggml_nrows(src0);
2505
+
2506
+ const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2507
+
2508
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2509
+ } break;
2422
2510
  case GGML_OP_SQR:
2423
2511
  {
2424
2512
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -2573,10 +2661,7 @@ static bool ggml_metal_encode_node(
2573
2661
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2574
2662
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2575
2663
 
2576
- const int64_t nrows_x = ggml_nrows(src0);
2577
- const int64_t nrows_y = src0->ne[1];
2578
-
2579
- const uint32_t n_head = nrows_x/nrows_y;
2664
+ const uint32_t n_head = src0->ne[2];
2580
2665
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2581
2666
 
2582
2667
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2636,6 +2721,18 @@ static bool ggml_metal_encode_node(
2636
2721
  /*.ne00 =*/ ne00,
2637
2722
  /*.ne01 =*/ ne01,
2638
2723
  /*.ne02 =*/ ne02,
2724
+ /*.nb01 =*/ nb01,
2725
+ /*.nb02 =*/ nb02,
2726
+ /*.nb03 =*/ nb03,
2727
+ /*.ne11 =*/ ne11,
2728
+ /*.ne12 =*/ ne12,
2729
+ /*.ne13 =*/ ne13,
2730
+ /*.nb11 =*/ nb11,
2731
+ /*.nb12 =*/ nb12,
2732
+ /*.nb13 =*/ nb13,
2733
+ /*.nb1 =*/ nb1,
2734
+ /*.nb2 =*/ nb2,
2735
+ /*.nb3 =*/ nb3,
2639
2736
  /*.scale =*/ scale,
2640
2737
  /*.max_bias =*/ max_bias,
2641
2738
  /*.m0 =*/ m0,
@@ -2655,7 +2752,7 @@ static bool ggml_metal_encode_node(
2655
2752
 
2656
2753
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2657
2754
 
2658
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2755
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2659
2756
  } break;
2660
2757
  case GGML_OP_DIAG_MASK_INF:
2661
2758
  {
@@ -2729,71 +2826,91 @@ static bool ggml_metal_encode_node(
2729
2826
  struct ggml_tensor * src3 = node->src[3];
2730
2827
  struct ggml_tensor * src4 = node->src[4];
2731
2828
  struct ggml_tensor * src5 = node->src[5];
2829
+ struct ggml_tensor * src6 = node->src[6];
2732
2830
 
2733
2831
  GGML_ASSERT(src3);
2734
2832
  GGML_ASSERT(src4);
2735
2833
  GGML_ASSERT(src5);
2834
+ GGML_ASSERT(src6);
2736
2835
 
2737
2836
  size_t offs_src3 = 0;
2738
2837
  size_t offs_src4 = 0;
2739
2838
  size_t offs_src5 = 0;
2839
+ size_t offs_src6 = 0;
2740
2840
 
2741
2841
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2742
2842
  id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
2743
2843
  id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
2844
+ id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
2744
2845
 
2745
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
2846
+ const int64_t ne30 = src3->ne[0];
2746
2847
  const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
2747
2848
 
2748
- const uint64_t nb30 = src3->nb[0];
2849
+ const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
2749
2850
  const uint64_t nb31 = src3->nb[1];
2750
2851
 
2751
2852
  const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
2752
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
2853
+ const int64_t ne41 = src4->ne[1];
2753
2854
  const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
2855
+ const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
2754
2856
 
2755
- const uint64_t nb40 = src4->nb[0];
2857
+ const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
2756
2858
  const uint64_t nb41 = src4->nb[1];
2757
2859
  const uint64_t nb42 = src4->nb[2];
2860
+ const uint64_t nb43 = src4->nb[3];
2758
2861
 
2759
2862
  const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
2760
2863
  const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
2761
2864
  const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
2865
+ const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
2762
2866
 
2763
- const uint64_t nb50 = src5->nb[0];
2867
+ const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
2764
2868
  const uint64_t nb51 = src5->nb[1];
2765
2869
  const uint64_t nb52 = src5->nb[2];
2870
+ const uint64_t nb53 = src5->nb[3];
2871
+
2872
+ const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
2873
+
2874
+ const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
2766
2875
 
2767
2876
  const int64_t d_state = ne00;
2768
2877
  const int64_t d_inner = ne01;
2769
- const int64_t n_seq_tokens = ne11;
2770
- const int64_t n_seqs = ne02;
2878
+ const int64_t n_head = ne02;
2879
+ const int64_t n_group = ne41;
2880
+ const int64_t n_seq_tokens = ne12;
2881
+ const int64_t n_seqs = ne13;
2771
2882
 
2772
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2883
+ id<MTLComputePipelineState> pipeline = nil;
2884
+
2885
+ if (ne30 == 1) {
2886
+ // Mamba-2
2887
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
2888
+ } else {
2889
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
2890
+ }
2773
2891
 
2774
2892
  ggml_metal_kargs_ssm_scan args = {
2775
- /*.d_state =*/ d_state,
2776
- /*.d_inner =*/ d_inner,
2893
+ /*.d_state =*/ d_state,
2894
+ /*.d_inner =*/ d_inner,
2895
+ /*.n_head =*/ n_head,
2896
+ /*.n_group =*/ n_group,
2777
2897
  /*.n_seq_tokens =*/ n_seq_tokens,
2778
- /*.n_seqs =*/ n_seqs,
2779
- /*.nb00 =*/ nb00,
2780
- /*.nb01 =*/ nb01,
2781
- /*.nb02 =*/ nb02,
2782
- /*.nb10 =*/ nb10,
2783
- /*.nb11 =*/ nb11,
2784
- /*.nb12 =*/ nb12,
2785
- /*.nb13 =*/ nb13,
2786
- /*.nb20 =*/ nb20,
2787
- /*.nb21 =*/ nb21,
2788
- /*.nb22 =*/ nb22,
2789
- /*.nb30 =*/ nb30,
2790
- /*.nb31 =*/ nb31,
2791
- /*.nb40 =*/ nb40,
2792
- /*.nb41 =*/ nb41,
2793
- /*.nb42 =*/ nb42,
2794
- /*.nb50 =*/ nb50,
2795
- /*.nb51 =*/ nb51,
2796
- /*.nb52 =*/ nb52,
2898
+ /*.n_seqs =*/ n_seqs,
2899
+ /*.nb01 =*/ nb01,
2900
+ /*.nb02 =*/ nb02,
2901
+ /*.nb03 =*/ nb03,
2902
+ /*.nb11 =*/ nb11,
2903
+ /*.nb12 =*/ nb12,
2904
+ /*.nb13 =*/ nb13,
2905
+ /*.nb21 =*/ nb21,
2906
+ /*.nb22 =*/ nb22,
2907
+ /*.nb31 =*/ nb31,
2908
+ /*.nb41 =*/ nb41,
2909
+ /*.nb42 =*/ nb42,
2910
+ /*.nb43 =*/ nb43,
2911
+ /*.nb51 =*/ nb51,
2912
+ /*.nb52 =*/ nb52,
2913
+ /*.nb53 =*/ nb53,
2797
2914
  };
2798
2915
 
2799
2916
  [encoder setComputePipelineState:pipeline];
@@ -2803,10 +2920,17 @@ static bool ggml_metal_encode_node(
2803
2920
  [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2804
2921
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2805
2922
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2806
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2807
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
2923
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
2924
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
2925
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
2808
2926
 
2809
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2927
+ if (ne30 == 1) {
2928
+ // Mamba-2
2929
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2930
+ } else {
2931
+ GGML_ASSERT(d_inner == 1);
2932
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2933
+ }
2810
2934
  } break;
2811
2935
  case GGML_OP_RWKV_WKV6:
2812
2936
  {
@@ -4908,7 +5032,11 @@ static bool ggml_metal_encode_node(
4908
5032
  /*.nb21 =*/ nb21,
4909
5033
  /*.nb22 =*/ nb22,
4910
5034
  /*.nb23 =*/ nb23,
5035
+ /*.ne32 =*/ ne32,
5036
+ /*.ne33 =*/ ne33,
4911
5037
  /*.nb31 =*/ nb31,
5038
+ /*.nb32 =*/ nb32,
5039
+ /*.nb33 =*/ nb33,
4912
5040
  /*.ne1 =*/ ne1,
4913
5041
  /*.ne2 =*/ ne2,
4914
5042
  /*.scale =*/ scale,