@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -11,6 +11,8 @@
11
11
  #include "ggml-cuda/clamp.cuh"
12
12
  #include "ggml-cuda/concat.cuh"
13
13
  #include "ggml-cuda/conv-transpose-1d.cuh"
14
+ #include "ggml-cuda/conv2d-dw.cuh"
15
+ #include "ggml-cuda/conv2d-transpose.cuh"
14
16
  #include "ggml-cuda/convert.cuh"
15
17
  #include "ggml-cuda/count-equal.cuh"
16
18
  #include "ggml-cuda/cpy.cuh"
@@ -35,6 +37,7 @@
35
37
  #include "ggml-cuda/ssm-scan.cuh"
36
38
  #include "ggml-cuda/sum.cuh"
37
39
  #include "ggml-cuda/sumrows.cuh"
40
+ #include "ggml-cuda/mean.cuh"
38
41
  #include "ggml-cuda/tsembd.cuh"
39
42
  #include "ggml-cuda/unary.cuh"
40
43
  #include "ggml-cuda/upscale.cuh"
@@ -47,6 +50,7 @@
47
50
  #include <atomic>
48
51
  #include <charconv>
49
52
  #include <cinttypes>
53
+ #include <condition_variable>
50
54
  #include <cstddef>
51
55
  #include <cstdint>
52
56
  #include <float.h>
@@ -54,9 +58,8 @@
54
58
  #include <map>
55
59
  #include <memory>
56
60
  #include <mutex>
57
- #include <stdint.h>
58
- #include <stdio.h>
59
61
  #include <stdarg.h>
62
+ #include <stdio.h>
60
63
  #include <stdlib.h>
61
64
  #include <string>
62
65
  #include <vector>
@@ -97,8 +100,7 @@ int ggml_cuda_get_device() {
97
100
  static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
98
101
  ggml_cuda_set_device(device);
99
102
  cudaError_t err;
100
- if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
101
- {
103
+ if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
102
104
  err = cudaMallocManaged(ptr, size);
103
105
  #if defined(GGML_USE_HIP)
104
106
  if (err == hipSuccess) {
@@ -116,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
116
118
  err = cudaMalloc(ptr, size);
117
119
  }
118
120
  #endif // defined(GGML_USE_HIP)
119
- }
120
- else
121
- {
121
+ } else {
122
122
  err = cudaMalloc(ptr, size);
123
123
  }
124
124
  return err;
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514
514
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
515
515
  }
516
516
 
517
+ // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
518
+ // this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
519
+
520
+ static std::mutex ggml_cuda_lock;
521
+ static std::condition_variable ggml_cuda_lock_cv;
522
+ static std::atomic<int> ggml_cuda_lock_counter;
523
+
524
+ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
525
+ std::unique_lock<std::mutex> lock(ggml_cuda_lock);
526
+ ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
527
+
528
+ if (copy_event != nullptr) {
529
+ CUDA_CHECK(cudaEventDestroy(copy_event));
530
+ }
531
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
532
+ for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
533
+ if (streams[i][j] != nullptr) {
534
+ CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
535
+ }
536
+ }
537
+ if (cublas_handles[i] != nullptr) {
538
+ CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
539
+ }
540
+ }
541
+ }
542
+
543
+
517
544
  // cuda buffer
518
545
 
519
546
  struct ggml_backend_cuda_buffer_context {
@@ -1200,9 +1227,12 @@ static void ggml_cuda_op_mul_mat_cublas(
1200
1227
 
1201
1228
  const int cc = ggml_cuda_info().devices[id].cc;
1202
1229
 
1230
+ const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1231
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
1232
+
1203
1233
  const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
1204
1234
 
1205
- if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1235
+ if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1206
1236
  ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
1207
1237
  if (src1->type != GGML_TYPE_BF16) {
1208
1238
  const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1230,7 +1260,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1230
1260
 
1231
1261
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
1232
1262
  to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1233
- } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
1263
+ } else if (fast_fp16_hardware_available(cc) && use_fp16) {
1234
1264
  // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
1235
1265
  ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
1236
1266
  if (src0->type != GGML_TYPE_F16) {
@@ -1719,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
1719
1749
  }
1720
1750
 
1721
1751
  static __global__ void k_compute_batched_ptrs(
1722
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
1752
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
1723
1753
  const void ** ptrs_src, void ** ptrs_dst,
1724
1754
  int64_t ne12, int64_t ne13,
1725
1755
  int64_t ne23,
@@ -1742,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs(
1742
1772
  ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
1743
1773
  }
1744
1774
 
1745
- static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1775
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
1776
+ template<ggml_type T>
1777
+ struct batched_mul_mat_traits;
1778
+
1779
+ template<>
1780
+ struct batched_mul_mat_traits<GGML_TYPE_F32> {
1781
+ using cuda_type = float;
1782
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1783
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
1784
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
1785
+ static inline const float alpha = 1.0f;
1786
+ static inline const float beta = 0.0f;
1787
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1788
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1789
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
1790
+ };
1791
+
1792
+ template<>
1793
+ struct batched_mul_mat_traits<GGML_TYPE_BF16> {
1794
+ using cuda_type = nv_bfloat16;
1795
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1796
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
1797
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
1798
+ static inline const float alpha = 1.0f;
1799
+ static inline const float beta = 0.0f;
1800
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1801
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1802
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
1803
+ };
1804
+
1805
+ template<>
1806
+ struct batched_mul_mat_traits<GGML_TYPE_F16> {
1807
+ using cuda_type = half;
1808
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
1809
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
1810
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
1811
+ static inline const half alpha = 1.0;
1812
+ static inline const half beta = 0.0;
1813
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
1814
+ static inline const void* get_beta() { static const half val = beta; return &val; }
1815
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
1816
+ };
1817
+
1818
+ template<ggml_type src0_type>
1819
+ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1820
+ using traits = batched_mul_mat_traits<src0_type>;
1821
+ using cuda_t = typename traits::cuda_type;
1822
+
1746
1823
  GGML_ASSERT(!ggml_is_transposed(src0));
1747
1824
  GGML_ASSERT(!ggml_is_transposed(src1));
1748
-
1749
1825
  GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
1750
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
1826
+ GGML_ASSERT(src0->type == src0_type);
1827
+ GGML_ASSERT(ggml_is_contiguous(dst));
1751
1828
 
1752
1829
  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
1753
1830
  // As long as dst is contiguous this does not matter though.
1754
- GGML_ASSERT(ggml_is_contiguous(dst));
1755
1831
 
1756
1832
  GGML_TENSOR_BINARY_OP_LOCALS
1757
1833
 
1758
1834
  const int64_t ne_dst = ggml_nelements(dst);
1759
-
1760
1835
  cudaStream_t main_stream = ctx.stream();
1761
-
1762
1836
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
1763
1837
 
1764
- const half * src0_f16 = (const half *) src0->data;
1765
1838
  float * dst_ddf = (float *) dst->data;
1766
-
1767
- const half * src1_f16 = (const half *) src1->data;
1768
1839
  const size_t ts_src1 = ggml_type_size(src1->type);
1769
1840
  GGML_ASSERT(nb10 == ts_src1);
1770
1841
  int64_t s11 = nb11 / ts_src1;
1771
1842
  int64_t s12 = nb12 / ts_src1;
1772
1843
  int64_t s13 = nb13 / ts_src1;
1773
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
1774
1844
 
1775
- // convert src1 to fp16
1776
- if (src1->type != GGML_TYPE_F16) {
1777
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
1778
- const int64_t ne_src1 = ggml_nelements(src1);
1779
- src1_f16_alloc.alloc(ne_src1);
1780
- GGML_ASSERT(to_fp16_cuda != nullptr);
1845
+ const cuda_t * src0_ptr = nullptr;
1846
+ const cuda_t * src1_ptr = nullptr;
1847
+
1848
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
1849
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
1781
1850
 
1782
- to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1851
+ // Handle src0
1852
+ src0_ptr = (const cuda_t *) src0->data;
1783
1853
 
1784
- src1_f16 = src1_f16_alloc.get();
1854
+ // Handle src1 - convert if necessary
1855
+ if (src1->type == src0_type) {
1856
+ src1_ptr = (const cuda_t *) src1->data;
1857
+ } else {
1858
+ // Convert src1 to target type using traits conversion functions
1859
+ const int64_t ne_src1 = ggml_nelements(src1);
1860
+ src1_alloc.alloc(ne_src1);
1861
+
1862
+ const auto convert_func = traits::get_nc_converter(src1->type);
1863
+ GGML_ASSERT(convert_func != nullptr);
1864
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1865
+ src1_ptr = src1_alloc.get();
1785
1866
  s11 = ne10;
1786
1867
  s12 = ne11*s11;
1787
1868
  s13 = ne12*s12;
1788
1869
  }
1789
1870
 
1790
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
1871
+ // Setup destination buffer
1872
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
1791
1873
  char * dst_t;
1792
-
1793
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1794
- cudaDataType_t cu_data_type = CUDA_R_16F;
1795
-
1796
- // dst strides
1797
1874
  size_t nbd2 = dst->nb[2];
1798
1875
  size_t nbd3 = dst->nb[3];
1799
1876
 
1800
- const half alpha_f16 = 1.0f;
1801
- const half beta_f16 = 0.0f;
1802
-
1877
+ cublasComputeType_t cu_compute_type = traits::compute_type;
1878
+ cudaDataType_t cu_data_type = traits::data_type;
1879
+ cudaDataType_t cu_data_type_a = traits::data_type;
1880
+ cudaDataType_t cu_data_type_b = traits::data_type;
1881
+ const void * alpha = traits::get_alpha();
1882
+ const void * beta = traits::get_beta();
1803
1883
  const float alpha_f32 = 1.0f;
1804
- const float beta_f32 = 0.0f;
1805
-
1806
- const void * alpha = &alpha_f16;
1807
- const void * beta = &beta_f16;
1884
+ const float beta_f32 = 0.0f;
1808
1885
 
1809
1886
  if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1810
- dst_t = (char *) dst_f16.alloc(ne_dst);
1811
-
1812
- nbd2 /= sizeof(float) / sizeof(half);
1813
- nbd3 /= sizeof(float) / sizeof(half);
1887
+ if constexpr (src0_type == GGML_TYPE_F32) {
1888
+ dst_t = (char *) dst_ddf; // Direct F32 output
1889
+ } else {
1890
+ dst_t = (char *) dst_temp.alloc(ne_dst);
1891
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
1892
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
1893
+ }
1814
1894
  } else {
1815
1895
  dst_t = (char *) dst_ddf;
1816
-
1817
1896
  cu_compute_type = CUBLAS_COMPUTE_32F;
1818
- cu_data_type = CUDA_R_32F;
1819
-
1897
+ cu_data_type = CUDA_R_32F;
1820
1898
  alpha = &alpha_f32;
1821
- beta = &beta_f32;
1899
+ beta = &beta_f32;
1822
1900
  }
1823
1901
 
1824
1902
  int id = ggml_cuda_get_device();
@@ -1826,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1826
1904
  if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1827
1905
  cu_compute_type = CUBLAS_COMPUTE_32F;
1828
1906
  alpha = &alpha_f32;
1829
- beta = &beta_f32;
1907
+ beta = &beta_f32;
1830
1908
  }
1831
1909
 
1832
1910
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1836,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1836
1914
  const int64_t r2 = ne12/ne02;
1837
1915
  const int64_t r3 = ne13/ne03;
1838
1916
 
1839
- #if 0
1840
- // use cublasGemmEx
1841
- {
1842
- for (int i13 = 0; i13 < ne13; ++i13) {
1843
- for (int i12 = 0; i12 < ne12; ++i12) {
1844
- int i03 = i13 / r3;
1845
- int i02 = i12 / r2;
1846
-
1847
- CUBLAS_CHECK(
1848
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1849
- ne01, ne11, ne10,
1850
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
1851
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
1852
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
1853
- cu_compute_type,
1854
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1855
- }
1856
- }
1857
- }
1858
- #else
1859
1917
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1860
1918
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
1861
1919
  // use cublasGemmStridedBatchedEx
1862
1920
  CUBLAS_CHECK(
1863
1921
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1864
1922
  ne01, ne11, ne10,
1865
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
1866
- src1_f16, CUDA_R_16F, s11, s12, // strideB
1867
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1923
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1924
+ src1_ptr, cu_data_type_b, s11, s12, // strideB
1925
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1868
1926
  ne12*ne13,
1869
1927
  cu_compute_type,
1870
1928
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1875,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1875
1933
  ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
1876
1934
  ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
1877
1935
 
1936
+ size_t src1_stride_size = sizeof(cuda_t);
1937
+
1878
1938
  dim3 block_dims(ne13, ne12);
1879
1939
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
1880
- src0_f16, src1_f16, dst_t,
1940
+ src0_ptr, src1_ptr, dst_t,
1881
1941
  ptrs_src.get(), ptrs_dst.get(),
1882
1942
  ne12, ne13,
1883
1943
  ne23,
1884
1944
  nb02, nb03,
1885
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
1886
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
1945
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
1946
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
1887
1947
  nbd2, nbd3,
1888
1948
  r2, r3);
1949
+
1889
1950
  CUDA_CHECK(cudaGetLastError());
1890
1951
 
1891
1952
  CUBLAS_CHECK(
1892
1953
  cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1893
1954
  ne01, ne11, ne10,
1894
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
1895
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
1896
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1955
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
1956
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
1957
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1897
1958
  ne23,
1898
1959
  cu_compute_type,
1899
1960
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1900
1961
  }
1901
- #endif
1902
1962
 
1903
- if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
1904
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
1905
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
1963
+ // Convert output back to F32 if needed
1964
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
1965
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
1966
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
1967
+ }
1968
+ }
1969
+
1970
+ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1971
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
1972
+
1973
+ switch (src0->type) {
1974
+ case GGML_TYPE_F32:
1975
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
1976
+ break;
1977
+ case GGML_TYPE_BF16:
1978
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
1979
+ break;
1980
+ case GGML_TYPE_F16:
1981
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
1982
+ break;
1983
+ default:
1984
+ GGML_ABORT("Unsupported type");
1906
1985
  }
1907
1986
  }
1908
1987
 
@@ -1916,16 +1995,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1916
1995
  && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
1917
1996
 
1918
1997
  bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
1919
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1920
- && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
1998
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
1921
1999
  bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
1922
2000
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
1923
2001
  && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
1924
2002
  bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
1925
2003
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
1926
2004
 
1927
- bool any_gpus_with_slow_fp16 = false;
1928
- bool any_gpus_without_fp16_mma = false;
2005
+ bool any_gpus_with_slow_fp16 = false;
1929
2006
 
1930
2007
  if (split) {
1931
2008
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +2013,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1936
2013
  continue;
1937
2014
  }
1938
2015
 
1939
- const int cc = ggml_cuda_info().devices[id].cc;
1940
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1941
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1942
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
2016
+ const int cc = ggml_cuda_info().devices[id].cc;
2017
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2018
+ use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
2019
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1943
2020
  }
1944
2021
  } else {
1945
- const int cc = ggml_cuda_info().devices[ctx.device].cc;
1946
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
1947
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1948
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
2022
+ const int cc = ggml_cuda_info().devices[ctx.device].cc;
2023
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2024
+ use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
2025
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
1949
2026
  }
1950
2027
 
1951
2028
  // debug helpers
@@ -1956,7 +2033,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1956
2033
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
1957
2034
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
1958
2035
 
1959
- if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
2036
+ //TODO update for generic tensor parallelism
2037
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2038
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2039
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2040
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2041
+
2042
+ if (!split && use_mul_mat_vec) {
1960
2043
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
1961
2044
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
1962
2045
  ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
@@ -1964,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1964
2047
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
1965
2048
  } else if (!split && use_mul_mat_q) {
1966
2049
  ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
1967
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1968
- !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
2050
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
2051
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1969
2052
  // general KQ + KQV multi-batch without FlashAttention
1970
2053
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
1971
2054
  } else if (use_mul_mat_vec) {
@@ -2220,6 +2303,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2220
2303
  return false;
2221
2304
  }
2222
2305
  break;
2306
+ case GGML_OP_GLU:
2307
+ switch (ggml_get_glu_op(dst)) {
2308
+ case GGML_GLU_OP_REGLU:
2309
+ ggml_cuda_op_reglu(ctx, dst);
2310
+ break;
2311
+ case GGML_GLU_OP_GEGLU:
2312
+ ggml_cuda_op_geglu(ctx, dst);
2313
+ break;
2314
+ case GGML_GLU_OP_SWIGLU:
2315
+ ggml_cuda_op_swiglu(ctx, dst);
2316
+ break;
2317
+ case GGML_GLU_OP_GEGLU_ERF:
2318
+ ggml_cuda_op_geglu_erf(ctx, dst);
2319
+ break;
2320
+ case GGML_GLU_OP_GEGLU_QUICK:
2321
+ ggml_cuda_op_geglu_quick(ctx, dst);
2322
+ break;
2323
+ default:
2324
+ return false;
2325
+ }
2326
+ break;
2223
2327
  case GGML_OP_NORM:
2224
2328
  ggml_cuda_op_norm(ctx, dst);
2225
2329
  break;
@@ -2310,6 +2414,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2310
2414
  case GGML_OP_IM2COL:
2311
2415
  ggml_cuda_op_im2col(ctx, dst);
2312
2416
  break;
2417
+ case GGML_OP_CONV_2D_DW:
2418
+ ggml_cuda_op_conv2d_dw(ctx, dst);
2419
+ break;
2420
+ case GGML_OP_CONV_TRANSPOSE_2D:
2421
+ ggml_cuda_conv_2d_transpose_p0(ctx, dst);
2422
+ break;
2313
2423
  case GGML_OP_CONV_TRANSPOSE_1D:
2314
2424
  ggml_cuda_op_conv_transpose_1d(ctx,dst);
2315
2425
  break;
@@ -2322,6 +2432,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2322
2432
  case GGML_OP_SUM_ROWS:
2323
2433
  ggml_cuda_op_sum_rows(ctx, dst);
2324
2434
  break;
2435
+ case GGML_OP_MEAN:
2436
+ ggml_cuda_op_mean(ctx, dst);
2437
+ break;
2325
2438
  case GGML_OP_SSM_CONV:
2326
2439
  ggml_cuda_op_ssm_conv(ctx, dst);
2327
2440
  break;
@@ -2685,6 +2798,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2685
2798
 
2686
2799
  CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
2687
2800
  graph_evaluated_or_captured = true; // CUDA graph has been captured
2801
+
2802
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2803
+ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
2804
+ ggml_cuda_lock_cv.notify_all();
2805
+ }
2688
2806
  } else {
2689
2807
  graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
2690
2808
  }
@@ -2760,7 +2878,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2760
2878
  }
2761
2879
  }
2762
2880
 
2763
- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
2881
+ if (use_cuda_graph && cuda_graph_update_required) {
2882
+ // Start CUDA graph capture
2883
+ {
2884
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2885
+ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
2886
+ }
2887
+
2764
2888
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
2765
2889
  }
2766
2890
 
@@ -2993,6 +3117,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
2993
3117
  return false;
2994
3118
  }
2995
3119
  break;
3120
+ case GGML_OP_GLU:
3121
+ switch (ggml_get_glu_op(op)) {
3122
+ case GGML_GLU_OP_REGLU:
3123
+ case GGML_GLU_OP_GEGLU:
3124
+ case GGML_GLU_OP_SWIGLU:
3125
+ case GGML_GLU_OP_GEGLU_ERF:
3126
+ case GGML_GLU_OP_GEGLU_QUICK:
3127
+ return ggml_is_contiguous_1(op->src[0]);
3128
+ default:
3129
+ return false;
3130
+ }
3131
+ break;
2996
3132
  case GGML_OP_MUL_MAT:
2997
3133
  case GGML_OP_MUL_MAT_ID:
2998
3134
  {
@@ -3016,9 +3152,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3016
3152
  return false;
3017
3153
  }
3018
3154
  #ifdef GGML_USE_MUSA
3019
- if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3020
- !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3021
- return false;
3155
+ const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3156
+ if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3157
+ if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
3158
+ a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
3159
+ return false;
3160
+ }
3161
+ if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
3162
+ a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
3163
+ return false;
3164
+ }
3022
3165
  }
3023
3166
  #endif // GGML_USE_MUSA
3024
3167
  switch (a->type) {
@@ -3045,11 +3188,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3045
3188
  case GGML_TYPE_IQ4_NL:
3046
3189
  case GGML_TYPE_IQ4_XS:
3047
3190
  case GGML_TYPE_BF16:
3048
- #ifdef GGML_USE_MUSA
3049
- if (a->type == GGML_TYPE_Q3_K) {
3050
- return false;
3051
- }
3052
- #endif // GGML_USE_MUSA
3053
3191
  return true;
3054
3192
  default:
3055
3193
  return false;
@@ -3062,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3062
3200
  switch (op->src[0]->type) {
3063
3201
  case GGML_TYPE_F16:
3064
3202
  case GGML_TYPE_F32:
3203
+ case GGML_TYPE_BF16:
3204
+ case GGML_TYPE_I32:
3065
3205
  case GGML_TYPE_Q4_0:
3066
3206
  case GGML_TYPE_Q4_1:
3067
3207
  case GGML_TYPE_Q5_0:
@@ -3191,12 +3331,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3191
3331
  case GGML_OP_COS:
3192
3332
  case GGML_OP_CLAMP:
3193
3333
  case GGML_OP_LOG:
3194
- case GGML_OP_SSM_SCAN:
3195
- case GGML_OP_SSM_CONV:
3196
3334
  return true;
3335
+ case GGML_OP_SSM_SCAN: {
3336
+ if (op->src[3]->ne[0] == 1) {
3337
+ // Mamba2
3338
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
3339
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
3340
+ } else {
3341
+ // Mamba
3342
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
3343
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
3344
+ }
3345
+ }
3346
+ case GGML_OP_SSM_CONV: {
3347
+ // assumes d_inner % threads == 0
3348
+ return op->src[0]->ne[1] % 128 == 0;
3349
+ }
3197
3350
  case GGML_OP_CONT:
3198
3351
  return op->src[0]->type != GGML_TYPE_BF16;
3199
3352
  case GGML_OP_DIAG_MASK_INF:
3353
+ return true;
3200
3354
  case GGML_OP_SOFT_MAX:
3201
3355
  return true;
3202
3356
  case GGML_OP_SOFT_MAX_BACK: {
@@ -3209,16 +3363,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3209
3363
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
3210
3364
  }
3211
3365
  case GGML_OP_IM2COL:
3366
+ case GGML_OP_CONV_2D_DW:
3367
+ case GGML_OP_CONV_TRANSPOSE_2D:
3212
3368
  case GGML_OP_POOL_2D:
3213
3369
  case GGML_OP_SUM:
3214
3370
  case GGML_OP_SUM_ROWS:
3371
+ case GGML_OP_MEAN:
3215
3372
  case GGML_OP_ARGSORT:
3216
3373
  case GGML_OP_ACC:
3217
3374
  return true;
3218
3375
  case GGML_OP_GROUP_NORM:
3219
3376
  return ggml_is_contiguous(op->src[0]);
3220
3377
  case GGML_OP_UPSCALE:
3221
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
3222
3378
  case GGML_OP_PAD:
3223
3379
  case GGML_OP_ARANGE:
3224
3380
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -3242,6 +3398,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3242
3398
  if (op->src[0]->ne[0] == 192) {
3243
3399
  return false;
3244
3400
  }
3401
+ // TODO: support broadcast
3402
+ // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
3403
+ // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
3245
3404
  if (op->src[0]->ne[3] != 1) {
3246
3405
  return false;
3247
3406
  }