@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347) hide show
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -1,8 +1,20 @@
1
1
  #pragma once
2
2
 
3
3
  #include "common.cuh"
4
+
4
5
  #include <cstdint>
5
6
 
7
+ static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
8
+ const uint8_t * x8 = (const uint8_t *) x;
9
+
10
+ int x32 = x8[4*i32 + 0] << 0;
11
+ x32 |= x8[4*i32 + 1] << 8;
12
+ x32 |= x8[4*i32 + 2] << 16;
13
+ x32 |= x8[4*i32 + 3] << 24;
14
+
15
+ return x32;
16
+ }
17
+
6
18
  static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
7
19
  const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
8
20
 
@@ -16,6 +28,72 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
16
28
  return ((const int *) x)[i32]; // assume at least 4 byte alignment
17
29
  }
18
30
 
31
+ // q4 contains 8 indices with 4 bit each.
32
+ // This function selects those bytes from table that are at those indices and returns them as int2.
33
+ // The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
34
+ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
35
+ #if defined(GGML_USE_HIP)
36
+ // Load the 16-byte table into four 32-bit unsigned integers.
37
+ const uint32_t *values = (const uint32_t *)table;
38
+
39
+ const uint32_t q_even = q4;
40
+ const uint32_t q_odd = (q4 >> 4);
41
+
42
+ // Perform lookups in the lower half of the table (indices 0-7).
43
+ uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
44
+ uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
45
+
46
+ // Perform lookups in the upper half of the table (indices 8-15).
47
+ uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
48
+ uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
49
+
50
+ // Select between the low and high results based on the MSB of each index nibble.
51
+ uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
52
+ uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
53
+ uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
54
+ uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
55
+
56
+ return make_int2(res_x, res_y);
57
+ #elif !defined(GGML_USE_MUSA)
58
+ // CUDA does not have an instruction for selecting bytes with 4 bit indices.
59
+ // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
60
+ const uint32_t * table32 = (const uint32_t *) table;
61
+
62
+ // __byte_perm selects bytes based on the lower 16 bits in its third argument.
63
+ // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
64
+ // To handle the fourth bit, first call __byte_perm both for the low and the high 64 bits of table, using the low 3 bits.
65
+ // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
66
+ uint32_t tmp[2];
67
+ const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
68
+ #pragma unroll
69
+ for (uint32_t i = 0; i < 2; ++i) {
70
+ const uint32_t shift = 16 * i;
71
+
72
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
73
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
74
+ tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
75
+ }
76
+
77
+ // tmp contains the bytes from table in the same order as the 4 bit indices in q4.
78
+ // However, for the result we need ints with all even/odd 4 bit indices in q4.
79
+ // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
80
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
81
+ #else
82
+ // Generic implementation.
83
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
84
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
85
+ const char4 val0_8 = make_char4(
86
+ table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
87
+
88
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
89
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
90
+ const char4 val1_8 = make_char4(
91
+ table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
92
+
93
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
94
+ #endif
95
+ }
96
+
19
97
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
20
98
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
21
99
 
@@ -61,7 +139,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
61
139
  sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
62
140
  }
63
141
 
64
- #ifdef GGML_CUDA_F16
142
+ #ifdef FAST_FP16_AVAILABLE
65
143
  const float2 tmp = __half22float2(__hmul2(dm4, ds8));
66
144
  const float d4d8 = tmp.x;
67
145
  const float m4s8 = tmp.y;
@@ -70,7 +148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
70
148
  const float2 ds8f = __half22float2(ds8);
71
149
  const float d4d8 = dm4f.x * ds8f.x;
72
150
  const float m4s8 = dm4f.y * ds8f.y;
73
- #endif // GGML_CUDA_F16
151
+ #endif // FAST_FP16_AVAILABLE
74
152
 
75
153
  // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
76
154
  return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
@@ -132,7 +210,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
132
210
  sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
133
211
  }
134
212
 
135
- #ifdef GGML_CUDA_F16
213
+ #ifdef FAST_FP16_AVAILABLE
136
214
  const float2 tmp = __half22float2(__hmul2(dm5, ds8));
137
215
  const float d5d8 = tmp.x;
138
216
  const float m5s8 = tmp.y;
@@ -141,7 +219,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
141
219
  const float2 ds8f = __half22float2(ds8);
142
220
  const float d5d8 = dm5f.x * ds8f.x;
143
221
  const float m5s8 = dm5f.y * ds8f.y;
144
- #endif // GGML_CUDA_F16
222
+ #endif // FAST_FP16_AVAILABLE
145
223
 
146
224
  // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
147
225
  return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
@@ -175,7 +253,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
175
253
  sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
176
254
  }
177
255
 
178
- #ifdef GGML_CUDA_F16
256
+ #ifdef FAST_FP16_AVAILABLE
179
257
  const float2 tmp = __half22float2(__hmul2(dm8, ds8));
180
258
  const float d8d8 = tmp.x;
181
259
  const float m8s8 = tmp.y;
@@ -184,7 +262,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
184
262
  const float2 ds8f = __half22float2(ds8);
185
263
  const float d8d8 = dm8f.x * ds8f.x;
186
264
  const float m8s8 = dm8f.y * ds8f.y;
187
- #endif // GGML_CUDA_F16
265
+ #endif // FAST_FP16_AVAILABLE
188
266
 
189
267
  // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
190
268
  return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
@@ -211,6 +289,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
211
289
  return d8_1*sumf;
212
290
  }
213
291
 
292
+ #define VDR_MXFP4_Q8_1_MMVQ 2
293
+ #define VDR_MXFP4_Q8_1_MMQ 4
294
+
295
+ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
296
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
297
+
298
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
299
+
300
+ const int * q8 = (const int *) bq8_1->qs + iqs;
301
+
302
+ int sumi = 0;
303
+ #pragma unroll
304
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
305
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
306
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
307
+
308
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
309
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
310
+ }
311
+
312
+ const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
313
+ return d * sumi;
314
+ }
315
+
214
316
  #define VDR_Q2_K_Q8_1_MMVQ 1
215
317
  #define VDR_Q2_K_Q8_1_MMQ 4
216
318
 
@@ -1068,20 +1170,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
1068
1170
  return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
1069
1171
  }
1070
1172
 
1071
- static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
1072
- const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
1073
- const int8_t * q0_8 = (const int8_t *) &q0_32;
1074
- const char4 val0_8 = make_char4(
1075
- kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
1076
-
1077
- const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
1078
- const int8_t * q1_8 = (const int8_t *) &q1_32;
1079
- const char4 val1_8 = make_char4(
1080
- kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
1081
-
1082
- return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
1083
- }
1084
-
1085
1173
  #define VDR_IQ4_NL_Q8_1_MMVQ 2
1086
1174
  #define VDR_IQ4_NL_Q8_1_MMQ 4
1087
1175
 
@@ -1096,7 +1184,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
1096
1184
  #pragma unroll
1097
1185
  for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
1098
1186
  const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
1099
- const int2 v = get_int_from_table_16(aux_q4);
1187
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
1100
1188
 
1101
1189
  sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
1102
1190
  sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1206,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
1118
1206
  #pragma unroll
1119
1207
  for (int j = 0; j < 4; ++j) {
1120
1208
  const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
1121
- const int2 v = get_int_from_table_16(aux_q4);
1209
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
1122
1210
 
1123
1211
  const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
1124
1212
  const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
@@ -6,6 +6,10 @@
6
6
  #include <cuda_bf16.h>
7
7
  #include <cuda_fp16.h>
8
8
 
9
+ #if CUDART_VERSION >= 12050
10
+ #include <cuda_fp8.h>
11
+ #endif // CUDART_VERSION >= 12050
12
+
9
13
  #if CUDART_VERSION < 11020
10
14
  #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
11
15
  #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
@@ -1,14 +1,10 @@
1
1
  #pragma once
2
2
 
3
- #define HIP_ENABLE_WARP_SYNC_BUILTINS 1
3
+ #define HIP_DISABLE_WARP_SYNC_BUILTINS 1
4
4
  #include <hip/hip_runtime.h>
5
5
  #include <hipblas/hipblas.h>
6
6
  #include <hip/hip_fp16.h>
7
- #include <hip/hip_bfloat16.h>
8
- #ifdef __HIP_PLATFORM_AMD__
9
- // for rocblas_initialize()
10
- #include "rocblas/rocblas.h"
11
- #endif // __HIP_PLATFORM_AMD__
7
+ #include <hip/hip_bf16.h>
12
8
 
13
9
  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
14
10
  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -26,7 +22,10 @@
26
22
  #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
27
23
  #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
28
24
  #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
25
+ #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
29
26
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
27
+ #define __all_sync(mask, var) __all(var)
28
+ #define __any_sync(mask, var) __any(var)
30
29
  #define cublasCreate hipblasCreate
31
30
  #define cublasDestroy hipblasDestroy
32
31
  #define cublasGemmEx hipblasGemmEx
@@ -139,7 +138,7 @@
139
138
  #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
140
139
  #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
141
140
 
142
- #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
141
+ #if HIP_VERSION >= 60500000
143
142
  #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
144
143
  #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
145
144
  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -151,7 +150,11 @@
151
150
  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
152
151
  #define cublasComputeType_t hipblasDatatype_t
153
152
  #define cudaDataType_t hipblasDatatype_t
154
- #endif
153
+ #endif // HIP_VERSION >= 60500000
154
+
155
+ #if !defined(__HIP_PLATFORM_AMD__)
156
+ #error "The HIP backend supports only AMD targets"
157
+ #endif // !defined(__HIP_PLATFORM_AMD__)
155
158
 
156
159
  #define __CUDA_ARCH__ 1300
157
160
 
@@ -179,8 +182,7 @@
179
182
  #define RDNA4
180
183
  #endif
181
184
 
182
- #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
183
- defined(__gfx1150__) || defined(__gfx1151__)
185
+ #if defined(__GFX11__)
184
186
  #define RDNA3
185
187
  #endif
186
188
 
@@ -197,7 +199,8 @@
197
199
  #define __has_builtin(x) 0
198
200
  #endif
199
201
 
200
- typedef hip_bfloat16 nv_bfloat16;
202
+ typedef __hip_bfloat16 nv_bfloat16;
203
+ typedef __hip_bfloat162 nv_bfloat162;
201
204
 
202
205
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
203
206
  typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
@@ -248,17 +251,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
248
251
  }
249
252
  return c;
250
253
  }
251
-
252
- #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
253
- // __shfl_xor() for half2 was added in ROCm 5.6
254
- static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
255
- typedef union half2_b32 {
256
- half2 val;
257
- int b32;
258
- } half2_b32_t;
259
- half2_b32_t tmp;
260
- tmp.val = var;
261
- tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
262
- return tmp.val;
263
- }
264
- #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
@@ -137,4 +137,5 @@
137
137
  #define cudaStreamEndCapture musaStreamEndCapture
138
138
  #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
139
139
 
140
- typedef mt_bfloat16 nv_bfloat16;
140
+ typedef __mt_bfloat16 nv_bfloat16;
141
+ typedef __mt_bfloat162 nv_bfloat162;
@@ -46,8 +46,8 @@ if (GGML_HIP_ROCWMMA_FATTN)
46
46
  endif()
47
47
  endif()
48
48
 
49
- if (${hip_VERSION} VERSION_LESS 5.5)
50
- message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
49
+ if (${hip_VERSION} VERSION_LESS 6.1)
50
+ message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
51
51
  endif()
52
52
 
53
53
  message(STATUS "HIP and hipBLAS found")
@@ -113,10 +113,18 @@ if (GGML_HIP_ROCWMMA_FATTN)
113
113
  add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
114
114
  endif()
115
115
 
116
+ if (NOT GGML_HIP_MMQ_MFMA)
117
+ add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
118
+ endif()
119
+
116
120
  if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
117
121
  add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
118
122
  endif()
119
123
 
124
+ if (GGML_HIP_EXPORT_METRICS)
125
+ set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
126
+ endif()
127
+
120
128
  if (NOT GGML_CUDA_FA)
121
129
  add_compile_definitions(GGML_CUDA_NO_FA)
122
130
  endif()
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
410
410
  #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
411
411
  #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
412
412
 
413
+ static inline float ggml_e8m0_to_fp32(uint8_t x) {
414
+ uint32_t bits; // Stores the raw bit representation of the float
415
+
416
+ // Handle special case for minimum exponent (denormalized float)
417
+ if (x == 0) {
418
+ // Bit pattern for 2^(-127):
419
+ // - Sign bit: 0 (positive)
420
+ // - Exponent: 0 (denormalized number)
421
+ // - Mantissa: 0x400000 (0.5 in fractional form)
422
+ // Value = 0.5 * 2^(-126) = 2^(-127)
423
+ bits = 0x00400000;
424
+ }
425
+ // note: disabled as we don't need to handle NaNs (x == 0xFF falls through to the normalized branch and yields +Inf, bits 0x7F800000)
426
+ //// Handle special case for NaN (all bits set)
427
+ //else if (x == 0xFF) {
428
+ // // Standard quiet NaN pattern:
429
+ // // - Sign bit: 0
430
+ // // - Exponent: all 1s (0xFF)
431
+ // // - Mantissa: 0x400000 (quiet NaN flag)
432
+ // bits = 0x7FC00000;
433
+ //}
434
+ // Normalized values (most common case)
435
+ else {
436
+ // Construct normalized float by shifting exponent into position:
437
+ // - Exponent field: 8 bits (positions 30-23)
438
+ // - Mantissa: 0 (implicit leading 1)
439
+ // Value = 2^(x - 127)
440
+ bits = (uint32_t) x << 23;
441
+ }
442
+
443
+ float result; // Final float value
444
+ // Safely reinterpret bit pattern as float without type-punning issues
445
+ memcpy(&result, &bits, sizeof(float));
446
+ return result;
447
+ }
448
+
449
+ // Equal to ggml_e8m0_to_fp32/2
450
+ // Useful with MXFP4 quantization since the E0M2 values are doubled
451
+ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
452
+ uint32_t bits;
453
+
454
+ // For x < 2: use precomputed denormal patterns
455
+ if (x < 2) {
456
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
457
+ bits = 0x00200000 << x;
458
+ }
459
+ // For x >= 2: normalized exponent adjustment
460
+ else {
461
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
462
+ bits = (uint32_t)(x - 1) << 23;
463
+ }
464
+ // Note: NaNs are not handled here (x == 0xFF yields the finite value 2^127, bits 0x7F000000, rather than NaN)
465
+
466
+ float result;
467
+ memcpy(&result, &bits, sizeof(float));
468
+ return result;
469
+ }
470
+
471
+ #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
472
+ #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
473
+
413
474
  /**
414
475
  * Converts brain16 to float32.
415
476
  *
@@ -23,6 +23,9 @@
23
23
  #define N_R0_Q8_0 4
24
24
  #define N_SG_Q8_0 2
25
25
 
26
+ #define N_R0_MXFP4 2
27
+ #define N_SG_MXFP4 2
28
+
26
29
  #define N_R0_Q2_K 4
27
30
  #define N_SG_Q2_K 2
28
31
 
@@ -129,6 +132,15 @@ typedef struct {
129
132
  uint64_t o1[8];
130
133
  } ggml_metal_kargs_bin;
131
134
 
135
+ typedef struct {
136
+ int64_t ne0;
137
+ int64_t ne1;
138
+ size_t nb01;
139
+ size_t nb02;
140
+ size_t nb11;
141
+ size_t nb21;
142
+ } ggml_metal_kargs_add_id;
143
+
132
144
  typedef struct {
133
145
  int32_t ne00;
134
146
  int32_t ne01;
@@ -237,6 +249,7 @@ typedef struct {
237
249
  uint64_t nb33;
238
250
  int32_t ne1;
239
251
  int32_t ne2;
252
+ int32_t ne3;
240
253
  float scale;
241
254
  float max_bias;
242
255
  float m0;
@@ -245,6 +258,11 @@ typedef struct {
245
258
  float logit_softcap;
246
259
  } ggml_metal_kargs_flash_attn_ext;
247
260
 
261
+ typedef struct {
262
+ int32_t nrows;
263
+ int32_t ne20;
264
+ } ggml_metal_kargs_flash_attn_ext_reduce;
265
+
248
266
  typedef struct {
249
267
  int32_t ne00;
250
268
  int32_t ne02;
@@ -308,40 +326,31 @@ typedef struct {
308
326
  } ggml_metal_kargs_mul_mv_ext;
309
327
 
310
328
  typedef struct {
329
+ int32_t ne02;
311
330
  int32_t ne10;
312
331
  int32_t ne11; // n_expert_used (bcast)
313
332
  uint64_t nb11;
314
333
  uint64_t nb12;
315
- int32_t neh11; // n_tokens
316
- uint64_t nbh11;
334
+ int32_t ne21; // n_tokens
317
335
  int32_t ne20; // n_expert_used
318
336
  uint64_t nb21;
319
337
  } ggml_metal_kargs_mul_mm_id_map0;
320
338
 
321
- typedef struct {
322
- int32_t ne20; // n_expert_used
323
- int32_t neh0;
324
- int32_t neh1;
325
- uint64_t nbh1;
326
- uint64_t nbh2;
327
- int32_t ne0;
328
- uint64_t nb1;
329
- uint64_t nb2;
330
- } ggml_metal_kargs_mul_mm_id_map1;
331
-
332
339
  typedef struct {
333
340
  int32_t ne00;
334
341
  int32_t ne02;
335
342
  uint64_t nb01;
336
343
  uint64_t nb02;
337
344
  uint64_t nb03;
338
- int32_t neh12;
339
- uint64_t nbh10;
340
- uint64_t nbh11;
341
- uint64_t nbh12;
342
- uint64_t nbh13;
343
- int32_t neh0;
344
- int32_t neh1;
345
+ int32_t ne11;
346
+ uint64_t nb10;
347
+ uint64_t nb11;
348
+ uint64_t nb12;
349
+ uint64_t nb13;
350
+ int32_t ne20;
351
+ int32_t ne21;
352
+ int32_t ne0;
353
+ int32_t ne1;
345
354
  int16_t r2;
346
355
  int16_t r3;
347
356
  } ggml_metal_kargs_mul_mm_id;
@@ -444,6 +453,8 @@ typedef struct{
444
453
  uint64_t nb1;
445
454
  int32_t i00;
446
455
  int32_t i10;
456
+ float alpha;
457
+ float limit;
447
458
  } ggml_metal_kargs_glu;
448
459
 
449
460
  typedef struct {