@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -5,14 +5,16 @@
5
5
  #define FATTN_KQ_STRIDE_TILE_F16 64
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
8
+ #if !defined(GGML_USE_HIP)
9
9
  __launch_bounds__(nwarps*WARP_SIZE, 2)
10
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
10
+ #endif // !defined(GGML_USE_HIP)
11
11
  static __global__ void flash_attn_tile_ext_f16(
12
12
  const char * __restrict__ Q,
13
13
  const char * __restrict__ K,
14
14
  const char * __restrict__ V,
15
15
  const char * __restrict__ mask,
16
+ const char * __restrict__ sinks,
17
+ const int * __restrict__ KV_max,
16
18
  float * __restrict__ dst,
17
19
  float2 * __restrict__ dst_meta,
18
20
  const float scale,
@@ -47,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
47
49
  const int sequence = blockIdx.z / ne02;
48
50
  const int head = blockIdx.z - sequence*ne02;
49
51
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
50
- const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
51
- const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
52
- const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
53
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
52
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
53
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
54
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
55
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
56
+ const float * sinksf = (const float *) (sinks);
54
57
 
55
58
  const int stride_KV2 = nb11 / sizeof(half2);
56
59
 
@@ -90,7 +93,8 @@ static __global__ void flash_attn_tile_ext_f16(
90
93
 
91
94
  __syncthreads();
92
95
 
93
- for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
96
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
97
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
94
98
  // Calculate KQ tile and keep track of new maximum KQ values:
95
99
 
96
100
  half kqmax_new[ncols/nwarps];
@@ -239,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
239
243
  __syncthreads();
240
244
  }
241
245
 
246
+ //Attention sink: adjust running max and sum once per head
247
+ if (sinksf && blockIdx.y == 0) {
248
+ const half sink = __float2half(sinksf[head]);
249
+
250
+ #pragma unroll
251
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
252
+ half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
253
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
254
+
255
+ const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
256
+ kqmax[j0/nwarps] = kqmax_new_j;
257
+
258
+ const half val = hexp(sink - kqmax[j0/nwarps]);
259
+ kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
260
+ if (threadIdx.x == 0) {
261
+ kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
262
+ }
263
+
264
+ #pragma unroll
265
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
266
+ VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
267
+ }
268
+ }
269
+ }
270
+
242
271
  float2 * dst2 = (float2 *) dst;
243
272
 
244
273
  #pragma unroll
@@ -270,17 +299,15 @@ static __global__ void flash_attn_tile_ext_f16(
270
299
  }
271
300
  }
272
301
  #else
273
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
274
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
275
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
276
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
277
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
278
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
279
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
280
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
281
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
282
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
283
- GGML_UNUSED(nb23);
302
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
303
+ max_bias, m0, m1, n_head_log2, logit_softcap,
304
+ ne00, ne01, ne02, ne03,
305
+ nb01, nb02, nb03,
306
+ ne10, ne11, ne12, ne13,
307
+ nb11, nb12, nb13,
308
+ nb21, nb22, nb23,
309
+ ne31, ne32, ne33,
310
+ nb31, nb32, nb33);
284
311
  NO_DEVICE_CODE;
285
312
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
286
313
  }
@@ -5,14 +5,16 @@
5
5
  #define FATTN_KQ_STRIDE_TILE_F32 32
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
8
+ #if !defined(GGML_USE_HIP)
9
9
  __launch_bounds__(nwarps*WARP_SIZE, 2)
10
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
10
+ #endif // !defined(GGML_USE_HIP)
11
11
  static __global__ void flash_attn_tile_ext_f32(
12
12
  const char * __restrict__ Q,
13
13
  const char * __restrict__ K,
14
14
  const char * __restrict__ V,
15
15
  const char * __restrict__ mask,
16
+ const char * __restrict__ sinks,
17
+ const int * __restrict__ KV_max,
16
18
  float * __restrict__ dst,
17
19
  float2 * __restrict__ dst_meta,
18
20
  const float scale,
@@ -36,17 +38,15 @@ static __global__ void flash_attn_tile_ext_f32(
36
38
  return;
37
39
  #endif // FP16_MMA_AVAILABLE
38
40
  if (use_logit_softcap && !(D == 128 || D == 256)) {
39
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
40
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
41
- GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
42
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
43
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
44
- GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
45
- GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
46
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
47
- GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
48
- GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
49
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
41
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
42
+ max_bias, m0, m1, n_head_log2, logit_softcap,
43
+ ne00, ne01, ne02, ne03,
44
+ nb01, nb02, nb03,
45
+ ne10, ne11, ne12, ne13,
46
+ nb11, nb12, nb13,
47
+ nb21, nb22, nb23,
48
+ ne31, ne32, ne33,
49
+ nb31, nb32, nb33);
50
50
  NO_DEVICE_CODE;
51
51
  return;
52
52
  }
@@ -58,10 +58,11 @@ static __global__ void flash_attn_tile_ext_f32(
58
58
  const int sequence = blockIdx.z / ne02;
59
59
  const int head = blockIdx.z - sequence*ne02;
60
60
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
61
- const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
62
- const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
63
- const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
64
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
61
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
62
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
63
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
64
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
65
+ const float * sinksf = (const float *) (sinks);
65
66
 
66
67
  const int stride_KV2 = nb11 / sizeof(half2);
67
68
 
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
99
100
 
100
101
  __syncthreads();
101
102
 
102
- for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
103
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
104
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
103
105
  // Calculate KQ tile and keep track of new maximum KQ values:
104
106
 
105
107
  float kqmax_new[ncols/nwarps];
@@ -249,6 +251,33 @@ static __global__ void flash_attn_tile_ext_f32(
249
251
  __syncthreads();
250
252
  }
251
253
 
254
+
255
+ //Attention sink: adjust running max and sum once per head
256
+ if (sinksf && blockIdx.y == 0) {
257
+ const float sink = sinksf[head];
258
+
259
+ #pragma unroll
260
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
261
+ float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
262
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
263
+
264
+ const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
265
+ kqmax[j0/nwarps] = kqmax_new_j;
266
+
267
+ const float val = expf(sink - kqmax[j0/nwarps]);
268
+ kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
269
+ if (threadIdx.x == 0) {
270
+ kqsum[j0/nwarps] += val;
271
+ }
272
+
273
+ #pragma unroll
274
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
275
+ VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
276
+ VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
277
+ }
278
+ }
279
+ }
280
+
252
281
  float2 * dst2 = (float2 *) dst;
253
282
 
254
283
  #pragma unroll
@@ -281,17 +310,15 @@ static __global__ void flash_attn_tile_ext_f32(
281
310
  }
282
311
  }
283
312
  #else
284
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
285
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
286
- GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
287
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
288
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
289
- GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
290
- GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
291
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
292
- GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
293
- GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
294
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
313
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
314
+ max_bias, m0, m1, n_head_log2, logit_softcap,
315
+ ne00, ne01, ne02, ne03,
316
+ nb01, nb02, nb03,
317
+ ne10, ne11, ne12, ne13,
318
+ nb11, nb12, nb13,
319
+ nb21, nb22, nb23,
320
+ ne31, ne32, ne33,
321
+ nb31, nb32, nb33);
295
322
  NO_DEVICE_CODE;
296
323
  #endif // FLASH_ATTN_AVAILABLE
297
324
  }
@@ -1,6 +1,12 @@
1
1
  #include "common.cuh"
2
2
  #include "fattn-common.cuh"
3
3
 
4
+ // Currenlty llvm with the amdgcn target dose not support unrolling loops
5
+ // that contain a break that can not be resolved at compile time.
6
+ #ifdef __clang__
7
+ #pragma clang diagnostic push
8
+ #pragma clang diagnostic ignored "-Wpass-failed"
9
+ #endif // __clang__
4
10
  template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
11
  #ifndef GGML_USE_HIP
6
12
  __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f16(
10
16
  const char * __restrict__ K,
11
17
  const char * __restrict__ V,
12
18
  const char * __restrict__ mask,
19
+ const char * __restrict__ sinks,
20
+ const int * __restrict__ KV_max,
13
21
  float * __restrict__ dst,
14
22
  float2 * __restrict__ dst_meta,
15
23
  const float scale,
@@ -54,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
54
62
  K += nb13*sequence + nb12*(head / gqa_ratio);
55
63
  V += nb23*sequence + nb22*(head / gqa_ratio);
56
64
 
57
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
65
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
66
+ const float * sinksf = (const float *) (sinks);
58
67
 
59
68
  const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
60
69
  const half slopeh = __float2half(slopef);
@@ -68,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
68
77
  half2 * KQ2 = (half2 *) KQ;
69
78
 
70
79
  half kqmax[ncols];
80
+ half kqsum[ncols];
71
81
  #pragma unroll
72
82
  for (int j = 0; j < ncols; ++j) {
73
83
  kqmax[j] = -HALF_MAX_HALF;
84
+ kqsum[j] = 0.0f;
74
85
  }
75
- half kqsum[ncols] = {0.0f};
76
86
 
77
87
  __shared__ half kqmax_shared[ncols][WARP_SIZE];
78
88
  __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -171,10 +181,14 @@ static __global__ void flash_attn_vec_ext_f16(
171
181
 
172
182
  half2 VKQ[ncols] = {{0.0f, 0.0f}};
173
183
 
184
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
174
185
  K += blockIdx.y*D * nb11;
175
186
  V += blockIdx.y*D * nb21;
176
187
  maskh += blockIdx.y*D;
177
- for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
188
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
189
+ // Increment pointers after each loop:
190
+ K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
191
+
178
192
  // Calculate KQ tile and keep track of new maximum KQ values:
179
193
 
180
194
  if (mask) {
@@ -182,29 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
182
196
  for (int j = 0; j < ncols; ++j) {
183
197
  maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
184
198
  }
185
-
186
199
  __syncthreads();
187
-
188
- // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
189
- // In such cases, skip the KV slice.
190
- // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
191
- #ifndef GGML_USE_HIP
192
- bool skip = true;
193
- #pragma unroll
194
- for (int j = 0; j < ncols; ++j) {
195
- #pragma unroll
196
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
197
- const int i = i0 + threadIdx.x;
198
-
199
- const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
200
- skip = skip && isinf(tmp.x) && isinf(tmp.y);
201
- }
202
- }
203
- if (__all_sync(0xFFFFFFFF, skip)) {
204
- __syncthreads();
205
- continue;
206
- }
207
- #endif // GGML_USE_HIP
208
200
  }
209
201
 
210
202
  // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -291,9 +283,38 @@ static __global__ void flash_attn_vec_ext_f16(
291
283
  }
292
284
  }
293
285
 
294
- K += gridDim.y*D * nb11;
295
- V += gridDim.y*D * nb21;
296
- maskh += gridDim.y*D;
286
+ __syncthreads();
287
+ }
288
+
289
+ if (sinksf && blockIdx.y == 0) {
290
+ const half sink = __float2half(sinksf[head]);
291
+
292
+ #pragma unroll
293
+ for (int j = 0; j < ncols; ++j) {
294
+ if (threadIdx.x == 0) {
295
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
296
+ }
297
+ }
298
+
299
+ __syncthreads();
300
+
301
+ #pragma unroll
302
+ for (int j = 0; j < ncols; ++j) {
303
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
304
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
305
+
306
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
307
+ kqmax[j] = kqmax_new_j;
308
+
309
+ const half val = hexp(sink - kqmax[j]);
310
+ kqsum[j] = kqsum[j]*KQ_max_scale;
311
+
312
+ if (tid == 0) {
313
+ kqsum[j] += val;
314
+ }
315
+
316
+ VKQ[j] *= __half2half2(KQ_max_scale);
317
+ }
297
318
 
298
319
  __syncthreads();
299
320
  }
@@ -328,20 +349,21 @@ static __global__ void flash_attn_vec_ext_f16(
328
349
  dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
329
350
  }
330
351
  #else
331
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
332
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
333
- GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
334
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
335
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
336
- GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
337
- GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
338
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
339
- GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
340
- GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
341
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
352
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
353
+ max_bias, m0, m1, n_head_log2, logit_softcap,
354
+ ne00, ne01, ne02, ne03,
355
+ nb01, nb02, nb03,
356
+ ne10, ne11, ne12, ne13,
357
+ nb11, nb12, nb13,
358
+ nb21, nb22, nb23,
359
+ ne31, ne32, ne33,
360
+ nb31, nb32, nb33);
342
361
  NO_DEVICE_CODE;
343
362
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
344
363
  }
364
+ #ifdef __clang__
365
+ #pragma clang diagnostic pop
366
+ #endif // __clang__
345
367
 
346
368
  template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
347
369
  void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -1,6 +1,12 @@
1
1
  #include "common.cuh"
2
2
  #include "fattn-common.cuh"
3
3
 
4
+ // Currenlty llvm with the amdgcn target dose not support unrolling loops
5
+ // that contain a break that can not be resolved at compile time.
6
+ #ifdef __clang__
7
+ #pragma clang diagnostic push
8
+ #pragma clang diagnostic ignored "-Wpass-failed"
9
+ #endif // __clang__
4
10
  template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
11
  #ifndef GGML_USE_HIP
6
12
  __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f32(
10
16
  const char * __restrict__ K,
11
17
  const char * __restrict__ V,
12
18
  const char * __restrict__ mask,
19
+ const char * __restrict__ sinks,
20
+ const int * __restrict__ KV_max,
13
21
  float * __restrict__ dst,
14
22
  float2 * __restrict__ dst_meta,
15
23
  const float scale,
@@ -29,17 +37,15 @@ static __global__ void flash_attn_vec_ext_f32(
29
37
 
30
38
  // Skip unused kernel variants for faster compilation:
31
39
  if (use_logit_softcap && !(D == 128 || D == 256)) {
32
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
33
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
34
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
35
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
36
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
37
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
38
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
39
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
40
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
41
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
42
- GGML_UNUSED(nb23);
40
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
41
+ max_bias, m0, m1, n_head_log2, logit_softcap,
42
+ ne00, ne01, ne02, ne03,
43
+ nb01, nb02, nb03,
44
+ ne10, ne11, ne12, ne13,
45
+ nb11, nb12, nb13,
46
+ nb21, nb22, nb23,
47
+ ne31, ne32, ne33,
48
+ nb31, nb32, nb33);
43
49
  NO_DEVICE_CODE;
44
50
  return;
45
51
  }
@@ -65,7 +71,8 @@ static __global__ void flash_attn_vec_ext_f32(
65
71
  K += nb13*sequence + nb12*(head / gqa_ratio);
66
72
  V += nb23*sequence + nb22*(head / gqa_ratio);
67
73
 
68
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
74
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
75
+ const float * sinksf = (const float *) (sinks);
69
76
 
70
77
  const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
71
78
 
@@ -81,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
81
88
  }
82
89
 
83
90
  float kqmax[ncols];
91
+ float kqsum[ncols];
84
92
  #pragma unroll
85
93
  for (int j = 0; j < ncols; ++j) {
86
94
  kqmax[j] = -FLT_MAX/2.0f;
95
+ kqsum[j] = 0.0f;
87
96
  }
88
- float kqsum[ncols] = {0.0f};
89
97
 
90
98
  __shared__ float kqmax_shared[ncols][WARP_SIZE];
91
99
  __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -177,10 +185,14 @@ static __global__ void flash_attn_vec_ext_f32(
177
185
 
178
186
  float VKQ[ncols] = {0.0f};
179
187
 
188
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
180
189
  K += blockIdx.y*D * nb11;
181
190
  V += blockIdx.y*D * nb21;
182
191
  maskh += blockIdx.y*D;
183
- for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
192
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
193
+ // Increment pointers after each loop:
194
+ K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
195
+
184
196
  // Calculate KQ tile and keep track of new maximum KQ values:
185
197
 
186
198
  if (mask) {
@@ -188,28 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
188
200
  for (int j = 0; j < ncols; ++j) {
189
201
  maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
190
202
  }
191
-
192
203
  __syncthreads();
193
-
194
- // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
195
- // In such cases, skip the KV slice.
196
- // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
197
- #ifndef GGML_USE_HIP
198
- bool skip = true;
199
- #pragma unroll
200
- for (int j = 0; j < ncols; ++j) {
201
- #pragma unroll
202
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
203
- const int i = i0 + threadIdx.x;
204
-
205
- skip = skip && isinf(maskf_shared[j*D + i]);
206
- }
207
- }
208
- if (__all_sync(0xFFFFFFFF, skip)) {
209
- __syncthreads();
210
- continue;
211
- }
212
- #endif // GGML_USE_HIP
213
204
  }
214
205
 
215
206
  float kqmax_new_arr[ncols];
@@ -286,9 +277,38 @@ static __global__ void flash_attn_vec_ext_f32(
286
277
  }
287
278
  }
288
279
 
289
- K += gridDim.y*D * nb11;
290
- V += gridDim.y*D * nb21;
291
- maskh += gridDim.y*D;
280
+ __syncthreads();
281
+ }
282
+
283
+ if (sinksf && blockIdx.y == 0) {
284
+ const float sink = sinksf[head];
285
+
286
+ #pragma unroll
287
+ for (int j = 0; j < ncols; ++j) {
288
+ if (threadIdx.x == 0) {
289
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
290
+ }
291
+ }
292
+
293
+ __syncthreads();
294
+
295
+ #pragma unroll
296
+ for (int j = 0; j < ncols; ++j) {
297
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
298
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
299
+
300
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
301
+ kqmax[j] = kqmax_new_j;
302
+
303
+ const float val = expf(sink - kqmax[j]);
304
+ kqsum[j] = kqsum[j]*KQ_max_scale;
305
+
306
+ if (tid == 0) {
307
+ kqsum[j] += val;
308
+ }
309
+
310
+ VKQ[j] *= KQ_max_scale;
311
+ }
292
312
 
293
313
  __syncthreads();
294
314
  }
@@ -323,20 +343,21 @@ static __global__ void flash_attn_vec_ext_f32(
323
343
  dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
324
344
  }
325
345
  #else
326
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
327
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
328
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
329
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
330
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
331
- GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
332
- GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
333
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
334
- GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
335
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
336
- GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
346
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
347
+ max_bias, m0, m1, n_head_log2, logit_softcap,
348
+ ne00, ne01, ne02, ne03,
349
+ nb01, nb02, nb03,
350
+ ne10, ne11, ne12, ne13,
351
+ nb11, nb12, nb13,
352
+ nb21, nb22, nb23,
353
+ ne31, ne32, ne33,
354
+ nb31, nb32, nb33);
337
355
  NO_DEVICE_CODE;
338
356
  #endif // FLASH_ATTN_AVAILABLE
339
357
  }
358
+ #ifdef __clang__
359
+ #pragma clang diagnostic pop
360
+ #endif // __clang__
340
361
 
341
362
  template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
342
363
  void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {