@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
27
27
  const int ne12,
28
28
  const int ne13,
29
29
  const int ne31,
30
+ const int ne32,
30
31
  const int nb31,
32
+ const int nb32,
31
33
  const int nb01,
32
34
  const int nb02,
33
35
  const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
51
53
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
52
54
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
53
55
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
54
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
55
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
56
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
57
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
56
58
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
57
59
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
58
60
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
79
81
  Q += nb02* blockIdx.z + nb01*ic0;
80
82
  K += nb12*(blockIdx.z / gqa_ratio);
81
83
  V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
82
- const half * maskh = (const half *) mask + ne11*ic0;
84
+
85
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
83
86
 
84
87
  const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
85
88
 
@@ -334,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
334
337
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
335
338
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
336
339
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
337
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
338
- GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
339
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
340
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
341
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
342
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
343
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
340
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
341
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
342
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
343
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32);
344
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32);
345
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
346
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
347
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
348
+ GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
344
349
  NO_DEVICE_CODE;
345
350
  #endif // FLASH_ATTN_AVAILABLE
346
351
  }
@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
46
46
  const int ne12,
47
47
  const int ne13,
48
48
  const int ne31,
49
+ const int ne32,
49
50
  const int nb31,
51
+ const int nb32,
50
52
  const int nb01,
51
53
  const int nb02,
52
54
  const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
94
96
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
95
97
 
96
98
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
97
- const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
98
- const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
99
- const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
100
- const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
101
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
99
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
100
+ const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
101
+ const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
102
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
103
+ const half2 * mask2 = (const half2 *) maskh;
102
104
 
103
105
  const int stride_Q = nb01 / sizeof(float);
104
106
  const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
440
442
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
441
443
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
442
444
  GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
443
- GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
445
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
444
446
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
445
447
  GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
446
448
  GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
168
168
  get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
169
169
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
170
170
  break;
171
+ case GGML_TYPE_I32:
172
+ get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
173
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
174
+ break;
171
175
  case GGML_TYPE_BF16:
172
176
  get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
173
177
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
210
214
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
211
215
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
212
216
  break;
217
+ case GGML_TYPE_I32:
218
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
219
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
220
+ break;
213
221
  case GGML_TYPE_F16:
214
222
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
215
223
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
1749
1749
  }
1750
1750
 
1751
1751
  static __global__ void k_compute_batched_ptrs(
1752
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
1752
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
1753
1753
  const void ** ptrs_src, void ** ptrs_dst,
1754
1754
  int64_t ne12, int64_t ne13,
1755
1755
  int64_t ne23,
@@ -1772,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs(
1772
1772
  ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
1773
1773
  }
1774
1774
 
1775
- static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1775
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
1776
+ template<ggml_type T>
1777
+ struct batched_mul_mat_traits;
1778
+
1779
+ template<>
1780
+ struct batched_mul_mat_traits<GGML_TYPE_F32> {
1781
+ using cuda_type = float;
1782
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1783
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
1784
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
1785
+ static inline const float alpha = 1.0f;
1786
+ static inline const float beta = 0.0f;
1787
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1788
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1789
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
1790
+ };
1791
+
1792
+ template<>
1793
+ struct batched_mul_mat_traits<GGML_TYPE_BF16> {
1794
+ using cuda_type = nv_bfloat16;
1795
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1796
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
1797
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
1798
+ static inline const float alpha = 1.0f;
1799
+ static inline const float beta = 0.0f;
1800
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1801
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1802
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
1803
+ };
1804
+
1805
+ template<>
1806
+ struct batched_mul_mat_traits<GGML_TYPE_F16> {
1807
+ using cuda_type = half;
1808
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
1809
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
1810
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
1811
+ static inline const half alpha = 1.0;
1812
+ static inline const half beta = 0.0;
1813
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
1814
+ static inline const void* get_beta() { static const half val = beta; return &val; }
1815
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
1816
+ };
1817
+
1818
+ template<ggml_type src0_type>
1819
+ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1820
+ using traits = batched_mul_mat_traits<src0_type>;
1821
+ using cuda_t = typename traits::cuda_type;
1822
+
1776
1823
  GGML_ASSERT(!ggml_is_transposed(src0));
1777
1824
  GGML_ASSERT(!ggml_is_transposed(src1));
1778
-
1779
1825
  GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
1780
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
1826
+ GGML_ASSERT(src0->type == src0_type);
1827
+ GGML_ASSERT(ggml_is_contiguous(dst));
1781
1828
 
1782
1829
  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
1783
1830
  // As long as dst is contiguous this does not matter though.
1784
- GGML_ASSERT(ggml_is_contiguous(dst));
1785
1831
 
1786
1832
  GGML_TENSOR_BINARY_OP_LOCALS
1787
1833
 
1788
1834
  const int64_t ne_dst = ggml_nelements(dst);
1789
-
1790
1835
  cudaStream_t main_stream = ctx.stream();
1791
-
1792
1836
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
1793
1837
 
1794
- const half * src0_f16 = (const half *) src0->data;
1795
1838
  float * dst_ddf = (float *) dst->data;
1796
-
1797
- const half * src1_f16 = (const half *) src1->data;
1798
1839
  const size_t ts_src1 = ggml_type_size(src1->type);
1799
1840
  GGML_ASSERT(nb10 == ts_src1);
1800
1841
  int64_t s11 = nb11 / ts_src1;
1801
1842
  int64_t s12 = nb12 / ts_src1;
1802
1843
  int64_t s13 = nb13 / ts_src1;
1803
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
1804
1844
 
1805
- // convert src1 to fp16
1806
- if (src1->type != GGML_TYPE_F16) {
1807
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
1808
- const int64_t ne_src1 = ggml_nelements(src1);
1809
- src1_f16_alloc.alloc(ne_src1);
1810
- GGML_ASSERT(to_fp16_cuda != nullptr);
1845
+ const cuda_t * src0_ptr = nullptr;
1846
+ const cuda_t * src1_ptr = nullptr;
1847
+
1848
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
1849
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
1811
1850
 
1812
- to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1851
+ // Handle src0
1852
+ src0_ptr = (const cuda_t *) src0->data;
1813
1853
 
1814
- src1_f16 = src1_f16_alloc.get();
1854
+ // Handle src1 - convert if necessary
1855
+ if (src1->type == src0_type) {
1856
+ src1_ptr = (const cuda_t *) src1->data;
1857
+ } else {
1858
+ // Convert src1 to target type using traits conversion functions
1859
+ const int64_t ne_src1 = ggml_nelements(src1);
1860
+ src1_alloc.alloc(ne_src1);
1861
+
1862
+ const auto convert_func = traits::get_nc_converter(src1->type);
1863
+ GGML_ASSERT(convert_func != nullptr);
1864
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1865
+ src1_ptr = src1_alloc.get();
1815
1866
  s11 = ne10;
1816
1867
  s12 = ne11*s11;
1817
1868
  s13 = ne12*s12;
1818
1869
  }
1819
1870
 
1820
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
1871
+ // Setup destination buffer
1872
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
1821
1873
  char * dst_t;
1822
-
1823
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1824
- cudaDataType_t cu_data_type = CUDA_R_16F;
1825
-
1826
- // dst strides
1827
1874
  size_t nbd2 = dst->nb[2];
1828
1875
  size_t nbd3 = dst->nb[3];
1829
1876
 
1830
- const half alpha_f16 = 1.0f;
1831
- const half beta_f16 = 0.0f;
1832
-
1877
+ cublasComputeType_t cu_compute_type = traits::compute_type;
1878
+ cudaDataType_t cu_data_type = traits::data_type;
1879
+ cudaDataType_t cu_data_type_a = traits::data_type;
1880
+ cudaDataType_t cu_data_type_b = traits::data_type;
1881
+ const void * alpha = traits::get_alpha();
1882
+ const void * beta = traits::get_beta();
1833
1883
  const float alpha_f32 = 1.0f;
1834
- const float beta_f32 = 0.0f;
1835
-
1836
- const void * alpha = &alpha_f16;
1837
- const void * beta = &beta_f16;
1884
+ const float beta_f32 = 0.0f;
1838
1885
 
1839
1886
  if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1840
- dst_t = (char *) dst_f16.alloc(ne_dst);
1841
-
1842
- nbd2 /= sizeof(float) / sizeof(half);
1843
- nbd3 /= sizeof(float) / sizeof(half);
1887
+ if constexpr (src0_type == GGML_TYPE_F32) {
1888
+ dst_t = (char *) dst_ddf; // Direct F32 output
1889
+ } else {
1890
+ dst_t = (char *) dst_temp.alloc(ne_dst);
1891
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
1892
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
1893
+ }
1844
1894
  } else {
1845
1895
  dst_t = (char *) dst_ddf;
1846
-
1847
1896
  cu_compute_type = CUBLAS_COMPUTE_32F;
1848
- cu_data_type = CUDA_R_32F;
1849
-
1897
+ cu_data_type = CUDA_R_32F;
1850
1898
  alpha = &alpha_f32;
1851
- beta = &beta_f32;
1899
+ beta = &beta_f32;
1852
1900
  }
1853
1901
 
1854
1902
  int id = ggml_cuda_get_device();
@@ -1856,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1856
1904
  if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1857
1905
  cu_compute_type = CUBLAS_COMPUTE_32F;
1858
1906
  alpha = &alpha_f32;
1859
- beta = &beta_f32;
1907
+ beta = &beta_f32;
1860
1908
  }
1861
1909
 
1862
1910
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1866
1914
  const int64_t r2 = ne12/ne02;
1867
1915
  const int64_t r3 = ne13/ne03;
1868
1916
 
1869
- #if 0
1870
- // use cublasGemmEx
1871
- {
1872
- for (int i13 = 0; i13 < ne13; ++i13) {
1873
- for (int i12 = 0; i12 < ne12; ++i12) {
1874
- int i03 = i13 / r3;
1875
- int i02 = i12 / r2;
1876
-
1877
- CUBLAS_CHECK(
1878
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1879
- ne01, ne11, ne10,
1880
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
1881
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
1882
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
1883
- cu_compute_type,
1884
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1885
- }
1886
- }
1887
- }
1888
- #else
1889
1917
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1890
1918
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
1891
1919
  // use cublasGemmStridedBatchedEx
1892
1920
  CUBLAS_CHECK(
1893
1921
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1894
1922
  ne01, ne11, ne10,
1895
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
1896
- src1_f16, CUDA_R_16F, s11, s12, // strideB
1897
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1923
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1924
+ src1_ptr, cu_data_type_b, s11, s12, // strideB
1925
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1898
1926
  ne12*ne13,
1899
1927
  cu_compute_type,
1900
1928
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1905
1933
  ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
1906
1934
  ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
1907
1935
 
1936
+ size_t src1_stride_size = sizeof(cuda_t);
1937
+
1908
1938
  dim3 block_dims(ne13, ne12);
1909
1939
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
1910
- src0_f16, src1_f16, dst_t,
1940
+ src0_ptr, src1_ptr, dst_t,
1911
1941
  ptrs_src.get(), ptrs_dst.get(),
1912
1942
  ne12, ne13,
1913
1943
  ne23,
1914
1944
  nb02, nb03,
1915
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
1916
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
1945
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
1946
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
1917
1947
  nbd2, nbd3,
1918
1948
  r2, r3);
1949
+
1919
1950
  CUDA_CHECK(cudaGetLastError());
1920
1951
 
1921
1952
  CUBLAS_CHECK(
1922
1953
  cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1923
1954
  ne01, ne11, ne10,
1924
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
1925
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
1926
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1955
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
1956
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
1957
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1927
1958
  ne23,
1928
1959
  cu_compute_type,
1929
1960
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1930
1961
  }
1931
- #endif
1932
1962
 
1933
- if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
1934
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
1935
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
1963
+ // Convert output back to F32 if needed
1964
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
1965
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
1966
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
1967
+ }
1968
+ }
1969
+
1970
+ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1971
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
1972
+
1973
+ switch (src0->type) {
1974
+ case GGML_TYPE_F32:
1975
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
1976
+ break;
1977
+ case GGML_TYPE_BF16:
1978
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
1979
+ break;
1980
+ case GGML_TYPE_F16:
1981
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
1982
+ break;
1983
+ default:
1984
+ GGML_ABORT("Unsupported type");
1936
1985
  }
1937
1986
  }
1938
1987
 
@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1984
2033
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
1985
2034
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
1986
2035
 
2036
+ //TODO update for generic tensor parallelism
2037
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2038
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2039
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2040
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2041
+
1987
2042
  if (!split && use_mul_mat_vec) {
1988
2043
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
1989
2044
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1992
2047
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
1993
2048
  } else if (!split && use_mul_mat_q) {
1994
2049
  ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
1995
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1996
- !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
2050
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
2051
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1997
2052
  // general KQ + KQV multi-batch without FlashAttention
1998
2053
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
1999
2054
  } else if (use_mul_mat_vec) {
@@ -2248,6 +2303,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2248
2303
  return false;
2249
2304
  }
2250
2305
  break;
2306
+ case GGML_OP_GLU:
2307
+ switch (ggml_get_glu_op(dst)) {
2308
+ case GGML_GLU_OP_REGLU:
2309
+ ggml_cuda_op_reglu(ctx, dst);
2310
+ break;
2311
+ case GGML_GLU_OP_GEGLU:
2312
+ ggml_cuda_op_geglu(ctx, dst);
2313
+ break;
2314
+ case GGML_GLU_OP_SWIGLU:
2315
+ ggml_cuda_op_swiglu(ctx, dst);
2316
+ break;
2317
+ case GGML_GLU_OP_GEGLU_ERF:
2318
+ ggml_cuda_op_geglu_erf(ctx, dst);
2319
+ break;
2320
+ case GGML_GLU_OP_GEGLU_QUICK:
2321
+ ggml_cuda_op_geglu_quick(ctx, dst);
2322
+ break;
2323
+ default:
2324
+ return false;
2325
+ }
2326
+ break;
2251
2327
  case GGML_OP_NORM:
2252
2328
  ggml_cuda_op_norm(ctx, dst);
2253
2329
  break;
@@ -3041,6 +3117,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3041
3117
  return false;
3042
3118
  }
3043
3119
  break;
3120
+ case GGML_OP_GLU:
3121
+ switch (ggml_get_glu_op(op)) {
3122
+ case GGML_GLU_OP_REGLU:
3123
+ case GGML_GLU_OP_GEGLU:
3124
+ case GGML_GLU_OP_SWIGLU:
3125
+ case GGML_GLU_OP_GEGLU_ERF:
3126
+ case GGML_GLU_OP_GEGLU_QUICK:
3127
+ return ggml_is_contiguous_1(op->src[0]);
3128
+ default:
3129
+ return false;
3130
+ }
3131
+ break;
3044
3132
  case GGML_OP_MUL_MAT:
3045
3133
  case GGML_OP_MUL_MAT_ID:
3046
3134
  {
@@ -3112,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3112
3200
  switch (op->src[0]->type) {
3113
3201
  case GGML_TYPE_F16:
3114
3202
  case GGML_TYPE_F32:
3203
+ case GGML_TYPE_BF16:
3204
+ case GGML_TYPE_I32:
3115
3205
  case GGML_TYPE_Q4_0:
3116
3206
  case GGML_TYPE_Q4_1:
3117
3207
  case GGML_TYPE_Q5_0:
@@ -3241,12 +3331,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3241
3331
  case GGML_OP_COS:
3242
3332
  case GGML_OP_CLAMP:
3243
3333
  case GGML_OP_LOG:
3244
- case GGML_OP_SSM_SCAN:
3245
- case GGML_OP_SSM_CONV:
3246
3334
  return true;
3335
+ case GGML_OP_SSM_SCAN: {
3336
+ if (op->src[3]->ne[0] == 1) {
3337
+ // Mamba2
3338
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
3339
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
3340
+ } else {
3341
+ // Mamba
3342
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
3343
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
3344
+ }
3345
+ }
3346
+ case GGML_OP_SSM_CONV: {
3347
+ // assumes d_inner % threads == 0
3348
+ return op->src[0]->ne[1] % 128 == 0;
3349
+ }
3247
3350
  case GGML_OP_CONT:
3248
3351
  return op->src[0]->type != GGML_TYPE_BF16;
3249
3352
  case GGML_OP_DIAG_MASK_INF:
3353
+ return true;
3250
3354
  case GGML_OP_SOFT_MAX:
3251
3355
  return true;
3252
3356
  case GGML_OP_SOFT_MAX_BACK: {
@@ -3271,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3271
3375
  case GGML_OP_GROUP_NORM:
3272
3376
  return ggml_is_contiguous(op->src[0]);
3273
3377
  case GGML_OP_UPSCALE:
3274
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
3275
3378
  case GGML_OP_PAD:
3276
3379
  case GGML_OP_ARANGE:
3277
3380
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -3295,6 +3398,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3295
3398
  if (op->src[0]->ne[0] == 192) {
3296
3399
  return false;
3297
3400
  }
3401
+ // TODO: support broadcast
3402
+ // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
3403
+ // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
3298
3404
  if (op->src[0]->ne[3] != 1) {
3299
3405
  return false;
3300
3406
  }
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3016
3016
 
3017
3017
  const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
3018
3018
 
3019
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3020
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
3021
- if (!shared_memory_limit_raised[id]) {
3022
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3023
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3024
- shared_memory_limit_raised[id] = true;
3025
- }
3026
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3019
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
3020
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
3027
3021
 
3028
3022
  const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3029
3023
  const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
50
50
 
51
51
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
52
52
 
53
- if (i0 >= n_dims) {
54
- const int i = row_dst*ne0 + i0;
55
-
56
- dst[i + 0] = x[i + 0];
57
- dst[i + 1] = x[i + 1];
58
-
59
- return;
60
- }
61
-
62
53
  const int row_x = row_dst % ne1;
63
54
  const int channel_x = row_dst / ne1;
64
55
 
65
56
  const int idst = row_dst*ne0 + i0;
66
57
  const int ix = channel_x*s2 + row_x*s1 + i0;
67
58
 
59
+ if (i0 >= n_dims) {
60
+ dst[idst + 0] = x[ix + 0];
61
+ dst[idst + 1] = x[ix + 1];
62
+
63
+ return;
64
+ }
65
+
68
66
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
69
67
 
70
68
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
94
92
 
95
93
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
96
94
 
97
- if (i0 >= n_dims) {
98
- const int i = row_dst*ne0 + i0;
99
-
100
- dst[i + 0] = x[i + 0];
101
- dst[i + 1] = x[i + 1];
102
-
103
- return;
104
- }
105
-
106
95
  const int row_x = row_dst % ne1;
107
96
  const int channel_x = row_dst / ne1;
108
97
 
109
98
  const int idst = row_dst*ne0 + i0/2;
110
99
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
111
100
 
101
+ if (i0 >= n_dims) {
102
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
104
+
105
+ return;
106
+ }
107
+
112
108
  const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
113
109
 
114
110
  const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
138
134
 
139
135
  const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
140
136
 
141
- if (i0 >= n_dims) {
142
- const int i = row_dst*ne0 + i0;
143
-
144
- dst[i + 0] = x[i + 0];
145
- dst[i + 1] = x[i + 1];
146
-
147
- return;
148
- }
149
-
150
137
  const int row_x = row_dst % ne1;
151
138
  const int channel_x = row_dst / ne1;
152
139
 
153
140
  const int idst = row_dst*ne0 + i0/2;
154
141
  const int ix = channel_x*s2 + row_x*s1 + i0/2;
155
142
 
143
+ if (i0 >= n_dims) {
144
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
145
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
146
+
147
+ return;
148
+ }
149
+
156
150
  const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
157
151
  const int sec_w = sections.v[1] + sections.v[0];
158
152
  const int sector = (i0 / 2) % sect_dims;