@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -43,6 +43,7 @@
43
43
  #include "ggml-cuda/upscale.cuh"
44
44
  #include "ggml-cuda/wkv.cuh"
45
45
  #include "ggml-cuda/gla.cuh"
46
+ #include "ggml-cuda/set-rows.cuh"
46
47
  #include "ggml.h"
47
48
 
48
49
  #include <algorithm>
@@ -54,6 +55,7 @@
54
55
  #include <cstddef>
55
56
  #include <cstdint>
56
57
  #include <float.h>
58
+ #include <initializer_list>
57
59
  #include <limits>
58
60
  #include <map>
59
61
  #include <memory>
@@ -1749,7 +1751,7 @@ static void ggml_cuda_op_mul_mat(
1749
1751
  }
1750
1752
 
1751
1753
  static __global__ void k_compute_batched_ptrs(
1752
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
1754
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
1753
1755
  const void ** ptrs_src, void ** ptrs_dst,
1754
1756
  int64_t ne12, int64_t ne13,
1755
1757
  int64_t ne23,
@@ -1772,83 +1774,131 @@ static __global__ void k_compute_batched_ptrs(
1772
1774
  ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
1773
1775
  }
1774
1776
 
1775
- static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1777
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
1778
+ template<ggml_type T>
1779
+ struct batched_mul_mat_traits;
1780
+
1781
+ template<>
1782
+ struct batched_mul_mat_traits<GGML_TYPE_F32> {
1783
+ using cuda_type = float;
1784
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1785
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
1786
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
1787
+ static inline const float alpha = 1.0f;
1788
+ static inline const float beta = 0.0f;
1789
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1790
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1791
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
1792
+ };
1793
+
1794
+ template<>
1795
+ struct batched_mul_mat_traits<GGML_TYPE_BF16> {
1796
+ using cuda_type = nv_bfloat16;
1797
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1798
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
1799
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
1800
+ static inline const float alpha = 1.0f;
1801
+ static inline const float beta = 0.0f;
1802
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
1803
+ static inline const void* get_beta() { static const float val = beta; return &val; }
1804
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
1805
+ };
1806
+
1807
+ template<>
1808
+ struct batched_mul_mat_traits<GGML_TYPE_F16> {
1809
+ using cuda_type = half;
1810
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
1811
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
1812
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
1813
+ static inline const half alpha = 1.0;
1814
+ static inline const half beta = 0.0;
1815
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
1816
+ static inline const void* get_beta() { static const half val = beta; return &val; }
1817
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
1818
+ };
1819
+
1820
+ template<ggml_type src0_type>
1821
+ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1822
+ using traits = batched_mul_mat_traits<src0_type>;
1823
+ using cuda_t = typename traits::cuda_type;
1824
+
1776
1825
  GGML_ASSERT(!ggml_is_transposed(src0));
1777
1826
  GGML_ASSERT(!ggml_is_transposed(src1));
1778
-
1779
1827
  GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
1780
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
1828
+ GGML_ASSERT(src0->type == src0_type);
1829
+ GGML_ASSERT(ggml_is_contiguous(dst));
1781
1830
 
1782
1831
  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
1783
1832
  // As long as dst is contiguous this does not matter though.
1784
- GGML_ASSERT(ggml_is_contiguous(dst));
1785
1833
 
1786
1834
  GGML_TENSOR_BINARY_OP_LOCALS
1787
1835
 
1788
1836
  const int64_t ne_dst = ggml_nelements(dst);
1789
-
1790
1837
  cudaStream_t main_stream = ctx.stream();
1791
-
1792
1838
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
1793
1839
 
1794
- const half * src0_f16 = (const half *) src0->data;
1795
1840
  float * dst_ddf = (float *) dst->data;
1796
-
1797
- const half * src1_f16 = (const half *) src1->data;
1798
1841
  const size_t ts_src1 = ggml_type_size(src1->type);
1799
1842
  GGML_ASSERT(nb10 == ts_src1);
1800
1843
  int64_t s11 = nb11 / ts_src1;
1801
1844
  int64_t s12 = nb12 / ts_src1;
1802
1845
  int64_t s13 = nb13 / ts_src1;
1803
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
1804
1846
 
1805
- // convert src1 to fp16
1806
- if (src1->type != GGML_TYPE_F16) {
1807
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
1808
- const int64_t ne_src1 = ggml_nelements(src1);
1809
- src1_f16_alloc.alloc(ne_src1);
1810
- GGML_ASSERT(to_fp16_cuda != nullptr);
1847
+ const cuda_t * src0_ptr = nullptr;
1848
+ const cuda_t * src1_ptr = nullptr;
1849
+
1850
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
1851
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
1811
1852
 
1812
- to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1853
+ // Handle src0
1854
+ src0_ptr = (const cuda_t *) src0->data;
1813
1855
 
1814
- src1_f16 = src1_f16_alloc.get();
1856
+ // Handle src1 - convert if necessary
1857
+ if (src1->type == src0_type) {
1858
+ src1_ptr = (const cuda_t *) src1->data;
1859
+ } else {
1860
+ // Convert src1 to target type using traits conversion functions
1861
+ const int64_t ne_src1 = ggml_nelements(src1);
1862
+ src1_alloc.alloc(ne_src1);
1863
+
1864
+ const auto convert_func = traits::get_nc_converter(src1->type);
1865
+ GGML_ASSERT(convert_func != nullptr);
1866
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1867
+ src1_ptr = src1_alloc.get();
1815
1868
  s11 = ne10;
1816
1869
  s12 = ne11*s11;
1817
1870
  s13 = ne12*s12;
1818
1871
  }
1819
1872
 
1820
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
1873
+ // Setup destination buffer
1874
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
1821
1875
  char * dst_t;
1822
-
1823
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1824
- cudaDataType_t cu_data_type = CUDA_R_16F;
1825
-
1826
- // dst strides
1827
1876
  size_t nbd2 = dst->nb[2];
1828
1877
  size_t nbd3 = dst->nb[3];
1829
1878
 
1830
- const half alpha_f16 = 1.0f;
1831
- const half beta_f16 = 0.0f;
1832
-
1879
+ cublasComputeType_t cu_compute_type = traits::compute_type;
1880
+ cudaDataType_t cu_data_type = traits::data_type;
1881
+ cudaDataType_t cu_data_type_a = traits::data_type;
1882
+ cudaDataType_t cu_data_type_b = traits::data_type;
1883
+ const void * alpha = traits::get_alpha();
1884
+ const void * beta = traits::get_beta();
1833
1885
  const float alpha_f32 = 1.0f;
1834
- const float beta_f32 = 0.0f;
1835
-
1836
- const void * alpha = &alpha_f16;
1837
- const void * beta = &beta_f16;
1886
+ const float beta_f32 = 0.0f;
1838
1887
 
1839
1888
  if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1840
- dst_t = (char *) dst_f16.alloc(ne_dst);
1841
-
1842
- nbd2 /= sizeof(float) / sizeof(half);
1843
- nbd3 /= sizeof(float) / sizeof(half);
1889
+ if constexpr (src0_type == GGML_TYPE_F32) {
1890
+ dst_t = (char *) dst_ddf; // Direct F32 output
1891
+ } else {
1892
+ dst_t = (char *) dst_temp.alloc(ne_dst);
1893
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
1894
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
1895
+ }
1844
1896
  } else {
1845
1897
  dst_t = (char *) dst_ddf;
1846
-
1847
1898
  cu_compute_type = CUBLAS_COMPUTE_32F;
1848
- cu_data_type = CUDA_R_32F;
1849
-
1899
+ cu_data_type = CUDA_R_32F;
1850
1900
  alpha = &alpha_f32;
1851
- beta = &beta_f32;
1901
+ beta = &beta_f32;
1852
1902
  }
1853
1903
 
1854
1904
  int id = ggml_cuda_get_device();
@@ -1856,7 +1906,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1856
1906
  if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1857
1907
  cu_compute_type = CUBLAS_COMPUTE_32F;
1858
1908
  alpha = &alpha_f32;
1859
- beta = &beta_f32;
1909
+ beta = &beta_f32;
1860
1910
  }
1861
1911
 
1862
1912
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1916,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1866
1916
  const int64_t r2 = ne12/ne02;
1867
1917
  const int64_t r3 = ne13/ne03;
1868
1918
 
1869
- #if 0
1870
- // use cublasGemmEx
1871
- {
1872
- for (int i13 = 0; i13 < ne13; ++i13) {
1873
- for (int i12 = 0; i12 < ne12; ++i12) {
1874
- int i03 = i13 / r3;
1875
- int i02 = i12 / r2;
1876
-
1877
- CUBLAS_CHECK(
1878
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1879
- ne01, ne11, ne10,
1880
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
1881
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
1882
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
1883
- cu_compute_type,
1884
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1885
- }
1886
- }
1887
- }
1888
- #else
1889
1919
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1890
1920
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
1891
1921
  // use cublasGemmStridedBatchedEx
1892
1922
  CUBLAS_CHECK(
1893
1923
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1894
1924
  ne01, ne11, ne10,
1895
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
1896
- src1_f16, CUDA_R_16F, s11, s12, // strideB
1897
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1925
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1926
+ src1_ptr, cu_data_type_b, s11, s12, // strideB
1927
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1898
1928
  ne12*ne13,
1899
1929
  cu_compute_type,
1900
1930
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1935,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1905
1935
  ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
1906
1936
  ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
1907
1937
 
1938
+ size_t src1_stride_size = sizeof(cuda_t);
1939
+
1908
1940
  dim3 block_dims(ne13, ne12);
1909
1941
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
1910
- src0_f16, src1_f16, dst_t,
1942
+ src0_ptr, src1_ptr, dst_t,
1911
1943
  ptrs_src.get(), ptrs_dst.get(),
1912
1944
  ne12, ne13,
1913
1945
  ne23,
1914
1946
  nb02, nb03,
1915
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
1916
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
1947
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
1948
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
1917
1949
  nbd2, nbd3,
1918
1950
  r2, r3);
1951
+
1919
1952
  CUDA_CHECK(cudaGetLastError());
1920
1953
 
1921
1954
  CUBLAS_CHECK(
1922
1955
  cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1923
1956
  ne01, ne11, ne10,
1924
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
1925
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
1926
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1957
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
1958
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
1959
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1927
1960
  ne23,
1928
1961
  cu_compute_type,
1929
1962
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1930
1963
  }
1931
- #endif
1932
1964
 
1933
- if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
1934
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
1935
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
1965
+ // Convert output back to F32 if needed
1966
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
1967
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
1968
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
1969
+ }
1970
+ }
1971
+
1972
+ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1973
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
1974
+
1975
+ switch (src0->type) {
1976
+ case GGML_TYPE_F32:
1977
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
1978
+ break;
1979
+ case GGML_TYPE_BF16:
1980
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
1981
+ break;
1982
+ case GGML_TYPE_F16:
1983
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
1984
+ break;
1985
+ default:
1986
+ GGML_ABORT("Unsupported type");
1936
1987
  }
1937
1988
  }
1938
1989
 
@@ -1984,6 +2035,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1984
2035
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
1985
2036
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
1986
2037
 
2038
+ //TODO update for generic tensor parallelism
2039
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2040
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2041
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2042
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2043
+
1987
2044
  if (!split && use_mul_mat_vec) {
1988
2045
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
1989
2046
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2049,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1992
2049
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
1993
2050
  } else if (!split && use_mul_mat_q) {
1994
2051
  ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
1995
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1996
- !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
2052
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
2053
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1997
2054
  // general KQ + KQV multi-batch without FlashAttention
1998
2055
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
1999
2056
  } else if (use_mul_mat_vec) {
@@ -2175,6 +2232,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2175
2232
  case GGML_OP_GET_ROWS_BACK:
2176
2233
  ggml_cuda_op_get_rows_back(ctx, dst);
2177
2234
  break;
2235
+ case GGML_OP_SET_ROWS:
2236
+ ggml_cuda_op_set_rows(ctx, dst);
2237
+ break;
2178
2238
  case GGML_OP_DUP:
2179
2239
  ggml_cuda_dup(ctx, dst);
2180
2240
  break;
@@ -2244,6 +2304,30 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2244
2304
  case GGML_UNARY_OP_EXP:
2245
2305
  ggml_cuda_op_exp(ctx, dst);
2246
2306
  break;
2307
+ case GGML_UNARY_OP_ELU:
2308
+ ggml_cuda_op_elu(ctx, dst);
2309
+ break;
2310
+ default:
2311
+ return false;
2312
+ }
2313
+ break;
2314
+ case GGML_OP_GLU:
2315
+ switch (ggml_get_glu_op(dst)) {
2316
+ case GGML_GLU_OP_REGLU:
2317
+ ggml_cuda_op_reglu(ctx, dst);
2318
+ break;
2319
+ case GGML_GLU_OP_GEGLU:
2320
+ ggml_cuda_op_geglu(ctx, dst);
2321
+ break;
2322
+ case GGML_GLU_OP_SWIGLU:
2323
+ ggml_cuda_op_swiglu(ctx, dst);
2324
+ break;
2325
+ case GGML_GLU_OP_GEGLU_ERF:
2326
+ ggml_cuda_op_geglu_erf(ctx, dst);
2327
+ break;
2328
+ case GGML_GLU_OP_GEGLU_QUICK:
2329
+ ggml_cuda_op_geglu_quick(ctx, dst);
2330
+ break;
2247
2331
  default:
2248
2332
  return false;
2249
2333
  }
@@ -2507,6 +2591,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2507
2591
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2508
2592
  cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
2509
2593
 
2594
+ const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2595
+ const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2596
+
2510
2597
  for (int i = 0; i < cgraph->n_nodes; i++) {
2511
2598
  ggml_tensor * node = cgraph->nodes[i];
2512
2599
 
@@ -2528,9 +2615,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2528
2615
  #endif
2529
2616
  }
2530
2617
 
2531
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
2532
- // disable CUDA graphs for batch size > 1 for now.
2533
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
2618
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
2619
+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2620
+ // by means of matching node names. See
2621
+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2622
+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2623
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2534
2624
  use_cuda_graph = false;
2535
2625
  #ifndef NDEBUG
2536
2626
  GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
@@ -2676,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2676
2766
  }
2677
2767
  #endif
2678
2768
 
2769
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2770
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2771
+ return false;
2772
+ }
2773
+
2774
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
2775
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
2776
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2777
+
2778
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
2779
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
2780
+
2781
+ //rms norm only supports F32
2782
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
2783
+ mul->src[1]->type != GGML_TYPE_F32 ||
2784
+ mul->type != GGML_TYPE_F32) {
2785
+ return false;
2786
+ }
2787
+
2788
+ //if rms norm is the B operand, then we don't handle broadcast
2789
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
2790
+ return false;
2791
+ }
2792
+
2793
+ //rms_norm kernel assumes contigous rows
2794
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
2795
+ return false;
2796
+ }
2797
+ }
2798
+
2799
+ return true;
2800
+ }
2801
+
2679
2802
  static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2680
2803
  bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
2681
2804
  // flag used to determine whether it is an integrated_gpu
@@ -2685,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2685
2808
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2686
2809
  // With the use of CUDA graphs, the execution will be performed by the graph launch.
2687
2810
  if (!use_cuda_graph || cuda_graph_update_required) {
2811
+
2688
2812
  for (int i = 0; i < cgraph->n_nodes; i++) {
2689
2813
  ggml_tensor * node = cgraph->nodes[i];
2690
2814
 
@@ -2692,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2692
2816
  continue;
2693
2817
  }
2694
2818
 
2819
+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
2820
+ if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2821
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
2822
+ i++;
2823
+ continue;
2824
+ }
2695
2825
  #ifndef NDEBUG
2696
2826
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
2697
2827
  for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -3036,11 +3166,24 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3036
3166
  case GGML_UNARY_OP_GELU_QUICK:
3037
3167
  case GGML_UNARY_OP_TANH:
3038
3168
  case GGML_UNARY_OP_EXP:
3169
+ case GGML_UNARY_OP_ELU:
3039
3170
  return ggml_is_contiguous(op->src[0]);
3040
3171
  default:
3041
3172
  return false;
3042
3173
  }
3043
3174
  break;
3175
+ case GGML_OP_GLU:
3176
+ switch (ggml_get_glu_op(op)) {
3177
+ case GGML_GLU_OP_REGLU:
3178
+ case GGML_GLU_OP_GEGLU:
3179
+ case GGML_GLU_OP_SWIGLU:
3180
+ case GGML_GLU_OP_GEGLU_ERF:
3181
+ case GGML_GLU_OP_GEGLU_QUICK:
3182
+ return ggml_is_contiguous_1(op->src[0]);
3183
+ default:
3184
+ return false;
3185
+ }
3186
+ break;
3044
3187
  case GGML_OP_MUL_MAT:
3045
3188
  case GGML_OP_MUL_MAT_ID:
3046
3189
  {
@@ -3112,6 +3255,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3112
3255
  switch (op->src[0]->type) {
3113
3256
  case GGML_TYPE_F16:
3114
3257
  case GGML_TYPE_F32:
3258
+ case GGML_TYPE_BF16:
3259
+ case GGML_TYPE_I32:
3115
3260
  case GGML_TYPE_Q4_0:
3116
3261
  case GGML_TYPE_Q4_1:
3117
3262
  case GGML_TYPE_Q5_0:
@@ -3126,17 +3271,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3126
3271
  {
3127
3272
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
3128
3273
  } break;
3274
+ case GGML_OP_SET_ROWS:
3275
+ {
3276
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
3277
+ op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
3278
+ op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
3279
+ op->src[0]->type == GGML_TYPE_F32 &&
3280
+ op->src[1]->type == GGML_TYPE_I64;
3281
+ } break;
3129
3282
  case GGML_OP_CPY:
3130
3283
  {
3131
3284
  ggml_type src0_type = op->src[0]->type;
3132
3285
  ggml_type src1_type = op->src[1]->type;
3133
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
3134
- return true;
3135
- }
3136
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
3137
- return true;
3138
- }
3139
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
3286
+ if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
3287
+ (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
3288
+ ) {
3140
3289
  return true;
3141
3290
  }
3142
3291
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3172,12 +3321,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3172
3321
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
3173
3322
  return true;
3174
3323
  }
3175
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
3176
- return true;
3177
- }
3178
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3179
- return true;
3180
- }
3181
3324
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
3182
3325
  return true;
3183
3326
  }
@@ -3241,12 +3384,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3241
3384
  case GGML_OP_COS:
3242
3385
  case GGML_OP_CLAMP:
3243
3386
  case GGML_OP_LOG:
3244
- case GGML_OP_SSM_SCAN:
3245
- case GGML_OP_SSM_CONV:
3246
3387
  return true;
3388
+ case GGML_OP_SSM_SCAN: {
3389
+ if (op->src[3]->ne[0] == 1) {
3390
+ // Mamba2
3391
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
3392
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
3393
+ } else {
3394
+ // Mamba
3395
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
3396
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
3397
+ }
3398
+ }
3399
+ case GGML_OP_SSM_CONV: {
3400
+ // assumes d_inner % threads == 0
3401
+ return op->src[0]->ne[1] % 128 == 0;
3402
+ }
3247
3403
  case GGML_OP_CONT:
3248
- return op->src[0]->type != GGML_TYPE_BF16;
3404
+ return true;
3249
3405
  case GGML_OP_DIAG_MASK_INF:
3406
+ return true;
3250
3407
  case GGML_OP_SOFT_MAX:
3251
3408
  return true;
3252
3409
  case GGML_OP_SOFT_MAX_BACK: {
@@ -3271,7 +3428,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3271
3428
  case GGML_OP_GROUP_NORM:
3272
3429
  return ggml_is_contiguous(op->src[0]);
3273
3430
  case GGML_OP_UPSCALE:
3274
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
3275
3431
  case GGML_OP_PAD:
3276
3432
  case GGML_OP_ARANGE:
3277
3433
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -3295,9 +3451,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3295
3451
  if (op->src[0]->ne[0] == 192) {
3296
3452
  return false;
3297
3453
  }
3298
- if (op->src[0]->ne[3] != 1) {
3299
- return false;
3300
- }
3301
3454
  if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
3302
3455
  return false;
3303
3456
  }
@@ -3310,6 +3463,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3310
3463
  if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
3311
3464
  return true;
3312
3465
  }
3466
+ if (op->src[3] && op->src[3]->ne[2] != 1) {
3467
+ return false;
3468
+ }
3313
3469
  return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
3314
3470
  op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
3315
3471
  }
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
10
10
  return;
11
11
  }
12
12
 
13
- const int64_t ksize = OW * (KH > 1 ? KW : 1);
13
+ const int64_t ksize = OW * KH;
14
14
  const int64_t kx = i / ksize;
15
15
  const int64_t kd = kx * ksize;
16
16
  const int64_t ky = (i - kd) / OW;