@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0

package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp
@@ -1,6 +1,61 @@
  #include "mmvq.hpp"
+
+ #include "ggml.h"
+ #include "common.hpp"
+ #include "quants.hpp"
+
  #include "vecdotq.hpp"
- #include <cassert>
+
+ template <typename reorder_vec_dot_q_sycl>
+ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+                                   const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
+     using block_type   = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
+     using block_traits = typename block_type::traits;
+
+     const auto sg           = nd_item.get_sub_group();
+     const int  sg_range     = sg.get_group_linear_range();
+     const int  workgroup_id = nd_item.get_group_linear_id();
+     const int  sg_id        = sg.get_group_linear_id();
+     const int  row          = workgroup_id * sg_range + sg_id;
+
+     if (row >= nrows) {
+         return;
+     }
+
+     const int     blocks_per_row              = ncols / block_traits::qk;
+     constexpr int blocks_per_subgroup         = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
+     constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+     const int     nblocks                     = nrows * (ncols / block_traits::qk);
+
+     static_assert(blocks_per_subgroup > 0);
+     static_assert(block_elements_per_subgroup > 0);
+
+     const block_q8_1 * y = (const block_q8_1 *) vy;
+
+     float partial_sum = 0.0f;
+     for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
+         const int ibx = row * blocks_per_row + i;  // x block index
+         // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
+         const int bx_offset = block_type::get_block_offset(ibx);
+         const int d_offset  = block_type::get_d_offset(nrows, ncols, ibx);
+
+         // Y block index that aligns with ibx
+         const int iby = i * block_type::block_to_q8_1_ratio();
+
+ #pragma unroll
+         for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
+             // x block quant index when casting the quants to int
+             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
+
+             partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
+         }
+     }
+
+     auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
+
+     if (sg.leader()) {
+         dst[row] = sum;
+     }
+ }
 
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
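
As a quick aside for readers tracing the new kernel: the loop bounds in mul_mat_vec_q_reorder are plain integer arithmetic over the block traits, so they can be sanity-checked on the host. Below is a minimal standalone C++ sketch, not part of the package, assuming a sub-group width of 32 and the Q4_0 constants from ggml-common.h (QI4_0 = 4, vdr_mmvq = 2):

    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        constexpr int WARP_SIZE = 32;  // assumed SYCL sub-group width
        constexpr int qi  = 4;         // QI4_0: 32-bit ints of packed quants per block
        constexpr int vdr = 2;         // vdr_mmvq: ints consumed per vec-dot call

        constexpr int blocks_per_subgroup         = ceil_div(vdr * WARP_SIZE, qi);  // 16
        constexpr int block_elements_per_subgroup = qi / vdr;                       // 2

        // One sub-group owns one output row; lanes share the row's blocks.
        for (int lane = 0; lane < 4; ++lane) {
            const int first_block = lane / block_elements_per_subgroup;          // 0, 0, 1, 1
            const int iqs         = vdr * (lane % block_elements_per_subgroup);  // 0, 2, 0, 2
            printf("lane %d: first block %d, iqs %d, stride %d\n", lane, first_block, iqs, blocks_per_subgroup);
        }
        return 0;
    }

With these values two lanes cooperate on each block, so a 32-lane sub-group advances 16 blocks per iteration, which is exactly the stride blocks_per_subgroup expresses.
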
@@ -480,26 +535,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
      }
  }
 
- static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols,
-                                        const int nrows,
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                                const int nrows, dpct::queue_ptr stream) {
+     GGML_ASSERT(ncols % QK4_0 == 0);
+     const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+     constexpr size_t num_subgroups = 16;
+     GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
+     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+     stream->submit([&](sycl::handler & cgh) {
+         cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                          [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
+                                                                                           nd_item);
+                          });
+     });
+ }
+
+ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
                                         dpct::queue_ptr stream) {
      GGML_ASSERT(ncols % QK4_0 == 0);
      const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
      const sycl::range<3> block_nums(1, 1, block_num_y);
      const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-     {
-
-         stream->submit([&](sycl::handler &cgh) {
 
-             cgh.parallel_for(
-                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                 [=](sycl::nd_item<3> item_ct1)
-                     [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                         mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
-                                       VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
-                             vx, vy, dst, ncols, nrows, item_ct1);
-                     });
+     {
+         stream->submit([&](sycl::handler & cgh) {
+             cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
+                                      vx, vy, dst, ncols, nrows, item_ct1);
+                              });
          });
      }
  }
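
Note the launch-geometry change: the existing launcher keeps the old one-row-per-workgroup mapping, while the reorder launcher packs num_subgroups = 16 sub-groups into each workgroup and gives each sub-group one row. A back-of-envelope check, assuming WARP_SIZE = 32 and a GGML_SYCL_MMV_Y of 1 (both are assumptions, not guaranteed by this diff):

    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        constexpr int WARP_SIZE       = 32;  // assumed sub-group width
        constexpr int GGML_SYCL_MMV_Y = 1;   // assumed value
        constexpr int num_subgroups   = 16;  // from the launcher above

        const int nrows       = 4096;                              // example row count
        const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);  // 4096

        // Mirrors the nd_range built in reorder_mul_mat_vec_q4_0_q8_1_sycl:
        const int global_x       = block_num_y * WARP_SIZE;    // 131072 work-items
        const int workgroup_x    = num_subgroups * WARP_SIZE;  // 512 work-items
        const int num_workgroups = global_x / workgroup_x;     // 256

        printf("%d workgroups x %d sub-groups = %d rows covered\n",
               num_workgroups, num_subgroups, num_workgroups * num_subgroups);  // 4096
        return 0;
    }

The GGML_ASSERT(block_num_y % num_subgroups == 0) in the launcher is what keeps this division exact.
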
@@ -672,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
      }
  }
 
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                                const int nrows, dpct::queue_ptr stream) {
+     GGML_ASSERT(ncols % QK_K == 0);
+
+     const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+     constexpr size_t num_subgroups = 16;
+     GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+     stream->submit([&](sycl::handler & cgh) {
+         cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                          [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
+                                                                                            nrows, nd_item);
+                          });
+     });
+ }
+
+
  static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols,
                                         const int nrows,
@@ -916,93 +1005,100 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
      }
  }
 
- void ggml_sycl_op_mul_mat_vec_q(
-     ggml_backend_sycl_context & ctx,
-     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-     const int64_t src1_ncols, const int64_t src1_padded_col_size,
-     const dpct::queue_ptr &stream) {
-
+ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+                                 ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+                                 const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
+                                 const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
+                                 const dpct::queue_ptr & stream) {
      const int64_t ne10 = src1->ne[0];
      GGML_ASSERT(ne10 % QK8_1 == 0);
 
-     const int64_t ne00 = src0->ne[0];
+     const int64_t ne00     = src0->ne[0];
      const int64_t row_diff = row_high - row_low;
 
      int id;
-     SYCL_CHECK(
-         CHECK_TRY_ERROR(id = get_current_device_id()));
+     SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
      const size_t q8_1_ts = sizeof(block_q8_1);
      const size_t q8_1_bs = QK8_1;
      // the main device has a larger memory buffer to hold the results from all GPUs
      // nrows_dst == nrows of the matrix that the kernel writes into
 
-     for (int i = 0; i < src1_ncols; i++)
-     {
+     for (int i = 0; i < src1_ncols; i++) {
          const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
-         const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
-         float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
+         const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
+         float * dst_dd_i_bs        = dst_dd_i + i * dst->ne[0];
          switch (src0->type) {
-         case GGML_TYPE_Q4_0:
-             mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q4_1:
-             mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q5_0:
-             mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q5_1:
-             mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q8_0:
-             mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q2_K:
-             mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q3_K:
-             mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q4_K:
-             mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q5_K:
-             mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_Q6_K:
-             mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ1_S:
-             mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ1_M:
-             mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ2_XXS:
-             mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ2_XS:
-             mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ2_S:
-             mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ3_XXS:
-             mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ3_S:
-             mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ4_NL:
-             mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         case GGML_TYPE_IQ4_XS:
-             mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
-             break;
-         default:
-             GGML_ABORT("fatal error");
+             case GGML_TYPE_Q4_0:
+                 if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                     ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                     GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
+                     reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 } else {
+                     GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
+                     mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 }
+                 break;
+             case GGML_TYPE_Q4_1:
+                 mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q5_0:
+                 mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q5_1:
+                 mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q8_0:
+                 mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q2_K:
+                 mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q3_K:
+                 mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q4_K:
+                 if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                     ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                     reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 } else {
+                     mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 }
+                 break;
+             case GGML_TYPE_Q5_K:
+                 mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_Q6_K:
+                 mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ1_S:
+                 mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ1_M:
+                 mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ2_XXS:
+                 mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ2_XS:
+                 mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ2_S:
+                 mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ3_XXS:
+                 mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ3_S:
+                 mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ4_NL:
+                 mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             case GGML_TYPE_IQ4_XS:
+                 mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                 break;
+             default:
+                 GGML_ABORT("fatal error");
          }
      }
      GGML_UNUSED(src1);
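
The only behavioral change in this dispatcher is the guard on the Q4_0 and Q4_K cases: the reorder kernels run only when src0 carries a ggml_tensor_extra_gpu whose optimized_feature.reorder flag is set. A standalone sketch of that pattern, with stub types standing in for the real ggml structs (the stubs are illustrative; only the null-check-then-flag logic comes from the diff):

    #include <cstdio>

    // Stand-ins for the real ggml-sycl types, just to make the guard concrete.
    struct optimized_features    { bool reorder = false; };
    struct ggml_tensor_extra_gpu { optimized_features optimized_feature; };
    struct ggml_tensor           { ggml_tensor_extra_gpu * extra = nullptr; };

    static bool use_reorder_path(const ggml_tensor & src0) {
        // Same shape as the dispatch above: the extra must exist and opt in.
        return src0.extra != nullptr && src0.extra->optimized_feature.reorder;
    }

    int main() {
        ggml_tensor src0;
        printf("no extra attached -> reorder? %d\n", use_reorder_path(src0));  // 0
        ggml_tensor_extra_gpu extra;
        extra.optimized_feature.reorder = true;
        src0.extra = &extra;
        printf("flag set          -> reorder? %d\n", use_reorder_path(src0));  // 1
        return 0;
    }
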

package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp
@@ -0,0 +1,83 @@
+ //
+ // MIT license
+ // Copyright (C) 2025 Codeplay Software Ltd.
+ // Copyright (C) 2025 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #ifndef GGML_SYCL_QUANTS_HPP
+ #define GGML_SYCL_QUANTS_HPP
+
+ #include "ggml-common.h"
+ #include "ggml.h"
+
+ namespace ggml_sycl_reordered {
+
+
+ // The reordered block moves quants (qs) and scales (d) into two
+ // uniform regions of memory that are contiguous in the same tensor.
+ // That is, instead of:
+ //   [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
+ // we have:
+ //   [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN]
+ //
+ // Notes: out-of-bounds qs will run into d values
+ // Alignment relies on the allocated size of qs
+
+ template <ggml_type type> struct block_q_t;
+
+
+ // qk  number of weights / quants in a block
+ // qr  number of weights in a byte (described as 'before dequantization')
+ //     for quantization types that have low and high bits split, qr is calculated
+ //     using the lower bits, e.g. for Q6 quants QR6 is 2
+ // qi  number of 32 bit integers needed to represent all the quants from a block (`qs` field)
+ // See ggml-common.h to see how these are calculated
+ template <> struct block_q_t<GGML_TYPE_Q4_0> {
+     struct traits {
+         static constexpr uint32_t qk = QK4_0;
+         static constexpr uint32_t qi = QI4_0;
+         static constexpr uint32_t qr = QR4_0;
+         static constexpr uint32_t vdr_mmvq = 2;
+     };
+
+     static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+     static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+         return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+     }
+
+     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+ };
+
+ template <> struct block_q_t<GGML_TYPE_Q4_K> {
+     struct traits {
+         static constexpr uint32_t qk = QK_K;
+         static constexpr uint32_t qi = QI4_K;
+         static constexpr uint32_t qr = QR4_K;
+         static constexpr uint32_t vdr_mmvq = 2;
+     };
+
+     static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+     static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+         auto nblocks = (nrows * (ncols / traits::qk));
+         return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+     }
+
+     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+
+     constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
+
+     constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
+ };
+
+ } // namespace ggml_sycl_reordered
+
+ #endif // GGML_SYCL_QUANTS_HPP
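
To make the reordered layout concrete: for Q4_0, each block contributes qk / qr = 16 bytes of packed nibbles, and every scale trails the whole qs region. A small host-side walk through the offset helpers, using an illustrative 4 x 64 tensor (the shape and the sizeof_half stand-in are assumptions for the example only):

    #include <cstdio>

    int main() {
        // Q4_0 constants from ggml-common.h (qk = QK4_0, qr = QR4_0).
        constexpr int qk = 32, qr = 2;
        constexpr int sizeof_half = 2;  // stand-in for sizeof(ggml_half)

        const int nrows = 4, ncols = 64;             // illustrative tensor shape
        const int blocks_per_row = ncols / qk;       // 2
        const int nblocks = nrows * blocks_per_row;  // 8

        for (int ibx = 0; ibx < nblocks; ++ibx) {
            const int block_offset = ibx * (qk / qr);                       // packed nibbles
            const int d_offset = (ncols / qr) * nrows + ibx * sizeof_half;  // trailing scales
            printf("block %d: qs at byte %3d, d at byte %3d\n", ibx, block_offset, d_offset);
        }
        // All 8 * 16 = 128 qs bytes come first; the 8 scales occupy bytes 128..143.
        return 0;
    }

Q4_K follows the same idea, except its d values also trail a per-block scale table, which is why its get_d_offset adds nblocks * K_SCALE_SIZE.
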

package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp
@@ -1,6 +1,6 @@
  //
  // MIT license
- // Copyright (C) 2024 Intel Corporation
+ // Copyright (C) 2025 Intel Corporation
  // SPDX-License-Identifier: MIT
  //
 
@@ -14,8 +14,11 @@
  #define GGML_SYCL_VECDOTQ_HPP
 
  #include "dpct/helper.hpp"
+ #include "ggml.h"
+ #include "quants.hpp"
 
- typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+ typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+                                   const int & iqs);
 
  static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
      const uint16_t* x16 =
@@ -252,13 +255,121 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
+ template <ggml_type T> struct reorder_vec_dot_q_sycl {
+     static_assert(T != T, "ggml_type for reorder vecdot not implemented");
+ };
+
+ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
+     static constexpr ggml_type gtype = GGML_TYPE_Q4_0;
+
+     using q4_0_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;
+     using q4_0_traits = typename q4_0_block::traits;
+
+     __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {
+         int sumi = 0;
+
+ #pragma unroll
+         for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
+             const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+             const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+             // SIMD dot product of quantized values
+             sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+             sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+         }
+
+         const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
+
+         // second part effectively subtracts 8 from each quant value
+         return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
+     }
+
+     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
+                                      const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
+         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
+         const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
+         int v[q4_0_traits::vdr_mmvq];
+         int u[2 * q4_0_traits::vdr_mmvq];
+
+ #pragma unroll
+
+         for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
+             v[i]         = get_int_from_uint8(bq4_0, iqs + i);
+             u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+             u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
+         }
+
+         return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
+     };
+ };
+
+ static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
+                                              const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
+                                              const int & iqs) {
+     int   v[2];
+     int   u[2 * QR4_K];
+     float d8[QR4_K];
+
+     v[0] = q4[0];
+     v[1] = q4[4];
+
+     uint16_t aux[2];
+     const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+     if (j < 2) {
+         aux[0] = scales[j + 0] & 0x3f3f;
+         aux[1] = scales[j + 2] & 0x3f3f;
+     } else {
+         aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+         aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+     }
+
+     const uint8_t * sc = (const uint8_t *) aux;
+     const uint8_t * m  = sc + 2;
+
+     const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+
+     for (int i = 0; i < QR4_K; ++i) {
+         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+         d8[i] = bq8i->ds[0];
+
+         const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
+         u[2 * i + 0]   = q8[0];
+         u[2 * i + 1]   = q8[4];
+     }
+
+     return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
+ }
+
+ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
+     static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
+
+     using q4_k_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
+     using q4_k_traits = typename q4_k_block::traits;
+
+     float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
+                      const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
+         const int ib = ibx_offset / (QK_K / 2);
+
+         const uint8_t * base     = static_cast<const uint8_t *>(vbq);
+         const uint8_t * qs       = base + ibx_offset;
+         const int total_qs_bytes = nblocks * (QK_K / 2);
+         const uint8_t * scs      = base + total_qs_bytes + ib * K_SCALE_SIZE;
+         const ggml_half2 * dms   = reinterpret_cast<const ggml_half2 *>(base + d_offset);
+
+         const int bq8_offset    = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+         const int * q4          = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+         const uint16_t * scales = (const uint16_t *) scs;
+
+         return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
+     }
+ };
+
  #define VDR_Q4_0_Q8_1_MMVQ 2
  #define VDR_Q4_0_Q8_1_MMQ 4
 
  template <int vdr>
- static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
-                                                     const float &d4,
-                                                     const sycl::half2 &ds8) {
+ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,
+                                                     const sycl::half2 & ds8) {
      int sumi = 0;
  #pragma unroll
      for (int i = 0; i < vdr; ++i) {
@@ -270,8 +381,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
          sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
      }
 
-     const sycl::float2 ds8f =
-         ds8.convert<float, sycl::rounding_mode::automatic>();
+     const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
 
      // second part effectively subtracts 8 from each quant value
      return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
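
For anyone verifying the arithmetic: dpct::dp4a here is a packed dot product that treats each 32-bit int as four signed bytes, multiplies them pairwise, and adds the result to an accumulator (matching the semantics of CUDA's __dp4a; the reference below is a hedged reimplementation, not the library code):

    #include <cstdint>
    #include <cstdio>

    // Scalar reference for a signed 4-byte packed dot product with accumulate.
    static int dp4a_ref(int a, int b, int acc) {
        for (int k = 0; k < 4; ++k) {
            acc += static_cast<int8_t>(a >> (8 * k)) * static_cast<int8_t>(b >> (8 * k));
        }
        return acc;
    }

    int main() {
        // Four Q4_0 nibbles (each 0..15) packed into an int, against four int8 activations.
        const int v = 0x0F010203;                  // bytes: 3, 2, 1, 15
        const int u = 0x01FF0102;                  // bytes: 2, 1, -1, 1
        printf("sumi = %d\n", dp4a_ref(v, u, 0));  // 3*2 + 2*1 + 1*(-1) + 15*1 = 22
        return 0;
    }

The "second part" noted in the comment above then shifts each nibble from [0, 15] to [-8, 7] without unpacking, using the q8_1 block's precomputed sum term (ds8.y()).
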
@@ -456,13 +566,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
      const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
      int v[VDR_Q4_0_Q8_1_MMVQ];
-     int u[2*VDR_Q4_0_Q8_1_MMVQ];
+     int u[2 * VDR_Q4_0_Q8_1_MMVQ];
 
  #pragma unroll
      for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
-         v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
-         u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+         v[i]         = get_int_from_uint8(bq4_0->qs, iqs + i);
+         u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+         u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
      }
 
      return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
@@ -600,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
      return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
  }
 
- static __dpct_inline__ float
- vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
+ static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+                                                const int & iqs) {
  #ifndef GGML_QKK_64
-     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
-
-     int v[2];
-     int u[2*QR4_K];
-     float d8[QR4_K];
-
-     // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
-     const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
 
-     // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
-     // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
-     // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
-     // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
-
-     const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-     v[0] = q4[0];
-     v[1] = q4[4];
-
-     const uint16_t * scales = (const uint16_t *)bq4_K->scales;
-     uint16_t aux[2];
-     const int j = bq8_offset/2;
-     if (j < 2) {
-         aux[0] = scales[j+0] & 0x3f3f;
-         aux[1] = scales[j+2] & 0x3f3f;
-     } else {
-         aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-         aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-     }
-     const uint8_t * sc = (const uint8_t *)aux;
-     const uint8_t * m = sc + 2;
-
-     for (int i = 0; i < QR4_K; ++i) {
-         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-         d8[i] = bq8i->ds[0];
+     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-         u[2*i+0] = q8[0];
-         u[2*i+1] = q8[4];
-     }
+     const int bq8_offset    = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+     const int * q4          = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+     const uint16_t * scales = (const uint16_t *) bq4_K->scales;
 
-     return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+     return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
 
  #else