@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp (new file)
@@ -0,0 +1,125 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "im2col.hpp"
+
+template <typename T>
+static void im2col_kernel(
+        const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
+        const sycl::nd_item<3> &item_ct1) {
+    const int64_t work_group_size = item_ct1.get_local_range(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+
+    // make each work-item deal with more elements since sycl global range can not exceed max int
+    for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
+
+        const int64_t ksize = OW * (KH > 1 ? KW : 1);
+        const int64_t kx = i / ksize;
+        const int64_t kd = kx * ksize;
+        const int64_t ky = (i - kd) / OW;
+        const int64_t ix = i % OW;
+
+        const int64_t oh = item_ct1.get_group(1);
+        const int64_t batch = item_ct1.get_group(0) / IC;
+        const int64_t ic = item_ct1.get_group(0) % IC;
+
+        const int64_t iiw = ix * s0 + kx * d0 - p0;
+        const int64_t iih = oh * s1 + ky * d1 - p1;
+
+        const int64_t offset_dst =
+            ((batch * OH + oh) * OW + ix) * CHW +
+            (ic * (KW * KH) + ky * KW + kx);
+
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] =
+                sycl::vec<float, 1>(0.0f)
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        } else {
+            const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+            dst[offset_dst] =
+                sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        }
+    }
+}
+
+template <typename T>
+static void im2col_sycl(
+        const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
+        int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1,
+        queue_ptr stream) {
+    const int64_t parallel_elements = OW * KW * KH;
+    const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
+
+    // decrease global range when it exceeds the max int
+    int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
+    sycl::range<3> block_nums(batch * IC, OH, num_blocks);
+    sycl::range<3> local_range(1, 1, local_size);
+
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * local_range, local_range),
+            [=](sycl::nd_item<3> item_ct1) {
+                im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
+                              parallel_elements, (IC * KH * KW), s0, s1, p0,
+                              p1, d0, d1, item_ct1);
+            });
+    }
+}
+
+void ggml_sycl_op_im2col(
+        ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+        ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+        const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }
+
+    (void) src0;
+    (void) src0_dd;
+}
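The hunk above is the new SYCL im2col kernel (entry 184 in the file list). As an aid to reading its index arithmetic, here is a minimal CPU sketch that produces the same `[OH*OW, IC*KH*KW]` layout for a single batch, including the zero-padding branch. It is an illustration written for this diff, not code from the package; the function name is ours, and the output-size formula is the standard convolution one.

```cpp
// cpu_im2col_sketch.cpp -- CPU analogue of the im2col_kernel above (illustration only).
#include <cstdint>
#include <cstdio>
#include <vector>

// x is a single [IC, IH, IW] image; the result is an [OH*OW, IC*KH*KW] matrix,
// with out-of-bounds taps left at zero exactly like the kernel's iih/iiw check.
std::vector<float> cpu_im2col(const std::vector<float> &x,
                              int64_t IC, int64_t IH, int64_t IW,
                              int64_t KH, int64_t KW,
                              int s0, int s1, int p0, int p1, int d0, int d1) {
    const int64_t OH  = (IH + 2 * p1 - d1 * (KH - 1) - 1) / s1 + 1; // standard conv output size
    const int64_t OW  = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
    const int64_t CHW = IC * KH * KW;
    std::vector<float> dst(OH * OW * CHW, 0.0f);
    for (int64_t ic = 0; ic < IC; ic++)
    for (int64_t oh = 0; oh < OH; oh++)
    for (int64_t ow = 0; ow < OW; ow++)
    for (int64_t ky = 0; ky < KH; ky++)
    for (int64_t kx = 0; kx < KW; kx++) {
        const int64_t iih = oh * s1 + ky * d1 - p1; // same formulas as the kernel
        const int64_t iiw = ow * s0 + kx * d0 - p0;
        const int64_t offset_dst = (oh * OW + ow) * CHW + ic * (KW * KH) + ky * KW + kx;
        if (iih >= 0 && iih < IH && iiw >= 0 && iiw < IW) {
            dst[offset_dst] = x[(ic * IH + iih) * IW + iiw];
        } // else: stays zero (padding)
    }
    return dst;
}

int main() {
    // 1 channel, 4x4 input, 3x3 kernel, stride 1, pad 1, dilation 1 -> 16 output rows of 9.
    std::vector<float> x(16);
    for (int i = 0; i < 16; i++) x[i] = (float) i;
    std::vector<float> cols = cpu_im2col(x, 1, 4, 4, 3, 3, 1, 1, 1, 1, 1, 1);
    printf("rows=%zu cols=%d\n", cols.size() / 9, 9);
    return 0;
}
```

Each output row corresponds to one spatial output position and each column to one (channel, kernel-offset) pair, which is what lets the convolution be computed as a single matrix multiply afterwards.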
package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp (new file)
@@ -0,0 +1,23 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_IM2COL_HPP
+#define GGML_SYCL_IM2COL_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_im2col(
+        ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+        ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+        const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_IM2COL_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp
@@ -1,6 +1,6 @@
 #include "mmvq.hpp"
 #include "vecdotq.hpp"
-
+#include <cassert>
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
 static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
@@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 
     // partial sum for each thread
     float tmp = 0.0f;
@@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 
     // partial sum for each thread
     float tmp = 0.0f;
@@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
     // partial sum for each thread
     float tmp = 0.0f;
 
@@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
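Two recurring changes run through the kernel hunks above: `WARP_SIZE` becomes `QK_WARP_SIZE` in the quantized mat-vec kernels, and each kernel gains `assert(blocks_per_warp>0)`, which guards the integer division `vdr * QK_WARP_SIZE / qi` against rounding down to zero when `qi` exceeds `vdr * QK_WARP_SIZE`. The reduction loop itself is an XOR butterfly across sub-group lanes; the sketch below simulates `dpct::permute_sub_group_by_xor` on a plain array to show why halving the mask log2(N) times leaves the full sum in every lane. It is an illustration only; `QK_WARP_SIZE = 16` is an assumed example value, not taken from the package.

```cpp
// butterfly_reduction_sketch.cpp -- scalar simulation of the XOR-mask reduction
// used in the hunks above (assumes a power-of-two warp size).
#include <cstdio>

int main() {
    const int QK_WARP_SIZE = 16;                                // assumed example value
    float lane[QK_WARP_SIZE];
    for (int i = 0; i < QK_WARP_SIZE; i++) lane[i] = (float) i; // per-lane partial sums

    // Each step, lane i adds the value held by lane (i ^ mask); this exchange is
    // what permute_sub_group_by_xor provides on the GPU. After log2(N) steps,
    // every lane holds the sum over all N lanes.
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        float next[QK_WARP_SIZE];
        for (int i = 0; i < QK_WARP_SIZE; i++) next[i] = lane[i] + lane[i ^ mask];
        for (int i = 0; i < QK_WARP_SIZE; i++) lane[i] = next[i];
    }
    printf("sum in every lane: %g (expected %g)\n",
           lane[0], 0.5f * QK_WARP_SIZE * (QK_WARP_SIZE - 1));
    return 0;
}
```

The launcher hunks for the same file continue below.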
package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp (continued)
@@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
                                   VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
                                   VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
                                   VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
                                   VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
                                   VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
                                   VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
                                   VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
                                   VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
                                   VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
                                   VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
@@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
            cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
@@ -894,15 +896,15 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_NL == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                    mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+                    mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
@@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                     mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
                         vx, vy, dst, ncols, nrows, item_ct1);
                 });
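The launcher hunks follow the same pattern throughout: both the work-group shape and the `[[intel::reqd_sub_group_size(...)]]` attribute switch from `WARP_SIZE` to `QK_WARP_SIZE`, which lets the quantized mat-vec kernels pin a sub-group size independent of the backend-wide constant; the iq4_nl launcher additionally bumps its `vdr` template argument from 1 to 2. For reference, this sketch reproduces the launch-geometry arithmetic shared by every `mul_mat_vec_*_sycl` launcher; the two constants are assumed example values, not taken from the package.

```cpp
// launch_geometry_sketch.cpp -- the ND-range sizing used by the launchers above.
#include <cstdio>

int main() {
    const int GGML_SYCL_MMV_Y = 1;  // assumed: rows handled per work-group in y
    const int QK_WARP_SIZE    = 16; // assumed: sub-group size pinned via reqd_sub_group_size
    const int nrows = 4096;

    // Rounding-up division, as in every launcher:
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;

    // block_nums = (1, 1, block_num_y) work-groups, each of
    // block_dims = (1, GGML_SYCL_MMV_Y, QK_WARP_SIZE) work-items, so one
    // sub-group of QK_WARP_SIZE lanes cooperates on each row of the matrix.
    const long items_per_group = 1L * GGML_SYCL_MMV_Y * QK_WARP_SIZE;
    const long total_items     = items_per_group * block_num_y;
    printf("work-groups=%d, items/group=%ld, total items=%ld\n",
           block_num_y, items_per_group, total_items);
    return 0;
}
```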