@novastera-oss/llamarn 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/PureCppImpl.cpp +9 -27
  14. package/cpp/SystemUtils.h +2 -2
  15. package/cpp/build-info.cpp +2 -2
  16. package/cpp/llama.cpp/README.md +11 -3
  17. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  18. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  19. package/cpp/llama.cpp/common/arg.cpp +153 -113
  20. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  21. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  22. package/cpp/llama.cpp/common/chat.cpp +847 -699
  23. package/cpp/llama.cpp/common/chat.h +73 -6
  24. package/cpp/llama.cpp/common/common.cpp +50 -82
  25. package/cpp/llama.cpp/common/common.h +21 -17
  26. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  27. package/cpp/llama.cpp/common/json-partial.h +37 -0
  28. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  29. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  30. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  31. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  32. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  33. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  34. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  37. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  38. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  75. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  120. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  121. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  122. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  123. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  124. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  125. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  126. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  127. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  128. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  129. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  130. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  131. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  132. package/cpp/llama.cpp/include/llama.h +62 -125
  133. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  150. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  152. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  159. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  160. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  161. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  162. package/cpp/llama.cpp/models/templates/README.md +2 -0
  163. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  164. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  165. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  166. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  167. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  168. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  169. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  170. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  171. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  172. package/cpp/llama.cpp/src/llama-context.h +30 -0
  173. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  174. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  175. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  176. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  177. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  178. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  179. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  180. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  181. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  182. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  183. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  184. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  185. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  186. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  187. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  188. package/cpp/llama.cpp/src/llama-model.h +6 -1
  189. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  190. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  191. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  192. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  193. package/cpp/llama.cpp/src/llama.cpp +14 -0
  194. package/cpp/rn-completion.cpp +60 -5
  195. package/ios/include/chat.h +73 -6
  196. package/ios/include/common/minja/chat-template.hpp +9 -5
  197. package/ios/include/common/minja/minja.hpp +69 -36
  198. package/ios/include/common.h +21 -17
  199. package/ios/include/llama.h +62 -125
  200. package/ios/libs/llama.xcframework/Info.plist +19 -19
  201. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  205. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  206. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  212. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  213. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  227. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  228. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  233. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  234. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  240. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  241. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  242. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  246. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  247. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  253. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  254. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  255. package/package.json +1 -1
  256. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  257. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  267. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  268. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -883,7 +883,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
  #endif
  }
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)

  size_t vl = QK8_0;

@@ -1221,7 +1221,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
  #endif
  }
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)

  size_t vl = QK8_1;

@@ -2384,7 +2384,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
@@ -2774,7 +2774,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc) + summs;
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
@@ -3121,7 +3121,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl;
  size_t vlenb = __riscv_vlenb();

@@ -3460,7 +3460,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(acc) + summs;
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl;
  size_t vlenb = __riscv_vlenb();

@@ -3897,7 +3897,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  sumf = hsum_float_8(accum);
- #elif defined(__riscv_v_intrinsic)
+ #elif defined(__riscv_v)
  size_t vl = qk;

  for (; ib < nb; ++ib) {
@@ -5100,14 +5100,111 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector
+
+ float sumf = 0;
+ uint8_t atmp[16];
+
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ uint8_t *patmp = atmp;
+ int vsums;
+ int tmp;
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl16], e8, m1\n\t"
+ "th.vmv.v.x v8, zero\n\t"
+ "th.vlb.v v1, (%[sc])\n\t"
+ "th.vand.vi v0, v1, 0xF\n\t"
+ "th.vsrl.vi v1, v1, 4\n\t"
+ "th.vsb.v v0, (%[scale])\n\t"
+ "th.vwaddu.vx v16, v1, zero\n\t"
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
+ "th.vlh.v v2, (%[bsums])\n\t"
+ "th.vwmul.vv v4, v16, v2\n\t"
+ "th.vsetvli zero, %[vl16], e32, m4\n\t"
+ "th.vredsum.vs v8, v4, v8\n\t"
+ "th.vmv.x.s %[vsums], v8"
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+ , [vl16] "r" (16)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf += dmin * vsums;
+ int isum = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vlb.v v0, (%[q2])\n\t"
+ "th.vsrl.vi v2, v0, 2\n\t"
+ "th.vsrl.vi v4, v0, 4\n\t"
+ "th.vsrl.vi v6, v0, 6\n\t"
+ "th.vand.vi v0, v0, 0x3\n\t"
+ "th.vand.vi v2, v2, 0x3\n\t"
+ "th.vand.vi v4, v4, 0x3\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v8, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[tmp], 8\n\t"
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
+ "th.vlbu.v v12, (%[scale])\n\t"
+ "th.vmul.vv v10, v10, v12\n\t"
+ "th.vredsum.vs v0, v10, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+ , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q2 += 32; q8 += 128; patmp += 8;
+ }
+
+ sumf += dall * isum;
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
+ uint8_t atmp[16];

+ const int vector_length = __riscv_vlenb() * 8;
  uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
- uint8_t atmp[16];

  switch (vector_length) {
  case 256:
@@ -6137,13 +6234,140 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

- uint32_t aux[3];
  uint32_t utmp[4];
+ float sumf = 0;

- const int vector_length = __riscv_vlenb() * 8;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+
+ int8_t * scale = (int8_t *)utmp;
+ int tmp;
+ __asm__ __volatile__(
+ "li %[tmp], 12\n\t"
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
+ "th.vlb.v v0, (%[s6b])\n\t"
+ "th.vmv.v.v v2, v0\n\t"
+ "li %[tmp], 2\n\t"
+ "th.vsetvli zero, %[tmp], e64, m1\n\t"
+ "th.vmv.v.x v9, %[sh]\n\t"\
+ "th.vslidedown.vi v1, v0, 1\n\t"
+ "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+ "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vid.v v9\n\t"
+ "th.vmv.x.s %[tmp], v1\n\t"
+ "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+ "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+ "th.vsrl.vv v4, v1, v9\n\t"
+ "th.vsrl.vv v2, v0, v8\n\t"
+ "th.vand.vx v5, v4, %[kmask1]\n\t"
+ "th.vand.vx v3, v2, %[kmask2]\n\t"
+ "th.vsll.vi v6, v5, 4\n\t"
+ "th.vor.vv v7, v6, v3\n\t"
+ "li %[tmp], 16\n\t"
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
+ "th.vsub.vx v0, v7, %[c]\n\t"
+ "th.vsb.v v0, (%[scale])"
+ : [tmp] "=&r" (tmp)
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+
+ uint8_t m = 1;
+ int isum = 0;
+ for (int j = 0; j < QK_K; j += 128) {
+ __asm__ __volatile__(
+ // fixme: use v0p7 mask layout directly
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vlb.v v8, (%[q3])\n\t"
+ "th.vsrl.vi v10, v8, 2\n\t"
+ "th.vsrl.vi v12, v8, 4\n\t"
+ "th.vsrl.vi v14, v8, 6\n\t"
+ "th.vand.vi v8, v8, 3\n\t"
+ "th.vand.vi v10, v10, 3\n\t"
+ "th.vand.vi v12, v12, 3\n\t"
+ "th.vlb.v v2, (%[qh])\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v8, v8, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v10, v10, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v12, v12, -4, v0.t\n\t"
+ "th.vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "th.vmseq.vx v0, v4, zero\n\t"
+ "th.vadd.vi v14, v14, -4, v0.t\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v0, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "li %[tmp], 16\n\t"
+ "th.vsetvli zero, %[tmp], e16, m2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[tmp], 8\n\t"
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
+ "th.vlb.v v12, (%[scale])\n\t"
+ "th.vmul.vv v10, v10, v12\n\t"
+ "th.vredsum.vs v0, v10, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q3 += 32; q8 += 128; scale += 8;
+ }
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ sumf += d * isum;
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ uint32_t utmp[4];
  float sumf = 0;
+ uint32_t aux[3];
+ const int vector_length = __riscv_vlenb() * 8;

  switch (vector_length) {
  case 256:
@@ -6331,7 +6555,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  "vslideup.vi v13, v14, 1\n\t"
  "vslideup.vi v10, v8, 2\n\t"
  "vslideup.vi v11, v13, 2\n\t"
- "vsetivli zero, 8, e32, m2\n\t"\
+ "vsetivli zero, 8, e32, m2\n\t"
  "vle8.v v15, (%[scale])\n\t"
  "vsext.vf4 v12, v15\n\t"
  "vmul.vv v10, v10, v12\n\t"
@@ -6771,7 +6995,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
  assert(n % QK_K == 0);
+ #ifdef __ARM_FEATURE_MATMUL_INT8
+ assert((nrc == 2) || (nrc == 1));
+ #else
  assert(nrc == 1);
+ #endif
  UNUSED(nrc);
  UNUSED(bx);
  UNUSED(by);
@@ -6788,6 +7016,146 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  uint32_t utmp[4];

+ #if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q4_K * GGML_RESTRICT x0 = x;
+ const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
+ const block_q8_K * GGML_RESTRICT y0 = y;
+ const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+ const uint8_t * GGML_RESTRICT qx0 = x0->qs;
+ const uint8_t * GGML_RESTRICT qx1 = x1->qs;
+ const int8_t * GGML_RESTRICT qy0 = y0->qs;
+ const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+ // decode scales and mins
+ int8_t x0_scales[8], x1_scales[8];
+ int16x8_t x0_mins, x1_mins;
+ {
+ uint32_t scales_mins[3];
+ memcpy(scales_mins, x0->scales, 12);
+ const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+ const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+ const uint32x2_t mins = {mins_0_3, mins_4_7};
+ x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+ uint32_t scales[2];
+ scales[0] = scales_mins[0] & kmask1; // scales 0~3
+ scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+ memcpy(x0_scales, scales, 8);
+ }
+ {
+ uint32_t scales_mins[3];
+ memcpy(scales_mins, x1->scales, 12);
+ const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+ const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+ const uint32x2_t mins = {mins_0_3, mins_4_7};
+ x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+ uint32_t scales[2];
+ scales[0] = scales_mins[0] & kmask1; // scales 0~3
+ scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+ memcpy(x1_scales, scales, 8);
+ }
+
+ int32x4_t visum = {0};
+
+ // process 64 data points per iteration, totally 256 data points
+ for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
+ const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
+ const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
+
+ int8x16_t vx0[4], vx1[4];
+ {
+ const uint8x16x2_t vv = vld1q_u8_x2(qx0);
+ vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+ vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+ vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+ vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+ }
+ {
+ const uint8x16x2_t vv = vld1q_u8_x2(qx1);
+ vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+ vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+ vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+ vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+ }
+
+ // process 32 data points (share same block scale) per iteration
+ for (int k = 0; k < 2; ++k) {
+ const int blk = j * 2 + k;
+ const int32x4_t block_scale = {
+ x0_scales[blk],
+ x0_scales[blk],
+ x1_scales[blk],
+ x1_scales[blk],
+ };
+
+ int32x4_t vr = {0};
+ for (int l = 0; l < 2; ++l) {
+ const int idx = k * 2 + l;
+ const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
+ const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
+ const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
+ const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
+ }
+ // apply block scale, will NOT overflow
+ // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
+ visum = vmlaq_s32(visum, vr, block_scale);
+ }
+ }
+
+ // adjust bias, apply superblock scale
+ {
+ int32_t bias[4];
+ // no obvious uplift from sve sdot-16, just use neon mul add
+ const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
+ const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
+ bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
+ vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
+ bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
+ vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
+ bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
+ vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
+ bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
+ vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
+ const float32x4_t dmins = {
+ GGML_FP16_TO_FP32(x0->dmin) * y0->d,
+ GGML_FP16_TO_FP32(x0->dmin) * y1->d,
+ GGML_FP16_TO_FP32(x1->dmin) * y0->d,
+ GGML_FP16_TO_FP32(x1->dmin) * y1->d,
+ };
+ vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
+
+ const float32x4_t superblock_scale = {
+ GGML_FP16_TO_FP32(x0->d) * y0->d,
+ GGML_FP16_TO_FP32(x0->d) * y1->d,
+ GGML_FP16_TO_FP32(x1->d) * y0->d,
+ GGML_FP16_TO_FP32(x1->d) * y1->d,
+ };
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+ }
+ }
+
+ // vfsum = ABCD -> ACBD
+ // AC -> s, BD -> (s+bs)
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+ vst1_f32(s, vget_low_f32 (vfsum));
+ vst1_f32(s + bs, vget_high_f32(vfsum));
+
+ return;
+ }
+ #endif
+
  #ifdef __ARM_FEATURE_SVE
  float sumf = 0;
  for (int i = 0; i < nb; ++i) {
@@ -7180,14 +7548,130 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

  const uint8_t * scales = (const uint8_t*)&utmp[0];
  const uint8_t * mins = (const uint8_t*)&utmp[2];

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;

+ for (int i = 0; i < nb; ++i) {
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ int tmp, tmp2, sumi;
+ __asm__ __volatile__(
+ "li %[t1], 12\n\t"
+ "th.vsetvli zero, %[t1], e8, m1\n\t"
+ "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+ "li %[t1], 4\n\t"
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
+ "th.vslidedown.vi v2, v1, 2\n\t"
+ "th.vmv.v.v v3, v2\n\t"
+ "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+ "li %[t1], 2\n\t"
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
+ "th.vmv.v.i v4, 4\n\t"
+ "th.vand.vx v8, v1, %[kmask1]\n\t"
+ "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+ "th.vsrl.vi v6, v1, 6\n\t"
+ "th.vsrl.vv v7, v2, v5\n\t"
+ "th.vand.vx v0, v6, %[kmask3]\n\t"
+ "th.vand.vx v2, v7, %[kmask2]\n\t"
+ "th.vsll.vi v6, v0, 4\n\t"
+ "li %[t2], 8\n\t"
+ "addi %[t1], %[utmp], 4\n\t"
+ "th.vor.vv v1, v6, v2\n\t"
+ "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
+ "th.vssw.v v1, (%[t1]), %[t2]\n\t"
+ "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
+ "th.vlw.v v2, (%[bsums])\n\t"
+ "th.vsetvli zero, %[t2], e16, m1\n\t"
+ "th.vnsrl.vi v0, v2, 0\n\t"
+ "th.vnsrl.vi v1, v2, 16\n\t"
+ "th.vadd.vv v2, v0, v1\n\t"
+ "th.vlbu.v v4, (%[mins])\n\t"
+ "th.vwmul.vv v6, v4, v2\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vsetvli zero, %[t2], e32, m2\n\t"
+ "th.vredsum.vs v0, v6, v0\n\t"
+ "th.vmv.x.s %[sumi], v0"
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf -= dmin * sumi;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ sumi = 0;
+ const uint8_t * scale = scales;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ int vl128 = 128, vl64 = 64, vl32 = 32;
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
+ "th.vlb.v v8, (%[q8])\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
+ "th.vlb.v v0, (%[q4])\n\t"
+ "th.vsrl.vi v4, v0, 4\n\t"
+ "th.vand.vi v0, v0, 0xF\n\t"
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
+ "th.vwmul.vv v28, v6, v14\n\t"
+ "th.vwmul.vv v20, v4, v10\n\t"
+ "th.vwmul.vv v24, v2, v12\n\t"
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "li %[tmp], 4\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vlbu.v v1, (%[scale])\n\t"
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vsetvli zero, %[vl32], e16, m4\n\t"
+ "th.vwredsum.vs v6, v24, v0\n\t"
+ "th.vwredsum.vs v7, v28, v0\n\t"
+ "th.vwredsum.vs v4, v16, v0\n\t"
+ "th.vwredsum.vs v5, v20, v0\n\t"
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
+ "th.vslideup.vi v6, v7, 1\n\t"
+ "th.vslideup.vi v4, v5, 1\n\t"
+ "th.vslideup.vi v4, v6, 2\n\t"
+ "th.vmul.vv v8, v4, v1\n\t"
+ "th.vredsum.vs v0, v8, v0\n\t"
+ "th.vmv.x.s %[tmp], v0\n\t"
+ "add %[sumi], %[sumi], %[tmp]"
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+
+ q4 += 64; q8 += 128; scale += 4;
+ }
+
+ sumf += d * sumi;
+
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ float sumf = 0;
+ const int vector_length = __riscv_vlenb() * 8;
+
  switch (vector_length) {
  case 256:
  for (int i = 0; i < nb; ++i) {
@@ -8074,7 +8558,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_v

  const uint8_t * scales = (const uint8_t*)&utmp[0];
  const uint8_t * mins = (const uint8_t*)&utmp[2];
@@ -8519,7 +9003,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
  assert(n % QK_K == 0);
+ #ifdef __ARM_FEATURE_MATMUL_INT8
+ assert((nrc == 2) || (nrc == 1));
+ #else
  assert(nrc == 1);
+ #endif
  UNUSED(nrc);
  UNUSED(bx);
  UNUSED(by);
@@ -8530,6 +9018,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  const int nb = n / QK_K;

+ #if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q6_K * GGML_RESTRICT x0 = x;
+ const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+ const block_q8_K * GGML_RESTRICT y0 = y;
+ const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+ const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+ const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+ const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+ const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+ const int8_t * GGML_RESTRICT qy0 = y0->qs;
+ const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+ const uint8x16_t mone = vdupq_n_u8(0x30);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ int32x4_t visum = vdupq_n_s32(0);
+
+ // process 8 blocks per iteration, totally 16 blocks
+ for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+ int8x16_t vx0[8], vx1[8];
+
+ // de-quantize vx0[8]
+ {
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+ vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+ vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+ vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+ vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+ vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+ vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+ vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+ vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+ }
+
+ // de-quantize vx1[8]
+ {
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+ vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+ vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+ vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+ vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+ vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+ vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+ vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+ vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+ }
+
+ // process 16 elements (one block with same scale) per iteration
+ // - vx = concat(ql, qh) - 32
+ // - r1,r2,r3,r4 = smmla(vx, vy)
+ for (int k = 0; k < 8; ++k) {
+ const int blk = j * 8 + k;
+
+ const int8x16_t vy0 = vld1q_s8(qy0);
+ const int8x16_t vy1 = vld1q_s8(qy1);
+ qy0 += 16;
+ qy1 += 16;
+
+ const int32x4_t block_scale = {
+ x0->scales[blk],
+ x0->scales[blk],
+ x1->scales[blk],
+ x1->scales[blk],
+ };
+
+ // calculate four results at once with outer product
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+ int32x4_t vr = vdupq_n_s32(0);
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+ // apply block scale, will NOT overflow
+ // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+ visum = vmlaq_s32(visum, vr, block_scale);
+ }
+ }
+
+ // adjust bias, apply superblock scale
+ {
+ int32_t bias[4];
+ #ifdef __ARM_FEATURE_SVE
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+ const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
+ const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+ const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+ const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+ const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+ const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+ const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+ const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+ const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+ const svint64_t zero = svdup_n_s64(0);
+ bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+ svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+ bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+ svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+ bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+ svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+ bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+ svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+ #else
+ // NEON doesn't support int16 dot product, fallback to separated mul and add
+ const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+ const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+ int8x16_t scales_s8 = vld1q_s8(x0->scales);
+ const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+ scales_s8 = vld1q_s8(x1->scales);
+ const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+ int32x4_t prod;
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+ bias[0] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+ bias[1] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+ bias[2] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+ bias[3] = vaddvq_s32(prod);
+
+ #endif
+ const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+ const float32x4_t superblock_scale = {
+ GGML_FP16_TO_FP32(x0->d) * y0->d,
+ GGML_FP16_TO_FP32(x0->d) * y1->d,
+ GGML_FP16_TO_FP32(x1->d) * y0->d,
+ GGML_FP16_TO_FP32(x1->d) * y1->d,
+ };
+
+ visum = vsubq_s32(visum, vibias);
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+ }
+ }
+
+ // vfsum = ABCD -> ACBD
+ // AC -> s, BD -> (s+bs)
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+ vst1_f32(s, vget_low_f32 (vfsum));
+ vst1_f32(s + bs, vget_high_f32(vfsum));
+
+ return;
+ }
+ #endif
+
  #ifdef __ARM_FEATURE_SVE
  const int vector_length = ggml_cpu_get_sve_cnt()*8;
  float sum = 0;
@@ -9037,11 +9716,92 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }
  *s = sumf;

- #elif defined __riscv_v_intrinsic
+ #elif defined __riscv_xtheadvector

- const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;

+ for (int i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ int sum_t = 0;
+ int t0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
+ "th.vlb.v v4, (%[qh])\n\t"
+ "th.vsll.vi v0, v4, 4\n\t"
+ "th.vsll.vi v2, v4, 2\n\t"
+ "th.vsrl.vi v6, v4, 2\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
+ "th.vlb.v v8, (%[q6])\n\t"
+ "th.vsrl.vi v12, v8, 4\n\t"
+ "th.vand.vi v8, v8, 0xF\n\t"
+ "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
+ "th.vand.vx v0, v0, %[mask]\n\t"
+ "th.vor.vv v8, v8, v0\n\t"
+ "th.vlb.v v0, (%[q8])\n\t"
+ "th.vsub.vx v8, v8, %[vl32]\n\t"
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
+ "th.vwmul.vv v16, v0, v8\n\t"
+ "th.vwmul.vv v24, v4, v12\n\t"
+ "li %[t0], 16\n\t"
+ "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
+ "th.vmv.v.x v0, zero\n\t"
+ "th.vwredsum.vs v10, v16, v0\n\t"
+ "th.vwredsum.vs v9, v18, v0\n\t"
+ "th.vwredsum.vs v8, v20, v0\n\t"
+ "th.vwredsum.vs v7, v22, v0\n\t"
+ "th.vwredsum.vs v11, v24, v0\n\t"
+ "th.vwredsum.vs v12, v26, v0\n\t"
+ "th.vwredsum.vs v13, v28, v0\n\t"
+ "th.vwredsum.vs v14, v30, v0\n\t"
+ "li %[t0], 4\n\t"
+ "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
+ "th.vslideup.vi v10, v9, 1\n\t"
+ "th.vslideup.vi v8, v7, 1\n\t"
+ "th.vslideup.vi v11, v12, 1\n\t"
+ "th.vslideup.vi v13, v14, 1\n\t"
+ "th.vslideup.vi v10, v8, 2\n\t"
+ "th.vslideup.vi v11, v13, 2\n\t"
+ "li %[t0], 8\n\t"
+ "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
+ "th.vlb.v v4, (%[scale])\n\t"
+ "th.vmul.vv v2, v4, v10\n\t"
+ "th.vredsum.vs v0, v2, v0\n\t"
+ "th.vmv.x.s %[t0], v0\n\t"
+ "add %[sumi], %[sumi], %[t0]"
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ , [mask] "r" (0x30)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
+ }
+
+ sumf += d * sum_t;
+
+ }
+
+ *s = sumf;
+
+ #elif defined __riscv_v
+
+ float sumf = 0;
+ const int vector_length = __riscv_vlenb() * 8;
+
  switch (vector_length) {
  case 256:
  for (int i = 0; i < nb; ++i) {