whispercpp 1.3.0 → 1.3.1

Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ggml/src/ggml-sycl/mmq.cpp
@@ -0,0 +1,3031 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #include "mmq.hpp"
+ #include "vecdotq.hpp"
+
+ typedef void (*allocate_tiles_sycl_t)(
+     int** x_ql,
+     sycl::half2** x_dm,
+     int** x_qh,
+     int** x_sc);
+ typedef void (*load_tiles_sycl_t)(
+     const void* __restrict__ vx,
+     int* __restrict__ x_ql,
+     sycl::half2* __restrict__ x_dm,
+     int* __restrict__ x_qh,
+     int* __restrict__ x_sc,
+     const int& i_offset,
+     const int& i_max,
+     const int& k,
+     const int& blocks_per_row);
+ typedef float (*vec_dot_q_mul_mat_sycl_t)(
+     const int* __restrict__ x_ql,
+     const sycl::half2* __restrict__ x_dm,
+     const int* __restrict__ x_qh,
+     const int* __restrict__ x_sc,
+     const int* __restrict__ y_qs,
+     const sycl::half2* __restrict__ y_ms,
+     const int& i,
+     const int& j,
+     const int& k);
+
42
+
43
+ template <int mmq_y>
44
+ static __dpct_inline__ void
45
+ allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
46
+ int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
47
+ (void)x_qh; (void)x_sc;
48
+
49
+ *x_ql = tile_x_qs_q4_0;
50
+ *x_dm = (sycl::half2 *)tile_x_d_q4_0;
51
+ }
52
+
53
+ template <int mmq_y, int nwarps, bool need_check>
54
+ static __dpct_inline__ void
55
+ load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
56
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
57
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
58
+ const int &k, const int &blocks_per_row) {
59
+ (void)x_qh; (void)x_sc;
60
+ GGML_SYCL_ASSUME(i_offset >= 0);
61
+ GGML_SYCL_ASSUME(i_offset < nwarps);
62
+ GGML_SYCL_ASSUME(k >= 0);
63
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
64
+
65
+ const int kbx = k / QI4_0;
66
+ const int kqsx = k % QI4_0;
67
+
68
+ const block_q4_0 * bx0 = (const block_q4_0 *) vx;
69
+
70
+ float * x_dmf = (float *) x_dm;
71
+
72
+ #pragma unroll
73
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
74
+ int i = i0 + i_offset;
75
+
76
+ if (need_check) {
77
+ i = sycl::min(i, i_max);
78
+ }
79
+
80
+ const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
81
+
82
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
83
+ // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
84
+ }
85
+
86
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
87
+ const int kbxd = k % blocks_per_tile_x_row;
88
+
89
+ #pragma unroll
90
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
91
+ int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
92
+
93
+ if (need_check) {
94
+ i = sycl::min(i, i_max);
95
+ }
96
+
97
+ const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
98
+
99
+ x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
100
+ }
101
+ }
102
+
103
+ static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
104
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
105
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
106
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
107
+ const int &i, const int &j, const int &k) {
108
+ (void)x_qh; (void)x_sc;
109
+
110
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
111
+ const float * x_dmf = (const float *) x_dm;
112
+
113
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
114
+
115
+ #pragma unroll
116
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
117
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
118
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
119
+ }
120
+
121
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
122
+ (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
123
+ y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
124
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
+     (void)x_qh; (void)x_sc;
+
+     *x_ql = tile_x_qs_q4_1;
+     *x_dm = tile_x_dm_q4_1;
+ }
+
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh; (void)x_sc;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI4_1;
+     const int kqsx = k % QI4_1;
+
+     const block_q4_1 * bx0 = (const block_q4_1 *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+     const int kbxd = k % blocks_per_tile_x_row;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+         int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+     }
+ }
+
+ static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh; (void)x_sc;
+
+     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+     int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+     #pragma unroll
+     for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+         u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+         u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+     }
+
+     return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+         (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+          y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
+     (void)x_qh; (void)x_sc;
+
+     *x_ql = tile_x_ql_q5_0;
+     *x_dm = (sycl::half2 *)tile_x_d_q5_0;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh; (void)x_sc;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI5_0;
+     const int kqsx = k % QI5_0;
+
+     const block_q5_0 * bx0 = (const block_q5_0 *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+         const int ql = get_int_from_uint8(bxi->qs, kqsx);
+         const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+         int qs0 = (ql >> 0) & 0x0F0F0F0F;
+         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+         qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+         qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+         qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+         qs0 = dpct::vectorized_binary<sycl::char4>(
+             qs0, 0x10101010, dpct::sub_sat()); // subtract 16
+
+         x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+         int qs1 = (ql >> 4) & 0x0F0F0F0F;
+         qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+         qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+         qs1 = dpct::vectorized_binary<sycl::char4>(
+             qs1, 0x10101010, dpct::sub_sat()); // subtract 16
+
+         x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+     const int kbxd = k % blocks_per_tile_x_row;
+     float * x_dmf = (float *) x_dm;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+         int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+     }
+ }
+
+ static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh; (void)x_sc;
+
+     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+     const float * x_dmf = (const float *) x_dm;
+     const float * y_df = (const float *) y_ds;
+
+     int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+     #pragma unroll
+     for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+         u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+         u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+     }
+
+     return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+         (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
+     (void)x_qh; (void)x_sc;
+
+     *x_ql = tile_x_ql_q5_1;
+     *x_dm = tile_x_dm_q5_1;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh; (void)x_sc;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI5_1;
+     const int kqsx = k % QI5_1;
+
+     const block_q5_1 * bx0 = (const block_q5_1 *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+         const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+         const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+         int qs0 = (ql >> 0) & 0x0F0F0F0F;
+         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+         qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+         qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+         qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+
+         x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+         int qs1 = (ql >> 4) & 0x0F0F0F0F;
+         qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+         qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+
+         x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+     const int kbxd = k % blocks_per_tile_x_row;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+         int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+     }
+ }
+
+ static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh; (void)x_sc;
+
+     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+     const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;
+
+     int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+     #pragma unroll
+     for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+         u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+         u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+     }
+
+     return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+         (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
+     (void)x_qh; (void)x_sc;
+
+     *x_ql = tile_x_qs_q8_0;
+     *x_dm = (sycl::half2 *)tile_x_d_q8_0;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh; (void)x_sc;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI8_0;
+     const int kqsx = k % QI8_0;
+     float * x_dmf = (float *) x_dm;
+
+     const block_q8_0 * bx0 = (const block_q8_0 *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+     const int kbxd = k % blocks_per_tile_x_row;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+         int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+     }
+ }
+
+ static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh; (void)x_sc;
+
+     const float * x_dmf = (const float *) x_dm;
+     const float * y_df = (const float *) y_ds;
+
+     return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+         (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+          y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
+     int *tile_x_sc_q2_K) {
+     (void)x_qh;
+
+     *x_ql = tile_x_ql_q2_K;
+     *x_dm = tile_x_dm_q2_K;
+     *x_sc = tile_x_sc_q2_K;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI2_K;
+     const int kqsx = k % QI2_K;
+
+     const block_q2_K * bx0 = (const block_q2_K *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+     const int kbxd = k % blocks_per_tile_x_row;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+         int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+         int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
+
+         x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
+     }
+ }
+
+ #define VDR_Q2_K_Q8_1_MMQ 2
+ // contiguous u/y values
+ static __dpct_inline__ float
+ vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+     const uint8_t *__restrict__ scales,
+     const sycl::half2 &dm2, const float &d8) {
+
+     int sumi_d = 0;
+     int sumi_m = 0;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+         int sumi_d_sc = 0;
+
+         const int sc = scales[i0 / (QI8_1/2)];
+
+         // fill int with 4x m
+         int m = sc >> 4;
+         m |= m << 8;
+         m |= m << 16;
+
+         #pragma unroll
+         for (int i = i0; i < i0 + QI8_1/2; ++i) {
+             sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+             sumi_m = dpct::dp4a(m, u[i],
+                 sumi_m); // multiply sum of q8_1 values with m
+         }
+
+         sumi_d += sumi_d_sc * (sc & 0xF);
+     }
+
+     const sycl::float2 dm2f =
+         dm2.convert<float, sycl::rounding_mode::automatic>();
+
+     return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
+ }
+
+ static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh;
+
+     const int kbx = k / QI2_K;
+     const int ky = (k % QI2_K) * QR2_K;
+     const float * y_df = (const float *) y_ds;
+
+     int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+     const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+     const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+     #pragma unroll
+     for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+         v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+     }
+
+     const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+     const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
+     return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
+     int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
+
+     *x_ql = tile_x_ql_q3_K;
+     *x_dm = tile_x_dm_q3_K;
+     *x_qh = tile_x_qh_q3_K;
+     *x_sc = tile_x_sc_q3_K;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI3_K;
+     const int kqsx = k % QI3_K;
+
+     const block_q3_K * bx0 = (const block_q3_K *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+     }
+
+     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+     const int kbxd = k % blocks_per_tile_x_row;
+     float * x_dmf = (float *) x_dm;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+         int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+         int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
+
+         // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+         x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+         int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
+
+         const int ksc = k % (QI3_K/4);
+
+         const int ksc_low = ksc % (QI3_K/8);
+         const int shift_low = 4 * (ksc / (QI3_K/8));
+         const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+         const int ksc_high = QI3_K/8;
+         const int shift_high = 2 * ksc;
+         const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+         const int sc = dpct::vectorized_binary<sycl::char4>(
+             sc_low | sc_high, 0x20202020, dpct::sub_sat());
+
+         x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
+     }
+ }
+
+ #define VDR_Q3_K_Q8_1_MMQ 2
+ // contiguous u/y values
+ static __dpct_inline__ float
+ vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+     const int8_t *__restrict__ scales, const float &d3,
+     const float &d8) {
+
+     int sumi = 0;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+         int sumi_sc = 0;
+
+         for (int i = i0; i < i0 + QI8_1/2; ++i) {
+             sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+         }
+
+         sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+     }
+
+     return d3*d8 * sumi;
+ }
+
+ static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+
+     const int kbx = k / QI3_K;
+     const int ky = (k % QI3_K) * QR3_K;
+     const float * x_dmf = (const float *) x_dm;
+     const float * y_df = (const float *) y_ds;
+
+     const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+     int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+     #pragma unroll
+     for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+         const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+         const int shift = 2 * ((ky % 32) / 8);
+         const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+         const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+         const int vlh = (vh << 2) & 0x04040404;
+
+         v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
+     }
+
+     const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
+     return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
+     int *tile_x_sc_q4_K) {
+     (void)x_qh;
+
+     *x_ql = tile_x_ql_q4_K;
+     *x_dm = tile_x_dm_q4_K;
+     *x_sc = tile_x_sc_q4_K;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI4_K; // == 0 if QK_K == 256
+     const int kqsx = k % QI4_K; // == k if QK_K == 256
+
+     const block_q4_K * bx0 = (const block_q4_K *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+     }
+
+     constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
+     const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+         int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ #if QK_K == 256
+         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+ #else
+         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+ #endif
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+         int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
+
+         const int * scales = (const int *) bxi->scales;
+
+         const int ksc = k % (WARP_SIZE/8);
+
+         // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+         scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
+
+         x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+     }
+ }
+
+
+ #define VDR_Q4_K_Q8_1_MMQ 8
+
+ // contiguous u/y values
+ static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
+     const int *__restrict__ v, const int *__restrict__ u,
+     const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+     const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+     #pragma unroll
+     for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+         int sumi_d = 0;
+
+         #pragma unroll
+         for (int j = 0; j < QI8_1; ++j) {
+             sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
+                 u[i * QI8_1 + j], sumi_d); // SIMD dot product
+         }
+
+         const sycl::float2 ds8f =
+             ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+         sumf_d += ds8f.x() * (sc[i] * sumi_d);
+         sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
+     }
+
+     const sycl::float2 dm4f =
+         dm4.convert<float, sycl::rounding_mode::automatic>();
+
+     return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+ }
+
+
+ static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh;
+
+     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
+
+     const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
+     return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+         x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
+     int *tile_x_sc_q5_K) {
+     (void)x_qh;
+
+     *x_ql = tile_x_ql_q5_K;
+     *x_dm = tile_x_dm_q5_K;
+     *x_sc = tile_x_sc_q5_K;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI5_K; // == 0 if QK_K == 256
+     const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+     const block_q5_K * bx0 = (const block_q5_K *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+         const int ky = QR5_K*kqsx;
+
+         const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+         const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+         const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+         const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+         const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+         const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+         x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+         x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+     }
+
+     constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
+     const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+         int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ #if QK_K == 256
+         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+ #endif
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+         int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
+
+         const int * scales = (const int *) bxi->scales;
+
+         const int ksc = k % (WARP_SIZE/8);
+
+         // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+         scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
+
+         x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+     }
+ }
+
+ #define VDR_Q5_K_Q8_1_MMQ 8
+
+ // contiguous u/y values
+ static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
+     const int *__restrict__ v, const int *__restrict__ u,
+     const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+     const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+     float sumf_d = 0.0f;
+     float sumf_m = 0.0f;
+
+     #pragma unroll
+     for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+         int sumi_d = 0;
+
+         #pragma unroll
+         for (int j = 0; j < QI8_1; ++j) {
+             sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
+                 sumi_d); // SIMD dot product
+         }
+
+         const sycl::float2 ds8f =
+             ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+         sumf_d += ds8f.x() * (sc[i] * sumi_d);
+         sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
+     }
+
+     const sycl::float2 dm4f =
+         dm4.convert<float, sycl::rounding_mode::automatic>();
+
+     return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+ }
+
+ static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh;
+
+     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
+
+     const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+     const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
+     return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+         x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+ }
+
+ template <int mmq_y>
+ static __dpct_inline__ void
+ allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+     int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
+     (void)x_qh;
+
+     *x_ql = tile_x_ql;
+     *x_dm = tile_x_dm;
+     *x_sc = tile_x_sc;
+ }
+
+ template <int mmq_y, int nwarps, bool need_check>
+ static __dpct_inline__ void
+ load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+     sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+     int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+     const int &k, const int &blocks_per_row) {
+     (void)x_qh;
+
+     GGML_SYCL_ASSUME(i_offset >= 0);
+     GGML_SYCL_ASSUME(i_offset < nwarps);
+     GGML_SYCL_ASSUME(k >= 0);
+     GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+     const int kbx = k / QI6_K; // == 0 if QK_K == 256
+     const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+     const block_q6_K * bx0 = (const block_q6_K *) vx;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+         int i = i0 + i_offset;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+         const int ky = QR6_K*kqsx;
+
+         const int ql = get_int_from_uint8(bxi->ql, kqsx);
+         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+         const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+         const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+         const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
+
+         const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+         const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+         x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
+             dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
+                 dpct::sub_sat());
+         x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
+             dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
+                 dpct::sub_sat());
+     }
+
+     constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+     const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+     float * x_dmf = (float *) x_dm;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+         int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+         x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+     }
+
+     #pragma unroll
+     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+         int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+         if (need_check) {
+             i = sycl::min(i, i_max);
+         }
+
+         const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
+
+         x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
+     }
+ }
+
+ #define VDR_Q6_K_Q8_1_MMQ 8
+
+ // contiguous u/y values
+ static __dpct_inline__ float
+ vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+     const int8_t *__restrict__ sc, const float &d6,
+     const float *__restrict__ d8) {
+
+     float sumf_d = 0.0f;
+
+     #pragma unroll
+     for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+         sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+         #pragma unroll
+         for (int i = i0; i < i0 + 2; ++i) {
+             sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
+                 sumi_d.x()); // SIMD dot product
+             sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
+                 sumi_d.x()); // SIMD dot product
+
+             sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
+                 sumi_d.y()); // SIMD dot product
+             sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
+                 sumi_d.y()); // SIMD dot product
+         }
+
+         sumf_d += d8[i0 / 4] *
+             (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
+     }
+
+     return d6 * sumf_d;
+ }
+
+ static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
+     const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+     const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+     const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+     const int &i, const int &j, const int &k) {
+     (void)x_qh;
+
+     const float * x_dmf = (const float *) x_dm;
+     const float * y_df = (const float *) y_ds;
+
+     const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
+
+     const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+     const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
+     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+ }
+
+ template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
+     int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
+     vec_dot_q_mul_mat_sycl_t vec_dot>
+ /*
+ DPCT1110:8: The total declared local variable size in device function mul_mat_q
+ exceeds 128 bytes and may cause high register pressure. Consult with your
+ hardware vendor to find the total register size available and adjust the code,
+ or use smaller sub-group size to avoid high register pressure.
+ */
+ static __dpct_inline__ void
+ mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
+     float *__restrict__ dst, const int ncols_x, const int nrows_x,
+     const int ncols_y, const int nrows_y, const int nrows_dst,
+     int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
+     int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
+     sycl::half2 *tile_y_ds) {
+
+     const block_q_t * x = (const block_q_t *) vx;
+     const block_q8_1 * y = (const block_q8_1 *) vy;
+
+     const int blocks_per_row_x = ncols_x / qk;
+     const int blocks_per_col_y = nrows_y / QK8_1;
+     const int blocks_per_warp = WARP_SIZE / qi;
+
+     const int & ncols_dst = ncols_y;
+
+     const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
+     const int & row_x_0 = row_dst_0;
+
+     const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
+     const int & col_y_0 = col_dst_0;
+
+     float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+
+     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+         load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
+             tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
+             nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
+             blocks_per_row_x);
+
+         #pragma unroll
+         for (int ir = 0; ir < qr; ++ir) {
+             const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
+             const int kbxd = kqs / QI8_1;
+
+             #pragma unroll
+             for (int i = 0; i < mmq_x; i += nwarps) {
+                 const int col_y_eff = dpct::min(
+                     (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
+                     ncols_y - 1); // to prevent out-of-bounds memory accesses
+
+                 const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+                 const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE + kqs % WARP_SIZE;
+                 tile_y_qs[index_y] = get_int_from_int8_aligned(
+                     by0->qs, item_ct1.get_local_id(2) % QI8_1);
+             }
+
+             #pragma unroll
+             for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                 const int ids = (ids0 + item_ct1.get_local_id(1) * QI8_1 +
+                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) % mmq_x;
+                 const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
+                 const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);
+
+                 // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                 const sycl::half2 *dsi_src =
+                     &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
+                        ir * (WARP_SIZE / QI8_1) + kby].ds;
+                 sycl::half2 *dsi_dst = &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
+                 if (need_sum) {
+                     *dsi_dst = *dsi_src;
+                 } else {
+                     float * dfi_dst = (float *) dsi_dst;
+                     *dfi_dst = (*dsi_src)[0];
+                 }
+             }
+
+             /*
+             DPCT1118:9: SYCL group functions and algorithms must be encountered
+             in converged control flow. You may need to adjust the code.
+             */
+             /*
+             DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
+             sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+             better performance if there is no access to global memory.
+             */
+             item_ct1.barrier();
+
+             // #pragma unroll // unrolling this loop causes too much register pressure
+             for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
+                 #pragma unroll
+                 for (int j = 0; j < mmq_x; j += nwarps) {
+                     #pragma unroll
+                     for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+                         sum[i / WARP_SIZE][j / nwarps] += vec_dot(
+                             tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                             tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
+                             item_ct1.get_local_id(1) + j, k);
+                     }
+                 }
+             }
+
+             /*
+             DPCT1118:10: SYCL group functions and algorithms must be encountered
+             in converged control flow. You may need to adjust the code.
+             */
+             /*
+             DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
+             sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+             better performance if there is no access to global memory.
+             */
+             item_ct1.barrier();
+         }
+     }
+
+     #pragma unroll
+     for (int j = 0; j < mmq_x; j += nwarps) {
+         const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);
+
+         if (col_dst >= ncols_dst) {
+             return;
+         }
+
+         #pragma unroll
+         for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+             const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;
+
+             if (row_dst >= nrows_dst) {
+                 continue;
+             }
+
+             dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
+         }
+     }
+ }
1336
+
1337
+ #define MMQ_X_Q4_0_RDNA2 64
1338
+ #define MMQ_Y_Q4_0_RDNA2 128
1339
+ #define NWARPS_Q4_0_RDNA2 8
1340
+ #define MMQ_X_Q4_0_RDNA1 64
1341
+ #define MMQ_Y_Q4_0_RDNA1 64
1342
+ #define NWARPS_Q4_0_RDNA1 8
1343
+ #if defined(SYCL_USE_XMX)
1344
+ #define MMQ_X_Q4_0_AMPERE 4
1345
+ #define MMQ_Y_Q4_0_AMPERE 32
1346
+ #define NWARPS_Q4_0_AMPERE 4
1347
+ #else
1348
+ #define MMQ_X_Q4_0_AMPERE 64
1349
+ #define MMQ_Y_Q4_0_AMPERE 128
1350
+ #define NWARPS_Q4_0_AMPERE 4
1351
+ #endif
1352
+ #define MMQ_X_Q4_0_PASCAL 64
1353
+ #define MMQ_Y_Q4_0_PASCAL 64
1354
+ #define NWARPS_Q4_0_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q4_0(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
+     int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+
+     const int mmq_x = MMQ_X_Q4_0_AMPERE;
+     const int mmq_y = MMQ_Y_Q4_0_AMPERE;
+     const int nwarps = NWARPS_Q4_0_AMPERE;
+     allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_qs_q4_0, tile_x_d_q4_0);
+     mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
+               load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
+               vec_dot_q4_0_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q4_1_RDNA2 64
+ #define MMQ_Y_Q4_1_RDNA2 128
+ #define NWARPS_Q4_1_RDNA2 8
+ #define MMQ_X_Q4_1_RDNA1 64
+ #define MMQ_Y_Q4_1_RDNA1 64
+ #define NWARPS_Q4_1_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q4_1_AMPERE 4
+ #define MMQ_Y_Q4_1_AMPERE 32
+ #define NWARPS_Q4_1_AMPERE 4
+ #else
+ #define MMQ_X_Q4_1_AMPERE 64
+ #define MMQ_Y_Q4_1_AMPERE 128
+ #define NWARPS_Q4_1_AMPERE 4
+ #endif
+ #define MMQ_X_Q4_1_PASCAL 64
+ #define MMQ_Y_Q4_1_PASCAL 64
+ #define NWARPS_Q4_1_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q4_1(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
+     sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q4_1_AMPERE;
+     const int mmq_y = MMQ_Y_Q4_1_AMPERE;
+     const int nwarps = NWARPS_Q4_1_AMPERE;
+     allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_qs_q4_1, tile_x_dm_q4_1);
+     mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
+               load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
+               vec_dot_q4_1_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q5_0_RDNA2 64
+ #define MMQ_Y_Q5_0_RDNA2 128
+ #define NWARPS_Q5_0_RDNA2 8
+ #define MMQ_X_Q5_0_RDNA1 64
+ #define MMQ_Y_Q5_0_RDNA1 64
+ #define NWARPS_Q5_0_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q5_0_AMPERE 4
+ #define MMQ_Y_Q5_0_AMPERE 32
+ #define NWARPS_Q5_0_AMPERE 4
+ #else
+ #define MMQ_X_Q5_0_AMPERE 128
+ #define MMQ_Y_Q5_0_AMPERE 64
+ #define NWARPS_Q5_0_AMPERE 4
+ #endif
+ #define MMQ_X_Q5_0_PASCAL 64
+ #define MMQ_Y_Q5_0_PASCAL 64
+ #define NWARPS_Q5_0_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q5_0(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
+     int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q5_0_AMPERE;
+     const int mmq_y = MMQ_Y_Q5_0_AMPERE;
+     const int nwarps = NWARPS_Q5_0_AMPERE;
+     allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q5_0, tile_x_d_q5_0);
+     mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
+               load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
+               vec_dot_q5_0_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q5_1_RDNA2 64
+ #define MMQ_Y_Q5_1_RDNA2 128
+ #define NWARPS_Q5_1_RDNA2 8
+ #define MMQ_X_Q5_1_RDNA1 64
+ #define MMQ_Y_Q5_1_RDNA1 64
+ #define NWARPS_Q5_1_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q5_1_AMPERE 4
+ #define MMQ_Y_Q5_1_AMPERE 32
+ #define NWARPS_Q5_1_AMPERE 4
+ #else
+ #define MMQ_X_Q5_1_AMPERE 128
+ #define MMQ_Y_Q5_1_AMPERE 64
+ #define NWARPS_Q5_1_AMPERE 4
+ #endif
+ #define MMQ_X_Q5_1_PASCAL 64
+ #define MMQ_Y_Q5_1_PASCAL 64
+ #define NWARPS_Q5_1_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q5_1(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
+     sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q5_1_AMPERE;
+     const int mmq_y = MMQ_Y_Q5_1_AMPERE;
+     const int nwarps = NWARPS_Q5_1_AMPERE;
+     allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q5_1, tile_x_dm_q5_1);
+     mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
+               load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
+               vec_dot_q5_1_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q8_0_RDNA2 64
+ #define MMQ_Y_Q8_0_RDNA2 128
+ #define NWARPS_Q8_0_RDNA2 8
+ #define MMQ_X_Q8_0_RDNA1 64
+ #define MMQ_Y_Q8_0_RDNA1 64
+ #define NWARPS_Q8_0_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q8_0_AMPERE 4
+ #define MMQ_Y_Q8_0_AMPERE 32
+ #define NWARPS_Q8_0_AMPERE 4
+ #else
+ #define MMQ_X_Q8_0_AMPERE 128
+ #define MMQ_Y_Q8_0_AMPERE 64
+ #define NWARPS_Q8_0_AMPERE 4
+ #endif
+ #define MMQ_X_Q8_0_PASCAL 64
+ #define MMQ_Y_Q8_0_PASCAL 64
+ #define NWARPS_Q8_0_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q8_0(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
+     int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q8_0_AMPERE;
+     const int mmq_y = MMQ_Y_Q8_0_AMPERE;
+     const int nwarps = NWARPS_Q8_0_AMPERE;
+     allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_qs_q8_0, tile_x_d_q8_0);
+     mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
+               load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
+               vec_dot_q8_0_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q2_K_RDNA2 64
+ #define MMQ_Y_Q2_K_RDNA2 128
+ #define NWARPS_Q2_K_RDNA2 8
+ #define MMQ_X_Q2_K_RDNA1 128
+ #define MMQ_Y_Q2_K_RDNA1 32
+ #define NWARPS_Q2_K_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q2_K_AMPERE 4
+ #define MMQ_Y_Q2_K_AMPERE 32
+ #define NWARPS_Q2_K_AMPERE 4
+ #else
+ #define MMQ_X_Q2_K_AMPERE 64
+ #define MMQ_Y_Q2_K_AMPERE 128
+ #define NWARPS_Q2_K_AMPERE 4
+ #endif
+ #define MMQ_X_Q2_K_PASCAL 64
+ #define MMQ_Y_Q2_K_PASCAL 64
+ #define NWARPS_Q2_K_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q2_K(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
+     sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
+     sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q2_K_AMPERE;
+     const int mmq_y = MMQ_Y_Q2_K_AMPERE;
+     const int nwarps = NWARPS_Q2_K_AMPERE;
+     allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
+     mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
+               load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
+               vec_dot_q2_K_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q3_K_RDNA2 128
+ #define MMQ_Y_Q3_K_RDNA2 64
+ #define NWARPS_Q3_K_RDNA2 8
+ #define MMQ_X_Q3_K_RDNA1 32
+ #define MMQ_Y_Q3_K_RDNA1 128
+ #define NWARPS_Q3_K_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q3_K_AMPERE 4
+ #define MMQ_Y_Q3_K_AMPERE 32
+ #define NWARPS_Q3_K_AMPERE 4
+ #else
+ #define MMQ_X_Q3_K_AMPERE 128
+ #define MMQ_Y_Q3_K_AMPERE 128
+ #define NWARPS_Q3_K_AMPERE 4
+ #endif
+ #define MMQ_X_Q3_K_PASCAL 64
+ #define MMQ_Y_Q3_K_PASCAL 64
+ #define NWARPS_Q3_K_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q3_K(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
+     sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
+     int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q3_K_AMPERE;
+     const int mmq_y = MMQ_Y_Q3_K_AMPERE;
+     const int nwarps = NWARPS_Q3_K_AMPERE;
+     allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
+                                tile_x_sc_q3_K);
+     mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
+               load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
+               vec_dot_q3_K_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q4_K_RDNA2 64
+ #define MMQ_Y_Q4_K_RDNA2 128
+ #define NWARPS_Q4_K_RDNA2 8
+ #define MMQ_X_Q4_K_RDNA1 32
+ #define MMQ_Y_Q4_K_RDNA1 64
+ #define NWARPS_Q4_K_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q4_K_AMPERE 4
+ #define MMQ_Y_Q4_K_AMPERE 32
+ #define NWARPS_Q4_K_AMPERE 4
+ #else
+ #define MMQ_X_Q4_K_AMPERE 64
+ #define MMQ_Y_Q4_K_AMPERE 128
+ #define NWARPS_Q4_K_AMPERE 4
+ #endif
+ #define MMQ_X_Q4_K_PASCAL 64
+ #define MMQ_Y_Q4_K_PASCAL 64
+ #define NWARPS_Q4_K_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q4_K(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
+     sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
+     sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q4_K_AMPERE;
+     const int mmq_y = MMQ_Y_Q4_K_AMPERE;
+     const int nwarps = NWARPS_Q4_K_AMPERE;
+     allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
+     mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
+               load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
+               vec_dot_q4_K_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q5_K_RDNA2 64
+ #define MMQ_Y_Q5_K_RDNA2 128
+ #define NWARPS_Q5_K_RDNA2 8
+ #define MMQ_X_Q5_K_RDNA1 32
+ #define MMQ_Y_Q5_K_RDNA1 64
+ #define NWARPS_Q5_K_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q5_K_AMPERE 4
+ #define MMQ_Y_Q5_K_AMPERE 32
+ #define NWARPS_Q5_K_AMPERE 4
+ #else
+ #define MMQ_X_Q5_K_AMPERE 64
+ #define MMQ_Y_Q5_K_AMPERE 128
+ #define NWARPS_Q5_K_AMPERE 4
+ #endif
+ #define MMQ_X_Q5_K_PASCAL 64
+ #define MMQ_Y_Q5_K_PASCAL 64
+ #define NWARPS_Q5_K_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q5_K(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
+     sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
+     sycl::half2 *tile_y_ds) {
+     int * tile_x_ql = nullptr;
+     sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q5_K_AMPERE;
+     const int mmq_y = MMQ_Y_Q5_K_AMPERE;
+     const int nwarps = NWARPS_Q5_K_AMPERE;
+     allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
+     mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
+               load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
+               vec_dot_q5_K_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
+
+ #define MMQ_X_Q6_K_RDNA2 64
+ #define MMQ_Y_Q6_K_RDNA2 128
+ #define NWARPS_Q6_K_RDNA2 8
+ #define MMQ_X_Q6_K_RDNA1 32
+ #define MMQ_Y_Q6_K_RDNA1 64
+ #define NWARPS_Q6_K_RDNA1 8
+ #if defined(SYCL_USE_XMX)
+ #define MMQ_X_Q6_K_AMPERE 4
+ #define MMQ_Y_Q6_K_AMPERE 32
+ #define NWARPS_Q6_K_AMPERE 4
+ #else
+ #define MMQ_X_Q6_K_AMPERE 64
+ #define MMQ_Y_Q6_K_AMPERE 64
+ #define NWARPS_Q6_K_AMPERE 4
+ #endif
+ #define MMQ_X_Q6_K_PASCAL 64
+ #define MMQ_Y_Q6_K_PASCAL 64
+ #define NWARPS_Q6_K_PASCAL 8
+
+ template <bool need_check> static void
+ mul_mat_q6_K(
+     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+     const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
+     int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+     // int * tile_x_ql = nullptr;
+     // sycl::half2 *tile_x_dm = nullptr;
+     int * tile_x_qh = nullptr;
+     // int * tile_x_sc = nullptr;
+
+     //sycl_todo: change according to hardware
+     const int mmq_x = MMQ_X_Q6_K_AMPERE;
+     const int mmq_y = MMQ_Y_Q6_K_AMPERE;
+     const int nwarps = NWARPS_Q6_K_AMPERE;
+     allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                                tile_x_ql, tile_x_dm, tile_x_sc);
+     mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
+               load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
+               vec_dot_q6_K_q8_1_mul_mat>(
+         vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+         tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+ }
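All of the wrappers above follow one pattern: fix the tile shape at compile time, bind the quantization type's tile loader and dot-product kernel, and forward to the generic mul_mat_q template. A reduced, compilable sketch of that compile-time dispatch shape (stub names, illustrative only, not the library's API):

    // Illustrative: the template-dispatch skeleton the wrappers above share.
    #include <cstdio>

    template <int MMQ_X, int MMQ_Y, int NWARPS, void (*VecDot)(int)>
    static void mul_mat_generic(int block) {
        // The generic kernel sees only the tile shape and the dot-product hook.
        printf("tile %dx%d, %d warps: ", MMQ_Y, MMQ_X, NWARPS);
        VecDot(block);
    }

    static void vec_dot_q4_0_stub(int b) { printf("q4_0 dot on block %d\n", b); }
    static void vec_dot_q6_K_stub(int b) { printf("q6_K dot on block %d\n", b); }

    int main() {
        mul_mat_generic<64, 128, 4, vec_dot_q4_0_stub>(0); // mirrors mul_mat_q4_0
        mul_mat_generic<64,  64, 4, vec_dot_q6_K_stub>(0); // mirrors mul_mat_q6_K
        return 0;
    }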
+
+ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q4_0_RDNA2;
+         mmq_y = MMQ_Y_Q4_0_RDNA2;
+         nwarps = NWARPS_Q4_0_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q4_0_RDNA1;
+         mmq_y = MMQ_Y_Q4_0_RDNA1;
+         nwarps = NWARPS_Q4_0_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q4_0_AMPERE;
+         mmq_y = MMQ_Y_Q4_0_AMPERE;
+         nwarps = NWARPS_Q4_0_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q4_0_PASCAL;
+         mmq_y = MMQ_Y_Q4_0_PASCAL;
+         nwarps = NWARPS_Q4_0_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q4_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q4_0_acc_ct1),
+                             get_pointer(tile_x_d_q4_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q4_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q4_0_acc_ct1),
+                             get_pointer(tile_x_d_q4_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
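The launcher above (and each one that follows) rounds the output up to whole tiles, sizes the nd_range as work-groups times work-group size, and picks the bounds-checked kernel variant only when nrows_x is not a multiple of mmq_y. A standalone sketch of that launch-geometry arithmetic (example sizes; WARP_SIZE assumed 32 here for illustration):

    // Illustrative: the launch-geometry arithmetic used by the launchers above.
    #include <cstdio>

    int main() {
        const int WARP_SIZE = 32;                 // assumption: 32-lane sub-groups
        const int nrows_x = 4096, ncols_y = 512;  // example problem size
        const int mmq_x = 64, mmq_y = 128, nwarps = 4;
        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; // tiles along rows of x
        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; // tiles along cols of y
        // nd_range: global size = block_nums * block_dims, local size = block_dims.
        printf("grid: %d x %d work-groups of %d threads; bounds check needed: %s\n",
               block_num_x, block_num_y, nwarps * WARP_SIZE,
               (nrows_x % mmq_y == 0) ? "no" : "yes");
        return 0;
    }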
+
+ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q4_1_RDNA2;
+         mmq_y = MMQ_Y_Q4_1_RDNA2;
+         nwarps = NWARPS_Q4_1_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q4_1_RDNA1;
+         mmq_y = MMQ_Y_Q4_1_RDNA1;
+         nwarps = NWARPS_Q4_1_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q4_1_AMPERE;
+         mmq_y = MMQ_Y_Q4_1_AMPERE;
+         nwarps = NWARPS_Q4_1_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q4_1_PASCAL;
+         mmq_y = MMQ_Y_Q4_1_PASCAL;
+         nwarps = NWARPS_Q4_1_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q4_1<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q4_1_acc_ct1),
+                             get_pointer(tile_x_dm_q4_1_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q4_1<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q4_1_acc_ct1),
+                             get_pointer(tile_x_dm_q4_1_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
+
+ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q5_0_RDNA2;
+         mmq_y = MMQ_Y_Q5_0_RDNA2;
+         nwarps = NWARPS_Q5_0_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q5_0_RDNA1;
+         mmq_y = MMQ_Y_Q5_0_RDNA1;
+         nwarps = NWARPS_Q5_0_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q5_0_AMPERE;
+         mmq_y = MMQ_Y_Q5_0_AMPERE;
+         nwarps = NWARPS_Q5_0_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q5_0_PASCAL;
+         mmq_y = MMQ_Y_Q5_0_PASCAL;
+         nwarps = NWARPS_Q5_0_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q5_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q5_0_acc_ct1),
+                             get_pointer(tile_x_d_q5_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q5_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q5_0_acc_ct1),
+                             get_pointer(tile_x_d_q5_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
+
+ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q5_1_RDNA2;
+         mmq_y = MMQ_Y_Q5_1_RDNA2;
+         nwarps = NWARPS_Q5_1_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q5_1_RDNA1;
+         mmq_y = MMQ_Y_Q5_1_RDNA1;
+         nwarps = NWARPS_Q5_1_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q5_1_AMPERE;
+         mmq_y = MMQ_Y_Q5_1_AMPERE;
+         nwarps = NWARPS_Q5_1_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q5_1_PASCAL;
+         mmq_y = MMQ_Y_Q5_1_PASCAL;
+         nwarps = NWARPS_Q5_1_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q5_1<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q5_1_acc_ct1),
+                             get_pointer(tile_x_dm_q5_1_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q5_1<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q5_1_acc_ct1),
+                             get_pointer(tile_x_dm_q5_1_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
+
+ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q8_0_RDNA2;
+         mmq_y = MMQ_Y_Q8_0_RDNA2;
+         nwarps = NWARPS_Q8_0_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q8_0_RDNA1;
+         mmq_y = MMQ_Y_Q8_0_RDNA1;
+         nwarps = NWARPS_Q8_0_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q8_0_AMPERE;
+         mmq_y = MMQ_Y_Q8_0_AMPERE;
+         nwarps = NWARPS_Q8_0_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q8_0_PASCAL;
+         mmq_y = MMQ_Y_Q8_0_PASCAL;
+         nwarps = NWARPS_Q8_0_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q8_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q8_0_acc_ct1),
+                             get_pointer(tile_x_d_q8_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q8_0<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_qs_q8_0_acc_ct1),
+                             get_pointer(tile_x_d_q8_0_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
+
+ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q2_K_RDNA2;
+         mmq_y = MMQ_Y_Q2_K_RDNA2;
+         nwarps = NWARPS_Q2_K_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q2_K_RDNA1;
+         mmq_y = MMQ_Y_Q2_K_RDNA1;
+         nwarps = NWARPS_Q2_K_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q2_K_AMPERE;
+         mmq_y = MMQ_Y_Q2_K_AMPERE;
+         nwarps = NWARPS_Q2_K_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q2_K_PASCAL;
+         mmq_y = MMQ_Y_Q2_K_PASCAL;
+         nwarps = NWARPS_Q2_K_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q2_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q2_K_acc_ct1),
+                             get_pointer(tile_x_dm_q2_K_acc_ct1),
+                             get_pointer(tile_x_sc_q2_K_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q2_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q2_K_acc_ct1),
+                             get_pointer(tile_x_dm_q2_K_acc_ct1),
+                             get_pointer(tile_x_sc_q2_K_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
+
+ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+ #if QK_K == 256
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q3_K_RDNA2;
+         mmq_y = MMQ_Y_Q3_K_RDNA2;
+         nwarps = NWARPS_Q3_K_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q3_K_RDNA1;
+         mmq_y = MMQ_Y_Q3_K_RDNA1;
+         nwarps = NWARPS_Q3_K_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q3_K_AMPERE;
+         mmq_y = MMQ_Y_Q3_K_AMPERE;
+         nwarps = NWARPS_Q3_K_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q3_K_PASCAL;
+         mmq_y = MMQ_Y_Q3_K_PASCAL;
+         nwarps = NWARPS_Q3_K_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q3_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q3_K_acc_ct1),
+                             get_pointer(tile_x_dm_q3_K_acc_ct1),
+                             get_pointer(tile_x_qh_q3_K_acc_ct1),
+                             get_pointer(tile_x_sc_q3_K_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q3_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_q3_K_acc_ct1),
+                             get_pointer(tile_x_dm_q3_K_acc_ct1),
+                             get_pointer(tile_x_qh_q3_K_acc_ct1),
+                             get_pointer(tile_x_sc_q3_K_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ #endif
+ }
+ catch (sycl::exception const &exc) {
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
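Unlike its siblings, the q3_K launcher above guards its whole body with #if QK_K == 256, so for any other super-block size it compiles to a no-op. A tiny compilable sketch of that preprocessor pattern (the default value here is an assumption for illustration):

    // Illustrative: how a QK_K guard degrades a function body to a no-op.
    #include <cstdio>

    #ifndef QK_K
    #define QK_K 256   // assumption: default super-block size
    #endif

    static void launcher_shape_demo() {
    #if QK_K == 256
        printf("QK_K == 256: kernel is launched\n");
    #else
        // body compiled out: the function silently does nothing
    #endif
    }

    int main() { launcher_shape_demo(); return 0; }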
2598
+
2599
+ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
2600
+ float *dst, const int ncols_x,
2601
+ const int nrows_x, const int ncols_y,
2602
+ const int nrows_y, const int nrows_dst,
2603
+ dpct::queue_ptr stream) try {
2604
+
2605
+ int id;
2606
+ SYCL_CHECK(
2607
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2608
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2609
+
2610
+ int mmq_x, mmq_y, nwarps;
2611
+ if (compute_capability >= VER_GEN13) {
2612
+ mmq_x = MMQ_X_Q4_K_RDNA2;
2613
+ mmq_y = MMQ_Y_Q4_K_RDNA2;
2614
+ nwarps = NWARPS_Q4_K_RDNA2;
2615
+ } else if (compute_capability >= VER_GEN12) {
2616
+ mmq_x = MMQ_X_Q4_K_RDNA1;
2617
+ mmq_y = MMQ_Y_Q4_K_RDNA1;
2618
+ nwarps = NWARPS_Q4_K_RDNA1;
2619
+ } else if (compute_capability >= VER_GEN9) {
2620
+ mmq_x = MMQ_X_Q4_K_AMPERE;
2621
+ mmq_y = MMQ_Y_Q4_K_AMPERE;
2622
+ nwarps = NWARPS_Q4_K_AMPERE;
2623
+ } else if (compute_capability >= VER_4VEC) {
2624
+ mmq_x = MMQ_X_Q4_K_PASCAL;
2625
+ mmq_y = MMQ_Y_Q4_K_PASCAL;
2626
+ nwarps = NWARPS_Q4_K_PASCAL;
2627
+ } else {
2628
+ GGML_ABORT("fatal error");
2629
+ }
2630
+
2631
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2632
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2633
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2634
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2635
+
2636
+ if (nrows_x % mmq_y == 0) {
2637
+ const bool need_check = false;
2638
+ /*
2639
+ DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
2640
+ the limit. To get the device limit, query
2641
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2642
+ */
2643
+ {
2644
+ dpct::has_capability_or_fail(stream->get_device(),
2645
+ {sycl::aspect::fp16});
2646
+
2647
+ stream->submit([&](sycl::handler &cgh) {
2648
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2649
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2650
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2651
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2652
+ cgh);
2653
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2654
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2655
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2656
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2657
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2658
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2659
+
2660
+ cgh.parallel_for(
2661
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2662
+ [=](sycl::nd_item<3> item_ct1) {
2663
+ mul_mat_q4_K<need_check>(
2664
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2665
+ nrows_dst, item_ct1,
2666
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
2667
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
2668
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
2669
+ get_pointer(tile_y_qs_acc_ct1),
2670
+ get_pointer(tile_y_ds_acc_ct1));
2671
+ });
2672
+ });
2673
+ }
2674
+ } else {
2675
+ const bool need_check = true;
2676
+ /*
2677
+ DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
2678
+ the limit. To get the device limit, query
2679
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2680
+ */
2681
+ {
2682
+ dpct::has_capability_or_fail(stream->get_device(),
2683
+ {sycl::aspect::fp16});
2684
+
2685
+ stream->submit([&](sycl::handler &cgh) {
2686
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2687
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2688
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2689
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2690
+ cgh);
2691
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2692
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2693
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2694
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2695
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2696
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2697
+
2698
+ cgh.parallel_for(
2699
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2700
+ [=](sycl::nd_item<3> item_ct1) {
2701
+ mul_mat_q4_K<need_check>(
2702
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2703
+ nrows_dst, item_ct1,
2704
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
2705
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
2706
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
2707
+ get_pointer(tile_y_qs_acc_ct1),
2708
+ get_pointer(tile_y_ds_acc_ct1));
2709
+ });
2710
+ });
2711
+ }
2712
+ }
2713
+ }
2714
+ catch (sycl::exception const &exc) {
2715
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2716
+ << ", line:" << __LINE__ << std::endl;
2717
+ std::exit(1);
2718
+ }
2719
+
2720
+ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
2721
+ float *dst, const int ncols_x,
2722
+ const int nrows_x, const int ncols_y,
2723
+ const int nrows_y, const int nrows_dst,
2724
+ dpct::queue_ptr stream) try {
2725
+
2726
+ int id;
2727
+ SYCL_CHECK(
2728
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2729
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2730
+
2731
+ int mmq_x, mmq_y, nwarps;
2732
+ if (compute_capability >= VER_GEN13) {
2733
+ mmq_x = MMQ_X_Q5_K_RDNA2;
2734
+ mmq_y = MMQ_Y_Q5_K_RDNA2;
2735
+ nwarps = NWARPS_Q5_K_RDNA2;
2736
+ } else if (compute_capability >= VER_GEN12) {
2737
+ mmq_x = MMQ_X_Q5_K_RDNA1;
2738
+ mmq_y = MMQ_Y_Q5_K_RDNA1;
2739
+ nwarps = NWARPS_Q5_K_RDNA1;
2740
+ } else if (compute_capability >= VER_GEN9) {
2741
+ mmq_x = MMQ_X_Q5_K_AMPERE;
2742
+ mmq_y = MMQ_Y_Q5_K_AMPERE;
2743
+ nwarps = NWARPS_Q5_K_AMPERE;
2744
+ } else if (compute_capability >= VER_4VEC) {
2745
+ mmq_x = MMQ_X_Q5_K_PASCAL;
2746
+ mmq_y = MMQ_Y_Q5_K_PASCAL;
2747
+ nwarps = NWARPS_Q5_K_PASCAL;
2748
+ } else {
2749
+ GGML_ABORT("fatal error");
2750
+ }
2751
+
2752
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2753
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2754
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2755
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2756
+
2757
+ if (nrows_x % mmq_y == 0) {
2758
+ const bool need_check = false;
2759
+ /*
2760
+ DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
2761
+ the limit. To get the device limit, query
2762
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2763
+ */
2764
+ {
2765
+ dpct::has_capability_or_fail(stream->get_device(),
2766
+ {sycl::aspect::fp16});
2767
+
2768
+ stream->submit([&](sycl::handler &cgh) {
2769
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2770
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2771
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2772
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2773
+ cgh);
2774
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2775
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2776
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2777
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2778
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2779
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2780
+
2781
+ cgh.parallel_for(
2782
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2783
+ [=](sycl::nd_item<3> item_ct1) {
2784
+ mul_mat_q5_K<need_check>(
2785
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2786
+ nrows_dst, item_ct1,
2787
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
2788
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
2789
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
2790
+ get_pointer(tile_y_qs_acc_ct1),
2791
+ get_pointer(tile_y_ds_acc_ct1));
2792
+ });
2793
+ });
2794
+ }
2795
+ } else {
2796
+ const bool need_check = true;
2797
+ /*
2798
+ DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
2799
+ the limit. To get the device limit, query
2800
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2801
+ */
2802
+ {
2803
+ dpct::has_capability_or_fail(stream->get_device(),
2804
+ {sycl::aspect::fp16});
2805
+
2806
+ stream->submit([&](sycl::handler &cgh) {
2807
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2808
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2809
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2810
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2811
+ cgh);
2812
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2813
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2814
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2815
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2816
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2817
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2818
+
2819
+ cgh.parallel_for(
2820
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2821
+ [=](sycl::nd_item<3> item_ct1) {
2822
+ mul_mat_q5_K<need_check>(
2823
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2824
+ nrows_dst, item_ct1,
2825
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
2826
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
2827
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
2828
+ get_pointer(tile_y_qs_acc_ct1),
2829
+ get_pointer(tile_y_ds_acc_ct1));
2830
+ });
2831
+ });
2832
+ }
2833
+ }
2834
+ }
2835
+ catch (sycl::exception const &exc) {
2836
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2837
+ << ", line:" << __LINE__ << std::endl;
2838
+ std::exit(1);
2839
+ }
2840
+
2841
+ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols_x,
+                                         const int nrows_x, const int ncols_y,
+                                         const int nrows_y, const int nrows_dst,
+                                         dpct::queue_ptr stream) try {
+
+     int id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(id = get_current_device_id()));
+     const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+     int mmq_x, mmq_y, nwarps;
+     if (compute_capability >= VER_GEN13) {
+         mmq_x = MMQ_X_Q6_K_RDNA2;
+         mmq_y = MMQ_Y_Q6_K_RDNA2;
+         nwarps = NWARPS_Q6_K_RDNA2;
+     } else if (compute_capability >= VER_GEN12) {
+         mmq_x = MMQ_X_Q6_K_RDNA1;
+         mmq_y = MMQ_Y_Q6_K_RDNA1;
+         nwarps = NWARPS_Q6_K_RDNA1;
+     } else if (compute_capability >= VER_GEN9) {
+         mmq_x = MMQ_X_Q6_K_AMPERE;
+         mmq_y = MMQ_Y_Q6_K_AMPERE;
+         nwarps = NWARPS_Q6_K_AMPERE;
+     } else if (compute_capability >= VER_4VEC) {
+         mmq_x = MMQ_X_Q6_K_PASCAL;
+         mmq_y = MMQ_Y_Q6_K_PASCAL;
+         nwarps = NWARPS_Q6_K_PASCAL;
+     } else {
+         GGML_ABORT("fatal error");
+     }
+
+     const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+     const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+     const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+     const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+     if (nrows_x % mmq_y == 0) {
+         const bool need_check = false;
+         /*
+         DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q6_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_acc_ct1),
+                             get_pointer(tile_x_dm_acc_ct1),
+                             get_pointer(tile_x_sc_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     } else {
+         const bool need_check = true;
+         /*
+         DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
+         the limit. To get the device limit, query
+         info::device::max_work_group_size. Adjust the work-group size if needed.
+         */
+         {
+             dpct::has_capability_or_fail(stream->get_device(),
+                                          {sycl::aspect::fp16});
+
+             stream->submit([&](sycl::handler &cgh) {
+                 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+                     cgh);
+                 sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+                     sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                 sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                 cgh.parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         mul_mat_q6_K<need_check>(
+                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                             nrows_dst, item_ct1,
+                             get_pointer(tile_x_ql_acc_ct1),
+                             get_pointer(tile_x_dm_acc_ct1),
+                             get_pointer(tile_x_sc_acc_ct1),
+                             get_pointer(tile_y_qs_acc_ct1),
+                             get_pointer(tile_y_ds_acc_ct1));
+                     });
+             });
+         }
+     }
+ }
+ catch (sycl::exception const &exc) {
+   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+             << ", line:" << __LINE__ << std::endl;
+   std::exit(1);
+ }
+
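+ // Worked example of the launch geometry above, using illustrative values
+ // (these numbers are assumptions for the example, not taken from the
+ // headers): suppose WARP_SIZE = 32 and the selected tile is mmq_x = 64,
+ // mmq_y = 64, nwarps = 4. For nrows_x = 1000 and ncols_y = 96:
+ //   block_num_x = (1000 + 64 - 1) / 64 = 16   // ceil-div over rows of x
+ //   block_num_y = (  96 + 64 - 1) / 64 =  2   // ceil-div over cols of y
+ // giving block_nums = {1, 2, 16} work-groups of block_dims = {1, 4, 32}
+ // work-items each. Since 1000 % 64 != 0, the need_check = true branch is
+ // taken, so the kernel bounds-checks the final, partially filled row tile.
+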
+ void ggml_sycl_op_mul_mat_q(
+     ggml_backend_sycl_context & ctx,
+     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+     const int64_t src1_ncols, const int64_t src1_padded_row_size,
+     const dpct::queue_ptr &stream) try {
+
+     const int64_t ne00 = src0->ne[0];
+
+     const int64_t ne10 = src1->ne[0];
+     GGML_ASSERT(ne10 % QK8_1 == 0);
+
+     const int64_t ne0 = dst->ne[0];
+
+     const int64_t row_diff = row_high - row_low;
+
+     int device_id;
+     SYCL_CHECK(
+         CHECK_TRY_ERROR(device_id = get_current_device_id()));
+
+     // the main device has a larger memory buffer to hold the results from all GPUs
+     // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+     const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
+
+     switch (src0->type) {
+         case GGML_TYPE_Q4_0:
+             ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q4_1:
+             ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q5_0:
+             ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q5_1:
+             ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q8_0:
+             ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q2_K:
+             ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q3_K:
+             ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q4_K:
+             ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q5_K:
+             ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         case GGML_TYPE_Q6_K:
+             ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+             break;
+         default:
+             GGML_ABORT("fatal error");
+             break;
+     }
+
+     GGML_UNUSED(src1);
+     GGML_UNUSED(dst);
+     GGML_UNUSED(src1_ddf_i);
+ }
+ catch (sycl::exception const &exc) {
+   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+             << ", line:" << __LINE__ << std::endl;
+   std::exit(1);
+ }
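+
+ // A minimal sketch of the error-handling convention shared by every launcher
+ // in this file: a function-try-block whose catch reports the sycl::exception
+ // and terminates the process. The function name below is hypothetical and
+ // only illustrates the shape of the pattern.
+ static void example_sycl_launcher(dpct::queue_ptr stream) try {
+     // ... build accessors and submit kernels to *stream here ...
+     stream->wait_and_throw(); // surfaces asynchronous errors as exceptions
+ }
+ catch (sycl::exception const &exc) {
+   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+             << ", line:" << __LINE__ << std::endl;
+   std::exit(1);
+ }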