@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -0,0 +1,427 @@
1
+ #include "ggml.h"
2
+ #include "common.cuh"
3
+ #include "mma.cuh"
4
+ #include "mmf.cuh"
5
+
6
+ using namespace ggml_cuda_mma;
7
+
8
+ #define MMF_ROWS_PER_BLOCK 32
9
+
10
// Matrix multiplication dst = x * y for a small number of y columns using the ggml_cuda_mma tile API.
//
// Launch layout (derived from the indexing below):
//   blockIdx.x : row block of x/dst, covering rows [blockIdx.x*rows_per_block, +rows_per_block)
//   blockIdx.y : dst channel; the x channel is channel_dst / channel_ratio, the y channel is channel_dst
//   blockIdx.z : dst sample;  the x sample  is sample_dst  / sample_ratio,  the y sample  is sample_dst
//   threadIdx.x: lane within a warp (blockDim.x == physical warp size), threadIdx.y: warp index (< nwarps)
//
// Dynamic shared memory (data_mmv) is reused for two purposes and must be large enough for both:
//   - per-warp staging tiles: nwarps * tile_A::I * tile_k_padded elements of T
//   - combine buffer:         ntB*tile_C::J rows of kiw = nwarps*rows_per_block + 4 floats
//
// Preconditions not checked here (presumably enforced by the caller - confirm):
//   - ncols is a multiple of warp_size, so every lane of a warp runs the same number of k-iterations
//     (load_ldmatrix/mma are presumably warp-collective and must not be reached by a partial warp);
//   - the number of rows of x is a multiple of rows_per_block (no row bounds check below).
// ids and nchannels_y are accepted for interface compatibility but not read by this kernel.
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    typedef tile<16, 8, T>     tile_A;
    typedef tile< 8, 8, T>     tile_B;
    typedef tile<16, 8, float> tile_C;

    constexpr int warp_size     = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4; // row pitch of the staging tile (padded beyond warp_size)
    constexpr int ntA           = rows_per_block / tile_A::I;
    constexpr int ntB           = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0        = blockIdx.x * rows_per_block;
    const int channel_dst = blockIdx.y;
    const int channel_x   = channel_dst / channel_ratio;
    const int channel_y   = channel_dst;
    const int sample_dst  = blockIdx.z;
    const int sample_x    = sample_dst / sample_ratio;
    const int sample_y    = sample_dst;

    // Every product is promoted to 64 bit: an int*int channel/row term would be evaluated in 32 bit
    // (and could overflow for large tensors) before being added to the int64_t sample term.
    x   += int64_t(sample_x)  *stride_sample_x   + int64_t(channel_x)  *stride_channel_x   + int64_t(row0)*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y   + int64_t(channel_y)  *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + int64_t(channel_dst)*stride_channel_dst;

    // y viewed as pairs of floats, used when T is a 2-wide half/bf16 type.
    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];

    // Per-warp accumulators: ntA row tiles x ntB column tiles.
    tile_C C[ntA][ntB];

    // Each warp owns a private staging region of tile_A::I rows x tile_k_padded elements.
    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);

    // Each warp handles a warp_size-wide slice of the k dimension per iteration; warps interleave slices.
    // NOTE(review): the staging below has no explicit __syncwarp() between the per-lane shared-memory
    // stores and load_ldmatrix; it presumably relies on load_ldmatrix being warp-collective - confirm.
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        // Stage the x slice for all row tiles and load it into A fragments.
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

        // Stage the y slice one column tile at a time (zero-padding columns >= cols_per_block),
        // converting to T, then accumulate A*B into C.
#pragma unroll
        for (int itB = 0; itB < ntB; ++itB) {
            if constexpr (std::is_same_v<T, float>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                }
            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                }
            } else {
                static_assert(std::is_same_v<T, void>, "unsupported type");
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                tile_B B;
                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                for (int itA = 0; itA < ntA; ++itA) {
                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                }
            }
        }
    }

    // Combine phase: every warp holds a partial sum over its k slices; write them to shared memory
    // (column-major by j, with each warp in its own rows_per_block-wide row range) and reduce across warps.
    float * buf_iw = (float *) data_mmv;
    constexpr int kiw = nwarps*rows_per_block + 4;

    // buf_iw aliases the staging tiles, so all threads must be done reading tile_xy before it is overwritten.
    if constexpr (nwarps > 1) {
        __syncthreads();
    } else {
        __syncwarp();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    // The reduction below reads buf_iw entries that were generally written by other threads
    // (other lanes of the same warp even when nwarps == 1), so a barrier is required in both cases;
    // within a single warp __syncwarp() is sufficient.
    if constexpr (nwarps > 1) {
        __syncthreads();
    } else {
        __syncwarp();
    }

    // Warp threadIdx.y reduces columns j = threadIdx.y, threadIdx.y + nwarps, ...; lane threadIdx.x owns row row0 + threadIdx.x.
#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        // Only the last (partial) round can run past cols_per_block; no barriers follow, so returning is safe.
        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }
        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
    }
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
142
+
143
// Host-side launcher for the mul_mat_f kernel with a compile-time number of
// dst columns (cols_per_block).
//
// Launch shape: one block per (row tile of MMF_ROWS_PER_BLOCK rows, dst channel,
// dst sample); each block is warp_size x nwarps threads. nwarps is chosen at
// runtime (1 .. 256/warp_size) as the smallest count that minimizes the number
// of iterations over ncols_x, and is then mapped onto a kernel instantiation.
//
// Asserted preconditions: ids == nullptr (MUL_MAT_ID is not implemented here),
// ncols_x / stride_row / stride_col_y are even, and the dst channel/sample
// counts are multiples of the x channel/sample counts (broadcast ratios).
// NOTE(review): nrows_x is assumed to be a multiple of MMF_ROWS_PER_BLOCK
// (ggml_cuda_should_use_mmf checks this); any remainder rows are not launched.
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
    typedef tile<16, 8, T> tile_A;
    typedef tile< 8, 8, T> tile_B;

    GGML_ASSERT(!ids && "mul_mat_id not implemented");

    GGML_ASSERT(ncols_x      % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;

    const int device    = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

    // One warp covers warp_size*2 columns of x per iteration. Strict '<' keeps
    // the smallest warp count among those reaching the minimal iteration count.
    const int64_t cols_per_warp_iter = warp_size*2;
    const int64_t max_block_size     = 256;
    int64_t nwarps_best = 1;
    int64_t niter_best  = (ncols_x + cols_per_warp_iter - 1) / cols_per_warp_iter;
    for (int64_t nw = 2; nw <= max_block_size/warp_size; ++nw) {
        const int64_t cols_per_iter = nw*cols_per_warp_iter;
        const int64_t niter         = (ncols_x + cols_per_iter - 1) / cols_per_iter;
        if (niter < niter_best) {
            niter_best  = niter;
            nwarps_best = nw;
        }
    }

    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
    // Dynamic shared memory is reused by two phases; reserve the larger one.
    // The trailing "* 4" is bytes per staged element.
    //   iter:    per warp, tile_A::I rows of (warp_size + 4) elements
    //            (presumably the padded A/B staging tile used by the kernel).
    //   combine: cols_per_block (padded to tile_B::I) rows of
    //            (nwarps*rows_per_block + 4) floats for the cross-warp sum.
    const int nbytes_shared_iter    = nwarps_best * tile_A::I * (warp_size + 4) * 4;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared         = std::max(nbytes_shared_iter, nbytes_shared_combine);

    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);

    // nwarps is a template parameter of the kernel: map the runtime choice onto
    // the matching instantiation.
#define MUL_MAT_F_LAUNCH_WITH_NWARPS(NW)                                                                       \
    case NW:                                                                                                   \
        mul_mat_f<T, rows_per_block, cols_per_block, NW><<<block_nums, block_dims, nbytes_shared, stream>>>    \
            (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,                   \
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,                            \
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);                               \
        break;

    switch (nwarps_best) {
        MUL_MAT_F_LAUNCH_WITH_NWARPS(1)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(2)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(3)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(4)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(5)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(6)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(7)
        MUL_MAT_F_LAUNCH_WITH_NWARPS(8)
        default:
            GGML_ABORT("fatal error");
    }
#undef MUL_MAT_F_LAUNCH_WITH_NWARPS
}
239
+
240
// Maps the runtime number of dst columns (ncols_dst, 1..16) onto the
// mul_mat_f_cuda instantiation with the same compile-time cols_per_block and
// forwards every other argument unchanged. Any other column count aborts.
template <typename T>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
#define MUL_MAT_F_DISPATCH_NCOLS(NCOLS)                                                                          \
    case NCOLS:                                                                                                  \
        mul_mat_f_cuda<T, NCOLS>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,     \
            nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,     \
            nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);              \
        break;

    switch (ncols_dst) {
        MUL_MAT_F_DISPATCH_NCOLS( 1)
        MUL_MAT_F_DISPATCH_NCOLS( 2)
        MUL_MAT_F_DISPATCH_NCOLS( 3)
        MUL_MAT_F_DISPATCH_NCOLS( 4)
        MUL_MAT_F_DISPATCH_NCOLS( 5)
        MUL_MAT_F_DISPATCH_NCOLS( 6)
        MUL_MAT_F_DISPATCH_NCOLS( 7)
        MUL_MAT_F_DISPATCH_NCOLS( 8)
        MUL_MAT_F_DISPATCH_NCOLS( 9)
        MUL_MAT_F_DISPATCH_NCOLS(10)
        MUL_MAT_F_DISPATCH_NCOLS(11)
        MUL_MAT_F_DISPATCH_NCOLS(12)
        MUL_MAT_F_DISPATCH_NCOLS(13)
        MUL_MAT_F_DISPATCH_NCOLS(14)
        MUL_MAT_F_DISPATCH_NCOLS(15)
        MUL_MAT_F_DISPATCH_NCOLS(16)
        default:
            GGML_ABORT("fatal error");
    }
#undef MUL_MAT_F_DISPATCH_NCOLS
}
335
+
336
// Host entry point of the mmf path: computes dst from src0 x src1 (or the
// MUL_MAT_ID layout when ids != nullptr) using the mul_mat_f kernels.
//
// Asserted requirements:
// - src1 and dst are F32, and ids (if given) is I32.
// - src1 and dst agree on dim 3 (ne13 == ne3).
// - dim 0 of every operand is contiguous.
// src0 must be F32, F16 or BF16; any other type aborts.
// F16/BF16 src0 is reinterpreted as half2 / nv_bfloat162 pairs, so column
// counts and row/column strides are passed in units of those pairs.
// NOTE(review): mul_mat_f_cuda currently asserts ids == nullptr, so the
// MUL_MAT_ID branch of the layout selection is not exercised yet.
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type  == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(ne13 == ne3);

    // Innermost dimension must be densely packed for every operand.
    GGML_ASSERT( nb00 == ts_src0);
    GGML_ASSERT( nb10 == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT( nb0  == ts_dst);

    const float   * src1_d = (const float *) src1->data;
    const int32_t * ids_d  = ids ? (const int32_t *) ids->data : nullptr;
    float         * dst_d  = (float *) dst->data;

    // Byte strides converted to element strides of each tensor's own type.
    const int64_t s01 = nb01 / ts_src0;
    const int64_t s02 = nb02 / ts_src0;
    const int64_t s03 = nb03 / ts_src0;
    const int64_t s11 = nb11 / ts_src1;
    const int64_t s12 = nb12 / ts_src1;
    const int64_t s13 = nb13 / ts_src1;
    const int64_t s1  = nb1  / ts_dst;
    const int64_t s2  = nb2  / ts_dst;
    const int64_t s3  = nb3  / ts_dst;

    // MUL_MAT_ID stores tokens along dst dim 2 and experts along dim 1, i.e.
    // the roles of dims 1 and 2 are swapped compared to plain MUL_MAT.
    const bool    is_id              = ids != nullptr;
    const int64_t ncols_dst          = is_id ? ne2 : ne1;
    const int64_t nchannels_y        = is_id ? ne11 : ne12;
    const int64_t nchannels_dst      = is_id ? ne1 : ne2;
    const int64_t stride_channel_dst = is_id ? s1 : s2;
    const int64_t stride_channel_y   = is_id ? s11 : s12;

    GGML_ASSERT(!ids || ncols_dst == 1);

    // Launches with src0 viewed as an array of packed elements holding
    // vals_per_T source values each; src0 column count, src0 row/channel/sample
    // strides and the src1 column stride are rescaled to that packed unit.
    auto launch = [&](const auto * src0_packed, const int64_t vals_per_T) {
        mul_mat_f_switch_cols_per_block(
            src0_packed, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
            ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
            ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
    };

    switch (src0->type) {
        case GGML_TYPE_F32:
            launch((const float *) src0->data, 1);
            break;
        case GGML_TYPE_F16:
            launch((const half2 *) src0->data, 2);
            break;
        case GGML_TYPE_BF16:
            launch((const nv_bfloat162 *) src0->data, 2);
            break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}
406
+
407
// Returns true when the mmf path (ggml_cuda_mul_mat_f) can handle a matrix
// multiplication with the given src0 type/shape and ne11 src1 columns.
//
// type      : src0 type; only F32, F16 and BF16 are accepted.
// cc        : compute capability of the target device.
//             F16 needs Turing MMA; F32 and BF16 need Ampere MMA.
// warp_size : warp size of the target device.
// src0_ne   : src0->ne. ne[0] must be a multiple of warp_size * (values per
//             4-byte element); ne[1] must be a multiple of MMF_ROWS_PER_BLOCK.
// ne11      : number of src1 columns; at most 16 (the largest instantiated
//             cols_per_block).
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
    // Validate the type before any size arithmetic. For types wider than
    // 4 bytes (e.g. F64/I64) or block-quantized types, 4/ggml_type_size(type)
    // is 0, and the divisibility test below would be a modulo by zero.
    if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) {
        return false;
    }

    // Number of values packed into one 4-byte kernel element (1 for F32, 2 for F16/BF16).
    const int64_t vals_per_4_bytes = 4 / (int64_t) ggml_type_size(type);
    if (src0_ne[0] % (warp_size * vals_per_4_bytes) != 0) {
        return false;
    }
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
        return false;
    }
    if (ne11 > 16) {
        return false;
    }

    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return turing_mma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc);
        default:
            return false;
    }
}
@@ -0,0 +1,5 @@
1
#include "common.cuh"

// Computes dst from src0 x src1 with the mmf (tensor-core) kernels.
// src1 and dst must be F32. src0 must be F32, F16 or BF16.
// ids (MUL_MAT_ID) must be an I32 tensor or nullptr; the current
// implementation asserts ids == nullptr.
// Callers should first check ggml_cuda_should_use_mmf().
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

// Returns true if ggml_cuda_mul_mat_f supports src0 of the given type and
// shape (src0_ne = src0->ne) with ne11 src1 columns on a device with compute
// capability cc and warp size warp_size.
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);