@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
--- a/package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu
+++ b/package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu
@@ -1,65 +1,76 @@
 #include "im2col.cuh"
 
+#define MAX_GRIDDIM_Z 65535
+
 template <typename T>
 static __global__ void im2col_kernel(
-        const float * x, T * dst, int64_t batch_offset,
-        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        const float * x, T * dst,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
-    if (i >= pelements) {
+    if (i >= IC_KH_KW) {
         return;
     }
 
-    const int64_t ksize = OW * KH;
-    const int64_t kx = i / ksize;
-    const int64_t kd = kx * ksize;
-    const int64_t ky = (i - kd) / OW;
-    const int64_t ix = i % OW;
+    const int64_t iic = i / (KH_KW);
+    const int64_t rem = i - iic * KH_KW;
+    const int64_t ikh = rem / KW;
+    const int64_t ikw = rem - ikh * KW;
 
-    const int64_t oh = blockIdx.y;
-    const int64_t batch = blockIdx.z / IC;
-    const int64_t ic = blockIdx.z % IC;
+    const int64_t iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
+        const int64_t in  = iz / OH;
+        const int64_t ioh = iz - in * OH;
 
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = oh * s1 + ky * d1 - p1;
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;
 
-    const int64_t offset_dst =
-        ((batch * OH + oh) * OW + ix) * CHW +
-        (ic * (KW * KH) + ky * KW + kx);
+        const int64_t offset_dst =
+            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = 0.0f;
-    } else {
-        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
+            dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        }
     }
+
+    GGML_UNUSED(IC);
+    GGML_UNUSED(KH);
 }
 
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 template <typename T>
 static void im2col_cuda(const float * x, T* dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, batch * IC);
-    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    const int64_t N_OH = N * OH;
+    const int64_t KH_KW = KW*KH;
+    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
+    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
+                                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
+                                                                                     s0, s1, p0, p1, d0, d1);
 }
 
 static void im2col_cuda_f16(const float * x, half * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 static void im2col_cuda_f32(const float * x, float * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -91,13 +102,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW = dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[is_2D ? 3 : 2];
-    const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t N        = src1->ne[is_2D ? 3 : 2];
+    const int64_t IH_IW    = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
-        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     }
 }
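
Note on the im2col hunks above: the rewrite moves to one thread per element of IC*KH*KW on the x-axis, blockIdx.y over OW, and a grid-stride loop over N*OH in z, stepping by MAX_GRIDDIM_Z = 65535 because CUDA caps gridDim.z at 65535. For orientation, here is a minimal CPU reference of the [N, OH, OW, IC*KH*KW] destination layout the kernel writes; this is a sketch assuming a contiguous f32 NCHW source, and cpu_im2col_ref is an illustrative name, not a function from the package.

#include <cstdint>

// CPU reference: dst[N, OH, OW, IC*KH*KW] <- src[N, IC, IH, IW] (contiguous f32).
static void cpu_im2col_ref(const float * src, float * dst,
                           int64_t N, int64_t IC, int64_t IH, int64_t IW,
                           int64_t OH, int64_t OW, int64_t KH, int64_t KW,
                           int s0, int s1, int p0, int p1, int d0, int d1) {
    const int64_t IC_KH_KW = IC * KH * KW;
    for (int64_t in  = 0; in  < N;  ++in)
    for (int64_t ioh = 0; ioh < OH; ++ioh)
    for (int64_t iow = 0; iow < OW; ++iow)
    for (int64_t iic = 0; iic < IC; ++iic)
    for (int64_t ikh = 0; ikh < KH; ++ikh)
    for (int64_t ikw = 0; ikw < KW; ++ikw) {
        const int64_t iih = ioh * s1 + ikh * d1 - p1; // sampled input row
        const int64_t iiw = iow * s0 + ikw * d0 - p0; // sampled input column
        // same flattening as offset_dst in the kernel above
        const int64_t od = ((in * OH + ioh) * OW + iow) * IC_KH_KW
                         + iic * (KH * KW) + ikh * KW + ikw;
        dst[od] = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
            ? 0.0f // zero padding, matching the kernel's bounds check
            : src[((in * IC + iic) * IH + iih) * IW + iiw];
    }
}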
--- a/package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu
+++ b/package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu
@@ -1,4 +1,14 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}
 
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
@@ -13,7 +23,51 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    // Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+#ifdef USE_CUDA_GRAPH
+    cudaStreamCaptureStatus iscapturing;
+    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+#endif // USE_CUDA_GRAPH
+    if ((nrows == 1) &&
+#ifdef USE_CUDA_GRAPH
+        // CUDA_GRAPHS_DISABLED
+        ((ncols > 65536) &&
+        ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+          ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+          ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+        // CUDA_GRAPHS ENABLED
+        ((ncols > 32768) &&
+        !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+          ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+          ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+#else
+        (ncols > 65536)) {
+#endif // USE_CUDA_GRAPH
+        // Single row - use device-wide reduction
+        size_t tmp_size = 0;
+        ggml_cuda_pool & pool = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif // GGML_CUDA_USE_CUB
+
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
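
Note on the mean.cu hunks above: the new fast path uses CUB's device-wide reduction, which is always invoked twice; the first call with a null workspace pointer only reports the required temporary-storage size, and the second call performs the sum. The diff draws that buffer from ctx.pool() and finishes with the one-thread divide_by_count kernel so the mean stays on-device. Below is a standalone sketch of the same two-phase pattern using plain stream-ordered allocation instead of ggml's pool; sum_on_device is a hypothetical helper, not part of the package.

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Device-wide sum of n floats from d_in into *d_out via cub::DeviceReduce.
static float sum_on_device(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    size_t tmp_size = 0;
    // Phase 1: null workspace -> CUB only writes the required temp-storage size.
    cub::DeviceReduce::Sum(nullptr, tmp_size, d_in, d_out, n, stream);
    void * d_tmp = nullptr;
    cudaMallocAsync(&d_tmp, tmp_size, stream);
    // Phase 2: the same call with a real workspace performs the reduction.
    cub::DeviceReduce::Sum(d_tmp, tmp_size, d_in, d_out, n, stream);
    cudaFreeAsync(d_tmp, stream);
    float result = 0.0f;
    cudaMemcpyAsync(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    return result;
}

The CUDA-graph checks in the diff gate the profitability threshold: the CUB path kicks in above 32768 columns when graphs are actually enabled and above 65536 columns otherwise.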
--- a/package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh
+++ b/package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh
@@ -23,13 +23,13 @@
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(NEW_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -68,7 +68,7 @@ namespace ggml_cuda_mma {
         static constexpr int I = I_;
         static constexpr int J = J_;
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
@@ -132,7 +132,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     };
 
     template <int I_, int J_>
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
            : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
 #else
         load_generic(xs0, stride);
         GGML_UNUSED(t);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(NEW_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,29 +278,27 @@ namespace ggml_cuda_mma {
            : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
-        GGML_UNUSED(t);
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
+        GGML_UNUSED_VARS(t, xs0, stride);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -283,16 +313,14 @@ namespace ggml_cuda_mma {
            : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -313,16 +341,14 @@ namespace ggml_cuda_mma {
            : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -340,16 +366,14 @@ namespace ggml_cuda_mma {
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -376,16 +400,29 @@ namespace ggml_cuda_mma {
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -403,16 +440,29 @@ namespace ggml_cuda_mma {
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;
@@ -439,11 +489,9 @@ namespace ggml_cuda_mma {
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
@@ -467,9 +515,7 @@ namespace ggml_cuda_mma {
            0, 0, 0);
 #endif // defined(CDNA3)
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
@@ -495,9 +541,7 @@ namespace ggml_cuda_mma {
            0, 0, 0);
 #endif // defined(CDNA3)
 #else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
+        GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
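
Note on the mma.cuh hunks above: the NEW_MMA_AVAILABLE gate is renamed TURING_MMA_AVAILABLE, and the new tf32 (m16n8k8) and bf16 (m16n8k16) mma overloads sit behind a separate AMPERE_MMA_AVAILABLE gate, since those PTX instructions require sm_80 while ldmatrix and the half2/int8 variants already work on sm_75. The per-argument GGML_UNUSED chains collapse into the variadic GGML_UNUSED_VARS, which appears to come from ggml-impl.h (file 145 above, +61 lines). A hedged sketch of how gates like these are conventionally keyed to __CUDA_ARCH__; the package's real definitions live in ggml-cuda/common.cuh (file 91 above) and may differ in spelling and conditions.

// Illustrative only: architecture gates of the kind mma.cuh tests with #ifdef.
// Thresholds: sm_75 (Turing) for ldmatrix/mma.sync, sm_80 (Ampere) for tf32/bf16.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#define TURING_MMA_AVAILABLE
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define AMPERE_MMA_AVAILABLE
#endif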