@fugood/llama.node 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (320)
  1. package/CMakeLists.txt +5 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +8 -1
  17. package/package.json +1 -1
  18. package/patches/llama.patch +12 -12
  19. package/src/DetokenizeWorker.cpp +1 -1
  20. package/src/LlamaContext.cpp +33 -1
  21. package/src/LlamaContext.h +1 -0
  22. package/src/LoadSessionWorker.cpp +1 -0
  23. package/src/llama.cpp/.github/workflows/bench.yml +310 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +1315 -0
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +23 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +116 -0
  27. package/src/llama.cpp/.github/workflows/editorconfig.yml +27 -0
  28. package/src/llama.cpp/.github/workflows/gguf-publish.yml +44 -0
  29. package/src/llama.cpp/.github/workflows/labeler.yml +17 -0
  30. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +65 -0
  31. package/src/llama.cpp/.github/workflows/nix-ci.yml +72 -0
  32. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +22 -0
  33. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +36 -0
  34. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +35 -0
  35. package/src/llama.cpp/.github/workflows/python-lint.yml +23 -0
  36. package/src/llama.cpp/.github/workflows/python-type-check.yml +38 -0
  37. package/src/llama.cpp/.github/workflows/server.yml +183 -0
  38. package/src/llama.cpp/CMakeLists.txt +91 -1245
  39. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +1 -1
  40. package/src/llama.cpp/cmake/build-info.cmake +58 -0
  41. package/src/llama.cpp/cmake/git-vars.cmake +22 -0
  42. package/src/llama.cpp/common/CMakeLists.txt +4 -3
  43. package/src/llama.cpp/common/build-info.cpp.in +4 -0
  44. package/src/llama.cpp/common/common.cpp +1116 -877
  45. package/src/llama.cpp/common/common.h +191 -77
  46. package/src/llama.cpp/common/grammar-parser.cpp +118 -31
  47. package/src/llama.cpp/common/json-schema-to-grammar.cpp +346 -65
  48. package/src/llama.cpp/common/log.h +1 -1
  49. package/src/llama.cpp/common/ngram-cache.h +10 -3
  50. package/src/llama.cpp/common/sampling.cpp +19 -10
  51. package/src/llama.cpp/docs/build.md +353 -0
  52. package/src/llama.cpp/examples/CMakeLists.txt +22 -22
  53. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +6 -6
  55. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/batched/batched.cpp +52 -55
  57. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +20 -72
  59. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/chat-13B.bat +57 -0
  61. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/{finetune → cvector-generator}/CMakeLists.txt +2 -2
  63. package/src/llama.cpp/examples/cvector-generator/completions.txt +582 -0
  64. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +503 -0
  65. package/src/llama.cpp/examples/cvector-generator/mean.hpp +48 -0
  66. package/src/llama.cpp/examples/cvector-generator/negative.txt +4 -0
  67. package/src/llama.cpp/examples/cvector-generator/pca.hpp +325 -0
  68. package/src/llama.cpp/examples/cvector-generator/positive.txt +4 -0
  69. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +35 -0
  70. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  71. package/src/llama.cpp/examples/embedding/embedding.cpp +94 -46
  72. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +2 -2
  73. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +4 -6
  74. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/export-lora/export-lora.cpp +344 -386
  76. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +2 -2
  77. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +30 -25
  78. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/gguf/gguf.cpp +5 -0
  80. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +15 -0
  81. package/src/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h +46 -0
  82. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c +295 -0
  83. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h +52 -0
  84. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c +221 -0
  85. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h +24 -0
  86. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c +42 -0
  87. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h +7093 -0
  88. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +693 -0
  89. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  90. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +3 -3
  91. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  92. package/src/llama.cpp/examples/gritlm/gritlm.cpp +6 -2
  93. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  94. package/src/llama.cpp/examples/imatrix/imatrix.cpp +137 -176
  95. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/infill/infill.cpp +38 -153
  97. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +175 -94
  98. package/src/llama.cpp/examples/llama.android/app/build.gradle.kts +65 -0
  99. package/src/llama.cpp/examples/llama.android/build.gradle.kts +6 -0
  100. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +68 -0
  101. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +11 -7
  102. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +2 -2
  103. package/src/llama.cpp/examples/llama.android/settings.gradle.kts +18 -0
  104. package/src/llama.cpp/examples/llava/CMakeLists.txt +6 -5
  105. package/src/llama.cpp/examples/llava/android/build_64.sh +8 -0
  106. package/src/llama.cpp/examples/llava/clip.cpp +23 -14
  107. package/src/llama.cpp/examples/llava/llava-cli.cpp +8 -6
  108. package/src/llama.cpp/examples/llava/requirements.txt +3 -2
  109. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  110. package/src/llama.cpp/examples/lookahead/lookahead.cpp +2 -1
  111. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  112. package/src/llama.cpp/examples/lookup/lookup-create.cpp +2 -0
  113. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  114. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -2
  115. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  116. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  117. package/src/llama.cpp/examples/main/main.cpp +98 -75
  118. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +4 -5
  119. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  120. package/src/llama.cpp/examples/parallel/parallel.cpp +2 -1
  121. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  122. package/src/llama.cpp/examples/passkey/passkey.cpp +23 -43
  123. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  124. package/src/llama.cpp/examples/perplexity/perplexity.cpp +13 -10
  125. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  126. package/src/llama.cpp/examples/quantize/quantize.cpp +37 -34
  127. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  129. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  130. package/src/llama.cpp/examples/retrieval/retrieval.cpp +26 -77
  131. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +14 -7
  133. package/src/llama.cpp/examples/server/CMakeLists.txt +26 -2
  134. package/src/llama.cpp/examples/server/server.cpp +274 -671
  135. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  136. package/src/llama.cpp/examples/server/utils.hpp +28 -29
  137. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  138. package/src/llama.cpp/examples/simple/simple.cpp +21 -29
  139. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  140. package/src/llama.cpp/examples/speculative/speculative.cpp +2 -1
  141. package/src/llama.cpp/examples/sycl/CMakeLists.txt +1 -1
  142. package/src/llama.cpp/examples/sycl/build.sh +23 -0
  143. package/src/llama.cpp/examples/sycl/run-llama2.sh +36 -0
  144. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +33 -0
  145. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +9 -0
  146. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  147. package/src/llama.cpp/examples/tokenize/tokenize.cpp +16 -2
  148. package/src/llama.cpp/ggml/CMakeLists.txt +253 -0
  149. package/src/llama.cpp/{cmake → ggml/cmake}/FindSIMD.cmake +6 -6
  150. package/src/llama.cpp/{ggml-backend.h → ggml/include/ggml-backend.h} +22 -17
  151. package/src/llama.cpp/ggml/include/ggml-blas.h +23 -0
  152. package/src/llama.cpp/ggml/include/ggml-cann.h +125 -0
  153. package/src/llama.cpp/{ggml-cuda.h → ggml/include/ggml-cuda.h} +3 -0
  154. package/src/llama.cpp/{ggml-metal.h → ggml/include/ggml-metal.h} +1 -2
  155. package/src/llama.cpp/{ggml-sycl.h → ggml/include/ggml-sycl.h} +3 -10
  156. package/src/llama.cpp/{ggml.h → ggml/include/ggml.h} +80 -85
  157. package/src/llama.cpp/ggml/src/CMakeLists.txt +1329 -0
  158. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2193 -0
  159. package/src/llama.cpp/ggml/src/ggml-aarch64.h +39 -0
  160. package/src/llama.cpp/{ggml-alloc.c → ggml/src/ggml-alloc.c} +100 -49
  161. package/src/llama.cpp/{ggml-backend-impl.h → ggml/src/ggml-backend-impl.h} +20 -8
  162. package/src/llama.cpp/{ggml-backend.c → ggml/src/ggml-backend.c} +307 -167
  163. package/src/llama.cpp/ggml/src/ggml-blas.cpp +367 -0
  164. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +198 -0
  165. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +230 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +2944 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/common.h +282 -0
  169. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +32 -0
  170. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +17 -0
  171. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +223 -0
  172. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +186 -0
  173. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +180 -0
  174. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +193 -0
  175. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  176. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +208 -0
  177. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +206 -0
  178. package/src/llama.cpp/ggml/src/ggml-cann.cpp +2023 -0
  179. package/src/llama.cpp/{ggml-common.h → ggml/src/ggml-common.h} +41 -7
  180. package/src/llama.cpp/{ggml-impl.h → ggml/src/ggml-impl.h} +113 -9
  181. package/src/llama.cpp/{ggml-kompute.cpp → ggml/src/ggml-kompute.cpp} +33 -18
  182. package/src/llama.cpp/{ggml-quants.c → ggml/src/ggml-quants.c} +1460 -940
  183. package/src/llama.cpp/{ggml-quants.h → ggml/src/ggml-quants.h} +19 -20
  184. package/src/llama.cpp/{ggml-rpc.cpp → ggml/src/ggml-rpc.cpp} +95 -72
  185. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +27 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +53 -0
  187. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +355 -0
  188. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +195 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +21 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +547 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +27 -0
  192. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +698 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  195. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +3011 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  197. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.hpp +33 -0
  198. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1027 -0
  199. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  200. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +374 -0
  201. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +35 -0
  202. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +66 -0
  203. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +275 -0
  204. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +22 -0
  205. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +251 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +24 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +1140 -0
  208. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +5314 -0
  209. package/src/llama.cpp/{ggml-vulkan.cpp → ggml/src/ggml-vulkan.cpp} +1781 -1868
  210. package/src/llama.cpp/{ggml.c → ggml/src/ggml.c} +1245 -2087
  211. package/src/llama.cpp/{sgemm.cpp → ggml/src/llamafile/sgemm.cpp} +21 -24
  212. package/src/llama.cpp/{sgemm.h → ggml/src/llamafile/sgemm.h} +1 -1
  213. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +5 -0
  214. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +552 -0
  215. package/src/llama.cpp/{llama.h → include/llama.h} +175 -100
  216. package/src/llama.cpp/models/.editorconfig +1 -0
  217. package/src/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  218. package/src/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  219. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  220. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +112 -0
  221. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +46 -0
  222. package/src/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  223. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.inp +112 -0
  224. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.out +46 -0
  225. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  226. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +112 -0
  227. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +46 -0
  228. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  229. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +112 -0
  230. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +46 -0
  231. package/src/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  232. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.inp +112 -0
  233. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.out +46 -0
  234. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  235. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +112 -0
  236. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +46 -0
  237. package/src/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  238. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  239. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  242. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +112 -0
  243. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +46 -0
  244. package/src/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  245. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.inp +112 -0
  246. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.out +46 -0
  247. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  248. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +112 -0
  249. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.out +46 -0
  250. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  251. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +112 -0
  252. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.out +46 -0
  253. package/src/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  254. package/src/llama.cpp/models/ggml-vocab-refact.gguf.inp +112 -0
  255. package/src/llama.cpp/models/ggml-vocab-refact.gguf.out +46 -0
  256. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  257. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +112 -0
  258. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.out +46 -0
  259. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  260. package/src/llama.cpp/requirements/requirements-all.txt +12 -0
  261. package/src/llama.cpp/requirements/requirements-compare-llama-bench.txt +2 -0
  262. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -0
  263. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +3 -0
  264. package/src/llama.cpp/requirements/{requirements-convert.txt → requirements-convert_legacy_llama.txt} +1 -1
  265. package/src/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +1 -0
  266. package/src/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  267. package/src/llama.cpp/requirements/requirements-pydantic.txt +3 -0
  268. package/src/llama.cpp/requirements/requirements-test-tokenizer-random.txt +1 -0
  269. package/src/llama.cpp/requirements.txt +5 -4
  270. package/src/llama.cpp/scripts/build-info.sh +30 -0
  271. package/src/llama.cpp/scripts/install-oneapi.bat +19 -0
  272. package/src/llama.cpp/src/CMakeLists.txt +33 -0
  273. package/src/llama.cpp/src/llama-grammar.cpp +539 -0
  274. package/src/llama.cpp/src/llama-grammar.h +39 -0
  275. package/src/llama.cpp/src/llama-impl.h +26 -0
  276. package/src/llama.cpp/src/llama-sampling.cpp +635 -0
  277. package/src/llama.cpp/src/llama-sampling.h +56 -0
  278. package/src/llama.cpp/src/llama-vocab.cpp +1721 -0
  279. package/src/llama.cpp/src/llama-vocab.h +130 -0
  280. package/src/llama.cpp/{llama.cpp → src/llama.cpp} +5979 -5260
  281. package/src/llama.cpp/{unicode-data.cpp → src/unicode-data.cpp} +851 -802
  282. package/src/llama.cpp/{unicode.cpp → src/unicode.cpp} +52 -30
  283. package/src/llama.cpp/{unicode.h → src/unicode.h} +5 -1
  284. package/src/llama.cpp/tests/CMakeLists.txt +19 -20
  285. package/src/llama.cpp/tests/test-backend-ops.cpp +245 -67
  286. package/src/llama.cpp/tests/test-chat-template.cpp +57 -3
  287. package/src/llama.cpp/tests/test-double-float.cpp +2 -2
  288. package/src/llama.cpp/tests/test-grad0.cpp +2 -2
  289. package/src/llama.cpp/tests/test-grammar-integration.cpp +978 -31
  290. package/src/llama.cpp/tests/test-grammar-parser.cpp +423 -158
  291. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +508 -135
  292. package/src/llama.cpp/tests/test-llama-grammar.cpp +15 -9
  293. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -1
  294. package/src/llama.cpp/tests/test-quantize-perf.cpp +1 -1
  295. package/src/llama.cpp/tests/test-rope.cpp +3 -4
  296. package/src/llama.cpp/tests/test-sampling.cpp +5 -5
  297. package/src/llama.cpp/tests/test-tokenizer-0.cpp +6 -6
  298. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +20 -15
  299. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +22 -11
  300. package/bin/darwin/arm64/default.metallib +0 -0
  301. package/bin/darwin/x64/default.metallib +0 -0
  302. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +0 -5
  303. package/src/llama.cpp/examples/beam-search/beam-search.cpp +0 -188
  304. package/src/llama.cpp/examples/finetune/finetune.cpp +0 -1862
  305. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +0 -55
  306. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +0 -5
  307. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +0 -1253
  308. package/src/llama.cpp/ggml-opencl.cpp +0 -2305
  309. package/src/llama.cpp/ggml-opencl.h +0 -36
  310. package/src/llama.cpp/ggml-sycl.cpp +0 -17340
  311. package/src/llama.cpp/ggml-vulkan-shaders.hpp +0 -81211
  312. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +0 -2
  313. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -2
  314. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +0 -1
  315. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +0 -24
  316. /package/src/llama.cpp/{ggml-alloc.h → ggml/include/ggml-alloc.h} +0 -0
  317. /package/src/llama.cpp/{ggml-kompute.h → ggml/include/ggml-kompute.h} +0 -0
  318. /package/src/llama.cpp/{ggml-rpc.h → ggml/include/ggml-rpc.h} +0 -0
  319. /package/src/llama.cpp/{ggml-vulkan.h → ggml/include/ggml-vulkan.h} +0 -0
  320. /package/src/llama.cpp/{unicode-data.h → src/unicode-data.h} +0 -0
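
Going by its include guard (GGML_SYCL_VECDOTQ_HPP) and its +1140 line count, the single hunk reproduced below is the new SYCL quantized dot-product header listed as entry 207 above, package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp.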
@@ -0,0 +1,1140 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #ifndef GGML_SYCL_VECDOTQ_HPP
+ #define GGML_SYCL_VECDOTQ_HPP
+
+ #include "dpct/helper.hpp"
+
+ typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+
+ static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
+     const uint16_t* x16 =
+         (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
+                                                    // alignment
+
+     int x32 = 0;
+     x32 |= x16[0] << 0;
+     x32 |= x16[1] << 16;
+
+     return x32;
+ }
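// Annotation (not part of the package diff): a minimal standalone sketch of
// what the get_int_from_* helpers compute. They read the i32-th group of four
// consecutive bytes as one 32-bit integer; the two 16-bit loads above only
// require 2-byte alignment. On a little-endian target (assumed here) the
// result is the same as a plain 4-byte memcpy:
#include <cstdint>
#include <cstring>

static int get_int_reference(const int8_t * x8, int i32) {
    int x32;
    std::memcpy(&x32, x8 + sizeof(int) * i32, sizeof(x32)); // alignment-free byte copy
    return x32;
}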
+
+ static __dpct_inline__ int get_int_from_uint8(
+     const uint8_t* x8,
+     const int& i32) {
+     const uint16_t* x16 =
+         (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
+                                                    // alignment
+
+     int x32 = 0;
+     x32 |= x16[0] << 0;
+     x32 |= x16[1] << 16;
+
+     return x32;
+ }
+
+ static __dpct_inline__ int get_int_from_int8_aligned(
+     const int8_t* x8,
+     const int& i32) {
+     return *(
+         (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+ }
+
+ static __dpct_inline__ int get_int_from_uint8_aligned(
+     const uint8_t* x8,
+     const int& i32) {
+     return *(
+         (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+ }
+
+ static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
+                                                   const uint8_t *values,
+                                                   int &val1, int &val2) {
+
+     uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+     aux32 = q4 & 0x0f0f0f0f;
+     uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+     uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+     val1 = v1 | (v2 << 16);
+     aux32 = (q4 >> 4) & 0x0f0f0f0f;
+     v1 = values[q8[0]] | (values[q8[1]] << 8);
+     v2 = values[q8[2]] | (values[q8[3]] << 8);
+     val2 = v1 | (v2 << 16);
+ }
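// Annotation (not part of the package diff): a scalar equivalent of
// get_int_from_table_16, assuming little-endian layout and <cstdint> types.
// Each of the eight 4-bit nibbles in q4 indexes the 16-entry table 'values';
// the looked-up bytes are repacked into two ints (low nibbles -> val1, high
// nibbles -> val2):
static void get_int_from_table_16_ref(const uint32_t q4, const uint8_t * values,
                                      int & val1, int & val2) {
    uint32_t lo = 0, hi = 0;
    for (int j = 0; j < 4; ++j) {
        lo |= (uint32_t) values[(q4 >> (8*j    )) & 0xf] << (8*j); // low nibble of byte j
        hi |= (uint32_t) values[(q4 >> (8*j + 4)) & 0xf] << (8*j); // high nibble of byte j
    }
    val1 = (int) lo;
    val2 = (int) hi;
}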
+
+ #define VDR_Q2_K_Q8_1_MMVQ 1
+
+ // contiguous v/x values
+ static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+     const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales,
+     const sycl::half2 &dm2, const float *__restrict__ d8) {
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+ #pragma unroll
+     for (int i = 0; i < QR2_K; ++i) {
+         const int sc = scales[2*i];
+
+         const int vi = (v >> (2*i)) & 0x03030303;
+
+         sumf_d +=
+             d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+         // fill int with 4x m
+         int m = sc >> 4;
+         m |= m << 8;
+         m |= m << 16;
+         sumf_m += d8[i] *
+                   dpct::dp4a(
+                       m, u[i],
+                       0); // multiply constant q2_K part with sum of q8_1 values
+     }
+
+     const sycl::float2 dm2f =
+         dm2.convert<float, sycl::rounding_mode::automatic>();
+
+     return dm2f.x() * sumf_d - dm2f.y() * sumf_m;
+ }
+
+
+ #define VDR_Q3_K_Q8_1_MMVQ 1
+
+ // contiguous v/x values
+ static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+     const int &vl, const int &vh, const int *__restrict__ u,
+     const uint8_t *__restrict__ scales, const int &scale_offset,
+     const float &d3, const float *__restrict__ d8) {
+
+     float sumf = 0.0f;
+
+ #pragma unroll
+     for (int i = 0; i < QR3_K; ++i) {
+         const int isc = scale_offset + 2*i;
+
+         const int isc_low = isc % (QK_K/32);
+         const int sc_shift_low = 4 * (isc / (QK_K/32));
+         const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+         const int isc_high = isc % (QK_K/64);
+         const int sc_shift_high = 2 * (isc / (QK_K/64));
+         const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+         const int sc = (sc_low | sc_high) - 32;
+
+         const int vil = (vl >> (2*i)) & 0x03030303;
+
+         const int vih = ((vh >> i) << 2) & 0x04040404;
+
+         const int vi =
+             dpct::vectorized_binary<sycl::char4>(vil, vih, dpct::sub_sat());
+
+         sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+     }
+
+     return d3 * sumf;
+ }
+
+ #define VDR_Q4_K_Q8_1_MMVQ 2
+
+ // contiguous v/x values
+ static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+     const int *__restrict__ v, const int *__restrict__ u,
+     const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+     const sycl::half2 &dm4, const float *__restrict__ d8) {
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+ #pragma unroll
+     for (int i = 0; i < QR4_K; ++i) {
+         const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+         const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+         const int dot1 =
+             dpct::dp4a(v1i, u[2 * i + 1],
+                        dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product
+         const int dot2 =
+             dpct::dp4a(0x01010101, u[2 * i + 1],
+                        dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u
+
+         sumf_d += d8[i] * (dot1 * sc[i]);
+         sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+     }
+
+     const sycl::float2 dm4f =
+         dm4.convert<float, sycl::rounding_mode::automatic>();
+
+     return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+ }
+
+
+ #define VDR_Q5_K_Q8_1_MMVQ 2
+
+ // contiguous v/x values
+ static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+     const int *__restrict__ vl, const int *__restrict__ vh,
+     const int *__restrict__ u, const uint8_t *__restrict__ sc,
+     const uint8_t *__restrict__ m, const sycl::half2 &dm5,
+     const float *__restrict__ d8) {
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+ #pragma unroll
+     for (int i = 0; i < QR5_K; ++i) {
+         const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+         const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+         const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+         const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+         const int v0i = vl0i | vh0i;
+         const int v1i = vl1i | vh1i;
+
+         const int dot1 =
+             dpct::dp4a(v0i, u[2 * i + 0],
+                        dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product
+         const int dot2 =
+             dpct::dp4a(0x01010101, u[2 * i + 0],
+                        dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u
+
+         sumf_d += d8[i] * (dot1 * sc[i]);
+         sumf_m += d8[i] * (dot2 * m[i]);
+
+     }
+
+     const sycl::float2 dm5f =
+         dm5.convert<float, sycl::rounding_mode::automatic>();
+
+     return dm5f.x() * sumf_d - dm5f.y() * sumf_m;
+ }
+
+
+ #define VDR_Q6_K_Q8_1_MMVQ 1
+
+ // contiguous v/x values
+ static __dpct_inline__ float
+ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
+                             const int *__restrict__ u,
+                             const int8_t *__restrict__ scales, const float &d,
+                             const float *__restrict__ d8) {
+
+     float sumf = 0.0f;
+
+ #pragma unroll
+     for (int i = 0; i < QR6_K; ++i) {
+         const int sc = scales[4*i];
+
+         const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+         const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+         const int vi = dpct::vectorized_binary<sycl::char4>(
+             (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32
+
+         sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+     }
+
+     return d*sumf;
+ }
+
+ // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+ // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+ #define VDR_Q4_0_Q8_1_MMVQ 2
+ #define VDR_Q4_0_Q8_1_MMQ 4
+
+ template <int vdr>
+ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
+                                                     const float &d4,
+                                                     const sycl::half2 &ds8) {
+     int sumi = 0;
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+         const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+         // SIMD dot product of quantized values
+         sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+         sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+     }
+
+     const sycl::float2 ds8f =
+         ds8.convert<float, sycl::rounding_mode::automatic>();
+
+     // second part effectively subtracts 8 from each quant value
+     return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
+ }
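// Annotation (not part of the package diff): worked form of the return above.
// A q4_0 nibble q in [0,15] encodes the value d4*(q - 8), and in ggml a
// block_q8_1 stores ds8 = (d8, s8) with s8 = d8 * (sum of the block's quants).
// Expanding the dot product:
//   sum_k d4*(q_k - 8) * d8*u_k = d4*(d8 * sum_k q_k*u_k) - 8*d4*(d8 * sum_k u_k)
// The first term is d4 * sumi * ds8f.x(). ds8f.y() carries the offset sum for
// the whole block, while one call only covers vdr of the block's QI4_0 ints,
// so the (8 * vdr / QI4_0) factor apportions the -8 correction across the
// QI4_0/vdr cooperating threads (the "VDR = vec dot ratio" comment above).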
+
+ #define VDR_Q4_1_Q8_1_MMVQ 2
+ #define VDR_Q4_1_Q8_1_MMQ 4
+
+ template <int vdr>
+ static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u,
+                                                     const sycl::half2 &dm4,
+                                                     const sycl::half2 &ds8) {
+
+     int sumi = 0;
+
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+         const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+         // SIMD dot product of quantized values
+         sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+         sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+     }
+
+ #ifdef GGML_SYCL_F16
+     const sycl::float2 tmp =
+         (dm4 * ds8).convert<float, sycl::rounding_mode::automatic>();
+     const float d4d8 = tmp.x();
+     const float m4s8 = tmp.y();
+ #else
+     const sycl::float2 dm4f =
+         dm4.convert<float, sycl::rounding_mode::automatic>();
+     const sycl::float2 ds8f =
+         ds8.convert<float, sycl::rounding_mode::automatic>();
+     const float d4d8 = dm4f.x() * ds8f.x();
+     const float m4s8 = dm4f.y() * ds8f.y();
+ #endif // GGML_SYCL_F16
+
+     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+ }
+
+ #define VDR_Q5_0_Q8_1_MMVQ 2
+ #define VDR_Q5_0_Q8_1_MMQ 4
+
+ template <int vdr>
+ static __dpct_inline__ float
+ vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u,
+                        const float &d5, const sycl::half2 &ds8) {
+     int sumi = 0;
+
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+         vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+         vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+         vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+         vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+         sumi = dpct::dp4a(vi0, u[2 * i + 0],
+                           sumi); // SIMD dot product of quantized values
+
+         int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+         vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+         vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+         vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+         vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+         sumi = dpct::dp4a(vi1, u[2 * i + 1],
+                           sumi); // SIMD dot product of quantized values
+     }
+
+     const sycl::float2 ds8f =
+         ds8.convert<float, sycl::rounding_mode::automatic>();
+
+     // second part effectively subtracts 16 from each quant value
+     return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y());
+ }
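// Annotation (not part of the package diff): how to read the "0 -> 4" style
// comments above. q5_0 keeps the low 4 bits of each 5-bit quant in qs and the
// fifth bits packed in qh; bit l of vh[i] is the fifth bit of quant l and has
// to land at bit 8*l + 4 of the repacked int (bit 4 of byte l). For the low
// nibbles (quants 0..3) the targets are bits 4, 12, 20, 28, giving the shifts
// << 4, << 11, << 18, << 25; for the high nibbles the source bits are 16..19,
// giving >> 12, >> 5, << 2, << 9 to reach the same four targets.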
+
+ #define VDR_Q5_1_Q8_1_MMVQ 2
+ #define VDR_Q5_1_Q8_1_MMQ 4
+
+ template <int vdr>
+ static __dpct_inline__ float
+ vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u,
+                        const sycl::half2 &dm5, const sycl::half2 &ds8) {
+
+     int sumi = 0;
+
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+         vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+         vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+         vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+         vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+         sumi = dpct::dp4a(vi0, u[2 * i + 0],
+                           sumi); // SIMD dot product of quantized values
+
+         int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+         vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+         vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+         vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+         vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+         sumi = dpct::dp4a(vi1, u[2 * i + 1],
+                           sumi); // SIMD dot product of quantized values
+     }
+
+ #ifdef GGML_SYCL_F16
+     const sycl::float2 tmp =
+         (dm5 * ds8).convert<float, sycl::rounding_mode::automatic>();
+     const float d5d8 = tmp.x();
+     const float m5s8 = tmp.y();
+
+
+ #else
+     const sycl::float2 dm5f =
+         dm5.convert<float, sycl::rounding_mode::automatic>();
+     const sycl::float2 ds8f =
+         ds8.convert<float, sycl::rounding_mode::automatic>();
+     const float d5d8 = dm5f.x() * ds8f.x();
+     const float m5s8 = dm5f.y() * ds8f.y();
+ #endif // GGML_SYCL_F16
+
+     // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+ }
+
+ #define VDR_Q8_0_Q8_1_MMVQ 2
+ #define VDR_Q8_0_Q8_1_MMQ 8
+
+ template <int vdr>
+ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
+                                                     const float &d8_0,
+                                                     const float &d8_1) {
+
+     int sumi = 0;
+
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         // SIMD dot product of quantized values
+         sumi = dpct::dp4a(v[i], u[i], sumi);
+     }
+
+     return d8_0*d8_1 * sumi;
+ }
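// Annotation (not part of the package diff): a scalar reference for
// dpct::dp4a as used throughout this header (assuming <cstdint> types). It
// multiply-accumulates four packed signed 8-bit lanes, so the q8_0 x q8_1
// kernel above computes d8_0 * d8_1 * sum(q_i * u_i), four lanes at a time:
static int dp4a_ref(int a, int b, int c) {
    for (int j = 0; j < 4; ++j) {
        const int8_t aj = (int8_t)((uint32_t) a >> (8*j)); // lane j of a
        const int8_t bj = (int8_t)((uint32_t) b >> (8*j)); // lane j of b
        c += (int) aj * (int) bj;
    }
    return c;
}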
+
+ template <int vdr>
+ static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
+                                                     const sycl::half2 &dm8,
+                                                     const sycl::half2 &ds8) {
+
+     int sumi = 0;
+
+ #pragma unroll
+     for (int i = 0; i < vdr; ++i) {
+         // SIMD dot product of quantized values
+         sumi = dpct::dp4a(v[i], u[i], sumi);
+     }
+
+ #ifdef GGML_SYCL_F16
+     const sycl::float2 tmp =
+         (dm8 * ds8).convert<float, sycl::rounding_mode::automatic>();
+     const float d8d8 = tmp.x();
+     const float m8s8 = tmp.y();
+ #else
+     const sycl::float2 dm8f =
+         dm8.convert<float, sycl::rounding_mode::automatic>();
+     const sycl::float2 ds8f =
+         ds8.convert<float, sycl::rounding_mode::automatic>();
+     const float d8d8 = dm8f.x() * ds8f.x();
+     const float m8s8 = dm8f.y() * ds8f.y();
+ #endif // GGML_SYCL_F16
+
+     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+     int v[VDR_Q4_0_Q8_1_MMVQ];
+     int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+ #pragma unroll
+     for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+         v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
+         u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+     }
+
+     return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q4_1_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+     int v[VDR_Q4_1_Q8_1_MMVQ];
+     int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+ #pragma unroll
+     for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+         v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+         u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+     }
+
+     return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+     int vl[VDR_Q5_0_Q8_1_MMVQ];
+     int vh[VDR_Q5_0_Q8_1_MMVQ];
+     int u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+ #pragma unroll
+     for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+         vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
+         vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+         u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+     }
+
+     return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q5_1_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+     int vl[VDR_Q5_1_Q8_1_MMVQ];
+     int vh[VDR_Q5_1_Q8_1_MMVQ];
+     int u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+ #pragma unroll
+     for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+         vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+         vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+         u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+     }
+
+     return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q8_0_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+     int v[VDR_Q8_0_Q8_1_MMVQ];
+     int u[VDR_Q8_0_Q8_1_MMVQ];
+
+ #pragma unroll
+     for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+         v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+     }
+
+     return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d,
+                                                       bq8_1->ds[0]);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q2_K_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+     const int bq8_offset = QR2_K * (iqs / QI8_1);
+     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+     const uint8_t * scales = bq2_K->scales + scale_offset;
+
+     const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+     int u[QR2_K];
+     float d8[QR2_K];
+
+ #pragma unroll
+     for (int i = 0; i < QR2_K; ++ i) {
+         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+         d8[i] = bq8_1[bq8_offset + i].ds[0];
+     }
+
+     return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+     const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+     const float d = bq3_K->d;
+
+     const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+     // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+     const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+     int u[QR3_K];
+     float d8[QR3_K];
+
+ #pragma unroll
+     for (int i = 0; i < QR3_K; ++i) {
+         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+         d8[i] = bq8_1[bq8_offset + i].ds[0];
+     }
+
+     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ #ifndef GGML_QKK_64
+     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+     int v[2];
+     int u[2*QR4_K];
+     float d8[QR4_K];
+
+     // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+     const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+     // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+     // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+     // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+     // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+     const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+     v[0] = q4[0];
+     v[1] = q4[4];
+
+     const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+     uint16_t aux[2];
+     const int j = bq8_offset/2;
+     if (j < 2) {
+         aux[0] = scales[j+0] & 0x3f3f;
+         aux[1] = scales[j+2] & 0x3f3f;
+     } else {
+         aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+         aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+     }
+     const uint8_t * sc = (const uint8_t *)aux;
+     const uint8_t * m = sc + 2;
+
+     for (int i = 0; i < QR4_K; ++i) {
+         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+         d8[i] = bq8i->ds[0];
+
+         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+         u[2*i+0] = q8[0];
+         u[2*i+1] = q8[4];
+     }
+
+     return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+
+ #else
+
+ #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+     uint16_t aux16[2];
+     const uint8_t * s = (const uint8_t *)aux16;
+
+     const uint16_t * a = (const uint16_t *)bq4_K->scales;
+     aux16[0] = a[0] & 0x0f0f;
+     aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+     const float dall = bq4_K->dm[0];
+     const float dmin = bq4_K->dm[1];
+
+     const float d8_1 = bq8_1[0].ds[0];
+     const float d8_2 = bq8_1[1].ds[1];
+
+     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+     const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+     const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+     const int * q4 = (const int *)bq4_K->qs + (iqs/2);
+     const int v1 = q4[0];
+     const int v2 = q4[4];
+
+     const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+     const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+     const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));
+     const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));
+
+     sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+     sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+     return dall * sumf_d - dmin * sumf_m;
+
+ #else
+     bad_arch();
+ #endif // __SYCL_ARCH__ >= VER_4VEC
+
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ #ifndef GGML_QKK_64
+     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+     int vl[2];
+     int vh[2];
+     int u[2*QR5_K];
+     float d8[QR5_K];
+
+     const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+     const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+     const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+     vl[0] = ql[0];
+     vl[1] = ql[4];
+
+     vh[0] = qh[0] >> bq8_offset;
+     vh[1] = qh[4] >> bq8_offset;
+
+     const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+     uint16_t aux[2];
+     const int j = bq8_offset/2;
+     if (j < 2) {
+         aux[0] = scales[j+0] & 0x3f3f;
+         aux[1] = scales[j+2] & 0x3f3f;
+     } else {
+         aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+         aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+     }
+     const uint8_t * sc = (const uint8_t *)aux;
+     const uint8_t * m = sc + 2;
+
+ #pragma unroll
+     for (int i = 0; i < QR5_K; ++i) {
+         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+         d8[i] = bq8i->ds[0];
+
+         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+         u[2*i+0] = q8[0];
+         u[2*i+1] = q8[4];
+     }
+
+     return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+
+ #else
+
+ #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+     const int8_t * s = bq5_K->scales;
+
+     const float d = bq5_K->d;
+
+     const float d8_1 = bq8_1[0].ds[0];
+     const float d8_2 = bq8_1[1].ds[1];
+
+     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+     const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+     const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+     const int * ql = (const int *)bq5_K->qs + (iqs/2);
+     const int vl1 = ql[0];
+     const int vl2 = ql[4];
+
+     const int step = 4 * (iqs/2); // 0, 4, 8, 12
+     const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+     const int in = step%8; // 0, 4, 0, 4
+     const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+     const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+     const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+     const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+     const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+     const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
+                        + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
+
+     return d * sumf_d;
+
+ #else
+     bad_arch();
+ #endif // __SYCL_ARCH__ >= VER_4VEC
+
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+     const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+     const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+     const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+     const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+     const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+     const int8_t * scales = bq6_K->scales + scale_offset;
+
+     int u[QR6_K];
+     float d8[QR6_K];
+
+ #pragma unroll
+     for (int i = 0; i < QR6_K; ++i) {
+         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+         d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
+     }
+
+     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+ }
+
+
+ static __dpct_inline__ float
+ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
+                      const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                      const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
+                      const uint8_t *kmask_iq2xs) {
+ #if QK_K == 256
+     const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+     const int ib32 = iqs;
+     const uint16_t * q2 = bq2->qs + 4*ib32;
+     const uint8_t * aux8 = (const uint8_t *)q2;
+     const int8_t * q8 = bq8_1[ib32].qs;
+     uint32_t aux32 = q2[2] | (q2[3] << 16);
+     int sumi = 0;
+     for (int l = 0; l < 4; ++l) {
+         const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+         const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+         for (int j = 0; j < 8; ++j) {
+             sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+         }
+         q8 += 8;
+         aux32 >>= 7;
+     }
+     const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
+     return d * sumi;
+ #else
+     assert(false);
+     return 0.f;
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                     const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
+ #if DPCT_COMPATIBILITY_TEMP >= \
+     MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ #if QK_K == 256
+     const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+
+     const int ib32 = iqs;
+     const uint16_t * q2 = bq2->qs + 4*ib32;
+     const int8_t * q8 = bq8_1[ib32].qs;
+     const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+     const uint8_t ls2 = bq2->scales[ib32] >> 4;
+     int sumi1 = 0;
+     for (int l = 0; l < 2; ++l) {
+         const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+         const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid[0] ^ signs[0], signs[0], std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid[1] ^ signs[1], signs[1], std::minus<>());
+         sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+         sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+         q8 += 8;
+     }
+     int sumi2 = 0;
+     for (int l = 2; l < 4; ++l) {
+         const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+         const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid[0] ^ signs[0], signs[0], std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid[1] ^ signs[1], signs[1], std::minus<>());
+         sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+         sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+         q8 += 8;
+     }
+     const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+     return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+ #else
+     assert(false);
+     return 0.f;
+ #endif
+ #else
+     assert(false);
+     return 0.f;
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+ #if QK_K == 256
+     const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+
+     const int ib32 = iqs;
+     const int8_t * q8 = bq8_1[ib32].qs;
+     const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
+     const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+     const uint8_t ls2 = bq2->scales[ib32] >> 4;
+     int sumi1 = 0;
+     for (int l = 0; l < 2; ++l) {
+         const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+         const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+             ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+             std::equal_to<>());
+         const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+             ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+             std::equal_to<>());
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid[0] ^ signs0, signs0, std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid[1] ^ signs1, signs1, std::minus<>());
+         sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+         sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+         q8 += 8;
+     }
+     int sumi2 = 0;
+     for (int l = 2; l < 4; ++l) {
+         const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+         const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+             ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+             std::equal_to<>());
+         const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+             ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+             std::equal_to<>());
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid[0] ^ signs0, signs0, std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid[1] ^ signs1, signs1, std::minus<>());
+         sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+         sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+         q8 += 8;
+     }
+     const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+     return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+ #else
+     assert(false);
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
+                      const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                      const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
+ #if DPCT_COMPATIBILITY_TEMP >= \
+     MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ #if QK_K == 256
+     const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+     const int ib32 = iqs;
+     const uint8_t * q3 = bq2->qs + 8*ib32;
+     const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+     const int8_t * q8 = bq8_1[ib32].qs;
+     uint32_t aux32 = gas[0] | (gas[1] << 16);
+     int sumi = 0;
+     for (int l = 0; l < 4; ++l) {
+         const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+         const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+         const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid1[0] ^ signs[0], signs[0], std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid2[0] ^ signs[1], signs[1], std::minus<>());
+         sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+         sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+         q8 += 8;
+         aux32 >>= 7;
+     }
+     const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
+     return d * sumi;
+ #else
+     assert(false);
+     return 0.f;
+ #endif
+ #else
+     assert(false);
+     return 0.f;
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                    const uint32_t *iq3s_grid) {
+ #if QK_K == 256
+     const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+     const int ib32 = iqs;
+     const uint8_t * qs = bq2->qs + 8*ib32;
+     const int8_t * q8 = bq8_1[ib32].qs;
+     int sumi = 0;
+     for (int l = 0; l < 4; ++l) {
+         const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+         const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+         uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+             ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+             0x08040201, std::equal_to<>());
+         uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+             ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+             0x08040201, std::equal_to<>());
+         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+             grid1[0] ^ signs0, signs0, std::minus<>());
+         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+             grid2[0] ^ signs1, signs1, std::minus<>());
+         sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+         sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+         q8 += 8;
+     }
+     const float d =
+         (float)bq2->d *
+         (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+         bq8_1[ib32].ds[0];
+     return d * sumi;
+ #else
+     assert(false);
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                    const uint32_t *iq1s_grid_gpu) {
+ #if QK_K == 256
+     const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+     const int ib32 = iqs;
+     int sumi = 0;
+     const int * q8 = (const int *)bq8_1[ib32].qs;
+     for (int l = 0; l < 4; ++l) {
+         const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+         int grid0 = grid[0] & 0x0f0f0f0f;
+         int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+         sumi = dpct::dp4a(q8[2 * l + 1], grid1,
+                           dpct::dp4a(q8[2 * l + 0], grid0, sumi));
+     }
+
+     const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
+     const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
+     const float d = d1q * bq8_1[ib32].ds[0];
+     const float m = d1q * bq8_1[ib32].ds[1];
+     return d * sumi + m * delta;
+ #else
+     assert(false);
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+ #if QK_K == 256
+     const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+     const int ib32 = iqs;
+     int sumi[2] = {0, 0};
+     float sumf[2] = {0.f, 0.f};
+
+     const int * q8 = (const int *)bq8_1[ib32].qs;
+     for (int l = 0; l < 4; ++l) {
+         const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+         int grid0 = grid[0] & 0x0f0f0f0f;
+         int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+         sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
+                                  dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
+         const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+         const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
+                                     dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
+         sumf[l/2] += delta*sumy;
+     }
+
+     iq1m_scale_t scale;
+     const uint16_t * sc = (const uint16_t *)bq1->scales;
+     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+     const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
+     return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+ #else
+     assert(false);
+ #endif
+ }
+
+
+ static __dpct_inline__ float
+ vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+     const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+     const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+     const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
+
+     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+     int v1, v2;
+     int sumi1 = 0, sumi2 = 0;
+     for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+         const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+         get_int_from_table_16(aux, values, v1, v2);
+         sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
+         sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
+     }
+
+     const float d = (float)bq->d * bq8_1->ds[0];
+     return d * (sumi1 + sumi2);
+ }
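// Annotation (not part of the package diff): kvalues_iq4nl (from ggml-common.h)
// is the 16-entry non-linear int8 codebook of the iq4_nl/iq4_xs formats.
// get_int_from_table_16 expands each packed 4-bit index into its codebook
// byte, so the dp4a accumulation above runs on codebook values rather than on
// the raw nibbles as in the uniform q4_0/q4_1 paths.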
+
+
+ static __dpct_inline__ float
+ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ #if QK_K == 256
+     const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+     // iqs is 0...7
+     const int ib32 = iqs;
+     const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+     const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+     const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+     const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
+     int v1, v2;
+     int sumi1 = 0, sumi2 = 0;
+     for (int j = 0; j < 4; ++j) {
+         get_int_from_table_16(q4[j], values, v1, v2);
+         sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
+         sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
+     }
+     return d * (sumi1 + sumi2);
+ #else
+     assert(false);
+ #endif
+ }
+
+ #endif // GGML_SYCL_VECDOTQ_HPP