whispercpp 1.2.0.2 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (135)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
@@ -0,0 +1,1015 @@
1
+ #include "mmvq.hpp"
2
+ #include "vecdotq.hpp"
3
+ #include <cassert>
4
+
5
+ template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
6
+ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
7
+ const sycl::nd_item<3> &item_ct1) {
8
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
9
+ item_ct1.get_local_id(1);
10
+
11
+ if (row >= nrows) {
12
+ return;
13
+ }
14
+
15
+ const int blocks_per_row = ncols / qk;
16
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
17
+ assert(blocks_per_warp>0);
18
+
19
+ // partial sum for each thread
20
+ float tmp = 0.0f;
21
+
22
+ const block_q_t * x = (const block_q_t *) vx;
23
+ const block_q8_1 * y = (const block_q8_1 *) vy;
24
+
25
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
26
+ i += blocks_per_warp) {
27
+ const int ibx = row*blocks_per_row + i; // x block index
28
+
29
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
30
+
31
+ const int iqs =
32
+ vdr *
33
+ (item_ct1.get_local_id(2) %
34
+ (qi / vdr)); // x block quant index when casting the quants to int
35
+
36
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
37
+ }
38
+
39
+ // sum up partial sums and write back result
40
+ #pragma unroll
41
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
42
+ tmp +=
43
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
44
+ }
45
+
46
+ if (item_ct1.get_local_id(2) == 0) {
47
+ dst[row] = tmp;
48
+ }
49
+ }
50
+
51
+ template <int qk, int qi, typename block_q_t, int vdr>
52
+ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
53
+ const void *__restrict__ vy,
54
+ float *__restrict__ dst, const int ncols,
55
+ const int nrows,
56
+ const sycl::nd_item<3> &item_ct1) {
57
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
58
+ item_ct1.get_local_id(1);
59
+
60
+ if (row >= nrows) {
61
+ return;
62
+ }
63
+
64
+ const int blocks_per_row = ncols / qk;
65
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
66
+ assert(blocks_per_warp>0);
67
+
68
+ // partial sum for each thread
69
+ float tmp = 0.0f;
70
+
71
+ const block_q_t * x = (const block_q_t *) vx;
72
+ const block_q8_1 * y = (const block_q8_1 *) vy;
73
+
74
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
75
+ i += blocks_per_warp) {
76
+ const int ibx = row*blocks_per_row + i; // x block index
77
+
78
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
79
+
80
+ const int iqs =
81
+ vdr *
82
+ (item_ct1.get_local_id(2) %
83
+ (qi / vdr)); // x block quant index when casting the quants to int
84
+
85
+ tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
86
+ }
87
+
88
+ // sum up partial sums and write back result
89
+ #pragma unroll
90
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
91
+ tmp +=
92
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
93
+ }
94
+
95
+ if (item_ct1.get_local_id(2) == 0) {
96
+ dst[row] = tmp;
97
+ }
98
+ }
99
+
100
+ template <int qk, int qi, typename block_q_t, int vdr>
101
+ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
102
+ const void *__restrict__ vy,
103
+ float *__restrict__ dst, const int ncols,
104
+ const int nrows,
105
+ const sycl::nd_item<3> &item_ct1) {
106
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
107
+ item_ct1.get_local_id(1);
108
+
109
+ if (row >= nrows) {
110
+ return;
111
+ }
112
+
113
+ const int blocks_per_row = ncols / qk;
114
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
115
+ assert(blocks_per_warp>0);
116
+ // partial sum for each thread
117
+ float tmp = 0.0f;
118
+
119
+ const block_q_t * x = (const block_q_t *) vx;
120
+ const block_q8_1 * y = (const block_q8_1 *) vy;
121
+
122
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
123
+ i += blocks_per_warp) {
124
+ const int ibx = row*blocks_per_row + i; // x block index
125
+
126
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
127
+
128
+ const int iqs =
129
+ vdr *
130
+ (item_ct1.get_local_id(2) %
131
+ (qi / vdr)); // x block quant index when casting the quants to int
132
+
133
+ tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
134
+ }
135
+
136
+ // sum up partial sums and write back result
137
+ #pragma unroll
138
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
139
+ tmp +=
140
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
141
+ }
142
+
143
+ if (item_ct1.get_local_id(2) == 0) {
144
+ dst[row] = tmp;
145
+ }
146
+ }
147
+
148
+ template <int qk, int qi, typename block_q_t, int vdr>
149
+ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
150
+ const void *__restrict__ vy,
151
+ float *__restrict__ dst, const int ncols,
152
+ const int nrows,
153
+ const sycl::nd_item<3> &item_ct1) {
154
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
155
+ item_ct1.get_local_id(1);
156
+
157
+ if (row >= nrows) {
158
+ return;
159
+ }
160
+
161
+ const int blocks_per_row = ncols / qk;
162
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
163
+ assert(blocks_per_warp>0);
164
+ // partial sum for each thread
165
+ float tmp = 0.0f;
166
+
167
+ const block_q_t * x = (const block_q_t *) vx;
168
+ const block_q8_1 * y = (const block_q8_1 *) vy;
169
+
170
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
171
+ i += blocks_per_warp) {
172
+ const int ibx = row*blocks_per_row + i; // x block index
173
+
174
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
175
+
176
+ const int iqs =
177
+ vdr *
178
+ (item_ct1.get_local_id(2) %
179
+ (qi / vdr)); // x block quant index when casting the quants to int
180
+
181
+ tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
182
+ }
183
+
184
+ // sum up partial sums and write back result
185
+ #pragma unroll
186
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
187
+ tmp +=
188
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
189
+ }
190
+
191
+ if (item_ct1.get_local_id(2) == 0) {
192
+ dst[row] = tmp;
193
+ }
194
+ }
195
+
196
+ template <int qk, int qi, typename block_q_t, int vdr>
197
+ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
198
+ const void *__restrict__ vy,
199
+ float *__restrict__ dst, const int ncols,
200
+ const int nrows,
201
+ const sycl::nd_item<3> &item_ct1) {
202
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
203
+ item_ct1.get_local_id(1);
204
+
205
+ if (row >= nrows) {
206
+ return;
207
+ }
208
+
209
+ const int blocks_per_row = ncols / qk;
210
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
211
+ assert(blocks_per_warp>0);
212
+ // partial sum for each thread
213
+ float tmp = 0.0f;
214
+
215
+ const block_q_t * x = (const block_q_t *) vx;
216
+ const block_q8_1 * y = (const block_q8_1 *) vy;
217
+
218
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
219
+ i += blocks_per_warp) {
220
+ const int ibx = row*blocks_per_row + i; // x block index
221
+
222
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
223
+
224
+ const int iqs =
225
+ vdr *
226
+ (item_ct1.get_local_id(2) %
227
+ (qi / vdr)); // x block quant index when casting the quants to int
228
+
229
+ tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
230
+ }
231
+
232
+ // sum up partial sums and write back result
233
+ #pragma unroll
234
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
235
+ tmp +=
236
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
237
+ }
238
+
239
+ if (item_ct1.get_local_id(2) == 0) {
240
+ dst[row] = tmp;
241
+ }
242
+ }
243
+
244
+ template <int qk, int qi, typename block_q_t, int vdr>
245
+ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
246
+ const void *__restrict__ vy,
247
+ float *__restrict__ dst, const int ncols,
248
+ const int nrows,
249
+ const sycl::nd_item<3> &item_ct1) {
250
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
251
+ item_ct1.get_local_id(1);
252
+
253
+ if (row >= nrows) {
254
+ return;
255
+ }
256
+
257
+ const int blocks_per_row = ncols / qk;
258
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
259
+ assert(blocks_per_warp>0);
260
+ // partial sum for each thread
261
+ float tmp = 0.0f;
262
+
263
+ const block_q_t * x = (const block_q_t *) vx;
264
+ const block_q8_1 * y = (const block_q8_1 *) vy;
265
+
266
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
267
+ i += blocks_per_warp) {
268
+ const int ibx = row*blocks_per_row + i; // x block index
269
+
270
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
271
+
272
+ const int iqs =
273
+ vdr *
274
+ (item_ct1.get_local_id(2) %
275
+ (qi / vdr)); // x block quant index when casting the quants to int
276
+
277
+ tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
278
+ }
279
+
280
+ // sum up partial sums and write back result
281
+ #pragma unroll
282
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
283
+ tmp +=
284
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
285
+ }
286
+
287
+ if (item_ct1.get_local_id(2) == 0) {
288
+ dst[row] = tmp;
289
+ }
290
+ }
291
+
292
+ template <int qk, int qi, typename block_q_t, int vdr>
293
+ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
294
+ const void *__restrict__ vy,
295
+ float *__restrict__ dst, const int ncols,
296
+ const int nrows,
297
+ const sycl::nd_item<3> &item_ct1) {
298
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
299
+ item_ct1.get_local_id(1);
300
+
301
+ if (row >= nrows) {
302
+ return;
303
+ }
304
+
305
+ const int blocks_per_row = ncols / qk;
306
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
307
+ assert(blocks_per_warp>0);
308
+ // partial sum for each thread
309
+ float tmp = 0.0f;
310
+
311
+ const block_q_t * x = (const block_q_t *) vx;
312
+ const block_q8_1 * y = (const block_q8_1 *) vy;
313
+
314
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
315
+ i += blocks_per_warp) {
316
+ const int ibx = row*blocks_per_row + i; // x block index
317
+
318
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
319
+
320
+ const int iqs =
321
+ vdr *
322
+ (item_ct1.get_local_id(2) %
323
+ (qi / vdr)); // x block quant index when casting the quants to int
324
+
325
+ tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
326
+ }
327
+
328
+ // sum up partial sums and write back result
329
+ #pragma unroll
330
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
331
+ tmp +=
332
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
333
+ }
334
+
335
+ if (item_ct1.get_local_id(2) == 0) {
336
+ dst[row] = tmp;
337
+ }
338
+ }
339
+
340
+ template <int qk, int qi, typename block_q_t, int vdr>
341
+ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
342
+ const void *__restrict__ vy,
343
+ float *__restrict__ dst, const int ncols,
344
+ const int nrows,
345
+ const sycl::nd_item<3> &item_ct1) {
346
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
347
+ item_ct1.get_local_id(1);
348
+
349
+ if (row >= nrows) {
350
+ return;
351
+ }
352
+
353
+ const int blocks_per_row = ncols / qk;
354
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
355
+ assert(blocks_per_warp>0);
356
+ // partial sum for each thread
357
+ float tmp = 0.0f;
358
+
359
+ const block_q_t * x = (const block_q_t *) vx;
360
+ const block_q8_1 * y = (const block_q8_1 *) vy;
361
+
362
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
363
+ i += blocks_per_warp) {
364
+ const int ibx = row*blocks_per_row + i; // x block index
365
+
366
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
367
+
368
+ const int iqs =
369
+ vdr *
370
+ (item_ct1.get_local_id(2) %
371
+ (qi / vdr)); // x block quant index when casting the quants to int
372
+
373
+ tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
374
+ }
375
+
376
+ // sum up partial sums and write back result
377
+ #pragma unroll
378
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
379
+ tmp +=
380
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
381
+ }
382
+
383
+ if (item_ct1.get_local_id(2) == 0) {
384
+ dst[row] = tmp;
385
+ }
386
+ }
387
+
388
+ template <int qk, int qi, typename block_q_t, int vdr>
389
+ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
390
+ const void *__restrict__ vy,
391
+ float *__restrict__ dst, const int ncols,
392
+ const int nrows,
393
+ const sycl::nd_item<3> &item_ct1) {
394
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
395
+ item_ct1.get_local_id(1);
396
+
397
+ if (row >= nrows) {
398
+ return;
399
+ }
400
+
401
+ const int blocks_per_row = ncols / qk;
402
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
403
+ assert(blocks_per_warp>0);
404
+ // partial sum for each thread
405
+ float tmp = 0.0f;
406
+
407
+ const block_q_t * x = (const block_q_t *) vx;
408
+ const block_q8_1 * y = (const block_q8_1 *) vy;
409
+
410
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
411
+ i += blocks_per_warp) {
412
+ const int ibx = row*blocks_per_row + i; // x block index
413
+
414
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
415
+
416
+ const int iqs =
417
+ vdr *
418
+ (item_ct1.get_local_id(2) %
419
+ (qi / vdr)); // x block quant index when casting the quants to int
420
+
421
+ tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
422
+ }
423
+
424
+ // sum up partial sums and write back result
425
+ #pragma unroll
426
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
427
+ tmp +=
428
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
429
+ }
430
+
431
+ if (item_ct1.get_local_id(2) == 0) {
432
+ dst[row] = tmp;
433
+ }
434
+ }
435
+
436
+
437
+ template <int qk, int qi, typename block_q_t, int vdr>
438
+ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
439
+ const void *__restrict__ vy,
440
+ float *__restrict__ dst, const int ncols,
441
+ const int nrows,
442
+ const sycl::nd_item<3> &item_ct1) {
443
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
444
+ item_ct1.get_local_id(1);
445
+
446
+ if (row >= nrows) {
447
+ return;
448
+ }
449
+
450
+ const int blocks_per_row = ncols / qk;
451
+ const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
452
+ assert(blocks_per_warp>0);
453
+ // partial sum for each thread
454
+ float tmp = 0.0f;
455
+
456
+ const block_q_t * x = (const block_q_t *) vx;
457
+ const block_q8_1 * y = (const block_q8_1 *) vy;
458
+
459
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
460
+ i += blocks_per_warp) {
461
+ const int ibx = row*blocks_per_row + i; // x block index
462
+
463
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
464
+
465
+ const int iqs =
466
+ vdr *
467
+ (item_ct1.get_local_id(2) %
468
+ (qi / vdr)); // x block quant index when casting the quants to int
469
+
470
+ tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
471
+ }
472
+
473
+ // sum up partial sums and write back result
474
+ #pragma unroll
475
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
476
+ tmp +=
477
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
478
+ }
479
+
480
+ if (item_ct1.get_local_id(2) == 0) {
481
+ dst[row] = tmp;
482
+ }
483
+ }
484
+
485
+ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
486
+ float *dst, const int ncols,
487
+ const int nrows,
488
+ dpct::queue_ptr stream) {
489
+ GGML_ASSERT(ncols % QK4_0 == 0);
490
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
491
+ const sycl::range<3> block_nums(1, 1, block_num_y);
492
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
493
+ {
494
+
495
+ stream->submit([&](sycl::handler &cgh) {
496
+
497
+ cgh.parallel_for(
498
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
499
+ [=](sycl::nd_item<3> item_ct1)
500
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
501
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
502
+ VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
503
+ vx, vy, dst, ncols, nrows, item_ct1);
504
+ });
505
+ });
506
+ }
507
+ }
508
+
509
+ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
510
+ float *dst, const int ncols,
511
+ const int nrows,
512
+ dpct::queue_ptr stream) {
513
+ GGML_ASSERT(ncols % QK4_1 == 0);
514
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
515
+ const sycl::range<3> block_nums(1, 1, block_num_y);
516
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
517
+ {
518
+
519
+ stream->submit([&](sycl::handler &cgh) {
520
+
521
+ cgh.parallel_for(
522
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
523
+ [=](sycl::nd_item<3> item_ct1)
524
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
525
+ mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
526
+ VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
527
+ vx, vy, dst, ncols, nrows, item_ct1);
528
+ });
529
+ });
530
+ }
531
+ }
532
+
533
+ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
534
+ float *dst, const int ncols,
535
+ const int nrows,
536
+ dpct::queue_ptr stream) {
537
+ GGML_ASSERT(ncols % QK5_0 == 0);
538
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
539
+ const sycl::range<3> block_nums(1, 1, block_num_y);
540
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
541
+ {
542
+
543
+ stream->submit([&](sycl::handler &cgh) {
544
+
545
+ cgh.parallel_for(
546
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
547
+ [=](sycl::nd_item<3> item_ct1)
548
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
549
+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
550
+ VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
551
+ vx, vy, dst, ncols, nrows, item_ct1);
552
+ });
553
+ });
554
+ }
555
+ }
556
+
557
+ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
558
+ float *dst, const int ncols,
559
+ const int nrows,
560
+ dpct::queue_ptr stream) {
561
+ GGML_ASSERT(ncols % QK5_1 == 0);
562
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
563
+ const sycl::range<3> block_nums(1, 1, block_num_y);
564
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
565
+ {
566
+
567
+ stream->submit([&](sycl::handler &cgh) {
568
+
569
+ cgh.parallel_for(
570
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
571
+ [=](sycl::nd_item<3> item_ct1)
572
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
573
+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
574
+ VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
575
+ vx, vy, dst, ncols, nrows, item_ct1);
576
+ });
577
+ });
578
+ }
579
+ }
580
+
581
+ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
582
+ float *dst, const int ncols,
583
+ const int nrows,
584
+ dpct::queue_ptr stream) {
585
+ GGML_ASSERT(ncols % QK8_0 == 0);
586
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
587
+ const sycl::range<3> block_nums(1, 1, block_num_y);
588
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
589
+ {
590
+
591
+ stream->submit([&](sycl::handler &cgh) {
592
+
593
+ cgh.parallel_for(
594
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
595
+ [=](sycl::nd_item<3> item_ct1)
596
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
597
+ mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
598
+ VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
599
+ vx, vy, dst, ncols, nrows, item_ct1);
600
+ });
601
+ });
602
+ }
603
+ }
604
+
605
+ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
606
+ float *dst, const int ncols,
607
+ const int nrows,
608
+ dpct::queue_ptr stream) {
609
+ GGML_ASSERT(ncols % QK_K == 0);
610
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
611
+ const sycl::range<3> block_nums(1, 1, block_num_y);
612
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
613
+ {
614
+
615
+ stream->submit([&](sycl::handler &cgh) {
616
+
617
+ cgh.parallel_for(
618
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
619
+ [=](sycl::nd_item<3> item_ct1)
620
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
621
+ mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
622
+ VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
623
+ vx, vy, dst, ncols, nrows, item_ct1);
624
+ });
625
+ });
626
+ }
627
+ }
628
+
629
+ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
630
+ float *dst, const int ncols,
631
+ const int nrows,
632
+ dpct::queue_ptr stream) {
633
+ GGML_ASSERT(ncols % QK_K == 0);
634
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
635
+ const sycl::range<3> block_nums(1, 1, block_num_y);
636
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
637
+ {
638
+
639
+ stream->submit([&](sycl::handler &cgh) {
640
+
641
+ cgh.parallel_for(
642
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
643
+ [=](sycl::nd_item<3> item_ct1)
644
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
645
+ mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
646
+ VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
647
+ vx, vy, dst, ncols, nrows, item_ct1);
648
+ });
649
+ });
650
+ }
651
+ }
652
+
653
+ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
654
+ float *dst, const int ncols,
655
+ const int nrows,
656
+ dpct::queue_ptr stream) {
657
+ GGML_ASSERT(ncols % QK_K == 0);
658
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
659
+ const sycl::range<3> block_nums(1, 1, block_num_y);
660
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
661
+ {
662
+
663
+ stream->submit([&](sycl::handler &cgh) {
664
+
665
+ cgh.parallel_for(
666
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
667
+ [=](sycl::nd_item<3> item_ct1)
668
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
669
+ mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
670
+ VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
671
+ vx, vy, dst, ncols, nrows, item_ct1);
672
+ });
673
+ });
674
+ }
675
+ }
676
+
677
+ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
678
+ float *dst, const int ncols,
679
+ const int nrows,
680
+ dpct::queue_ptr stream) {
681
+ GGML_ASSERT(ncols % QK_K == 0);
682
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
683
+ const sycl::range<3> block_nums(1, 1, block_num_y);
684
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
685
+ {
686
+
687
+ stream->submit([&](sycl::handler &cgh) {
688
+
689
+ cgh.parallel_for(
690
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
691
+ [=](sycl::nd_item<3> item_ct1)
692
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
693
+ mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
694
+ VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
695
+ vx, vy, dst, ncols, nrows, item_ct1);
696
+ });
697
+ });
698
+ }
699
+ }
700
+
701
+ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
702
+ float *dst, const int ncols,
703
+ const int nrows,
704
+ dpct::queue_ptr stream) {
705
+ GGML_ASSERT(ncols % QK_K == 0);
706
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
707
+ const sycl::range<3> block_nums(1, 1, block_num_y);
708
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
709
+ {
710
+
711
+ stream->submit([&](sycl::handler &cgh) {
712
+
713
+ cgh.parallel_for(
714
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
715
+ [=](sycl::nd_item<3> item_ct1)
716
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
717
+ mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
718
+ VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
719
+ vx, vy, dst, ncols, nrows, item_ct1);
720
+ });
721
+ });
722
+ }
723
+ }
724
+
725
+
726
+ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
727
+ float *dst, const int ncols,
728
+ const int nrows,
729
+ dpct::queue_ptr stream) {
730
+ GGML_ASSERT(ncols % QK_K == 0);
731
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
732
+ const sycl::range<3> block_nums(1, 1, block_num_y);
733
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
734
+ {
735
+ stream->submit([&](sycl::handler &cgh) {
736
+ cgh.parallel_for(
737
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
738
+ [=](sycl::nd_item<3> item_ct1)
739
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
740
+ mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
741
+ vx, vy, dst, ncols, nrows, item_ct1);
742
+ });
743
+ });
744
+ }
745
+ }
746
+
747
+ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
748
+ float *dst, const int ncols,
749
+ const int nrows,
750
+ dpct::queue_ptr stream) {
751
+ GGML_ASSERT(ncols % QK_K == 0);
752
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
753
+ const sycl::range<3> block_nums(1, 1, block_num_y);
754
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
755
+ {
756
+ stream->submit([&](sycl::handler & cgh) {
757
+ cgh.parallel_for(
758
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
759
+ [=](sycl::nd_item<3> item_ct1)
760
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
761
+ mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
762
+ vx, vy, dst, ncols, nrows, item_ct1);
763
+ });
764
+ });
765
+ }
766
+ }
767
+
768
+ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
769
+ float *dst, const int ncols,
770
+ const int nrows,
771
+ dpct::queue_ptr stream) {
772
+ GGML_ASSERT(ncols % QK_K == 0);
773
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
774
+ const sycl::range<3> block_nums(1, 1, block_num_y);
775
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
776
+ {
777
+
778
+ stream->submit([&](sycl::handler &cgh) {
779
+ cgh.parallel_for(
780
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
781
+ [=](sycl::nd_item<3> item_ct1)
782
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
783
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
784
+ vx, vy, dst, ncols, nrows, item_ct1);
785
+ });
786
+ });
787
+ }
788
+ }
789
+
790
+ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
791
+ float *dst, const int ncols,
792
+ const int nrows,
793
+ dpct::queue_ptr stream) {
794
+ GGML_ASSERT(ncols % QK_K == 0);
795
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
796
+ const sycl::range<3> block_nums(1, 1, block_num_y);
797
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
798
+ {
799
+
800
+ stream->submit([&](sycl::handler &cgh) {
801
+ cgh.parallel_for(
802
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
803
+ [=](sycl::nd_item<3> item_ct1)
804
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
805
+ mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
806
+ vx, vy, dst, ncols, nrows, item_ct1);
807
+ });
808
+ });
809
+ }
810
+ }
811
+
812
+ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
813
+ float *dst, const int ncols,
814
+ const int nrows,
815
+ dpct::queue_ptr stream) {
816
+ GGML_ASSERT(ncols % QK_K == 0);
817
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
818
+ const sycl::range<3> block_nums(1, 1, block_num_y);
819
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
820
+ {
821
+
822
+ stream->submit([&](sycl::handler &cgh) {
823
+ cgh.parallel_for(
824
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
825
+ [=](sycl::nd_item<3> item_ct1)
826
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
827
+ mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
828
+ vx, vy, dst, ncols, nrows, item_ct1);
829
+ });
830
+ });
831
+ }
832
+ }
833
+
834
+ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
835
+ float *dst, const int ncols,
836
+ const int nrows,
837
+ dpct::queue_ptr stream) {
838
+ GGML_ASSERT(ncols % QK_K == 0);
839
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
840
+ const sycl::range<3> block_nums(1, 1, block_num_y);
841
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
842
+ {
843
+
844
+ stream->submit([&](sycl::handler &cgh) {
845
+ cgh.parallel_for(
846
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
847
+ [=](sycl::nd_item<3> item_ct1)
848
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
849
+ mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
850
+ vx, vy, dst, ncols, nrows, item_ct1);
851
+ });
852
+ });
853
+ }
854
+ }
855
+
856
+ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
857
+ float *dst, const int ncols,
858
+ const int nrows,
859
+ dpct::queue_ptr stream) {
860
+ GGML_ASSERT(ncols % QK_K == 0);
861
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
862
+ const sycl::range<3> block_nums(1, 1, block_num_y);
863
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
864
+ {
865
+ stream->submit([&](sycl::handler &cgh) {
866
+ cgh.parallel_for(
867
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
868
+ [=](sycl::nd_item<3> item_ct1)
869
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
870
+ mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
871
+ vx, vy, dst, ncols, nrows, item_ct1);
872
+ });
873
+ });
874
+ }
875
+ }
876
+
877
+ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
878
+ float *dst, const int ncols,
879
+ const int nrows,
880
+ dpct::queue_ptr stream) {
881
+ GGML_ASSERT(ncols % QK4_NL == 0);
882
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
883
+ const sycl::range<3> block_nums(1, 1, block_num_y);
884
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
885
+ {
886
+
887
+ stream->submit([&](sycl::handler &cgh) {
888
+ cgh.parallel_for(
889
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
890
+ [=](sycl::nd_item<3> item_ct1)
891
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
892
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
893
+ vx, vy, dst, ncols, nrows, item_ct1);
894
+ });
895
+ });
896
+ }
897
+ }
898
+
899
+ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
900
+ float *dst, const int ncols,
901
+ const int nrows,
902
+ dpct::queue_ptr stream) {
903
+ GGML_ASSERT(ncols % QK_K == 0);
904
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
905
+ const sycl::range<3> block_nums(1, 1, block_num_y);
906
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
907
+ {
908
+
909
+ stream->submit([&](sycl::handler &cgh) {
910
+ cgh.parallel_for(
911
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
+ [=](sycl::nd_item<3> item_ct1)
913
+ [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
914
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
915
+ vx, vy, dst, ncols, nrows, item_ct1);
916
+ });
917
+ });
918
+ }
919
+ }
920
+
921
+ void ggml_sycl_op_mul_mat_vec_q(
922
+ ggml_backend_sycl_context & ctx,
923
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
924
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
925
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
926
+ const int64_t src1_ncols, const int64_t src1_padded_col_size,
927
+ const dpct::queue_ptr &stream) {
928
+
929
+ const int64_t ne10 = src1->ne[0];
930
+ GGML_ASSERT(ne10 % QK8_1 == 0);
931
+
932
+ const int64_t ne00 = src0->ne[0];
933
+ const int64_t row_diff = row_high - row_low;
934
+
935
+ int id;
936
+ SYCL_CHECK(
937
+ CHECK_TRY_ERROR(id = get_current_device_id()));
938
+ const size_t q8_1_ts = sizeof(block_q8_1);
939
+ const size_t q8_1_bs = QK8_1;
940
+ // the main device has a larger memory buffer to hold the results from all GPUs
941
+ // nrows_dst == nrows of the matrix that the kernel writes into
942
+
943
+ for (int i = 0; i < src1_ncols; i++)
944
+ {
945
+ const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
946
+ const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
947
+ float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
948
+ switch (src0->type) {
949
+ case GGML_TYPE_Q4_0:
950
+ mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
951
+ break;
952
+ case GGML_TYPE_Q4_1:
953
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
954
+ break;
955
+ case GGML_TYPE_Q5_0:
956
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
957
+ break;
958
+ case GGML_TYPE_Q5_1:
959
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
960
+ break;
961
+ case GGML_TYPE_Q8_0:
962
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
963
+ break;
964
+ case GGML_TYPE_Q2_K:
965
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
966
+ break;
967
+ case GGML_TYPE_Q3_K:
968
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
969
+ break;
970
+ case GGML_TYPE_Q4_K:
971
+ mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
972
+ break;
973
+ case GGML_TYPE_Q5_K:
974
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
975
+ break;
976
+ case GGML_TYPE_Q6_K:
977
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
978
+ break;
979
+ case GGML_TYPE_IQ1_S:
980
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
981
+ break;
982
+ case GGML_TYPE_IQ1_M:
983
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
984
+ break;
985
+ case GGML_TYPE_IQ2_XXS:
986
+ mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
987
+ break;
988
+ case GGML_TYPE_IQ2_XS:
989
+ mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
990
+ break;
991
+ case GGML_TYPE_IQ2_S:
992
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
993
+ break;
994
+ case GGML_TYPE_IQ3_XXS:
995
+ mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
996
+ break;
997
+ case GGML_TYPE_IQ3_S:
998
+ mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
999
+ break;
1000
+ case GGML_TYPE_IQ4_NL:
1001
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1002
+ break;
1003
+ case GGML_TYPE_IQ4_XS:
1004
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1005
+ break;
1006
+ default:
1007
+ GGML_ABORT("fatal error");
1008
+ break;
1009
+ }
1010
+ }
1011
+ GGML_UNUSED(src1);
1012
+ GGML_UNUSED(dst);
1013
+ GGML_UNUSED(src1_ddf_i);
1014
+ GGML_UNUSED(ctx);
1015
+ }