@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -305,6 +305,27 @@ void main() {
305
305
  return;
306
306
  }
307
307
 
308
+ if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
309
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
310
+ float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
311
+
312
+ float ms = 1.0f;
313
+ float vs = 1.0f;
314
+
315
+ if (sink > Mf[r]) {
316
+ ms = exp(Mf[r] - sink);
317
+
318
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
319
+ Of[r][d] *= ms;
320
+ }
321
+ } else {
322
+ vs = exp(sink - Mf[r]);
323
+ }
324
+
325
+ Lf[r] = Lf[r]*ms + vs;
326
+ }
327
+ }
328
+
308
329
  float Lfrcp[Br];
309
330
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
310
331
  Lfrcp[r] = 1.0 / Lf[r];
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
9
9
  layout (constant_id = 5) const uint32_t Clamp = 0;
10
10
  layout (constant_id = 6) const uint32_t D_split = 16;
11
11
 
12
+ // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
13
+ const uint32_t HSK_pad = (HSK + 15) & ~15;
14
+ const uint32_t HSV_pad = (HSV + 15) & ~15;
15
+
12
16
  layout (push_constant) uniform parameter {
13
17
  uint32_t N;
14
18
  uint32_t KV;
@@ -50,10 +54,13 @@ layout (push_constant) uniform parameter {
50
54
  uint32_t k_num;
51
55
  } p;
52
56
 
57
+ #define SINK_ENABLE_BIT (1<<24)
53
58
  #define MASK_ENABLE_BIT (1<<16)
54
59
  #define N_LOG2_MASK 0xFFFF
55
60
 
56
- layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
61
+ layout (binding = 4) readonly buffer S {float data_s[];};
62
+
63
+ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
57
64
 
58
65
  #if defined(A_TYPE_PACKED16)
59
66
  #define BINDING_IDX_K 0
@@ -111,6 +118,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
111
118
  return ACC_TYPE(pow(base, ACC_TYPE(exph)));
112
119
  }
113
120
 
121
+ // Load the sink value, indexed by Q's dimension 2.
122
+ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
123
+ {
124
+ const uint32_t h = iq2 + (r % p.gqa_ratio);
125
+
126
+ return ACC_TYPE(data_s[h]);
127
+ }
128
+
114
129
  uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
115
130
  iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
116
131
  q_stride, k_stride, v_stride, m_stride;
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
46
46
  shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
47
47
  shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
48
48
 
49
- const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
49
+ const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
50
50
  shared f16vec4 Qf[Br * qstride];
51
51
 
52
52
  // Avoid padding for hsk==256 to make it fit in 48KB shmem.
53
53
  const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
54
54
  shared ACC_TYPE sfsh[Bc * sfshstride];
55
55
 
56
- const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
56
+ const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
57
57
  shared f16vec4 ksh[Bc * kshstride];
58
58
 
59
59
  shared float slope[Br];
@@ -74,6 +74,21 @@ void main() {
74
74
 
75
75
  #define tile_row(r) (row_tid * rows_per_thread + (r))
76
76
 
77
+ // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
78
+ if ((HSK % 16) != 0) {
79
+ [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
80
+ if (i + tid < Br * qstride) {
81
+ Qf[i + tid] = f16vec4(0);
82
+ }
83
+ }
84
+ [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
85
+ if (i + tid < Bc * kshstride) {
86
+ ksh[i + tid] = f16vec4(0);
87
+ }
88
+ }
89
+ barrier();
90
+ }
91
+
77
92
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
78
93
 
79
94
  [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -151,14 +166,14 @@ void main() {
151
166
  }
152
167
  barrier();
153
168
 
154
- // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
169
+ // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
155
170
  // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
156
171
  // This is written transposed in order to allow for N being 8 if implementations need it
157
172
  coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
158
173
  coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
159
174
  coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
160
175
 
161
- for (uint32_t d = 0; d < HSK / 16; ++d) {
176
+ for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
162
177
  coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
163
178
 
164
179
  uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -210,7 +225,7 @@ void main() {
210
225
 
211
226
  [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
212
227
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
213
- Of[r][d] = float16_t(eMf[r]) * Of[r][d];
228
+ Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
214
229
  }
215
230
  }
216
231
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -233,7 +248,7 @@ void main() {
233
248
  vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
234
249
  #endif
235
250
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
236
- Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
251
+ Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
237
252
  }
238
253
  }
239
254
  }
@@ -288,7 +303,7 @@ void main() {
288
303
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
289
304
  [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
290
305
 
291
- Of[r][d] = float16_t(eMf[r]) * Of[r][d];
306
+ Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
292
307
  tmpshv4[tid] = Of[r][d];
293
308
 
294
309
  barrier();
@@ -329,6 +344,27 @@ void main() {
329
344
  return;
330
345
  }
331
346
 
347
+ if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
348
+ [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
349
+ float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
350
+
351
+ float ms = 1.0f;
352
+ float vs = 1.0f;
353
+
354
+ if (sink > Mf[r]) {
355
+ ms = exp(Mf[r] - sink);
356
+
357
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
358
+ Of[r][d] *= ACC_TYPE(ms);
359
+ }
360
+ } else {
361
+ vs = exp(sink - Mf[r]);
362
+ }
363
+
364
+ Lf[r] = Lf[r]*ms + vs;
365
+ }
366
+ }
367
+
332
368
  float Lfrcp[rows_per_thread];
333
369
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
334
370
  Lfrcp[r] = 1.0 / Lf[r];
@@ -336,7 +372,7 @@ void main() {
336
372
 
337
373
  [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
338
374
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
339
- Of[r][d] *= float16_t(Lfrcp[r]);
375
+ Of[r][d] *= ACC_TYPE(Lfrcp[r]);
340
376
  }
341
377
  }
342
378
 
@@ -104,16 +104,16 @@ void main() {
104
104
  tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
105
105
  tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
106
106
 
107
- coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
108
- coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
107
+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
108
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
109
109
 
110
110
  uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
111
- coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
111
+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
112
112
 
113
- Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
113
+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
114
114
  Qf16 *= float16_t(p.scale);
115
115
 
116
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
116
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
117
117
 
118
118
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
119
119
 
@@ -140,10 +140,10 @@ void main() {
140
140
 
141
141
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
142
142
 
143
- coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
143
+ coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
144
144
 
145
145
  uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
146
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
146
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
147
147
  S = coopMatMulAdd(Qf16, K_T, S);
148
148
 
149
149
  if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@ void main() {
208
208
  rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
209
209
  rowsum = coopMatMulAdd(P_A, One, rowsum);
210
210
 
211
- coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
211
+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
212
212
  uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
213
- coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
213
+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
214
214
 
215
215
  L = eM*L + rowsum;
216
216
 
217
217
  // This is the "diagonal" matrix in the paper, but since we do componentwise
218
218
  // multiply rather than matrix multiply it has the diagonal element smeared
219
219
  // across the row
220
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
220
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
221
221
 
222
222
  // resize eM by using smear/reduce
223
223
  coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
224
224
 
225
225
  // multiply with fp16 accumulation, then add to O.
226
- coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
226
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
227
227
  PV = coopMatMulAdd(P_A, V, PV);
228
228
 
229
- O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
229
+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
230
230
  }
231
231
 
232
232
  // If there is split_k, then the split_k resolve shader does the final
233
233
  // division by L. Store the intermediate O value and per-row m and L values.
234
234
  if (p.k_num > 1) {
235
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
235
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
236
236
 
237
237
  uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
238
238
  coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,11 +243,39 @@ void main() {
243
243
  return;
244
244
  }
245
245
 
246
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
246
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
247
247
 
248
248
  // resize L by using smear/reduce
249
249
  coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
250
250
 
251
+ if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
252
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
253
+ coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
254
+
255
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
256
+
257
+ // resize M by using smear/reduce
258
+ coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
259
+
260
+ // O, Ldiag, Mr all have the same type so all element locations match
261
+ [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
262
+ ACC_TYPE sink = S[i];
263
+
264
+ ACC_TYPE ms = ACC_TYPE(1.0f);
265
+ ACC_TYPE vs = ACC_TYPE(1.0f);
266
+
267
+ if (sink > Mr[i]) {
268
+ ms = exp(Mr[i] - sink);
269
+
270
+ O[i] *= ms;
271
+ } else {
272
+ vs = exp(sink - Mr[i]);
273
+ }
274
+
275
+ Ldiag[i] = Ldiag[i]*ms + vs;
276
+ }
277
+ }
278
+
251
279
  [[unroll]]
252
280
  for (int k = 0; k < Ldiag.length(); ++k) {
253
281
  Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
@@ -257,7 +285,7 @@ void main() {
257
285
 
258
286
  uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
259
287
 
260
- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
288
+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
261
289
  if (p.gqa_ratio > 1) {
262
290
  coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
263
291
  } else {
@@ -267,6 +295,6 @@ void main() {
267
295
  // permute dimensions
268
296
  tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
269
297
 
270
- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
298
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
271
299
  }
272
300
  }
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
7
7
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8
8
 
9
9
  layout (binding = 0) readonly buffer A {float data_a[];};
10
- layout (binding = 1) writeonly buffer D {float data_d[];};
10
+ layout (binding = 1) readonly buffer B {float data_s[];};
11
+ layout (binding = 2) writeonly buffer D {float data_d[];};
11
12
 
12
13
  layout (push_constant) uniform parameter {
13
14
  uint D;
14
15
  uint N;
15
16
  uint ne3;
16
17
  uint k_num;
18
+ uint sinks;
17
19
  } p;
18
20
 
19
21
  shared float tmpsh[BLOCK_SIZE];
@@ -73,6 +75,22 @@ void main() {
73
75
  }
74
76
  L = tmpsh[0];
75
77
 
78
+ float sink;
79
+ if (p.sinks != 0) {
80
+ sink = data_s[n];
81
+
82
+ float ms = 1.0f;
83
+ float vs = 1.0f;
84
+
85
+ if (sink > m_max) {
86
+ ms = exp(m_max - sink);
87
+ } else {
88
+ vs = exp(sink - m_max);
89
+ }
90
+
91
+ L = L*ms + vs;
92
+ }
93
+
76
94
  L = 1.0 / L;
77
95
 
78
96
  // D dimension is split across workgroups in the y dimension
@@ -85,6 +103,13 @@ void main() {
85
103
  float m = data_a[m_offset + k * lm_stride];
86
104
  O += exp(m - m_max) * data_a[o_offset];
87
105
  }
106
+ if (p.sinks != 0) {
107
+ if (sink > m_max) {
108
+ float ms = 1.0f;
109
+ ms = exp(m_max - sink);
110
+ O *= ms;
111
+ }
112
+ }
88
113
  O *= L;
89
114
  data_d[iq3 * D * N + D * n + d] = O;
90
115
  }
@@ -2,6 +2,7 @@
2
2
  #extension GL_EXT_control_flow_attributes : require
3
3
 
4
4
  #include "rte.comp"
5
+ #include "utils.comp"
5
6
 
6
7
  layout (push_constant) uniform parameter
7
8
  {
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
28
29
  uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
29
30
  uint get_doffset() { return p.misalign_offsets & 0xFF; }
30
31
 
31
- // mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
32
- uint fastmod(uint a, uint b) {
33
- if ((b & (b-1)) == 0) {
34
- return a & (b-1);
35
- }
36
- return a % b;
37
- }
38
-
39
- uint fastdiv(uint a, uint b) {
40
- return (a < b) ? 0 : (a / b);
41
- }
42
32
 
43
33
  void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
44
- i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
45
- const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
46
- i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
47
- const uint i02_offset = i02*p.ne01*p.ne00;
48
- i01 = (idx - i03_offset - i02_offset) / p.ne00;
49
- i00 = idx - i03_offset - i02_offset - i01*p.ne00;
34
+ get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
50
35
  }
51
36
 
52
37
  uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
@@ -14,4 +14,6 @@ layout (push_constant) uniform parameter
14
14
  uint ne00;
15
15
  uint ne20;
16
16
  uint mode;
17
+ float alpha;
18
+ float limit;
17
19
  } p;
@@ -1,6 +1,10 @@
1
1
  #extension GL_EXT_control_flow_attributes : enable
2
2
  #extension GL_EXT_shader_16bit_storage : require
3
3
  #extension GL_EXT_shader_8bit_storage : require
4
+ #if USE_SUBGROUP_ADD
5
+ #extension GL_KHR_shader_subgroup_basic : require
6
+ #extension GL_KHR_shader_subgroup_arithmetic : require
7
+ #endif
4
8
 
5
9
  #ifdef MUL_MAT_ID
6
10
  #define EXPERT_COUNT 8
@@ -90,7 +94,38 @@ layout (constant_id = 2) const uint NUM_COLS = 1;
90
94
 
91
95
  shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
92
96
 
93
- void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
97
+ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
98
+ // subgroupAdd is probably faster on devices that support it,
99
+ // particularly when the workgroup has more than one subgroup
100
+ #if USE_SUBGROUP_ADD
101
+ // sum up partial sums within a subgroup
102
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
103
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
104
+ temp[j][n] = subgroupAdd(temp[j][n]);
105
+ }
106
+ }
107
+
108
+ // Go through shared memory to sum partials across subgroups
109
+ if (gl_SubgroupInvocationID == 0) {
110
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
111
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
112
+ tmpsh[j][n][gl_SubgroupID] = temp[j][n];
113
+ }
114
+ }
115
+ }
116
+ barrier();
117
+ if (tid == 0) {
118
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
119
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
120
+ temp[j][n] = FLOAT_TYPE(0);
121
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
122
+ temp[j][n] += tmpsh[j][n][s];
123
+ }
124
+ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
125
+ }
126
+ }
127
+ }
128
+ #else
94
129
  // sum up partial sums and write back result
95
130
  [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
96
131
  [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -115,4 +150,5 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
115
150
  }
116
151
  }
117
152
  }
153
+ #endif
118
154
  }
@@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
26
26
  uint ne12;
27
27
  uint b_offset;
28
28
  uint d_offset;
29
+ uint nb03;
30
+ uint nb13;
31
+ uint nb23;
29
32
  } p;
30
33
 
31
34
  shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -34,6 +37,7 @@ void main() {
34
37
  const uint tid = gl_LocalInvocationID.x;
35
38
  const uint row_x = gl_GlobalInvocationID.y;
36
39
  const uint channel = gl_GlobalInvocationID.z;
40
+ const uint i3 = gl_WorkGroupID.x;
37
41
  const uint channel_x = channel / p.channel_x_divisor;
38
42
  const uint channel_y = channel % p.ne12;
39
43
 
@@ -41,7 +45,7 @@ void main() {
41
45
  const uint nrows_dst = p.nrows_x;
42
46
  const uint row_dst = row_x;
43
47
 
44
- const uint idst = channel*nrows_dst + row_dst;
48
+ const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
45
49
 
46
50
  FLOAT_TYPE temp = 0.0f;
47
51
 
@@ -58,8 +62,8 @@ void main() {
58
62
 
59
63
  const uint row_y = col_x;
60
64
 
61
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
62
- const uint iy = channel_y*p.channel_stride_y + row_y;
65
+ const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
66
+ const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
63
67
 
64
68
  const vec4 av4 = vec4(data_a_v4[ix / 4]);
65
69
  const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -74,8 +78,8 @@ void main() {
74
78
 
75
79
  const uint row_y = col_x;
76
80
 
77
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
78
- const uint iy = channel_y*p.channel_stride_y + row_y;
81
+ const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
82
+ const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
79
83
 
80
84
  const vec4 av4 = vec4(data_a_v4[ix / 4]);
81
85
  const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -91,8 +95,8 @@ void main() {
91
95
 
92
96
  const uint row_y = col_x;
93
97
 
94
- const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
95
- const uint iy = channel_y*p.channel_stride_y + row_y;
98
+ const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
99
+ const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
96
100
 
97
101
  const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
98
102