@fugood/llama.node 0.3.13 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
@@ -3,44 +3,42 @@
3
3
  #include <cassert>
4
4
 
5
5
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
6
- static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
7
- const sycl::nd_item<3> &item_ct1) {
8
- const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
9
- item_ct1.get_local_id(1);
6
+ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
7
+ const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
8
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
10
9
 
11
10
  if (row >= nrows) {
12
11
  return;
13
12
  }
14
13
 
15
- const int blocks_per_row = ncols / qk;
16
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
17
- assert(blocks_per_warp>0);
14
+ const int blocks_per_row = ncols / qk;
15
+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; // Ensuring blocks_per_warp > 0
18
16
 
19
- // partial sum for each thread
17
+ assert(blocks_per_warp > 0);
18
+
19
+ // partial sum for each thread
20
20
  float tmp = 0.0f;
21
21
 
22
- const block_q_t * x = (const block_q_t *) vx;
22
+ const block_q_t * x = (const block_q_t *) vx;
23
23
  const block_q8_1 * y = (const block_q8_1 *) vy;
24
24
 
25
- for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
26
- i += blocks_per_warp) {
27
- const int ibx = row*blocks_per_row + i; // x block index
25
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
26
+ const int ibx = row * blocks_per_row + i; // x block index
28
27
 
29
- const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
28
+ const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
30
29
 
31
- const int iqs =
32
- vdr *
33
- (item_ct1.get_local_id(2) %
34
- (qi / vdr)); // x block quant index when casting the quants to int
30
+ for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
31
+ const int iqs = elem + vdr * (item_ct1.get_local_id(2) %
32
+ (qi / vdr)); // x block quant index when casting the quants to int
35
33
 
36
- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
34
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
35
+ }
37
36
  }
38
37
 
39
38
  // sum up partial sums and write back result
40
39
  #pragma unroll
41
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
42
- tmp +=
43
- dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
40
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
41
+ tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
44
42
  }
45
43
 
46
44
  if (item_ct1.get_local_id(2) == 0) {
@@ -62,7 +60,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
62
60
  }
63
61
 
64
62
  const int blocks_per_row = ncols / qk;
65
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
63
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
66
64
  assert(blocks_per_warp>0);
67
65
 
68
66
  // partial sum for each thread
@@ -87,7 +85,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
87
85
 
88
86
  // sum up partial sums and write back result
89
87
  #pragma unroll
90
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
88
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
91
89
  tmp +=
92
90
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
93
91
  }
@@ -111,7 +109,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
111
109
  }
112
110
 
113
111
  const int blocks_per_row = ncols / qk;
114
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
112
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
115
113
  assert(blocks_per_warp>0);
116
114
  // partial sum for each thread
117
115
  float tmp = 0.0f;
@@ -135,7 +133,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
135
133
 
136
134
  // sum up partial sums and write back result
137
135
  #pragma unroll
138
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
136
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
139
137
  tmp +=
140
138
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
141
139
  }
@@ -159,7 +157,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
159
157
  }
160
158
 
161
159
  const int blocks_per_row = ncols / qk;
162
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
160
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
163
161
  assert(blocks_per_warp>0);
164
162
  // partial sum for each thread
165
163
  float tmp = 0.0f;
@@ -183,7 +181,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
183
181
 
184
182
  // sum up partial sums and write back result
185
183
  #pragma unroll
186
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
184
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
187
185
  tmp +=
188
186
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
189
187
  }
@@ -207,7 +205,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
207
205
  }
208
206
 
209
207
  const int blocks_per_row = ncols / qk;
210
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
208
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
211
209
  assert(blocks_per_warp>0);
212
210
  // partial sum for each thread
213
211
  float tmp = 0.0f;
@@ -231,7 +229,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
231
229
 
232
230
  // sum up partial sums and write back result
233
231
  #pragma unroll
234
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
232
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
235
233
  tmp +=
236
234
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
237
235
  }
@@ -255,7 +253,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
255
253
  }
256
254
 
257
255
  const int blocks_per_row = ncols / qk;
258
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
256
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
259
257
  assert(blocks_per_warp>0);
260
258
  // partial sum for each thread
261
259
  float tmp = 0.0f;
@@ -279,7 +277,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
279
277
 
280
278
  // sum up partial sums and write back result
281
279
  #pragma unroll
282
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
280
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
283
281
  tmp +=
284
282
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
285
283
  }
@@ -303,7 +301,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
303
301
  }
304
302
 
305
303
  const int blocks_per_row = ncols / qk;
306
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
304
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
307
305
  assert(blocks_per_warp>0);
308
306
  // partial sum for each thread
309
307
  float tmp = 0.0f;
@@ -327,7 +325,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
327
325
 
328
326
  // sum up partial sums and write back result
329
327
  #pragma unroll
330
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
328
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
331
329
  tmp +=
332
330
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
333
331
  }
@@ -351,7 +349,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
351
349
  }
352
350
 
353
351
  const int blocks_per_row = ncols / qk;
354
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
352
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
355
353
  assert(blocks_per_warp>0);
356
354
  // partial sum for each thread
357
355
  float tmp = 0.0f;
@@ -375,7 +373,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
375
373
 
376
374
  // sum up partial sums and write back result
377
375
  #pragma unroll
378
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
376
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
379
377
  tmp +=
380
378
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
381
379
  }
@@ -399,7 +397,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
399
397
  }
400
398
 
401
399
  const int blocks_per_row = ncols / qk;
402
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
400
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
403
401
  assert(blocks_per_warp>0);
404
402
  // partial sum for each thread
405
403
  float tmp = 0.0f;
@@ -423,7 +421,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
423
421
 
424
422
  // sum up partial sums and write back result
425
423
  #pragma unroll
426
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
424
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
427
425
  tmp +=
428
426
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
429
427
  }
@@ -448,7 +446,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
448
446
  }
449
447
 
450
448
  const int blocks_per_row = ncols / qk;
451
- const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
449
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
452
450
  assert(blocks_per_warp>0);
453
451
  // partial sum for each thread
454
452
  float tmp = 0.0f;
@@ -472,7 +470,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
472
470
 
473
471
  // sum up partial sums and write back result
474
472
  #pragma unroll
475
- for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
473
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
476
474
  tmp +=
477
475
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
478
476
  }
@@ -489,7 +487,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
489
487
  GGML_ASSERT(ncols % QK4_0 == 0);
490
488
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
491
489
  const sycl::range<3> block_nums(1, 1, block_num_y);
492
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
490
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
493
491
  {
494
492
 
495
493
  stream->submit([&](sycl::handler &cgh) {
@@ -497,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
497
495
  cgh.parallel_for(
498
496
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
499
497
  [=](sycl::nd_item<3> item_ct1)
500
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
498
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
501
499
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
502
500
  VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
503
501
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -513,7 +511,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
513
511
  GGML_ASSERT(ncols % QK4_1 == 0);
514
512
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
515
513
  const sycl::range<3> block_nums(1, 1, block_num_y);
516
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
514
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
517
515
  {
518
516
 
519
517
  stream->submit([&](sycl::handler &cgh) {
@@ -521,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
521
519
  cgh.parallel_for(
522
520
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
523
521
  [=](sycl::nd_item<3> item_ct1)
524
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
522
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
525
523
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
526
524
  VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
527
525
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -537,7 +535,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
537
535
  GGML_ASSERT(ncols % QK5_0 == 0);
538
536
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
539
537
  const sycl::range<3> block_nums(1, 1, block_num_y);
540
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
538
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
541
539
  {
542
540
 
543
541
  stream->submit([&](sycl::handler &cgh) {
@@ -545,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
545
543
  cgh.parallel_for(
546
544
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
547
545
  [=](sycl::nd_item<3> item_ct1)
548
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
546
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
549
547
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
550
548
  VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
551
549
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -561,7 +559,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
561
559
  GGML_ASSERT(ncols % QK5_1 == 0);
562
560
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
563
561
  const sycl::range<3> block_nums(1, 1, block_num_y);
564
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
562
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
565
563
  {
566
564
 
567
565
  stream->submit([&](sycl::handler &cgh) {
@@ -569,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
569
567
  cgh.parallel_for(
570
568
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
571
569
  [=](sycl::nd_item<3> item_ct1)
572
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
570
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
573
571
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
574
572
  VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
575
573
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -585,7 +583,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
585
583
  GGML_ASSERT(ncols % QK8_0 == 0);
586
584
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
587
585
  const sycl::range<3> block_nums(1, 1, block_num_y);
588
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
586
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
589
587
  {
590
588
 
591
589
  stream->submit([&](sycl::handler &cgh) {
@@ -593,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
593
591
  cgh.parallel_for(
594
592
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
595
593
  [=](sycl::nd_item<3> item_ct1)
596
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
594
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
597
595
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
598
596
  VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
599
597
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -609,7 +607,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
609
607
  GGML_ASSERT(ncols % QK_K == 0);
610
608
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
611
609
  const sycl::range<3> block_nums(1, 1, block_num_y);
612
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
610
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
613
611
  {
614
612
 
615
613
  stream->submit([&](sycl::handler &cgh) {
@@ -617,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
617
615
  cgh.parallel_for(
618
616
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
619
617
  [=](sycl::nd_item<3> item_ct1)
620
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
618
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
621
619
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
622
620
  VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
623
621
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -633,7 +631,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
633
631
  GGML_ASSERT(ncols % QK_K == 0);
634
632
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
635
633
  const sycl::range<3> block_nums(1, 1, block_num_y);
636
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
634
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
637
635
  {
638
636
 
639
637
  stream->submit([&](sycl::handler &cgh) {
@@ -641,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
641
639
  cgh.parallel_for(
642
640
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
643
641
  [=](sycl::nd_item<3> item_ct1)
644
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
642
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
645
643
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
646
644
  VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
647
645
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -657,7 +655,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
657
655
  GGML_ASSERT(ncols % QK_K == 0);
658
656
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
659
657
  const sycl::range<3> block_nums(1, 1, block_num_y);
660
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
658
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
661
659
  {
662
660
 
663
661
  stream->submit([&](sycl::handler &cgh) {
@@ -665,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
665
663
  cgh.parallel_for(
666
664
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
667
665
  [=](sycl::nd_item<3> item_ct1)
668
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
666
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
669
667
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
670
668
  VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
671
669
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -681,7 +679,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
681
679
  GGML_ASSERT(ncols % QK_K == 0);
682
680
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
683
681
  const sycl::range<3> block_nums(1, 1, block_num_y);
684
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
682
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
685
683
  {
686
684
 
687
685
  stream->submit([&](sycl::handler &cgh) {
@@ -689,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
689
687
  cgh.parallel_for(
690
688
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
691
689
  [=](sycl::nd_item<3> item_ct1)
692
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
690
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
693
691
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
694
692
  VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
695
693
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -705,7 +703,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
705
703
  GGML_ASSERT(ncols % QK_K == 0);
706
704
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
707
705
  const sycl::range<3> block_nums(1, 1, block_num_y);
708
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
706
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
709
707
  {
710
708
 
711
709
  stream->submit([&](sycl::handler &cgh) {
@@ -713,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
713
711
  cgh.parallel_for(
714
712
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
715
713
  [=](sycl::nd_item<3> item_ct1)
716
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
714
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
717
715
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
718
716
  VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
719
717
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -730,13 +728,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
730
728
  GGML_ASSERT(ncols % QK_K == 0);
731
729
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
732
730
  const sycl::range<3> block_nums(1, 1, block_num_y);
733
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
731
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
734
732
  {
735
733
  stream->submit([&](sycl::handler &cgh) {
736
734
  cgh.parallel_for(
737
735
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
738
736
  [=](sycl::nd_item<3> item_ct1)
739
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
737
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
740
738
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
741
739
  vx, vy, dst, ncols, nrows, item_ct1);
742
740
  });
@@ -751,13 +749,13 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
751
749
  GGML_ASSERT(ncols % QK_K == 0);
752
750
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
753
751
  const sycl::range<3> block_nums(1, 1, block_num_y);
754
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
752
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
755
753
  {
756
754
  stream->submit([&](sycl::handler & cgh) {
757
755
  cgh.parallel_for(
758
756
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
759
757
  [=](sycl::nd_item<3> item_ct1)
760
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
758
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
761
759
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
762
760
  vx, vy, dst, ncols, nrows, item_ct1);
763
761
  });
@@ -772,14 +770,14 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
772
770
  GGML_ASSERT(ncols % QK_K == 0);
773
771
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
774
772
  const sycl::range<3> block_nums(1, 1, block_num_y);
775
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
773
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
776
774
  {
777
775
 
778
776
  stream->submit([&](sycl::handler &cgh) {
779
777
  cgh.parallel_for(
780
778
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
781
779
  [=](sycl::nd_item<3> item_ct1)
782
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
780
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
783
781
  mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
784
782
  vx, vy, dst, ncols, nrows, item_ct1);
785
783
  });
@@ -794,14 +792,14 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
794
792
  GGML_ASSERT(ncols % QK_K == 0);
795
793
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
796
794
  const sycl::range<3> block_nums(1, 1, block_num_y);
797
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
795
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
798
796
  {
799
797
 
800
798
  stream->submit([&](sycl::handler &cgh) {
801
799
  cgh.parallel_for(
802
800
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
803
801
  [=](sycl::nd_item<3> item_ct1)
804
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
802
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
805
803
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
806
804
  vx, vy, dst, ncols, nrows, item_ct1);
807
805
  });
@@ -816,14 +814,14 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
816
814
  GGML_ASSERT(ncols % QK_K == 0);
817
815
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
818
816
  const sycl::range<3> block_nums(1, 1, block_num_y);
819
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
817
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
820
818
  {
821
819
 
822
820
  stream->submit([&](sycl::handler &cgh) {
823
821
  cgh.parallel_for(
824
822
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
825
823
  [=](sycl::nd_item<3> item_ct1)
826
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
824
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
827
825
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
828
826
  vx, vy, dst, ncols, nrows, item_ct1);
829
827
  });
@@ -838,14 +836,14 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
838
836
  GGML_ASSERT(ncols % QK_K == 0);
839
837
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
840
838
  const sycl::range<3> block_nums(1, 1, block_num_y);
841
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
839
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
842
840
  {
843
841
 
844
842
  stream->submit([&](sycl::handler &cgh) {
845
843
  cgh.parallel_for(
846
844
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
847
845
  [=](sycl::nd_item<3> item_ct1)
848
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
846
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
849
847
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
850
848
  vx, vy, dst, ncols, nrows, item_ct1);
851
849
  });
@@ -860,13 +858,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
860
858
  GGML_ASSERT(ncols % QK_K == 0);
861
859
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
862
860
  const sycl::range<3> block_nums(1, 1, block_num_y);
863
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
861
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
864
862
  {
865
863
  stream->submit([&](sycl::handler &cgh) {
866
864
  cgh.parallel_for(
867
865
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
868
866
  [=](sycl::nd_item<3> item_ct1)
869
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
867
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
870
868
  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
871
869
  vx, vy, dst, ncols, nrows, item_ct1);
872
870
  });
@@ -881,14 +879,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
881
879
  GGML_ASSERT(ncols % QK4_NL == 0);
882
880
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
883
881
  const sycl::range<3> block_nums(1, 1, block_num_y);
884
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
882
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
885
883
  {
886
884
 
887
885
  stream->submit([&](sycl::handler &cgh) {
888
886
  cgh.parallel_for(
889
887
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
890
888
  [=](sycl::nd_item<3> item_ct1)
891
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
889
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
892
890
  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
893
891
  vx, vy, dst, ncols, nrows, item_ct1);
894
892
  });
@@ -903,14 +901,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
903
901
  GGML_ASSERT(ncols % QK_K == 0);
904
902
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
905
903
  const sycl::range<3> block_nums(1, 1, block_num_y);
906
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
904
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
907
905
  {
908
906
 
909
907
  stream->submit([&](sycl::handler &cgh) {
910
908
  cgh.parallel_for(
911
909
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
910
  [=](sycl::nd_item<3> item_ct1)
913
- [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
911
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
914
912
  mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
915
913
  vx, vy, dst, ncols, nrows, item_ct1);
916
914
  });
@@ -249,13 +249,16 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
249
249
 
250
250
  if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
251
251
  const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
252
+ GGML_SYCL_DEBUG("%s: F16 mask\n", __func__);
252
253
  soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
253
254
  main_stream, ctx.device);
254
255
  } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
255
256
  const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
257
+ GGML_SYCL_DEBUG("%s: F32 mask\n", __func__);
256
258
  soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
257
259
  } else {
258
260
  /* mask unavailable */
261
+ GGML_SYCL_DEBUG("%s: No mask\n", __func__);
259
262
  soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
260
263
  }
261
264
  }
@@ -0,0 +1,13 @@
1
+ #include "sycl_hw.hpp"
2
+
3
+
4
+ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
5
+ sycl_hw_info res;
6
+ int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
7
+ res.device_id = id;
8
+
9
+ syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>();
10
+ res.arch = arch;
11
+
12
+ return res;
13
+ }
@@ -0,0 +1,23 @@
1
+ #ifndef SYCL_HW_HPP
2
+ #define SYCL_HW_HPP
3
+
4
+ #include <algorithm>
5
+ #include <stdio.h>
6
+ #include <vector>
7
+ #include <map>
8
+
9
+ #include <sycl/sycl.hpp>
10
+
11
+ namespace syclex = sycl::ext::oneapi::experimental;
12
+
13
+ struct sycl_hw_info {
14
+ syclex::architecture arch;
15
+ int32_t device_id;
16
+ };
17
+
18
+ bool is_in_vector(std::vector<int> &vec, int item);
19
+
20
+ sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
21
+
22
+
23
+ #endif // SYCL_HW_HPP