@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp

@@ -1,6 +1,6 @@
  #include "mmvq.hpp"
  #include "vecdotq.hpp"
-
+ #include <cassert>

  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
@@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);

  // partial sum for each thread
  float tmp = 0.0f;
@@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);

  // partial sum for each thread
  float tmp = 0.0f;
@@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
  }

  const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+ assert(blocks_per_warp>0);
  // partial sum for each thread
  float tmp = 0.0f;

@@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,

  // sum up partial sums and write back result
  #pragma unroll
- for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
  tmp +=
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
  }
@@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK4_0 == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
  VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK4_1 == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
  VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK5_0 == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
  VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK5_1 == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
  VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK8_0 == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
  VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
  VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
  VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
  VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
  VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
  VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {
  stream->submit([&](sycl::handler &cgh) {
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
@@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {
  stream->submit([&](sycl::handler &cgh) {
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK4_NL == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
@@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
  GGML_ASSERT(ncols % QK_K == 0);
  const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
  const sycl::range<3> block_nums(1, 1, block_num_y);
- const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
  {

  stream->submit([&](sycl::handler &cgh) {
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
  mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
  vx, vy, dst, ncols, nrows, item_ct1);
  });
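Note on the mmvq.cpp hunks above: the WARP_SIZE to QK_WARP_SIZE rename applies both to the launch configuration (block_dims, intel::reqd_sub_group_size) and to the XOR-shuffle loop that folds each work-item's partial dot product into a full sub-group sum. As a point of reference only, here is a minimal standalone SYCL 2020 sketch of that butterfly reduction written against the standard sycl::permute_group_by_xor group algorithm rather than the dpct helper used in the patch; SG_SIZE and the queue/USM setup are illustrative assumptions, not part of this package.

```cpp
// Minimal SYCL 2020 sketch of an XOR-shuffle ("butterfly") sub-group reduction.
// SG_SIZE stands in for QK_WARP_SIZE and is an assumption; the device must
// actually support this sub-group size.
#include <sycl/sycl.hpp>
#include <cstdio>

constexpr int SG_SIZE = 32;

int main() {
    sycl::queue q;
    float *out = sycl::malloc_shared<float>(1, q);

    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(SG_SIZE), sycl::range<1>(SG_SIZE)),
        [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(SG_SIZE)]] {
            auto sg = it.get_sub_group();
            float tmp = 1.0f; // each work-item's partial sum

            // After log2(SG_SIZE) exchange steps every lane holds the full sum.
            for (int mask = SG_SIZE / 2; mask > 0; mask >>= 1) {
                tmp += sycl::permute_group_by_xor(sg, tmp, mask);
            }

            if (it.get_local_id(0) == 0) {
                out[0] = tmp; // expected: SG_SIZE
            }
        }).wait();

    std::printf("sub-group sum = %.1f\n", out[0]);
    sycl::free(out, q);
    return 0;
}
```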
package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp

@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep

  const int nthreads = item_ct1.get_local_range(2);
  const int nwarps = nthreads / WARP_SIZE;
- assert(nwarps % WARP_SIZE == 0);
  sycl::float2 mean_var = sycl::float2(0.f, 0.f);

  for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
  int end = start + group_size;
  const int nthreads = item_ct1.get_local_range(2);
  const int nwarps = nthreads / WARP_SIZE;
- assert(nwarps % WARP_SIZE == 0);
  start += item_ct1.get_local_id(2);
  int nreduce = nwarps / WARP_SIZE;

@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
  const int tid = item_ct1.get_local_id(2);
  const int nthreads = item_ct1.get_local_range(2);
  const int nwarps = nthreads / WARP_SIZE;
- assert(nwarps % WARP_SIZE == 0);
  float tmp = 0.0f; // partial sum for thread in warp

  for (int col = tid; col < ncols; col += block_size) {
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
  }
  else {
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  const sycl::range<3> block_dims(1, 1, work_group_size);
  /*
  DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
  }
  else {
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  const sycl::range<3> block_dims(1, 1, work_group_size);
  /*
  DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
  }
  else {
  const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
  const sycl::range<3> block_dims(1, 1, work_group_size);
  /*
  DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
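Note on the norm.cpp hunks: the per-work-item assert(nwarps % WARP_SIZE == 0) is removed from the kernels and replaced by a single assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0) in the host-side launchers. A small plain-C++ illustration of why the two forms express the same constraint for work-group sizes that are a multiple of WARP_SIZE; the concrete values below are assumptions for the sketch, not taken from the package.

```cpp
// Plain C++ illustration: with nthreads == work_group_size and
// nwarps == nthreads / WARP_SIZE, the old in-kernel check and the new
// host-side check coincide whenever work_group_size is a multiple of
// WARP_SIZE, so the assertion can be evaluated once per launch instead
// of once per work-item.
#include <cassert>

int main() {
    const int WARP_SIZE       = 32;   // illustrative value
    const int work_group_size = 1024; // illustrative value, multiple of WARP_SIZE

    const int nthreads = work_group_size;
    const int nwarps   = nthreads / WARP_SIZE;

    assert((nwarps % WARP_SIZE == 0) ==
           (work_group_size % (WARP_SIZE * WARP_SIZE) == 0));
    return 0;
}
```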
package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp (new file)

@@ -0,0 +1,56 @@
+ #include <sycl/sycl.hpp>
+ #include <oneapi/mkl.hpp>
+ #include "outprod.hpp"
+
+
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst) {
+
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // Get SYCL queue
+ dpct::queue_ptr stream = ctx.stream();
+
+ // Dimension checks
+ GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
+ GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
+ GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
+
+ // Get data pointers
+ const float* src0_d = (const float*)src0->data;
+ const float* src1_d = (const float*)src1->data;
+ float* dst_d = (float*)dst->data;
+
+ // GEMM parameters
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ // Handle transposition of src1
+ const bool src1_T = ggml_is_transposed(src1);
+ const oneapi::mkl::transpose src1_op =
+ src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+ const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+
+ try {
+ // Perform matrix multiplication using oneMKL GEMM
+ oneapi::mkl::blas::column_major::gemm(*stream,
+ oneapi::mkl::transpose::nontrans, src1_op,
+ ne0, ne1, ne01,
+ alpha,
+ src0_d, ne00,
+ src1_d, ldb,
+ beta,
+ dst_d, ne0);
+ }
+ catch (sycl::exception const& exc) {
+ std::cerr << exc.what() << std::endl;
+ GGML_ASSERT(false);
+ }
+ }
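Note on the new outprod.cpp: the out-prod op is lowered to a single column-major oneMKL GEMM that computes dst = src0 * src1^T in the non-transposed src1 case, which is why the code asserts ne0 == ne00, ne1 == ne10 and ne01 == ne11. The following plain-C++ reference loop is shown only to make the index mapping explicit; it assumes contiguous, column-major float buffers and is an illustration, not the ggml code path.

```cpp
// Reference loop for dst (ne00 x ne10) = src0 (ne00 x ne01) * src1^T,
// with src1 stored as ne10 x ne11 and ne11 == ne01. Assumed layout:
// element (i,k) of src0 at i + k*ne00, element (j,k) of src1 at j + k*ne10,
// element (i,j) of dst at i + j*ne00.
#include <cstddef>
#include <vector>

void out_prod_ref(const std::vector<float> &src0,
                  const std::vector<float> &src1,
                  std::vector<float>       &dst,
                  std::size_t ne00, std::size_t ne01, std::size_t ne10) {
    for (std::size_t j = 0; j < ne10; ++j) {
        for (std::size_t i = 0; i < ne00; ++i) {
            float acc = 0.0f;
            for (std::size_t k = 0; k < ne01; ++k) {
                // src1 is read transposed, matching oneapi::mkl::transpose::trans
                acc += src0[i + k * ne00] * src1[j + k * ne10];
            }
            dst[i + j * ne00] = acc; // alpha == 1, beta == 0
        }
    }
}
```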
package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp (new file)

@@ -0,0 +1,11 @@
+ #ifndef GGML_SYCL_OUTPROD_HPP
+ #define GGML_SYCL_OUTPROD_HPP
+
+ #include "common.hpp"
+
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst);
+
+
+ #endif // GGML_SYCL_OUTPROD_HPP
+
package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp

@@ -25,6 +25,11 @@
  #define SYCL_RELU_BLOCK_SIZE 256
  #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
  #define SYCL_HARDSWISH_BLOCK_SIZE 256
+ #define SYCL_EXP_BLOCK_SIZE 256
+ #define SYCL_NEG_BLOCK_SIZE 256
+ #define SYCL_SIGMOID_BLOCK_SIZE 256
+ #define SYCL_SQRT_BLOCK_SIZE 256
+ #define SYCL_SIN_BLOCK_SIZE 256
  #define SYCL_SQR_BLOCK_SIZE 256
  #define SYCL_CPY_BLOCK_SIZE 32
  #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
  #define SYCL_ACC_BLOCK_SIZE 256
  #define SYCL_IM2COL_BLOCK_SIZE 256
  #define SYCL_POOL2D_BLOCK_SIZE 256
+ #define SYCL_ARGMAX_BLOCK_SIZE 256
  #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
  #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256