whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ggml/src/ggml-sycl/dmmv.cpp (new file, entry 85 above)
@@ -0,0 +1,1023 @@
+ #include "convert.hpp"
+ #include "dmmv.hpp"
+ #include "dequantize.hpp"
+ #include "presets.hpp"
+
+
+ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+     const sycl::half *x = (const sycl::half *)vx;
+
+     // automatic half -> float type cast if dfloat == float
+     v.x() = x[ib + iqs + 0];
+     v.y() = x[ib + iqs + 1];
+ }
+
+ static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+     const float * x = (const float *) vx;
+
+     // automatic half -> float type cast if dfloat == float
+     v.x() = x[ib + iqs + 0];
+     v.y() = x[ib + iqs + 1];
+ }
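The dfloat/dfloat2 types used throughout this file come from the headers included above, which are not part of this hunk. A plausible reading, consistent with the #ifdef GGML_SYCL_F16 branches below, is the following sketch (assumed definitions, for orientation only):

#include <sycl/sycl.hpp>

// Assumed to live in presets.hpp or a common header (not shown in this diff):
// under GGML_SYCL_F16 the kernels accumulate in half-precision pairs; the
// "automatic half -> float type cast" comments above describe the fallback.
#ifdef GGML_SYCL_F16
typedef sycl::half  dfloat;
typedef sycl::half2 dfloat2;
#else
typedef float        dfloat;
typedef sycl::float2 dfloat2;
#endif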
+
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
+                                    const sycl::nd_item<3> &item_ct1) {
+     // qk = quantized weights per x block
+     // qr = number of quantized weights per data value in x block
+     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                     item_ct1.get_local_id(1);
+
+     if (row >= nrows) {
+         return;
+     }
+
+     const int tid = item_ct1.get_local_id(2);
+
+     const int iter_stride = 2*GGML_SYCL_DMMV_X;
+     const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+     const int y_offset = qr == 1 ? 1 : qk/2;
+
+     // partial sum for each thread
+ #ifdef GGML_SYCL_F16
+     sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+ #else
+     float tmp = 0.0f;
+ #endif // GGML_SYCL_F16
+
+     for (int i = 0; i < ncols; i += iter_stride) {
+         const int col = i + vals_per_iter*tid;
+         const int ib = (row*ncols + col)/qk; // x block index
+         const int iqs = (col%qk)/qr; // x quant index
+         const int iybs = col - col%qk; // y block start index
+
+         // processing >2 values per i iter is faster for fast GPUs
+ #pragma unroll
+         for (int j = 0; j < vals_per_iter; j += 2) {
+             // process 2 vals per j iter
+
+             // dequantize
+             // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+             dfloat2 v;
+             dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+             // matrix multiplication
+             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+ #ifdef GGML_SYCL_F16
+             dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                        y[iybs + iqs + j / qr + y_offset]};
+
+             tmp += v * t1;
+ #else
+             tmp += v.x() * y[iybs + iqs + j / qr + 0];
+             tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+ #endif // GGML_SYCL_F16
+         }
+     }
+
+     // sum up partial sums and write back result
+     const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+     for (int mask = mask_start; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (tid == 0) {
+ #ifdef GGML_SYCL_F16
+         dst[row] = tmp.x() + tmp.y();
+ #else
+         dst[row] = tmp;
+ #endif // GGML_SYCL_F16
+     }
+ }
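Stripped of the work distribution, the per-row arithmetic of the template above reduces to the serial sketch below (plain C++; dequant is a hypothetical stand-in for a dequantize_kernel_t, and nothing here is part of the package):

#include <cstdint>

// One output row of dequantize-mul-mat-vec, serialized. vx points at the
// quantized blocks of the row; dequant() expands two weights at (ib, iqs),
// which is exactly what the dfloat2-producing callbacks above do.
template <int qk, int qr>
static float dmmv_row_reference(const void * vx, const float * y, int ncols,
                                void (*dequant)(const void *, int64_t, int,
                                                float &, float &)) {
    const int y_offset = qr == 1 ? 1 : qk/2; // distance of the second weight in y
    float tmp = 0.0f;
    for (int col = 0; col < ncols; col += 2) {
        const int ib   = col / qk;        // x block index
        const int iqs  = (col % qk) / qr; // x quant index inside the block
        const int iybs = col - col % qk;  // y block start index
        float v0, v1;
        dequant(vx, ib, iqs, v0, v1);
        tmp += v0 * y[iybs + iqs + 0];
        tmp += v1 * y[iybs + iqs + y_offset];
    }
    return tmp;
}

The kernel spreads this loop across a sub-group (vals_per_iter values per work-item per iteration) and then merges the per-lane partial sums with the XOR shuffle loop.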
+
+ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+     GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+     const sycl::range<3> block_nums(1, 1, block_num_y);
+     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+     {
+         dpct::has_capability_or_fail(stream->get_device(),
+                                      {sycl::aspect::fp16});
+
+         stream->parallel_for(
+             sycl::nd_range<3>(block_nums * block_dims, block_dims),
+             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                 dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
+                                                           nrows, item_ct1);
+             });
+     }
+ }
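Every launcher in this file uses the same geometry: work-groups WARP_SIZE work-items wide and GGML_SYCL_MMV_Y rows tall, so one sub-group cooperates on one output row. A quick worked example of the ceil-division (illustrative numbers; the real constants come from headers not shown in this hunk):

// With GGML_SYCL_MMV_Y = 2 and nrows = 4097:
//   block_num_y = (4097 + 2 - 1) / 2 = 2049 work-groups,
// and the spare row in the last group is rejected by the kernel's bounds check.
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;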
+
+ /*
+ DPCT1110:4: The total declared local variable size in device function
+ dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
+ pressure. Consult with your hardware vendor to find the total register size
+ available and adjust the code, or use smaller sub-group size to avoid high
+ register pressure.
+ */
+ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
+                                         const float *__restrict__ yy,
+                                         float *__restrict__ dst,
+                                         const int ncols, int nrows,
+                                         const sycl::nd_item<3> &item_ct1) {
+
+     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                     item_ct1.get_local_id(1);
+     if (row > nrows) return;
+
+     const int num_blocks_per_row = ncols / QK_K;
+     const int ib0 = row*num_blocks_per_row;
+
+     const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+     float tmp = 0; // partial sum for thread in warp
+
+ #if QK_K == 256
+     const int tid =
+         item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+     const int ix =
+         item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+     const int step = 16/K_QUANTS_PER_ITERATION;
+
+     const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+     const int in = tid - step*im; // 0...15 or 0...7
+
+     const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
+     const int q_offset = 32*im + l0;
+     const int s_offset = 8*im;
+     const int y_offset = 128*im + l0;
+
+     uint32_t aux[4];
+     const uint8_t * d = (const uint8_t *)aux;
+     const uint8_t * m = (const uint8_t *)(aux + 2);
+
+     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + y_offset;
+         const uint8_t * q = x[i].qs + q_offset;
+
+         const float dall = x[i].dm[0];
+         const float dmin = x[i].dm[1];
+
+         const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+         aux[0] = a[0] & 0x0f0f0f0f;
+         aux[1] = a[1] & 0x0f0f0f0f;
+         aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+         aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+         float sum1 = 0, sum2 = 0;
+         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+             sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                   + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                   + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                   + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                   + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                   + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                   + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                   + y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+             sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[l+96] * m[6]
+                   + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+         }
+         tmp += dall * sum1 - dmin * sum2;
+
+     }
+ #else
+     const int tid = item_ct1.get_local_id(2) /
+                     (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+     const int ix = item_ct1.get_local_id(2) %
+                    (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+     const int offset = tid * K_QUANTS_PER_ITERATION;
+
+     uint32_t uaux[2];
+     const uint8_t * d = (const uint8_t *)uaux;
+
+
+     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + offset;
+         const uint8_t * q = x[i].qs + offset;
+         const uint32_t * s = (const uint32_t *)x[i].scales;
+
+         uaux[0] = s[0] & 0x0f0f0f0f;
+         uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+         const sycl::float2 dall =
+             x[i].dm.convert<float, sycl::rounding_mode::automatic>();
+
+         float sum1 = 0, sum2 = 0;
+         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+             const uint8_t ql = q[l];
+             sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                   + y[l+16] * d[1] * ((ql >> 2) & 3)
+                   + y[l+32] * d[2] * ((ql >> 4) & 3)
+                   + y[l+48] * d[3] * ((ql >> 6) & 3);
+             sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+         }
+         tmp += dall.x() * sum1 - dall.y() * sum2;
+     }
+
+ #endif
+
+     // sum up partial sums and write back result
+ #pragma unroll
+     for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (item_ct1.get_local_id(2) == 0) {
+         dst[row] = tmp;
+     }
+ }
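The loop over mask that closes this kernel (and every kernel below) is a sub-group butterfly reduction: dpct::permute_sub_group_by_xor exchanges tmp between lanes whose IDs differ in one bit, so after log2(sub-group size) rounds every lane holds the full row sum and lane 0 writes it out. A host-side simulation of the pattern (illustrative only; assumes a power-of-two lane count):

#include <vector>

// After the loop, every element equals the sum of the whole input, mirroring
// what each sub-group lane ends up holding on the device.
static std::vector<float> xor_butterfly_reduce(std::vector<float> lanes) {
    const int n = (int) lanes.size(); // power of two, e.g. WARP_SIZE
    for (int mask = n / 2; mask > 0; mask >>= 1) {
        std::vector<float> next(n);
        for (int lane = 0; lane < n; ++lane) {
            next[lane] = lanes[lane] + lanes[lane ^ mask]; // partner exchange
        }
        lanes = next;
    }
    return lanes;
}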
+
+ /*
+ DPCT1110:5: The total declared local variable size in device function
+ dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register
+ pressure. Consult with your hardware vendor to find the total register size
+ available and adjust the code, or use smaller sub-group size to avoid high
+ register pressure.
+ */
+ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
+                                         const float *__restrict__ yy,
+                                         float *__restrict__ dst,
+                                         const int ncols, int nrows,
+                                         const sycl::nd_item<3> &item_ct1) {
+
+     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                     item_ct1.get_local_id(1);
+     if (row > nrows) return;
+
+     const int num_blocks_per_row = ncols / QK_K;
+     const int ib0 = row*num_blocks_per_row;
+
+     const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+     float tmp = 0; // partial sum for thread in warp
+
+ #if QK_K == 256
+
+     const uint16_t kmask1 = 0x0303;
+     const uint16_t kmask2 = 0x0f0f;
+
+     const int tid =
+         item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+     const int ix =
+         item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+     const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+     const int step = 16/K_QUANTS_PER_ITERATION;
+     const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+     const int in = tid - step*im; // 0....15 or 0...7
+
+     const uint8_t m = 1 << (4*im);
+
+     const int l0 = n*in; // 0...15 or 0...14 in steps of 2
+     const int q_offset = 32*im + l0;
+     const int y_offset = 128*im + l0;
+
+     uint16_t utmp[4];
+     const int8_t * s = (const int8_t *)utmp;
+
+     const uint16_t s_shift = 4*im;
+
+     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + y_offset;
+         const uint8_t * q = x[i].qs + q_offset;
+         const uint8_t * h = x[i].hmask + l0;
+
+         const uint16_t * a = (const uint16_t *)x[i].scales;
+         utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+         utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+         utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+         utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+         const float d = x[i].d;
+
+         float sum = 0;
+         for (int l = 0; l < n; ++l) {
+             sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                  + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                  + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                  + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+             sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                  + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                  + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                  + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+         }
+         tmp += d * sum;
+
+     }
+ #else
+
+     const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+     const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+     const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
+     const int in = offset/8; // 0 or 1
+     const int im = offset%8; // 0...7
+
+     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + offset;
+         const uint8_t * q = x[i].qs + offset;
+         const uint8_t * s = x[i].scales;
+
+         const float dall = (float)x[i].d;
+
+         float sum = 0;
+         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+             const uint8_t hl = x[i].hmask[im+l] >> in;
+             const uint8_t ql = q[l];
+             sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                  + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                  + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                  + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+         }
+         tmp += sum;
+     }
+ #endif
+
+     // sum up partial sums and write back result
+ #pragma unroll
+     for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (item_ct1.get_local_id(2) == 0) {
+         dst[row] = tmp;
+     }
+ }
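Written out per weight, the q3_K reconstruction this kernel performs inline is the following (hypothetical helper for illustration; not part of the package):

#include <cstdint>

// A 2-bit quant is extended by one hmask bit: a cleared mask bit subtracts 4,
// matching the (h[l] & (m << k) ? 0 : 4) terms above; the 6-bit scale is
// recentered by 32 before multiplying by the super-block scale d.
static inline float dequant_q3_k_value(float d, int8_t s6, uint8_t q2, bool hbit) {
    const int q = (int) q2 - (hbit ? 0 : 4);
    return d * (float)(s6 - 32) * (float) q;
}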
+
+ /*
+ DPCT1110:6: The total declared local variable size in device function
+ dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
+ pressure. Consult with your hardware vendor to find the total register size
+ available and adjust the code, or use smaller sub-group size to avoid high
+ register pressure.
+ */
+ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
+                                         const float *__restrict__ yy,
+                                         float *__restrict__ dst,
+                                         const int ncols, int nrows,
+                                         const sycl::nd_item<3> &item_ct1) {
+
+     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                     item_ct1.get_local_id(1);
+     if (row > nrows) return;
+     const int num_blocks_per_row = ncols / QK_K;
+     const int ib0 = row*num_blocks_per_row;
+
+     const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+ #if QK_K == 256
+     const uint16_t kmask1 = 0x3f3f;
+     const uint16_t kmask2 = 0x0f0f;
+     const uint16_t kmask3 = 0xc0c0;
+
+     const int tid =
+         item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+     const int ix =
+         item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+     const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+     const int il = tid/step; // 0...3
+     const int ir = tid - step*il; // 0...7 or 0...3
+     const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+
+     const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+     const int in = il%2;
+
+     const int l0 = n*(2*ir + in);
+     const int q_offset = 32*im + l0;
+     const int y_offset = 64*im + l0;
+
+     uint16_t aux[4];
+     const uint8_t * sc = (const uint8_t *)aux;
+
+ #if K_QUANTS_PER_ITERATION == 2
+     uint32_t q32[4];
+     const uint8_t * q4 = (const uint8_t *)q32;
+ #else
+     uint16_t q16[4];
+     const uint8_t * q4 = (const uint8_t *)q16;
+ #endif
+
+     float tmp = 0; // partial sum for thread in warp
+
+     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+         const float * y1 = yy + i*QK_K + y_offset;
+         const float * y2 = y1 + 128;
+
+         const float dall = x[i].dm[0];
+         const float dmin = x[i].dm[1];
+
+         const uint16_t * a = (const uint16_t *)x[i].scales;
+         aux[0] = a[im+0] & kmask1;
+         aux[1] = a[im+2] & kmask1;
+         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+ #if K_QUANTS_PER_ITERATION == 2
+         const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+         const uint32_t * q2 = q1 + 16;
+
+         q32[0] = q1[0] & 0x0f0f0f0f;
+         q32[1] = q1[0] & 0xf0f0f0f0;
+         q32[2] = q2[0] & 0x0f0f0f0f;
+         q32[3] = q2[0] & 0xf0f0f0f0;
+
+         sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
+         float smin = 0;
+         for (int l = 0; l < 4; ++l) {
+             s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
+             s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
+             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+         }
+         tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
+                        s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
+                dmin * smin;
+ #else
+         const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+         const uint16_t * q2 = q1 + 32;
+
+         q16[0] = q1[0] & 0x0f0f;
+         q16[1] = q1[0] & 0xf0f0;
+         q16[2] = q2[0] & 0x0f0f;
+         q16[3] = q2[0] & 0xf0f0;
+
+         float4 s = {0.f, 0.f, 0.f, 0.f};
+         float smin = 0;
+         for (int l = 0; l < 2; ++l) {
+             s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+             s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+         }
+         tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+ #endif
+
+     }
+ #else
+     const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
+     const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+
+     const int step = tid * K_QUANTS_PER_ITERATION;
+
+     uint16_t aux16[2];
+     const uint8_t * s = (const uint8_t *)aux16;
+
+     float tmp = 0;
+
+     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+         const uint8_t * q = x[i].qs + step;
+         const float * y = yy + i*QK_K + step;
+         const uint16_t * a = (const uint16_t *)x[i].scales;
+         aux16[0] = a[0] & 0x0f0f;
+         aux16[1] = (a[0] >> 4) & 0x0f0f;
+         const float d = (float)x[i].dm[0];
+         const float m = (float)x[i].dm[1];
+         float sum = 0.f;
+         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+             sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                  + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                  + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+                  + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+         }
+         tmp += sum;
+     }
+
+ #endif
+
+     // sum up partial sums and write back result
+ #pragma unroll
+     for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (tid == 0) {
+         dst[row] = tmp;
+     }
+ }
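Per weight, the q4_K math above is simply (hypothetical helper, illustrative only):

#include <cstdint>

// 4-bit quant times its 6-bit group scale and the super-block scale dall,
// minus the group-minimum contribution dmin * m. The 1.f/16.f factors in the
// kernel exist only because high nibbles are masked but not shifted down.
static inline float dequant_q4_k_value(float dall, float dmin,
                                       uint8_t sc, uint8_t m, uint8_t q4) {
    return dall * (float) sc * (float) q4 - dmin * (float) m;
}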
+
+ /*
+ DPCT1110:7: The total declared local variable size in device function
+ dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
+ pressure. Consult with your hardware vendor to find the total register size
+ available and adjust the code, or use smaller sub-group size to avoid high
+ register pressure.
+ */
+ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
+                                         const float *__restrict__ yy,
+                                         float *__restrict__ dst,
+                                         const int ncols,
+                                         const sycl::nd_item<3> &item_ct1) {
+
+     const int row = item_ct1.get_group(2);
+     const int num_blocks_per_row = ncols / QK_K;
+     const int ib0 = row*num_blocks_per_row;
+
+     const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+     float tmp = 0; // partial sum for thread in warp
+
+ #if QK_K == 256
+     const uint16_t kmask1 = 0x3f3f;
+     const uint16_t kmask2 = 0x0f0f;
+     const uint16_t kmask3 = 0xc0c0;
+
+     const int tid = item_ct1.get_local_id(2) / 2; // 0...15
+     const int ix = item_ct1.get_local_id(2) % 2;
+
+     const int il = tid/4; // 0...3
+     const int ir = tid - 4*il; // 0...3
+     const int n = 2;
+
+     const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+     const int in = il%2;
+
+     const int l0 = n*(2*ir + in);
+     const int q_offset = 32*im + l0;
+     const int y_offset = 64*im + l0;
+
+     const uint8_t hm1 = 1 << (2*im);
+     const uint8_t hm2 = hm1 << 4;
+
+     uint16_t aux[4];
+     const uint8_t * sc = (const uint8_t *)aux;
+
+     uint16_t q16[8];
+     const uint8_t * q4 = (const uint8_t *)q16;
+
+     for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+         const uint8_t * ql1 = x[i].qs + q_offset;
+         const uint8_t * qh = x[i].qh + l0;
+         const float * y1 = yy + i*QK_K + y_offset;
+         const float * y2 = y1 + 128;
+
+         const float dall = x[i].dm[0];
+         const float dmin = x[i].dm[1];
+
+         const uint16_t * a = (const uint16_t *)x[i].scales;
+         aux[0] = a[im+0] & kmask1;
+         aux[1] = a[im+2] & kmask1;
+         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+         sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};
+         float smin = 0;
+         const uint16_t * q1 = (const uint16_t *)ql1;
+         const uint16_t * q2 = q1 + 32;
+         q16[0] = q1[0] & 0x0f0f;
+         q16[1] = q1[8] & 0x0f0f;
+         q16[2] = (q1[0] >> 4) & 0x0f0f;
+         q16[3] = (q1[8] >> 4) & 0x0f0f;
+         q16[4] = q2[0] & 0x0f0f;
+         q16[5] = q2[8] & 0x0f0f;
+         q16[6] = (q2[0] >> 4) & 0x0f0f;
+         q16[7] = (q2[8] >> 4) & 0x0f0f;
+         for (int l = 0; l < n; ++l) {
+             sum.x() +=
+                 y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +
+                 y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));
+             sum.y() +=
+                 y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +
+                 y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));
+             sum.z() +=
+                 y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) +
+                 y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));
+             sum.w() +=
+                 y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +
+                 y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));
+             smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+         }
+         tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +
+                        sum.w() * sc[5]) -
+                dmin * smin;
+     }
+
+ #else
+     const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
+     const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+     const int step = tid * K_QUANTS_PER_ITERATION;
+     const int im = step/8;
+     const int in = step%8;
+
+     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+         const uint8_t * q = x[i].qs + step;
+         const int8_t * s = x[i].scales;
+         const float * y = yy + i*QK_K + step;
+         const float d = x[i].d;
+         float sum = 0.f;
+         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+             const uint8_t h = x[i].qh[in+j] >> im;
+             sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                  + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                  + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+                  + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+         }
+         tmp += sum;
+     }
+ #endif
+
+     // sum up partial sums and write back result
+ #pragma unroll
+     for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (item_ct1.get_local_id(2) == 0) {
+         dst[row] = tmp;
+     }
+ }
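The q5_K analogue of the previous sketch differs only in the extra high bit pulled from qh, which contributes +16 before scaling (hypothetical helper, illustrative only):

#include <cstdint>

// Matches the (qh[...] & (hm << k) ? 16 : 0) terms above.
static inline float dequant_q5_k_value(float dall, float dmin,
                                       uint8_t sc, uint8_t m,
                                       uint8_t q4, bool hbit) {
    return dall * (float) sc * (float)(q4 + (hbit ? 16 : 0)) - dmin * (float) m;
}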
+
+ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
+                                         const sycl::nd_item<3> &item_ct1) {
+
+     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                     item_ct1.get_local_id(1);
+     if (row > nrows) return;
+
+     const int num_blocks_per_row = ncols / QK_K;
+     const int ib0 = row*num_blocks_per_row;
+
+     const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+ #if QK_K == 256
+
+     const int tid =
+         item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+     const int ix =
+         item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+     const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+     const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+     const int in = tid - step*im; // 0...15 or 0...7
+
+ #if K_QUANTS_PER_ITERATION == 1
+     const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
+     const int is = 0;
+ #else
+     const int l0 = 4 * in; // 0, 4, 8, ..., 28
+     const int is = in / 4;
+ #endif
+     const int ql_offset = 64*im + l0;
+     const int qh_offset = 32*im + l0;
+     const int s_offset = 8*im + is;
+     const int y_offset = 128*im + l0;
+
+     float tmp = 0; // partial sum for thread in warp
+
+     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + y_offset;
+         const uint8_t * ql = x[i].ql + ql_offset;
+         const uint8_t * qh = x[i].qh + qh_offset;
+         const int8_t * s = x[i].scales + s_offset;
+
+         const float d = x[i].d;
+
+ #if K_QUANTS_PER_ITERATION == 1
+         float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                   + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                   + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                   + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                   + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                   + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                   + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                   + y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+         tmp += sum;
+ #else
+         float sum = 0;
+         for (int l = 0; l < 4; ++l) {
+             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                  + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                  + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                  + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+         }
+         tmp += sum;
+ #endif
+
+     }
+
+ #else
+
+     const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7
+     const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+     const int step = tid * K_QUANTS_PER_ITERATION;
+
+     float tmp = 0; // partial sum for thread in warp
+
+     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+         const float * y = yy + i * QK_K + step;
+         const uint8_t * ql = x[i].ql + step;
+         const uint8_t * qh = x[i].qh + step;
+         const int8_t * s = x[i].scales;
+
+         const float d = x[i+0].d;
+
+         float sum = 0;
+         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+             sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                  + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                  + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+                  + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+         }
+         tmp += sum;
+
+     }
+
+ #endif
+
+     // sum up partial sums and write back result
+ #pragma unroll
+     for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+         tmp +=
+             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+     }
+
+     if (tid == 0) {
+         dst[row] = tmp;
+     }
+ }
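For q6_K the kernel rebuilds a signed 6-bit weight from a 4-bit low part (ql) and a 2-bit high part (qh), recentered by 32 (hypothetical helper, illustrative only):

#include <cstdint>

// Mirrors the (int8_t)((ql & 0xF) | (qh_bits << 4)) - 32 expressions above.
static inline float dequant_q6_k_value(float d, int8_t sc, uint8_t ql4, uint8_t qh2) {
    const int8_t q = (int8_t)((ql4 & 0xF) | (qh2 << 4)) - 32; // value in [-32, 31]
    return d * (float) sc * (float) q;
}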
761
+
762
+
763
+ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
764
+ float *dst, const int ncols,
765
+ const int nrows,
766
+ dpct::queue_ptr stream) {
767
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
768
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
769
+ // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
770
+ const sycl::range<3> block_nums(1, 1, block_num_y);
771
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
772
+ {
773
+ dpct::has_capability_or_fail(stream->get_device(),
774
+ {sycl::aspect::fp16});
775
+
776
+ stream->parallel_for(
777
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
778
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
779
+ dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
780
+ vx, y, dst, ncols, nrows, item_ct1);
781
+ });
782
+ }
783
+ }
784
+
785
+ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
786
+ float *dst, const int ncols,
787
+ const int nrows,
788
+ dpct::queue_ptr stream) {
789
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
790
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
791
+ const sycl::range<3> block_nums(1, 1, block_num_y);
792
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
793
+ {
794
+ dpct::has_capability_or_fail(stream->get_device(),
795
+ {sycl::aspect::fp16});
796
+
797
+ stream->parallel_for(
798
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
799
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
800
+ dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
801
+ vx, y, dst, ncols, nrows, item_ct1);
802
+ });
803
+ }
804
+ }
805
+
806
+ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
807
+ float *dst, const int ncols,
808
+ const int nrows,
809
+ dpct::queue_ptr stream) {
810
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
811
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
812
+ const sycl::range<3> block_nums(1, 1, block_num_y);
813
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
814
+ {
815
+ dpct::has_capability_or_fail(stream->get_device(),
816
+ {sycl::aspect::fp16});
817
+
818
+ stream->parallel_for(
819
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
820
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
821
+ dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
822
+ vx, y, dst, ncols, nrows, item_ct1);
823
+ });
824
+ }
825
+ }
826
+
827
+ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
828
+ float *dst, const int ncols,
829
+ const int nrows,
830
+ dpct::queue_ptr stream) {
831
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
832
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
833
+ const sycl::range<3> block_nums(1, 1, block_num_y);
834
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
835
+ {
836
+ dpct::has_capability_or_fail(stream->get_device(),
837
+ {sycl::aspect::fp16});
838
+
839
+ stream->parallel_for(
840
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
841
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
842
+ dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
843
+ vx, y, dst, ncols, nrows, item_ct1);
844
+ });
845
+ }
846
+ }
847
+
848
+ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
849
+ float *dst, const int ncols,
850
+ const int nrows,
851
+ dpct::queue_ptr stream) {
852
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
853
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
854
+ const sycl::range<3> block_nums(1, 1, block_num_y);
855
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
856
+ {
857
+ dpct::has_capability_or_fail(stream->get_device(),
858
+ {sycl::aspect::fp16});
859
+
860
+ stream->parallel_for(
861
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
862
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
863
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
864
+ vx, y, dst, ncols, nrows, item_ct1);
865
+ });
866
+ }
867
+ }
868
+
869
+ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
870
+ float *dst, const int ncols,
871
+ const int nrows,
872
+ dpct::queue_ptr stream) {
873
+ GGML_ASSERT(ncols % QK_K == 0);
874
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
875
+ const int block_num_y = (nrows + ny - 1) / ny;
876
+ const sycl::range<3> block_nums(1, 1, block_num_y);
877
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
878
+ stream->parallel_for(
879
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
880
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
881
+ dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
882
+ });
883
+ }
884
+
885
+ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
886
+ float *dst, const int ncols,
887
+ const int nrows,
888
+ dpct::queue_ptr stream) {
889
+ GGML_ASSERT(ncols % QK_K == 0);
890
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
891
+ const int block_num_y = (nrows + ny - 1) / ny;
892
+ const sycl::range<3> block_nums(1, 1, block_num_y);
893
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
894
+ stream->parallel_for(
895
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
896
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
897
+ dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
898
+ });
899
+ }
900
+
901
+ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
902
+ float *dst, const int ncols,
903
+ const int nrows,
904
+ dpct::queue_ptr stream) {
905
+ GGML_ASSERT(ncols % QK_K == 0);
906
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
907
+ const int block_num_y = (nrows + ny - 1) / ny;
908
+ const sycl::range<3> block_nums(1, 1, block_num_y);
909
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
910
+ stream->parallel_for(
911
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
913
+ dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
914
+ });
915
+ }
916
+
917
+ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
918
+ float *dst, const int ncols,
919
+ const int nrows,
920
+ dpct::queue_ptr stream) {
921
+ GGML_ASSERT(ncols % QK_K == 0);
922
+ const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
923
+ stream->parallel_for(
924
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
925
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
926
+ dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
927
+ });
928
+ }
929
+
930
+ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
931
+ float *dst, const int ncols,
932
+ const int nrows,
933
+ dpct::queue_ptr stream) {
934
+ GGML_ASSERT(ncols % QK_K == 0);
935
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
936
+ const int block_num_y = (nrows + ny - 1) / ny;
937
+ const sycl::range<3> block_nums(1, 1, block_num_y);
938
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
939
+ stream->parallel_for(
940
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
941
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
942
+ dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
943
+ });
944
+ }
+
+ void ggml_sycl_op_dequantize_mul_mat_vec(
+     ggml_backend_sycl_context & ctx,
+     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+     const int64_t src1_ncols, const int64_t src1_padded_row_size,
+     const dpct::queue_ptr &stream) {
+
+     const int64_t ne00 = src0->ne[0];
+     const int64_t row_diff = row_high - row_low;
+
+     GGML_ASSERT(src1->type == GGML_TYPE_F32);
+     // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+ #ifdef GGML_SYCL_F16
+     ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());
+     sycl::half *src1_dfloat = nullptr; // dfloat == half
+
+     bool src1_convert_f16 =
+         src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+         src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+     if (src1_convert_f16) {
+         src1_dfloat = src1_dfloat_a.alloc(ne00);
+         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+         GGML_ASSERT(to_fp16_sycl != nullptr);
+         to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
+     }
+ #else
+     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
+ #endif // GGML_SYCL_F16
+
+     switch (src0->type) {
+         case GGML_TYPE_Q4_0:
+             dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q4_1:
+             dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q5_0:
+             dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q5_1:
+             dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q8_0:
+             dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q2_K:
+             dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q3_K:
+             dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q4_K:
+             dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q5_K:
+             dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_Q6_K:
+             dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+             break;
+         case GGML_TYPE_F16:
+             convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+             break;
+         default:
+             printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
+             GGML_ABORT("fatal error");
+             break;
+     }
+
+     GGML_UNUSED(src1);
+     GGML_UNUSED(dst);
+     GGML_UNUSED(src1_ddq_i);
+     GGML_UNUSED(src1_ncols);
+     GGML_UNUSED(src1_padded_row_size);
+ }
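One detail of the dispatch above worth spelling out: only the non-K quant types and F16 ever consume the optional half-precision copy of src1; every K-quant kernel reads the original f32 activations. A condensed restatement (hypothetical helper, illustrative only):

// True for the types passed src1_dfloat (possibly converted to half when
// GGML_SYCL_F16 is defined); false for the K-quants, which get src1_ddf_i.
static bool reads_dfloat_src1(ggml_type t) {
    switch (t) {
        case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0: case GGML_TYPE_F16:
            return true;
        default:
            return false;
    }
}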