whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,3031 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #include "mmq.hpp"
14
+ #include "vecdotq.hpp"
15
+
16
// Function-pointer types used by the mul-mat-q kernels to dispatch the
// per-quantization-type tile helpers (allocate / load / vec-dot) at runtime.

// Binds the four tile pointers (quantized values, scales, high bits,
// sub-block scales) to the caller-provided local buffers.
typedef void (*allocate_tiles_sycl_t)(
    int** x_ql,
    sycl::half2** x_dm,
    int** x_qh,
    int** x_sc);
// Loads one tile of quantized blocks from the source matrix vx into the
// local buffers. i_offset/i_max select and clamp the rows handled by this
// work-group; k is the lane's column; blocks_per_row is the row stride in
// quant blocks.
typedef void (*load_tiles_sycl_t)(
    const void* __restrict__ vx,
    int* __restrict__ x_ql,
    sycl::half2* __restrict__ x_dm,
    int* __restrict__ x_qh,
    int* __restrict__ x_sc,
    const int& i_offset,
    const int& i_max,
    const int& k,
    const int& blocks_per_row);
// Computes one partial dot product between row i of the staged x tile and
// column j of the staged q8_1 y tile at position k.
typedef float (*vec_dot_q_mul_mat_sycl_t)(
    const int* __restrict__ x_ql,
    const sycl::half2* __restrict__ x_dm,
    const int* __restrict__ x_qh,
    const int* __restrict__ x_sc,
    const int* __restrict__ y_qs,
    const sycl::half2* __restrict__ y_ms,
    const int& i,
    const int& j,
    const int& k);
41
+
42
+
43
// Bind the q4_0 tile pointers to the caller-provided local buffers.
// q4_0 blocks carry no separate high-bit or sub-scale tables, so the
// x_qh/x_sc outputs are intentionally left untouched.
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
    (void)x_qh; (void)x_sc;

    *x_ql = tile_x_qs_q4_0;
    // q4_0 per-block scales are plain floats; the shared half2 slot is
    // reinterpreted accordingly (readers cast back to float*).
    *x_dm = (sycl::half2 *)tile_x_d_q4_0;
}

// Load a tile of q4_0 blocks from vx into the local buffers.
// i_offset selects this warp's rows, i_max clamps row indices when
// need_check is set (partial tiles at matrix edges), k is the lane's
// column within the tile.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI4_0;   // quant-block index within the row
    const int kqsx = k % QI4_0;  // int offset inside that block's qs[]

    const block_q4_0 * bx0 = (const block_q4_0 *) vx;

    // Scales for q4_0 are staged as floats in the x_dm buffer (see allocate).
    float * x_dmf = (float *) x_dm;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;

        // Row stride WARP_SIZE+1: the +1 padding avoids bank conflicts.
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: stage one scale per block, spread across lanes.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}

// Partial dot product of staged q4_0 row i against staged q8_1 column j
// at position k. Gathers the q8_1 ints that pair with the low and high
// nibbles, then defers to the shared q4_0*q8_1 implementation.
static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    // Position of the matching q8_1 values for this k.
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const float * x_dmf = (const float *) x_dm;  // q4_0 scales are floats

    int u[2*VDR_Q4_0_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
125
+
126
// Bind the q4_1 tile pointers to the caller-provided local buffers.
// q4_1 keeps its (d, m) pair as a real half2, so x_dm is used directly;
// there are no high bits or sub-scales (x_qh/x_sc unused).
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
    (void)x_qh; (void)x_sc;

    *x_ql = tile_x_qs_q4_1;
    *x_dm = tile_x_dm_q4_1;
}


// Load a tile of q4_1 blocks from vx into the local buffers.
// Same layout as load_tiles_q4_0, but quants are 32-byte aligned
// (get_int_from_uint8_aligned) and the scale is the half2 dm pair.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI4_1;   // quant-block index within the row
    const int kqsx = k % QI4_1;  // int offset inside that block's qs[]

    const block_q4_1 * bx0 = (const block_q4_1 *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;

        // WARP_SIZE+1 row stride: padding to avoid bank conflicts.
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: stage one dm pair per block, spread across lanes.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
    }
}

// Partial dot product of staged q4_1 row i against staged q8_1 column j
// at position k; mirrors vec_dot_q4_0_q8_1_mul_mat but passes the half2
// (d, m) pair through to the q4_1 implementation.
static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    // Position of the matching q8_1 values for this k.
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));

    int u[2*VDR_Q4_1_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
206
+
207
// Bind the q5_0 tile pointers to the caller-provided local buffers.
// Like q4_0, scales are plain floats staged through the x_dm slot;
// the 5th bits are merged into x_ql at load time, so x_qh is unused.
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
    (void)x_qh; (void)x_sc;

    *x_ql = tile_x_ql_q5_0;
    *x_dm = (sycl::half2 *)tile_x_d_q5_0;
}

// Load a tile of q5_0 blocks. Each source int of 8 nibbles is expanded
// into TWO staged ints (low and high nibble groups) with the per-value
// 5th bit scattered in from qh, then 16 is subtracted to center the
// values; hence the doubled 2*WARP_SIZE+1 row stride in x_ql.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI5_0;   // quant-block index within the row
    const int kqsx = k % QI5_0;  // int offset inside that block's qs[]

    const block_q5_0 * bx0 = (const block_q5_0 *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;

        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        // Shift the packed 5th bits so the 4 relevant ones are at bits 0..3.
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));

        // Low nibbles + scattered 5th bits (one per byte lane).
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = dpct::vectorized_binary<sycl::char4>(
            qs0, 0x10101010, dpct::sub_sat()); // subtract 16

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;

        // High nibbles + their 5th bits.
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        qs1 = dpct::vectorized_binary<sycl::char4>(
            qs1, 0x10101010, dpct::sub_sat()); // subtract 16

        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = k % blocks_per_tile_x_row;
    // q5_0 scales are floats staged through the x_dm buffer.
    float * x_dmf = (float *) x_dm;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}

// Partial dot product of staged q5_0 row i against staged q8_1 column j.
// The tile values were already expanded/centered at load time, so this
// reuses the q8_0*q8_1 implementation with float scales on both sides.
static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df = (const float *) y_ds;

    int u[2*VDR_Q5_0_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
    }

    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
310
+
311
+ template <int mmq_y>
312
+ static __dpct_inline__ void
313
+ allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
314
+ int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
315
+ (void)x_qh; (void)x_sc;
316
+
317
+ *x_ql = tile_x_ql_q5_1;
318
+ *x_dm = tile_x_dm_q5_1;
319
+ }
320
+
321
+ template <int mmq_y, int nwarps, bool need_check>
322
+ static __dpct_inline__ void
323
+ load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
324
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
325
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
326
+ const int &k, const int &blocks_per_row) {
327
+ (void)x_qh; (void)x_sc;
328
+
329
+ GGML_SYCL_ASSUME(i_offset >= 0);
330
+ GGML_SYCL_ASSUME(i_offset < nwarps);
331
+ GGML_SYCL_ASSUME(k >= 0);
332
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
333
+
334
+ const int kbx = k / QI5_1;
335
+ const int kqsx = k % QI5_1;
336
+
337
+ const block_q5_1 * bx0 = (const block_q5_1 *) vx;
338
+
339
+ #pragma unroll
340
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
341
+ int i = i0 + i_offset;
342
+
343
+ if (need_check) {
344
+ i = sycl::min(i, i_max);
345
+ }
346
+
347
+ const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
348
+
349
+ const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
350
+ const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
351
+
352
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
353
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
354
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
355
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
356
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
357
+
358
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
359
+
360
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
361
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
362
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
363
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
364
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
365
+
366
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
367
+ }
368
+
369
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
370
+ const int kbxd = k % blocks_per_tile_x_row;
371
+
372
+ #pragma unroll
373
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
374
+ int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
375
+
376
+ if (need_check) {
377
+ i = sycl::min(i, i_max);
378
+ }
379
+
380
+ const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
381
+
382
+ x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
383
+ }
384
+ }
385
+
386
+ static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
387
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
388
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
389
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
390
+ const int &i, const int &j, const int &k) {
391
+ (void)x_qh; (void)x_sc;
392
+
393
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
394
+ const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
395
+
396
+ int u[2*VDR_Q5_1_Q8_1_MMQ];
397
+
398
+ #pragma unroll
399
+ for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
400
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
401
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
402
+ }
403
+
404
+ return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
405
+ (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
406
+ }
407
+
408
// Bind the q8_0 tile pointers to the caller-provided local buffers.
// q8_0 scales are plain floats staged through the x_dm slot; there are
// no high bits or sub-scales (x_qh/x_sc unused).
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
    (void)x_qh; (void)x_sc;

    *x_ql = tile_x_qs_q8_0;
    *x_dm = (sycl::half2 *)tile_x_d_q8_0;
}

// Load a tile of q8_0 blocks from vx: quants are already 8-bit signed,
// so they are copied int-by-int with no unpacking; scales go to x_dm
// (as floats) in a second pass.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI8_0;   // quant-block index within the row
    const int kqsx = k % QI8_0;  // int offset inside that block's qs[]
    float * x_dmf = (float *) x_dm;

    const block_q8_0 * bx0 = (const block_q8_0 *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;

        // WARP_SIZE+1 row stride: padding to avoid bank conflicts.
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = k % blocks_per_tile_x_row;

    // Second pass: stage one float scale per block, spread across lanes.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
    }
}

// Partial dot product of staged q8_0 row i against staged q8_1 column j
// at position k. Both quant layouts match, so the y values are passed
// by pointer directly with no gather step.
static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    const float * x_dmf = (const float *) x_dm;
    const float * y_df = (const float *) y_ds;

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}
481
+
482
// Bind the q2_K tile pointers to the caller-provided local buffers.
// q2_K has sub-block scales (x_sc) in addition to quants and dm pairs;
// there are no separate high bits (x_qh unused).
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
                    int *tile_x_sc_q2_K) {
    (void)x_qh;

    *x_ql = tile_x_ql_q2_K;
    *x_dm = tile_x_dm_q2_K;
    *x_sc = tile_x_sc_q2_K;
}

// Load a tile of q2_K super-blocks: quants into x_ql, per-super-block
// (d, dmin) pairs into x_dm, packed 4-bit scale/min bytes into x_sc.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI2_K;   // super-block index within the row
    const int kqsx = k % QI2_K;  // int offset inside that block's qs[]

    const block_q2_K * bx0 = (const block_q2_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = k % blocks_per_tile_x_row;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        // Row index wraps modulo mmq_y here (unlike the q4_0/q4_1 loaders)
        // — kept as-is; presumably keeps the index inside the tile when
        // nwarps*QI2_K exceeds mmq_y. TODO(review): confirm intent upstream.
        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
    }

    // Third pass: stage the packed scale/min bytes, 4 ints per row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);

        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
    }
}
554
+
555
#define VDR_Q2_K_Q8_1_MMQ 2
// contiguous u/y values
// Core q2_K * q8_1 dot product for the mul-mat-q path.
// v: unpacked 2-bit x values, u: q8_1 values, scales: packed 4-bit
// scale (low nibble) / min (high nibble) bytes, dm2: per-super-block
// (d, dmin) pair, d8: q8_1 block scale.
// Result: d8 * (d * sum(sc*v*u) - dmin * sum(m*u)).
static __dpct_inline__ float
vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const uint8_t *__restrict__ scales,
                           const sycl::half2 &dm2, const float &d8) {

    int sumi_d = 0;  // scale-weighted dot-product accumulator
    int sumi_m = 0;  // min-weighted sum of the q8_1 values

#pragma unroll
    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
        int sumi_d_sc = 0;

        const int sc = scales[i0 / (QI8_1/2)];

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
            sumi_m = dpct::dp4a(m, u[i],
                                sumi_m); // multiply sum of q8_1 values with m
        }

        sumi_d += sumi_d_sc * (sc & 0xF);  // apply the 4-bit sub-scale
    }

    const sycl::float2 dm2f =
        dm2.convert<float, sycl::rounding_mode::automatic>();

    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
}

// Partial dot product of staged q2_K row i against staged q8_1 column j
// at position k: unpacks QR2_K*VDR 2-bit values from the shared shift
// position, locates the matching scale bytes, and defers to the impl.
static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;

    const int kbx = k / QI2_K;         // super-block within the row
    const int ky = (k % QI2_K) * QR2_K;
    const float * y_df = (const float *) y_ds;

    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];

    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));  // which 2-bit lane

#pragma unroll
    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;  // extract 2-bit values
    }

    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;

    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}
618
+
619
// Bind the q3_K tile pointers: this type uses all four buffers
// (quants, scales-as-floats via x_dm, high-bit masks, sub-scales).
template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {

    *x_ql = tile_x_ql_q3_K;
    *x_dm = tile_x_dm_q3_K;
    *x_qh = tile_x_qh_q3_K;
    *x_sc = tile_x_sc_q3_K;
}

// Load a tile of q3_K super-blocks: low 2 bits into x_ql, per-super-block
// float scale into x_dm, inverted high-bit masks into x_qh, and 6-bit
// sub-scales repacked to signed bytes into x_sc.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx = k / QI3_K;   // super-block index within the row
    const int kqsx = k % QI3_K;  // int offset inside that block's qs[]

    const block_q3_K * bx0 = (const block_q3_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);  // clamp rows past the matrix edge
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = k % blocks_per_tile_x_row;
    // q3_K scales are floats staged through the x_dm buffer.
    float * x_dmf = (float *) x_dm;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        // Row index wraps modulo mmq_y — same quirk as load_tiles_q2_K;
        // kept as-is. TODO(review): confirm intent upstream.
        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }

    // Third pass: stage the high-bit masks, 2 ints per row.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);

        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
    }

    // Fourth pass: rebuild the 6-bit sub-scales (low 4 bits + high 2 bits
    // stored separately) into signed bytes centered at 0 (-32).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);

        const int ksc = k % (QI3_K/4);

        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

        const int sc = dpct::vectorized_binary<sycl::char4>(
            sc_low | sc_high, 0x20202020, dpct::sub_sat());

        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
    }
}
718
+
719
#define VDR_Q3_K_Q8_1_MMQ  2

// contiguous u/y values
// MMQ dot product for q3_K x q8_1: v holds already-decoded q3_K values
// (sign-adjusted by the tile loader), u holds q8_1 quants. Groups of
// QI8_1/2 packed ints share one int8 sub-block scale from `scales`.
// d3: q3_K superblock scale, d8: q8_1 block scale.
static __dpct_inline__ float
vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ scales, const float &d3,
                           const float &d8) {

    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0; // partial sum for one sub-block scale

        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
    }

    return d3*d8 * sumi;
}
741
+
742
// MMQ vec-dot callback for q3_K x q8_1. Reassembles the 3-bit values from
// the separately staged planes — low 2 bits in x_ql, (inverted) high-bit
// mask in x_qh — then delegates to vec_dot_q3_K_q8_1_impl_mmq.
// i/j: row/column within the x/y tiles, k: position along the shared dim.
static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {

    const int kbx  = k / QI3_K;          // superblock index within the tile row
    const int ky  = (k % QI3_K) * QR3_K; // offset in dequantized (expanded) coords
    // for q3_K the dm/ds tiles carry plain floats (scale only, no min)
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;

    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
        const int shift = 2 * ((ky % 32) / 8);
        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; // low 2 bits per byte

        // high bit plane: the loader stored hmask inverted, so a set bit here
        // becomes a "subtract 4" after the saturating subtraction below
        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
        const int vlh = (vh << 2) & 0x04040404;

        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
    }

    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}
772
+
773
+ template <int mmq_y>
774
+ static __dpct_inline__ void
775
+ allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
776
+ int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
777
+ int *tile_x_sc_q4_K) {
778
+ (void)x_qh;
779
+
780
+ *x_ql = tile_x_ql_q4_K;
781
+ *x_dm = tile_x_dm_q4_K;
782
+ *x_sc = tile_x_sc_q4_K;
783
+ }
784
+
785
// Stage one WARP_SIZE-wide slice of q4_K data from vx into the x tiles.
// Per thread (i_offset selects the warp row, k the lane):
//   - x_ql: one packed 32-bit word of 4-bit quants
//   - x_dm: the per-superblock half2 {scale, min}
//   - x_sc: the 6-bit scales/mins repacked into byte-addressable form
// need_check clamps row indices to i_max so edge tiles never read past the
// last row (loads stay uniform instead of branching out).
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; // q4_K has no high-bit plane

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256

    const block_q4_K * bx0 = (const block_q4_K *) vx;

    // quants: one aligned 32-bit load per thread per tile row
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;

        // row stride WARP_SIZE + 1 — padded row, matching the index used by the vec_dot side
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // per-superblock {d, dmin} pairs
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }

    // repack the 6-bit scales/mins (stored interleaved in bxi->scales)
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
857
+
858
+
859
#define VDR_Q4_K_Q8_1_MMQ  8

// contiguous u/y values
// MMQ dot product for q4_K x q8_1. v packs two 4-bit values per byte
// (selected by the `>> (4 * i)` below), u holds q8_1 quants, sc/m are the
// per-sub-block scales and mins, dm4 the superblock {d, dmin}, ds8 the
// per-q8_1-block {d, sum}.
static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {

    float sumf_d = 0.0f; // scale-weighted dot-product term
    float sumf_m = 0.0f; // min-offset correction term

#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
        }

        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();

        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
    }

    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();

    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}
892
+
893
+
894
+ static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
895
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
896
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
897
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
898
+ const int &i, const int &j, const int &k) {
899
+ (void)x_qh;
900
+
901
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
902
+
903
+ const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
904
+ return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
905
+ x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
906
+ }
907
+
908
+ template <int mmq_y>
909
+ static __dpct_inline__ void
910
+ allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
911
+ int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
912
+ int *tile_x_sc_q5_K) {
913
+ (void)x_qh;
914
+
915
+ *x_ql = tile_x_ql_q5_K;
916
+ *x_dm = tile_x_dm_q5_K;
917
+ *x_sc = tile_x_sc_q5_K;
918
+ }
919
+
920
// Stage one WARP_SIZE-wide slice of q5_K data into the x tiles. The 5-bit
// values are decoded here: the low nibble from qs and the fifth bit from qh
// are merged and stored as full bytes in x_ql (hence the 2*WARP_SIZE row
// width). x_dm gets the superblock {d, dmin}; x_sc the repacked scales/mins.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; // high bits are merged into x_ql below

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256

    const block_q5_K * bx0 = (const block_q5_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp edge rows instead of branching out
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;

        // split the packed byte into its two 4-bit halves
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // extract the matching fifth bits and move them to bit 4
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // destination slots for the low and high halves in the widened row
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

        // NOTE(review): only the QK_K == 256 configuration writes x_dm here;
        // other QK_K values presumably take a different code path — confirm.
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }

    // repack the 6-bit scales/mins, same layout as the q4_K loader
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
1003
+
1004
#define VDR_Q5_K_Q8_1_MMQ  8

// contiguous u/y values
// MMQ dot product for q5_K x q8_1. Unlike the q4_K variant, v already holds
// fully decoded 5-bit values (the loader merged the high bit), so no shift
// or mask is applied here. sc/m: sub-block scales and mins; dm4: superblock
// {d, dmin}; ds8: per-q8_1-block {d, sum}.
static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {

    float sumf_d = 0.0f; // scale-weighted dot-product term
    float sumf_m = 0.0f; // min-offset correction term

#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
                                sumi_d); // SIMD dot product
        }

        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();

        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
    }

    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();

    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}
1037
+
1038
+ static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
1039
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
1040
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
1041
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
1042
+ const int &i, const int &j, const int &k) {
1043
+ (void)x_qh;
1044
+
1045
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
1046
+
1047
+ const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
1048
+ const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
1049
+ return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
1050
+ x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
1051
+ }
1052
+
1053
+ template <int mmq_y>
1054
+ static __dpct_inline__ void
1055
+ allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
1056
+ int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
1057
+ (void)x_qh;
1058
+
1059
+ *x_ql = tile_x_ql;
1060
+ *x_dm = tile_x_dm;
1061
+ *x_sc = tile_x_sc;
1062
+ }
1063
+
1064
// Stage one WARP_SIZE-wide slice of q6_K data into the x tiles. Each 6-bit
// value is decoded here (low nibble from ql, upper 2 bits from qh), centred
// by subtracting 32, and stored as full bytes in x_ql (2*WARP_SIZE row
// width). x_dm is reused as a float array holding just the superblock scale;
// x_sc receives the int8 sub-block scales.
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; // high bits are merged into x_ql below

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256

    const block_q6_K * bx0 = (const block_q6_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max); // clamp edge rows instead of branching out
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;

        // split the packed byte into its two 4-bit halves
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // pick the matching 2 high bits and place them at bits 4..5
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        // destination slots for the low and high halves in the widened row
        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);

        // subtract the 32 offset (saturating, per byte) so values are signed
        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
                                                 dpct::sub_sat());
        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
                                                 dpct::sub_sat());
    }

    constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm; // q6_K stores only d (no min): treat tile as floats

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // int8 sub-block scales, loaded 4 at a time
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}
1142
+
1143
#define VDR_Q6_K_Q8_1_MMQ  8

// contiguous u/y values
// MMQ dot product for q6_K x q8_1. v holds the already-decoded, 32-centred
// q6_K values; u the q8_1 quants. Two adjacent q6_K sub-block scales (sc)
// map onto one q8_1 scale (d8) — hence the int2 accumulator: .x() and .y()
// collect the partial sums for each of the two scales.
static __dpct_inline__ float
vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ sc, const float &d6,
                           const float *__restrict__ d8) {

    float sumf_d = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
                                    sumi_d.x()); // SIMD dot product
            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
                                    sumi_d.x()); // SIMD dot product

            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
                                    sumi_d.y()); // SIMD dot product
            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
                                    sumi_d.y()); // SIMD dot product
        }

        sumf_d += d8[i0 / 4] *
                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
    }

    return d6 * sumf_d;
}
1176
+
1177
// MMQ vec-dot callback for q6_K x q8_1: resolves the tile indices for the
// widened quant row, the int8 scales and the float superblock scale, then
// forwards to the scalar impl above.
static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; // high bits already merged into x_ql by the loader

    // for q6_K both dm/ds tiles actually hold plain floats
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;

    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);

    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
    const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}
1193
+
1194
// Generic tiled multiplication of a quantized matrix x (element type
// block_q_t, quantization described by qk/qr/qi) with a q8_1-quantized
// matrix y, writing float results to dst.
//   - x tiles are staged by the `load_tiles` callback into tile_x_*
//   - y tiles are staged inline below into tile_y_qs / tile_y_ds
//   - the inner product is delegated to the `vec_dot` callback
// need_sum: whether vec_dot needs the q8_1 block sums (otherwise only the
// f32 scale is staged). vdr: how many quant positions one vec_dot call
// consumes. The tile_* buffers are provided by the per-type kernel wrappers.
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
          vec_dot_q_mul_mat_sycl_t vec_dot>
/*
DPCT1110:8: The total declared local variable size in device function mul_mat_q
exceeds 128 bytes and may cause high register pressure. Consult with your
hardware vendor to find the total register size available and adjust the code,
or use smaller sub-group size to avoid high register pressure.
*/
static __dpct_inline__ void
mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
          float *__restrict__ dst, const int ncols_x, const int nrows_x,
          const int ncols_y, const int nrows_y, const int nrows_dst,
          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
          sycl::half2 *tile_y_ds) {

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    // each work-group handles an mmq_y x mmq_x output tile
    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
    const int & col_y_0 = col_dst_0;

    // per-thread accumulator for its share of the output tile
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    // march along the shared dimension one tile-width of blocks at a time
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        // i_max = nrows_x - row_x_0 - 1 lets the loader clamp edge rows
        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
                   blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
            const int kbxd = kqs / QI8_1;

            // stage the y quants for this repetition
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = dpct::min(
                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
                    ncols_y - 1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
                                    kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(
                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
            }

            // stage the y block scales (and sums, when needed)
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids =
                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
                    mmq_x;
                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const sycl::half2 *dsi_src =
                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
                       ir * (WARP_SIZE / QI8_1) + kby]
                         .ds;
                sycl::half2 *dsi_dst =
                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = (*dsi_src)[0];
                }
            }

            /*
            DPCT1118:9: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            // all tiles staged before any thread starts consuming them
            item_ct1.barrier();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
                            item_ct1.get_local_id(1) + j, k);
                    }
                }
            }

            /*
            DPCT1118:10: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            // all threads done with this tile before the next load overwrites it
            item_ct1.barrier();
        }
    }

    // write back the accumulated tile, skipping out-of-range rows/columns
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);

        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
1336
+
1337
+ #define MMQ_X_Q4_0_RDNA2 64
1338
+ #define MMQ_Y_Q4_0_RDNA2 128
1339
+ #define NWARPS_Q4_0_RDNA2 8
1340
+ #define MMQ_X_Q4_0_RDNA1 64
1341
+ #define MMQ_Y_Q4_0_RDNA1 64
1342
+ #define NWARPS_Q4_0_RDNA1 8
1343
+ #if defined(SYCL_USE_XMX)
1344
+ #define MMQ_X_Q4_0_AMPERE 4
1345
+ #define MMQ_Y_Q4_0_AMPERE 32
1346
+ #define NWARPS_Q4_0_AMPERE 4
1347
+ #else
1348
+ #define MMQ_X_Q4_0_AMPERE 64
1349
+ #define MMQ_Y_Q4_0_AMPERE 128
1350
+ #define NWARPS_Q4_0_AMPERE 4
1351
+ #endif
1352
+ #define MMQ_X_Q4_0_PASCAL 64
1353
+ #define MMQ_Y_Q4_0_PASCAL 64
1354
+ #define NWARPS_Q4_0_PASCAL 8
1355
+
1356
+ template <bool need_check> static void
1357
+ mul_mat_q4_0(
1358
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1359
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1360
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
1361
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
1362
+ int * tile_x_ql = nullptr;
1363
+ sycl::half2 *tile_x_dm = nullptr;
1364
+ int * tile_x_qh = nullptr;
1365
+ int * tile_x_sc = nullptr;
1366
+
1367
+ //sycl_todo: change according to hardware
1368
+
1369
+ const int mmq_x = MMQ_X_Q4_0_AMPERE;
1370
+ const int mmq_y = MMQ_Y_Q4_0_AMPERE;
1371
+ const int nwarps = NWARPS_Q4_0_AMPERE;
1372
+ allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1373
+ tile_x_qs_q4_0, tile_x_d_q4_0);
1374
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
1375
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
1376
+ vec_dot_q4_0_q8_1_mul_mat>(
1377
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1378
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1379
+ }
1380
+
1381
+ #define MMQ_X_Q4_1_RDNA2 64
1382
+ #define MMQ_Y_Q4_1_RDNA2 128
1383
+ #define NWARPS_Q4_1_RDNA2 8
1384
+ #define MMQ_X_Q4_1_RDNA1 64
1385
+ #define MMQ_Y_Q4_1_RDNA1 64
1386
+ #define NWARPS_Q4_1_RDNA1 8
1387
+ #if defined(SYCL_USE_XMX)
1388
+ #define MMQ_X_Q4_1_AMPERE 4
1389
+ #define MMQ_Y_Q4_1_AMPERE 32
1390
+ #define NWARPS_Q4_1_AMPERE 4
1391
+ #else
1392
+ #define MMQ_X_Q4_1_AMPERE 64
1393
+ #define MMQ_Y_Q4_1_AMPERE 128
1394
+ #define NWARPS_Q4_1_AMPERE 4
1395
+ #endif
1396
+ #define MMQ_X_Q4_1_PASCAL 64
1397
+ #define MMQ_Y_Q4_1_PASCAL 64
1398
+ #define NWARPS_Q4_1_PASCAL 8
1399
+
1400
+ template <bool need_check> static void
1401
+ mul_mat_q4_1(
1402
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1403
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1404
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
1405
+ sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1406
+ int * tile_x_ql = nullptr;
1407
+ sycl::half2 *tile_x_dm = nullptr;
1408
+ int * tile_x_qh = nullptr;
1409
+ int * tile_x_sc = nullptr;
1410
+
1411
+ //sycl_todo: change according to hardware
1412
+ const int mmq_x = MMQ_X_Q4_1_AMPERE;
1413
+ const int mmq_y = MMQ_Y_Q4_1_AMPERE;
1414
+ const int nwarps = NWARPS_Q4_1_AMPERE;
1415
+ allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1416
+ tile_x_qs_q4_1, tile_x_dm_q4_1);
1417
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
1418
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
1419
+ vec_dot_q4_1_q8_1_mul_mat>(
1420
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1421
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1422
+ }
1423
+
1424
+ #define MMQ_X_Q5_0_RDNA2 64
1425
+ #define MMQ_Y_Q5_0_RDNA2 128
1426
+ #define NWARPS_Q5_0_RDNA2 8
1427
+ #define MMQ_X_Q5_0_RDNA1 64
1428
+ #define MMQ_Y_Q5_0_RDNA1 64
1429
+ #define NWARPS_Q5_0_RDNA1 8
1430
+ #if defined(SYCL_USE_XMX)
1431
+ #define MMQ_X_Q5_0_AMPERE 4
1432
+ #define MMQ_Y_Q5_0_AMPERE 32
1433
+ #define NWARPS_Q5_0_AMPERE 4
1434
+ #else
1435
+ #define MMQ_X_Q5_0_AMPERE 128
1436
+ #define MMQ_Y_Q5_0_AMPERE 64
1437
+ #define NWARPS_Q5_0_AMPERE 4
1438
+ #endif
1439
+ #define MMQ_X_Q5_0_PASCAL 64
1440
+ #define MMQ_Y_Q5_0_PASCAL 64
1441
+ #define NWARPS_Q5_0_PASCAL 8
1442
+
1443
+ template <bool need_check> static void
1444
+ mul_mat_q5_0(
1445
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1446
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1447
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
1448
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
1449
+ int * tile_x_ql = nullptr;
1450
+ sycl::half2 *tile_x_dm = nullptr;
1451
+ int * tile_x_qh = nullptr;
1452
+ int * tile_x_sc = nullptr;
1453
+
1454
+ //sycl_todo: change according to hardware
1455
+ const int mmq_x = MMQ_X_Q5_0_AMPERE;
1456
+ const int mmq_y = MMQ_Y_Q5_0_AMPERE;
1457
+ const int nwarps = NWARPS_Q5_0_AMPERE;
1458
+ allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1459
+ tile_x_ql_q5_0, tile_x_d_q5_0);
1460
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
1461
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
1462
+ vec_dot_q5_0_q8_1_mul_mat>(
1463
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1464
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1465
+ }
1466
+
1467
+ #define MMQ_X_Q5_1_RDNA2 64
1468
+ #define MMQ_Y_Q5_1_RDNA2 128
1469
+ #define NWARPS_Q5_1_RDNA2 8
1470
+ #define MMQ_X_Q5_1_RDNA1 64
1471
+ #define MMQ_Y_Q5_1_RDNA1 64
1472
+ #define NWARPS_Q5_1_RDNA1 8
1473
+ #if defined(SYCL_USE_XMX)
1474
+ #define MMQ_X_Q5_1_AMPERE 4
1475
+ #define MMQ_Y_Q5_1_AMPERE 32
1476
+ #define NWARPS_Q5_1_AMPERE 4
1477
+ #else
1478
+ #define MMQ_X_Q5_1_AMPERE 128
1479
+ #define MMQ_Y_Q5_1_AMPERE 64
1480
+ #define NWARPS_Q5_1_AMPERE 4
1481
+ #endif
1482
+ #define MMQ_X_Q5_1_PASCAL 64
1483
+ #define MMQ_Y_Q5_1_PASCAL 64
1484
+ #define NWARPS_Q5_1_PASCAL 8
1485
+
1486
+ template <bool need_check> static void
1487
+ mul_mat_q5_1(
1488
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1489
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1490
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
1491
+ sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1492
+ int * tile_x_ql = nullptr;
1493
+ sycl::half2 *tile_x_dm = nullptr;
1494
+ int * tile_x_qh = nullptr;
1495
+ int * tile_x_sc = nullptr;
1496
+
1497
+ //sycl_todo: change according to hardware
1498
+ const int mmq_x = MMQ_X_Q5_1_AMPERE;
1499
+ const int mmq_y = MMQ_Y_Q5_1_AMPERE;
1500
+ const int nwarps = NWARPS_Q5_1_AMPERE;
1501
+ allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1502
+ tile_x_ql_q5_1, tile_x_dm_q5_1);
1503
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
1504
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
1505
+ vec_dot_q5_1_q8_1_mul_mat>(
1506
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1507
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1508
+ }
1509
+
1510
+ #define MMQ_X_Q8_0_RDNA2 64
1511
+ #define MMQ_Y_Q8_0_RDNA2 128
1512
+ #define NWARPS_Q8_0_RDNA2 8
1513
+ #define MMQ_X_Q8_0_RDNA1 64
1514
+ #define MMQ_Y_Q8_0_RDNA1 64
1515
+ #define NWARPS_Q8_0_RDNA1 8
1516
+ #if defined(SYCL_USE_XMX)
1517
+ #define MMQ_X_Q8_0_AMPERE 4
1518
+ #define MMQ_Y_Q8_0_AMPERE 32
1519
+ #define NWARPS_Q8_0_AMPERE 4
1520
+ #else
1521
+ #define MMQ_X_Q8_0_AMPERE 128
1522
+ #define MMQ_Y_Q8_0_AMPERE 64
1523
+ #define NWARPS_Q8_0_AMPERE 4
1524
+ #endif
1525
+ #define MMQ_X_Q8_0_PASCAL 64
1526
+ #define MMQ_Y_Q8_0_PASCAL 64
1527
+ #define NWARPS_Q8_0_PASCAL 8
1528
+
1529
+ template <bool need_check> static void
1530
+ mul_mat_q8_0(
1531
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1532
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1533
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
1534
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
1535
+ int * tile_x_ql = nullptr;
1536
+ sycl::half2 *tile_x_dm = nullptr;
1537
+ int * tile_x_qh = nullptr;
1538
+ int * tile_x_sc = nullptr;
1539
+
1540
+ //sycl_todo: change according to hardware
1541
+ const int mmq_x = MMQ_X_Q8_0_AMPERE;
1542
+ const int mmq_y = MMQ_Y_Q8_0_AMPERE;
1543
+ const int nwarps = NWARPS_Q8_0_AMPERE;
1544
+ allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1545
+ tile_x_qs_q8_0, tile_x_d_q8_0);
1546
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
1547
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
1548
+ vec_dot_q8_0_q8_1_mul_mat>(
1549
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1550
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1551
+ }
1552
+
1553
+ #define MMQ_X_Q2_K_RDNA2 64
1554
+ #define MMQ_Y_Q2_K_RDNA2 128
1555
+ #define NWARPS_Q2_K_RDNA2 8
1556
+ #define MMQ_X_Q2_K_RDNA1 128
1557
+ #define MMQ_Y_Q2_K_RDNA1 32
1558
+ #define NWARPS_Q2_K_RDNA1 8
1559
+ #if defined(SYCL_USE_XMX)
1560
+ #define MMQ_X_Q2_K_AMPERE 4
1561
+ #define MMQ_Y_Q2_K_AMPERE 32
1562
+ #define NWARPS_Q2_K_AMPERE 4
1563
+ #else
1564
+ #define MMQ_X_Q2_K_AMPERE 64
1565
+ #define MMQ_Y_Q2_K_AMPERE 128
1566
+ #define NWARPS_Q2_K_AMPERE 4
1567
+ #endif
1568
+ #define MMQ_X_Q2_K_PASCAL 64
1569
+ #define MMQ_Y_Q2_K_PASCAL 64
1570
+ #define NWARPS_Q2_K_PASCAL 8
1571
+
1572
+ template <bool need_check> static void
1573
+ mul_mat_q2_K(
1574
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1575
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1576
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
1577
+ sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
1578
+ sycl::half2 *tile_y_ds) {
1579
+ int * tile_x_ql = nullptr;
1580
+ sycl::half2 *tile_x_dm = nullptr;
1581
+ int * tile_x_qh = nullptr;
1582
+ int * tile_x_sc = nullptr;
1583
+
1584
+ //sycl_todo: change according to hardware
1585
+ const int mmq_x = MMQ_X_Q2_K_AMPERE;
1586
+ const int mmq_y = MMQ_Y_Q2_K_AMPERE;
1587
+ const int nwarps = NWARPS_Q2_K_AMPERE;
1588
+ allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1589
+ tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
1590
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
1591
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
1592
+ vec_dot_q2_K_q8_1_mul_mat>(
1593
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1594
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1595
+ }
1596
+
1597
+ #define MMQ_X_Q3_K_RDNA2 128
1598
+ #define MMQ_Y_Q3_K_RDNA2 64
1599
+ #define NWARPS_Q3_K_RDNA2 8
1600
+ #define MMQ_X_Q3_K_RDNA1 32
1601
+ #define MMQ_Y_Q3_K_RDNA1 128
1602
+ #define NWARPS_Q3_K_RDNA1 8
1603
+ #if defined(SYCL_USE_XMX)
1604
+ #define MMQ_X_Q3_K_AMPERE 4
1605
+ #define MMQ_Y_Q3_K_AMPERE 32
1606
+ #define NWARPS_Q3_K_AMPERE 4
1607
+ #else
1608
+ #define MMQ_X_Q3_K_AMPERE 128
1609
+ #define MMQ_Y_Q3_K_AMPERE 128
1610
+ #define NWARPS_Q3_K_AMPERE 4
1611
+ #endif
1612
+ #define MMQ_X_Q3_K_PASCAL 64
1613
+ #define MMQ_Y_Q3_K_PASCAL 64
1614
+ #define NWARPS_Q3_K_PASCAL 8
1615
+
1616
+ template <bool need_check> static void
1617
+ mul_mat_q3_K(
1618
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1619
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1620
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
1621
+ sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
1622
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
1623
+ int * tile_x_ql = nullptr;
1624
+ sycl::half2 *tile_x_dm = nullptr;
1625
+ int * tile_x_qh = nullptr;
1626
+ int * tile_x_sc = nullptr;
1627
+
1628
+ //sycl_todo: change according to hardware
1629
+ const int mmq_x = MMQ_X_Q3_K_AMPERE;
1630
+ const int mmq_y = MMQ_Y_Q3_K_AMPERE;
1631
+ const int nwarps = NWARPS_Q3_K_AMPERE;
1632
+ allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1633
+ tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
1634
+ tile_x_sc_q3_K);
1635
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
1636
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
1637
+ vec_dot_q3_K_q8_1_mul_mat>(
1638
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1639
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1640
+ }
1641
+
1642
+ #define MMQ_X_Q4_K_RDNA2 64
1643
+ #define MMQ_Y_Q4_K_RDNA2 128
1644
+ #define NWARPS_Q4_K_RDNA2 8
1645
+ #define MMQ_X_Q4_K_RDNA1 32
1646
+ #define MMQ_Y_Q4_K_RDNA1 64
1647
+ #define NWARPS_Q4_K_RDNA1 8
1648
+ #if defined(SYCL_USE_XMX)
1649
+ #define MMQ_X_Q4_K_AMPERE 4
1650
+ #define MMQ_Y_Q4_K_AMPERE 32
1651
+ #define NWARPS_Q4_K_AMPERE 4
1652
+ #else
1653
+ #define MMQ_X_Q4_K_AMPERE 64
1654
+ #define MMQ_Y_Q4_K_AMPERE 128
1655
+ #define NWARPS_Q4_K_AMPERE 4
1656
+ #endif
1657
+ #define MMQ_X_Q4_K_PASCAL 64
1658
+ #define MMQ_Y_Q4_K_PASCAL 64
1659
+ #define NWARPS_Q4_K_PASCAL 8
1660
+
1661
+ template <bool need_check> static void
1662
+ mul_mat_q4_K(
1663
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1664
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1665
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
1666
+ sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
1667
+ sycl::half2 *tile_y_ds) {
1668
+ int * tile_x_ql = nullptr;
1669
+ sycl::half2 *tile_x_dm = nullptr;
1670
+ int * tile_x_qh = nullptr;
1671
+ int * tile_x_sc = nullptr;
1672
+
1673
+ //sycl_todo: change according to hardware
1674
+ const int mmq_x = MMQ_X_Q4_K_AMPERE;
1675
+ const int mmq_y = MMQ_Y_Q4_K_AMPERE;
1676
+ const int nwarps = NWARPS_Q4_K_AMPERE;
1677
+ allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1678
+ tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
1679
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
1680
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
1681
+ vec_dot_q4_K_q8_1_mul_mat>(
1682
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1683
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1684
+ }
1685
+
1686
+ #define MMQ_X_Q5_K_RDNA2 64
1687
+ #define MMQ_Y_Q5_K_RDNA2 128
1688
+ #define NWARPS_Q5_K_RDNA2 8
1689
+ #define MMQ_X_Q5_K_RDNA1 32
1690
+ #define MMQ_Y_Q5_K_RDNA1 64
1691
+ #define NWARPS_Q5_K_RDNA1 8
1692
+ #if defined(SYCL_USE_XMX)
1693
+ #define MMQ_X_Q5_K_AMPERE 4
1694
+ #define MMQ_Y_Q5_K_AMPERE 32
1695
+ #define NWARPS_Q5_K_AMPERE 4
1696
+ #else
1697
+ #define MMQ_X_Q5_K_AMPERE 64
1698
+ #define MMQ_Y_Q5_K_AMPERE 128
1699
+ #define NWARPS_Q5_K_AMPERE 4
1700
+ #endif
1701
+ #define MMQ_X_Q5_K_PASCAL 64
1702
+ #define MMQ_Y_Q5_K_PASCAL 64
1703
+ #define NWARPS_Q5_K_PASCAL 8
1704
+
1705
+ template <bool need_check> static void
1706
+ mul_mat_q5_K(
1707
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1708
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1709
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
1710
+ sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
1711
+ sycl::half2 *tile_y_ds) {
1712
+ int * tile_x_ql = nullptr;
1713
+ sycl::half2 *tile_x_dm = nullptr;
1714
+ int * tile_x_qh = nullptr;
1715
+ int * tile_x_sc = nullptr;
1716
+
1717
+ //sycl_todo: change according to hardware
1718
+ const int mmq_x = MMQ_X_Q5_K_AMPERE;
1719
+ const int mmq_y = MMQ_Y_Q5_K_AMPERE;
1720
+ const int nwarps = NWARPS_Q5_K_AMPERE;
1721
+ allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1722
+ tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
1723
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
1724
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
1725
+ vec_dot_q5_K_q8_1_mul_mat>(
1726
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1727
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1728
+ }
1729
+
1730
+ #define MMQ_X_Q6_K_RDNA2 64
1731
+ #define MMQ_Y_Q6_K_RDNA2 128
1732
+ #define NWARPS_Q6_K_RDNA2 8
1733
+ #define MMQ_X_Q6_K_RDNA1 32
1734
+ #define MMQ_Y_Q6_K_RDNA1 64
1735
+ #define NWARPS_Q6_K_RDNA1 8
1736
+ #if defined(SYCL_USE_XMX)
1737
+ #define MMQ_X_Q6_K_AMPERE 4
1738
+ #define MMQ_Y_Q6_K_AMPERE 32
1739
+ #define NWARPS_Q6_K_AMPERE 4
1740
+ #else
1741
+ #define MMQ_X_Q6_K_AMPERE 64
1742
+ #define MMQ_Y_Q6_K_AMPERE 64
1743
+ #define NWARPS_Q6_K_AMPERE 4
1744
+ #endif
1745
+ #define MMQ_X_Q6_K_PASCAL 64
1746
+ #define MMQ_Y_Q6_K_PASCAL 64
1747
+ #define NWARPS_Q6_K_PASCAL 8
1748
+
1749
+ template <bool need_check> static void
1750
+ mul_mat_q6_K(
1751
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
1752
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
1753
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
1754
+ int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
1755
+ // int * tile_x_ql = nullptr;
1756
+ // sycl::half2 *tile_x_dm = nullptr;
1757
+ int * tile_x_qh = nullptr;
1758
+ // int * tile_x_sc = nullptr;
1759
+
1760
+ //sycl_todo: change according to hardware
1761
+ const int mmq_x = MMQ_X_Q6_K_AMPERE;
1762
+ const int mmq_y = MMQ_Y_Q6_K_AMPERE;
1763
+ const int nwarps = NWARPS_Q6_K_AMPERE;
1764
+ allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
1765
+ tile_x_ql, tile_x_dm, tile_x_sc);
1766
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
1767
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
1768
+ vec_dot_q6_K_q8_1_mul_mat>(
1769
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
1770
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
1771
+ }
1772
+
1773
+ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
1774
+ float *dst, const int ncols_x,
1775
+ const int nrows_x, const int ncols_y,
1776
+ const int nrows_y, const int nrows_dst,
1777
+ dpct::queue_ptr stream) try {
1778
+
1779
+ int id;
1780
+ SYCL_CHECK(
1781
+ CHECK_TRY_ERROR(id = get_current_device_id()));
1782
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
1783
+
1784
+ int mmq_x, mmq_y, nwarps;
1785
+ if (compute_capability >= VER_GEN13) {
1786
+ mmq_x = MMQ_X_Q4_0_RDNA2;
1787
+ mmq_y = MMQ_Y_Q4_0_RDNA2;
1788
+ nwarps = NWARPS_Q4_0_RDNA2;
1789
+ } else if (compute_capability >= VER_GEN12) {
1790
+ mmq_x = MMQ_X_Q4_0_RDNA1;
1791
+ mmq_y = MMQ_Y_Q4_0_RDNA1;
1792
+ nwarps = NWARPS_Q4_0_RDNA1;
1793
+ } else if (compute_capability >= VER_GEN9) {
1794
+ mmq_x = MMQ_X_Q4_0_AMPERE;
1795
+ mmq_y = MMQ_Y_Q4_0_AMPERE;
1796
+ nwarps = NWARPS_Q4_0_AMPERE;
1797
+ } else if (compute_capability >= VER_4VEC) {
1798
+ mmq_x = MMQ_X_Q4_0_PASCAL;
1799
+ mmq_y = MMQ_Y_Q4_0_PASCAL;
1800
+ nwarps = NWARPS_Q4_0_PASCAL;
1801
+ } else {
1802
+ GGML_ABORT("fatal error");
1803
+ }
1804
+
1805
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
1806
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
1807
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
1808
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
1809
+
1810
+ if (nrows_x % mmq_y == 0) {
1811
+ const bool need_check = false;
1812
+ /*
1813
+ DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
1814
+ the limit. To get the device limit, query
1815
+ info::device::max_work_group_size. Adjust the work-group size if needed.
1816
+ */
1817
+ {
1818
+ dpct::has_capability_or_fail(stream->get_device(),
1819
+ {sycl::aspect::fp16});
1820
+
1821
+ stream->submit([&](sycl::handler &cgh) {
1822
+ sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
1823
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
1824
+ sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
1825
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
1826
+ cgh);
1827
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1828
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1829
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1830
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1831
+
1832
+ cgh.parallel_for(
1833
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1834
+ [=](sycl::nd_item<3> item_ct1) {
1835
+ mul_mat_q4_0<need_check>(
1836
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1837
+ nrows_dst, item_ct1,
1838
+ get_pointer(tile_x_qs_q4_0_acc_ct1),
1839
+ get_pointer(tile_x_d_q4_0_acc_ct1),
1840
+ get_pointer(tile_y_qs_acc_ct1),
1841
+ get_pointer(tile_y_ds_acc_ct1));
1842
+ });
1843
+ });
1844
+ }
1845
+ } else {
1846
+ const bool need_check = true;
1847
+ /*
1848
+ DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
1849
+ the limit. To get the device limit, query
1850
+ info::device::max_work_group_size. Adjust the work-group size if needed.
1851
+ */
1852
+ {
1853
+ dpct::has_capability_or_fail(stream->get_device(),
1854
+ {sycl::aspect::fp16});
1855
+
1856
+ stream->submit([&](sycl::handler &cgh) {
1857
+ sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
1858
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
1859
+ sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
1860
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
1861
+ cgh);
1862
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1863
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1864
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1865
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1866
+
1867
+ cgh.parallel_for(
1868
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1869
+ [=](sycl::nd_item<3> item_ct1) {
1870
+ mul_mat_q4_0<need_check>(
1871
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1872
+ nrows_dst, item_ct1,
1873
+ get_pointer(tile_x_qs_q4_0_acc_ct1),
1874
+ get_pointer(tile_x_d_q4_0_acc_ct1),
1875
+ get_pointer(tile_y_qs_acc_ct1),
1876
+ get_pointer(tile_y_ds_acc_ct1));
1877
+ });
1878
+ });
1879
+ }
1880
+ }
1881
+ }
1882
+ catch (sycl::exception const &exc) {
1883
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1884
+ << ", line:" << __LINE__ << std::endl;
1885
+ std::exit(1);
1886
+ }
1887
+
1888
+ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
1889
+ float *dst, const int ncols_x,
1890
+ const int nrows_x, const int ncols_y,
1891
+ const int nrows_y, const int nrows_dst,
1892
+ dpct::queue_ptr stream) try {
1893
+
1894
+ int id;
1895
+ SYCL_CHECK(
1896
+ CHECK_TRY_ERROR(id = get_current_device_id()));
1897
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
1898
+
1899
+ int mmq_x, mmq_y, nwarps;
1900
+ if (compute_capability >= VER_GEN13) {
1901
+ mmq_x = MMQ_X_Q4_1_RDNA2;
1902
+ mmq_y = MMQ_Y_Q4_1_RDNA2;
1903
+ nwarps = NWARPS_Q4_1_RDNA2;
1904
+ } else if (compute_capability >= VER_GEN12) {
1905
+ mmq_x = MMQ_X_Q4_1_RDNA1;
1906
+ mmq_y = MMQ_Y_Q4_1_RDNA1;
1907
+ nwarps = NWARPS_Q4_1_RDNA1;
1908
+ } else if (compute_capability >= VER_GEN9) {
1909
+ mmq_x = MMQ_X_Q4_1_AMPERE;
1910
+ mmq_y = MMQ_Y_Q4_1_AMPERE;
1911
+ nwarps = NWARPS_Q4_1_AMPERE;
1912
+ } else if (compute_capability >= VER_4VEC) {
1913
+ mmq_x = MMQ_X_Q4_1_PASCAL;
1914
+ mmq_y = MMQ_Y_Q4_1_PASCAL;
1915
+ nwarps = NWARPS_Q4_1_PASCAL;
1916
+ } else {
1917
+ GGML_ABORT("fatal error");
1918
+ }
1919
+
1920
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
1921
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
1922
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
1923
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
1924
+
1925
+ if (nrows_x % mmq_y == 0) {
1926
+ const bool need_check = false;
1927
+ /*
1928
+ DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
1929
+ the limit. To get the device limit, query
1930
+ info::device::max_work_group_size. Adjust the work-group size if needed.
1931
+ */
1932
+ {
1933
+ dpct::has_capability_or_fail(stream->get_device(),
1934
+ {sycl::aspect::fp16});
1935
+
1936
+ stream->submit([&](sycl::handler &cgh) {
1937
+ sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1938
+ sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1939
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1940
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1941
+ cgh);
1942
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1943
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1944
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1945
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1946
+
1947
+ cgh.parallel_for(
1948
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1949
+ [=](sycl::nd_item<3> item_ct1) {
1950
+ mul_mat_q4_1<need_check>(
1951
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1952
+ nrows_dst, item_ct1,
1953
+ get_pointer(tile_x_qs_q4_1_acc_ct1),
1954
+ get_pointer(tile_x_dm_q4_1_acc_ct1),
1955
+ get_pointer(tile_y_qs_acc_ct1),
1956
+ get_pointer(tile_y_ds_acc_ct1));
1957
+ });
1958
+ });
1959
+ }
1960
+ } else {
1961
+ const bool need_check = true;
1962
+ /*
1963
+ DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
1964
+ the limit. To get the device limit, query
1965
+ info::device::max_work_group_size. Adjust the work-group size if needed.
1966
+ */
1967
+ {
1968
+ dpct::has_capability_or_fail(stream->get_device(),
1969
+ {sycl::aspect::fp16});
1970
+
1971
+ stream->submit([&](sycl::handler &cgh) {
1972
+ sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
1973
+ sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
1974
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
1975
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
1976
+ cgh);
1977
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
1978
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
1979
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
1980
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
1981
+
1982
+ cgh.parallel_for(
1983
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1984
+ [=](sycl::nd_item<3> item_ct1) {
1985
+ mul_mat_q4_1<need_check>(
1986
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
1987
+ nrows_dst, item_ct1,
1988
+ get_pointer(tile_x_qs_q4_1_acc_ct1),
1989
+ get_pointer(tile_x_dm_q4_1_acc_ct1),
1990
+ get_pointer(tile_y_qs_acc_ct1),
1991
+ get_pointer(tile_y_ds_acc_ct1));
1992
+ });
1993
+ });
1994
+ }
1995
+ }
1996
+ }
1997
+ catch (sycl::exception const &exc) {
1998
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
1999
+ << ", line:" << __LINE__ << std::endl;
2000
+ std::exit(1);
2001
+ }
2002
+
2003
+ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
2004
+ float *dst, const int ncols_x,
2005
+ const int nrows_x, const int ncols_y,
2006
+ const int nrows_y, const int nrows_dst,
2007
+ dpct::queue_ptr stream) try {
2008
+
2009
+ int id;
2010
+ SYCL_CHECK(
2011
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2012
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2013
+
2014
+ int mmq_x, mmq_y, nwarps;
2015
+ if (compute_capability >= VER_GEN13) {
2016
+ mmq_x = MMQ_X_Q5_0_RDNA2;
2017
+ mmq_y = MMQ_Y_Q5_0_RDNA2;
2018
+ nwarps = NWARPS_Q5_0_RDNA2;
2019
+ } else if (compute_capability >= VER_GEN12) {
2020
+ mmq_x = MMQ_X_Q5_0_RDNA1;
2021
+ mmq_y = MMQ_Y_Q5_0_RDNA1;
2022
+ nwarps = NWARPS_Q5_0_RDNA1;
2023
+ } else if (compute_capability >= VER_GEN9) {
2024
+ mmq_x = MMQ_X_Q5_0_AMPERE;
2025
+ mmq_y = MMQ_Y_Q5_0_AMPERE;
2026
+ nwarps = NWARPS_Q5_0_AMPERE;
2027
+ } else if (compute_capability >= VER_4VEC) {
2028
+ mmq_x = MMQ_X_Q5_0_PASCAL;
2029
+ mmq_y = MMQ_Y_Q5_0_PASCAL;
2030
+ nwarps = NWARPS_Q5_0_PASCAL;
2031
+ } else {
2032
+ GGML_ABORT("fatal error");
2033
+ }
2034
+
2035
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2036
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2037
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2038
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2039
+
2040
+ if (nrows_x % mmq_y == 0) {
2041
+ const bool need_check = false;
2042
+ /*
2043
+ DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
2044
+ the limit. To get the device limit, query
2045
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2046
+ */
2047
+ {
2048
+ dpct::has_capability_or_fail(stream->get_device(),
2049
+ {sycl::aspect::fp16});
2050
+
2051
+ stream->submit([&](sycl::handler &cgh) {
2052
+ sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
2053
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2054
+ sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
2055
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
2056
+ cgh);
2057
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2058
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2059
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2060
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2061
+
2062
+ cgh.parallel_for(
2063
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2064
+ [=](sycl::nd_item<3> item_ct1) {
2065
+ mul_mat_q5_0<need_check>(
2066
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2067
+ nrows_dst, item_ct1,
2068
+ get_pointer(tile_x_ql_q5_0_acc_ct1),
2069
+ get_pointer(tile_x_d_q5_0_acc_ct1),
2070
+ get_pointer(tile_y_qs_acc_ct1),
2071
+ get_pointer(tile_y_ds_acc_ct1));
2072
+ });
2073
+ });
2074
+ }
2075
+ } else {
2076
+ const bool need_check = true;
2077
+ /*
2078
+ DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
2079
+ the limit. To get the device limit, query
2080
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2081
+ */
2082
+ {
2083
+ dpct::has_capability_or_fail(stream->get_device(),
2084
+ {sycl::aspect::fp16});
2085
+
2086
+ stream->submit([&](sycl::handler &cgh) {
2087
+ sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
2088
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2089
+ sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
2090
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
2091
+ cgh);
2092
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2093
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2094
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2095
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2096
+
2097
+ cgh.parallel_for(
2098
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2099
+ [=](sycl::nd_item<3> item_ct1) {
2100
+ mul_mat_q5_0<need_check>(
2101
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2102
+ nrows_dst, item_ct1,
2103
+ get_pointer(tile_x_ql_q5_0_acc_ct1),
2104
+ get_pointer(tile_x_d_q5_0_acc_ct1),
2105
+ get_pointer(tile_y_qs_acc_ct1),
2106
+ get_pointer(tile_y_ds_acc_ct1));
2107
+ });
2108
+ });
2109
+ }
2110
+ }
2111
+ }
2112
+ catch (sycl::exception const &exc) {
2113
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2114
+ << ", line:" << __LINE__ << std::endl;
2115
+ std::exit(1);
2116
+ }
2117
+
2118
+ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
2119
+ float *dst, const int ncols_x,
2120
+ const int nrows_x, const int ncols_y,
2121
+ const int nrows_y, const int nrows_dst,
2122
+ dpct::queue_ptr stream) try {
2123
+
2124
+ int id;
2125
+ SYCL_CHECK(
2126
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2127
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2128
+
2129
+ int mmq_x, mmq_y, nwarps;
2130
+ if (compute_capability >= VER_GEN13) {
2131
+ mmq_x = MMQ_X_Q5_1_RDNA2;
2132
+ mmq_y = MMQ_Y_Q5_1_RDNA2;
2133
+ nwarps = NWARPS_Q5_1_RDNA2;
2134
+ } else if (compute_capability >= VER_GEN12) {
2135
+ mmq_x = MMQ_X_Q5_1_RDNA1;
2136
+ mmq_y = MMQ_Y_Q5_1_RDNA1;
2137
+ nwarps = NWARPS_Q5_1_RDNA1;
2138
+ } else if (compute_capability >= VER_GEN9) {
2139
+ mmq_x = MMQ_X_Q5_1_AMPERE;
2140
+ mmq_y = MMQ_Y_Q5_1_AMPERE;
2141
+ nwarps = NWARPS_Q5_1_AMPERE;
2142
+ } else if (compute_capability >= VER_4VEC) {
2143
+ mmq_x = MMQ_X_Q5_1_PASCAL;
2144
+ mmq_y = MMQ_Y_Q5_1_PASCAL;
2145
+ nwarps = NWARPS_Q5_1_PASCAL;
2146
+ } else {
2147
+ GGML_ABORT("fatal error");
2148
+ }
2149
+
2150
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2151
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2152
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2153
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2154
+
2155
+ if (nrows_x % mmq_y == 0) {
2156
+ const bool need_check = false;
2157
+ /*
2158
+ DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
2159
+ the limit. To get the device limit, query
2160
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2161
+ */
2162
+ {
2163
+ dpct::has_capability_or_fail(stream->get_device(),
2164
+ {sycl::aspect::fp16});
2165
+
2166
+ stream->submit([&](sycl::handler &cgh) {
2167
+ sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
2168
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2169
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
2170
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
2171
+ cgh);
2172
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2173
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2174
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2175
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2176
+
2177
+ cgh.parallel_for(
2178
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2179
+ [=](sycl::nd_item<3> item_ct1) {
2180
+ mul_mat_q5_1<need_check>(
2181
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2182
+ nrows_dst, item_ct1,
2183
+ get_pointer(tile_x_ql_q5_1_acc_ct1),
2184
+ get_pointer(tile_x_dm_q5_1_acc_ct1),
2185
+ get_pointer(tile_y_qs_acc_ct1),
2186
+ get_pointer(tile_y_ds_acc_ct1));
2187
+ });
2188
+ });
2189
+ }
2190
+ } else {
2191
+ const bool need_check = true;
2192
+ /*
2193
+ DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
2194
+ the limit. To get the device limit, query
2195
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2196
+ */
2197
+ {
2198
+ dpct::has_capability_or_fail(stream->get_device(),
2199
+ {sycl::aspect::fp16});
2200
+
2201
+ stream->submit([&](sycl::handler &cgh) {
2202
+ sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
2203
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2204
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
2205
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
2206
+ cgh);
2207
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2208
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2209
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2210
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2211
+
2212
+ cgh.parallel_for(
2213
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2214
+ [=](sycl::nd_item<3> item_ct1) {
2215
+ mul_mat_q5_1<need_check>(
2216
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2217
+ nrows_dst, item_ct1,
2218
+ get_pointer(tile_x_ql_q5_1_acc_ct1),
2219
+ get_pointer(tile_x_dm_q5_1_acc_ct1),
2220
+ get_pointer(tile_y_qs_acc_ct1),
2221
+ get_pointer(tile_y_ds_acc_ct1));
2222
+ });
2223
+ });
2224
+ }
2225
+ }
2226
+ }
2227
+ catch (sycl::exception const &exc) {
2228
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2229
+ << ", line:" << __LINE__ << std::endl;
2230
+ std::exit(1);
2231
+ }
2232
+
2233
+ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
2234
+ float *dst, const int ncols_x,
2235
+ const int nrows_x, const int ncols_y,
2236
+ const int nrows_y, const int nrows_dst,
2237
+ dpct::queue_ptr stream) try {
2238
+
2239
+ int id;
2240
+ SYCL_CHECK(
2241
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2242
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2243
+
2244
+ int mmq_x, mmq_y, nwarps;
2245
+ if (compute_capability >= VER_GEN13) {
2246
+ mmq_x = MMQ_X_Q8_0_RDNA2;
2247
+ mmq_y = MMQ_Y_Q8_0_RDNA2;
2248
+ nwarps = NWARPS_Q8_0_RDNA2;
2249
+ } else if (compute_capability >= VER_GEN12) {
2250
+ mmq_x = MMQ_X_Q8_0_RDNA1;
2251
+ mmq_y = MMQ_Y_Q8_0_RDNA1;
2252
+ nwarps = NWARPS_Q8_0_RDNA1;
2253
+ } else if (compute_capability >= VER_GEN9) {
2254
+ mmq_x = MMQ_X_Q8_0_AMPERE;
2255
+ mmq_y = MMQ_Y_Q8_0_AMPERE;
2256
+ nwarps = NWARPS_Q8_0_AMPERE;
2257
+ } else if (compute_capability >= VER_4VEC) {
2258
+ mmq_x = MMQ_X_Q8_0_PASCAL;
2259
+ mmq_y = MMQ_Y_Q8_0_PASCAL;
2260
+ nwarps = NWARPS_Q8_0_PASCAL;
2261
+ } else {
2262
+ GGML_ABORT("fatal error");
2263
+ }
2264
+
2265
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2266
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2267
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2268
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2269
+
2270
+ if (nrows_x % mmq_y == 0) {
2271
+ const bool need_check = false;
2272
+ /*
2273
+ DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
2274
+ the limit. To get the device limit, query
2275
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2276
+ */
2277
+ {
2278
+ dpct::has_capability_or_fail(stream->get_device(),
2279
+ {sycl::aspect::fp16});
2280
+
2281
+ stream->submit([&](sycl::handler &cgh) {
2282
+ sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2283
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2284
+ sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2285
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2286
+ cgh);
2287
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2288
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2289
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2290
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2291
+
2292
+ cgh.parallel_for(
2293
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2294
+ [=](sycl::nd_item<3> item_ct1) {
2295
+ mul_mat_q8_0<need_check>(
2296
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2297
+ nrows_dst, item_ct1,
2298
+ get_pointer(tile_x_qs_q8_0_acc_ct1),
2299
+ get_pointer(tile_x_d_q8_0_acc_ct1),
2300
+ get_pointer(tile_y_qs_acc_ct1),
2301
+ get_pointer(tile_y_ds_acc_ct1));
2302
+ });
2303
+ });
2304
+ }
2305
+ } else {
2306
+ const bool need_check = true;
2307
+ /*
2308
+ DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
2309
+ the limit. To get the device limit, query
2310
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2311
+ */
2312
+ {
2313
+ dpct::has_capability_or_fail(stream->get_device(),
2314
+ {sycl::aspect::fp16});
2315
+
2316
+ stream->submit([&](sycl::handler &cgh) {
2317
+ sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
2318
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2319
+ sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
2320
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
2321
+ cgh);
2322
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2323
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2324
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2325
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2326
+
2327
+ cgh.parallel_for(
2328
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2329
+ [=](sycl::nd_item<3> item_ct1) {
2330
+ mul_mat_q8_0<need_check>(
2331
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2332
+ nrows_dst, item_ct1,
2333
+ get_pointer(tile_x_qs_q8_0_acc_ct1),
2334
+ get_pointer(tile_x_d_q8_0_acc_ct1),
2335
+ get_pointer(tile_y_qs_acc_ct1),
2336
+ get_pointer(tile_y_ds_acc_ct1));
2337
+ });
2338
+ });
2339
+ }
2340
+ }
2341
+ }
2342
+ catch (sycl::exception const &exc) {
2343
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2344
+ << ", line:" << __LINE__ << std::endl;
2345
+ std::exit(1);
2346
+ }
2347
+
2348
+ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
2349
+ float *dst, const int ncols_x,
2350
+ const int nrows_x, const int ncols_y,
2351
+ const int nrows_y, const int nrows_dst,
2352
+ dpct::queue_ptr stream) try {
2353
+
2354
+ int id;
2355
+ SYCL_CHECK(
2356
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2357
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2358
+
2359
+ int mmq_x, mmq_y, nwarps;
2360
+ if (compute_capability >= VER_GEN13) {
2361
+ mmq_x = MMQ_X_Q2_K_RDNA2;
2362
+ mmq_y = MMQ_Y_Q2_K_RDNA2;
2363
+ nwarps = NWARPS_Q2_K_RDNA2;
2364
+ } else if (compute_capability >= VER_GEN12) {
2365
+ mmq_x = MMQ_X_Q2_K_RDNA1;
2366
+ mmq_y = MMQ_Y_Q2_K_RDNA1;
2367
+ nwarps = NWARPS_Q2_K_RDNA1;
2368
+ } else if (compute_capability >= VER_GEN9) {
2369
+ mmq_x = MMQ_X_Q2_K_AMPERE;
2370
+ mmq_y = MMQ_Y_Q2_K_AMPERE;
2371
+ nwarps = NWARPS_Q2_K_AMPERE;
2372
+ } else if (compute_capability >= VER_4VEC) {
2373
+ mmq_x = MMQ_X_Q2_K_PASCAL;
2374
+ mmq_y = MMQ_Y_Q2_K_PASCAL;
2375
+ nwarps = NWARPS_Q2_K_PASCAL;
2376
+ } else {
2377
+ GGML_ABORT("fatal error");
2378
+ }
2379
+
2380
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2381
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2382
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2383
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2384
+
2385
+ if (nrows_x % mmq_y == 0) {
2386
+ const bool need_check = false;
2387
+ /*
2388
+ DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
2389
+ the limit. To get the device limit, query
2390
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2391
+ */
2392
+ {
2393
+ dpct::has_capability_or_fail(stream->get_device(),
2394
+ {sycl::aspect::fp16});
2395
+
2396
+ stream->submit([&](sycl::handler &cgh) {
2397
+ sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
2398
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2399
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
2400
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
2401
+ cgh);
2402
+ sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
2403
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2404
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2405
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2406
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2407
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2408
+
2409
+ cgh.parallel_for(
2410
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2411
+ [=](sycl::nd_item<3> item_ct1) {
2412
+ mul_mat_q2_K<need_check>(
2413
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2414
+ nrows_dst, item_ct1,
2415
+ get_pointer(tile_x_ql_q2_K_acc_ct1),
2416
+ get_pointer(tile_x_dm_q2_K_acc_ct1),
2417
+ get_pointer(tile_x_sc_q2_K_acc_ct1),
2418
+ get_pointer(tile_y_qs_acc_ct1),
2419
+ get_pointer(tile_y_ds_acc_ct1));
2420
+ });
2421
+ });
2422
+ }
2423
+ } else {
2424
+ const bool need_check = true;
2425
+ /*
2426
+ DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
2427
+ the limit. To get the device limit, query
2428
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2429
+ */
2430
+ {
2431
+ dpct::has_capability_or_fail(stream->get_device(),
2432
+ {sycl::aspect::fp16});
2433
+
2434
+ stream->submit([&](sycl::handler &cgh) {
2435
+ sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
2436
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2437
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
2438
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
2439
+ cgh);
2440
+ sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
2441
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2442
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2443
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2444
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2445
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2446
+
2447
+ cgh.parallel_for(
2448
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2449
+ [=](sycl::nd_item<3> item_ct1) {
2450
+ mul_mat_q2_K<need_check>(
2451
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2452
+ nrows_dst, item_ct1,
2453
+ get_pointer(tile_x_ql_q2_K_acc_ct1),
2454
+ get_pointer(tile_x_dm_q2_K_acc_ct1),
2455
+ get_pointer(tile_x_sc_q2_K_acc_ct1),
2456
+ get_pointer(tile_y_qs_acc_ct1),
2457
+ get_pointer(tile_y_ds_acc_ct1));
2458
+ });
2459
+ });
2460
+ }
2461
+ }
2462
+ }
2463
+ catch (sycl::exception const &exc) {
2464
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2465
+ << ", line:" << __LINE__ << std::endl;
2466
+ std::exit(1);
2467
+ }
2468
+
2469
+ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
2470
+ float *dst, const int ncols_x,
2471
+ const int nrows_x, const int ncols_y,
2472
+ const int nrows_y, const int nrows_dst,
2473
+ dpct::queue_ptr stream) try {
2474
+
2475
+ #if QK_K == 256
2476
+
2477
+ int id;
2478
+ SYCL_CHECK(
2479
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2480
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2481
+
2482
+ int mmq_x, mmq_y, nwarps;
2483
+ if (compute_capability >= VER_GEN13) {
2484
+ mmq_x = MMQ_X_Q3_K_RDNA2;
2485
+ mmq_y = MMQ_Y_Q3_K_RDNA2;
2486
+ nwarps = NWARPS_Q3_K_RDNA2;
2487
+ } else if (compute_capability >= VER_GEN12) {
2488
+ mmq_x = MMQ_X_Q3_K_RDNA1;
2489
+ mmq_y = MMQ_Y_Q3_K_RDNA1;
2490
+ nwarps = NWARPS_Q3_K_RDNA1;
2491
+ } else if (compute_capability >= VER_GEN9) {
2492
+ mmq_x = MMQ_X_Q3_K_AMPERE;
2493
+ mmq_y = MMQ_Y_Q3_K_AMPERE;
2494
+ nwarps = NWARPS_Q3_K_AMPERE;
2495
+ } else if (compute_capability >= VER_4VEC) {
2496
+ mmq_x = MMQ_X_Q3_K_PASCAL;
2497
+ mmq_y = MMQ_Y_Q3_K_PASCAL;
2498
+ nwarps = NWARPS_Q3_K_PASCAL;
2499
+ } else {
2500
+ GGML_ABORT("fatal error");
2501
+ }
2502
+
2503
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2504
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2505
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2506
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2507
+
2508
+ if (nrows_x % mmq_y == 0) {
2509
+ const bool need_check = false;
2510
+ /*
2511
+ DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
2512
+ the limit. To get the device limit, query
2513
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2514
+ */
2515
+ {
2516
+ dpct::has_capability_or_fail(stream->get_device(),
2517
+ {sycl::aspect::fp16});
2518
+
2519
+ stream->submit([&](sycl::handler &cgh) {
2520
+ sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
2521
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2522
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
2523
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
2524
+ cgh);
2525
+ sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
2526
+ sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
2527
+ sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
2528
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2529
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2530
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2531
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2532
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2533
+
2534
+ cgh.parallel_for(
2535
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2536
+ [=](sycl::nd_item<3> item_ct1) {
2537
+ mul_mat_q3_K<need_check>(
2538
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2539
+ nrows_dst, item_ct1,
2540
+ get_pointer(tile_x_ql_q3_K_acc_ct1),
2541
+ get_pointer(tile_x_dm_q3_K_acc_ct1),
2542
+ get_pointer(tile_x_qh_q3_K_acc_ct1),
2543
+ get_pointer(tile_x_sc_q3_K_acc_ct1),
2544
+ get_pointer(tile_y_qs_acc_ct1),
2545
+ get_pointer(tile_y_ds_acc_ct1));
2546
+ });
2547
+ });
2548
+ }
2549
+ } else {
2550
+ const bool need_check = true;
2551
+ /*
2552
+ DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
2553
+ the limit. To get the device limit, query
2554
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2555
+ */
2556
+ {
2557
+ dpct::has_capability_or_fail(stream->get_device(),
2558
+ {sycl::aspect::fp16});
2559
+
2560
+ stream->submit([&](sycl::handler &cgh) {
2561
+ sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
2562
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2563
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
2564
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
2565
+ cgh);
2566
+ sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
2567
+ sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
2568
+ sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
2569
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
2570
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2571
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2572
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2573
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2574
+
2575
+ cgh.parallel_for(
2576
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2577
+ [=](sycl::nd_item<3> item_ct1) {
2578
+ mul_mat_q3_K<need_check>(
2579
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2580
+ nrows_dst, item_ct1,
2581
+ get_pointer(tile_x_ql_q3_K_acc_ct1),
2582
+ get_pointer(tile_x_dm_q3_K_acc_ct1),
2583
+ get_pointer(tile_x_qh_q3_K_acc_ct1),
2584
+ get_pointer(tile_x_sc_q3_K_acc_ct1),
2585
+ get_pointer(tile_y_qs_acc_ct1),
2586
+ get_pointer(tile_y_ds_acc_ct1));
2587
+ });
2588
+ });
2589
+ }
2590
+ }
2591
+ #endif
2592
+ }
2593
+ catch (sycl::exception const &exc) {
2594
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2595
+ << ", line:" << __LINE__ << std::endl;
2596
+ std::exit(1);
2597
+ }
2598
+
2599
+ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
2600
+ float *dst, const int ncols_x,
2601
+ const int nrows_x, const int ncols_y,
2602
+ const int nrows_y, const int nrows_dst,
2603
+ dpct::queue_ptr stream) try {
2604
+
2605
+ int id;
2606
+ SYCL_CHECK(
2607
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2608
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2609
+
2610
+ int mmq_x, mmq_y, nwarps;
2611
+ if (compute_capability >= VER_GEN13) {
2612
+ mmq_x = MMQ_X_Q4_K_RDNA2;
2613
+ mmq_y = MMQ_Y_Q4_K_RDNA2;
2614
+ nwarps = NWARPS_Q4_K_RDNA2;
2615
+ } else if (compute_capability >= VER_GEN12) {
2616
+ mmq_x = MMQ_X_Q4_K_RDNA1;
2617
+ mmq_y = MMQ_Y_Q4_K_RDNA1;
2618
+ nwarps = NWARPS_Q4_K_RDNA1;
2619
+ } else if (compute_capability >= VER_GEN9) {
2620
+ mmq_x = MMQ_X_Q4_K_AMPERE;
2621
+ mmq_y = MMQ_Y_Q4_K_AMPERE;
2622
+ nwarps = NWARPS_Q4_K_AMPERE;
2623
+ } else if (compute_capability >= VER_4VEC) {
2624
+ mmq_x = MMQ_X_Q4_K_PASCAL;
2625
+ mmq_y = MMQ_Y_Q4_K_PASCAL;
2626
+ nwarps = NWARPS_Q4_K_PASCAL;
2627
+ } else {
2628
+ GGML_ABORT("fatal error");
2629
+ }
2630
+
2631
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2632
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2633
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2634
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2635
+
2636
+ if (nrows_x % mmq_y == 0) {
2637
+ const bool need_check = false;
2638
+ /*
2639
+ DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
2640
+ the limit. To get the device limit, query
2641
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2642
+ */
2643
+ {
2644
+ dpct::has_capability_or_fail(stream->get_device(),
2645
+ {sycl::aspect::fp16});
2646
+
2647
+ stream->submit([&](sycl::handler &cgh) {
2648
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2649
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2650
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2651
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2652
+ cgh);
2653
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2654
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2655
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2656
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2657
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2658
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2659
+
2660
+ cgh.parallel_for(
2661
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2662
+ [=](sycl::nd_item<3> item_ct1) {
2663
+ mul_mat_q4_K<need_check>(
2664
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2665
+ nrows_dst, item_ct1,
2666
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
2667
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
2668
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
2669
+ get_pointer(tile_y_qs_acc_ct1),
2670
+ get_pointer(tile_y_ds_acc_ct1));
2671
+ });
2672
+ });
2673
+ }
2674
+ } else {
2675
+ const bool need_check = true;
2676
+ /*
2677
+ DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
2678
+ the limit. To get the device limit, query
2679
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2680
+ */
2681
+ {
2682
+ dpct::has_capability_or_fail(stream->get_device(),
2683
+ {sycl::aspect::fp16});
2684
+
2685
+ stream->submit([&](sycl::handler &cgh) {
2686
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
2687
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
2688
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
2689
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
2690
+ cgh);
2691
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
2692
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2693
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2694
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2695
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2696
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2697
+
2698
+ cgh.parallel_for(
2699
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2700
+ [=](sycl::nd_item<3> item_ct1) {
2701
+ mul_mat_q4_K<need_check>(
2702
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2703
+ nrows_dst, item_ct1,
2704
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
2705
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
2706
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
2707
+ get_pointer(tile_y_qs_acc_ct1),
2708
+ get_pointer(tile_y_ds_acc_ct1));
2709
+ });
2710
+ });
2711
+ }
2712
+ }
2713
+ }
2714
+ catch (sycl::exception const &exc) {
2715
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2716
+ << ", line:" << __LINE__ << std::endl;
2717
+ std::exit(1);
2718
+ }
2719
+
2720
+ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
2721
+ float *dst, const int ncols_x,
2722
+ const int nrows_x, const int ncols_y,
2723
+ const int nrows_y, const int nrows_dst,
2724
+ dpct::queue_ptr stream) try {
2725
+
2726
+ int id;
2727
+ SYCL_CHECK(
2728
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2729
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2730
+
2731
+ int mmq_x, mmq_y, nwarps;
2732
+ if (compute_capability >= VER_GEN13) {
2733
+ mmq_x = MMQ_X_Q5_K_RDNA2;
2734
+ mmq_y = MMQ_Y_Q5_K_RDNA2;
2735
+ nwarps = NWARPS_Q5_K_RDNA2;
2736
+ } else if (compute_capability >= VER_GEN12) {
2737
+ mmq_x = MMQ_X_Q5_K_RDNA1;
2738
+ mmq_y = MMQ_Y_Q5_K_RDNA1;
2739
+ nwarps = NWARPS_Q5_K_RDNA1;
2740
+ } else if (compute_capability >= VER_GEN9) {
2741
+ mmq_x = MMQ_X_Q5_K_AMPERE;
2742
+ mmq_y = MMQ_Y_Q5_K_AMPERE;
2743
+ nwarps = NWARPS_Q5_K_AMPERE;
2744
+ } else if (compute_capability >= VER_4VEC) {
2745
+ mmq_x = MMQ_X_Q5_K_PASCAL;
2746
+ mmq_y = MMQ_Y_Q5_K_PASCAL;
2747
+ nwarps = NWARPS_Q5_K_PASCAL;
2748
+ } else {
2749
+ GGML_ABORT("fatal error");
2750
+ }
2751
+
2752
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2753
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2754
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2755
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2756
+
2757
+ if (nrows_x % mmq_y == 0) {
2758
+ const bool need_check = false;
2759
+ /*
2760
+ DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
2761
+ the limit. To get the device limit, query
2762
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2763
+ */
2764
+ {
2765
+ dpct::has_capability_or_fail(stream->get_device(),
2766
+ {sycl::aspect::fp16});
2767
+
2768
+ stream->submit([&](sycl::handler &cgh) {
2769
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2770
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2771
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2772
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2773
+ cgh);
2774
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2775
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2776
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2777
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2778
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2779
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2780
+
2781
+ cgh.parallel_for(
2782
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2783
+ [=](sycl::nd_item<3> item_ct1) {
2784
+ mul_mat_q5_K<need_check>(
2785
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2786
+ nrows_dst, item_ct1,
2787
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
2788
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
2789
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
2790
+ get_pointer(tile_y_qs_acc_ct1),
2791
+ get_pointer(tile_y_ds_acc_ct1));
2792
+ });
2793
+ });
2794
+ }
2795
+ } else {
2796
+ const bool need_check = true;
2797
+ /*
2798
+ DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
2799
+ the limit. To get the device limit, query
2800
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2801
+ */
2802
+ {
2803
+ dpct::has_capability_or_fail(stream->get_device(),
2804
+ {sycl::aspect::fp16});
2805
+
2806
+ stream->submit([&](sycl::handler &cgh) {
2807
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
2808
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2809
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
2810
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
2811
+ cgh);
2812
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
2813
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2814
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2815
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2816
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2817
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2818
+
2819
+ cgh.parallel_for(
2820
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2821
+ [=](sycl::nd_item<3> item_ct1) {
2822
+ mul_mat_q5_K<need_check>(
2823
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2824
+ nrows_dst, item_ct1,
2825
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
2826
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
2827
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
2828
+ get_pointer(tile_y_qs_acc_ct1),
2829
+ get_pointer(tile_y_ds_acc_ct1));
2830
+ });
2831
+ });
2832
+ }
2833
+ }
2834
+ }
2835
+ catch (sycl::exception const &exc) {
2836
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2837
+ << ", line:" << __LINE__ << std::endl;
2838
+ std::exit(1);
2839
+ }
2840
+
2841
+ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
2842
+ float *dst, const int ncols_x,
2843
+ const int nrows_x, const int ncols_y,
2844
+ const int nrows_y, const int nrows_dst,
2845
+ dpct::queue_ptr stream) try {
2846
+
2847
+ int id;
2848
+ SYCL_CHECK(
2849
+ CHECK_TRY_ERROR(id = get_current_device_id()));
2850
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
2851
+
2852
+ int mmq_x, mmq_y, nwarps;
2853
+ if (compute_capability >= VER_GEN13) {
2854
+ mmq_x = MMQ_X_Q6_K_RDNA2;
2855
+ mmq_y = MMQ_Y_Q6_K_RDNA2;
2856
+ nwarps = NWARPS_Q6_K_RDNA2;
2857
+ } else if (compute_capability >= VER_GEN12) {
2858
+ mmq_x = MMQ_X_Q6_K_RDNA1;
2859
+ mmq_y = MMQ_Y_Q6_K_RDNA1;
2860
+ nwarps = NWARPS_Q6_K_RDNA1;
2861
+ } else if (compute_capability >= VER_GEN9) {
2862
+ mmq_x = MMQ_X_Q6_K_AMPERE;
2863
+ mmq_y = MMQ_Y_Q6_K_AMPERE;
2864
+ nwarps = NWARPS_Q6_K_AMPERE;
2865
+ } else if (compute_capability >= VER_4VEC) {
2866
+ mmq_x = MMQ_X_Q6_K_PASCAL;
2867
+ mmq_y = MMQ_Y_Q6_K_PASCAL;
2868
+ nwarps = NWARPS_Q6_K_PASCAL;
2869
+ } else {
2870
+ GGML_ABORT("fatal error");
2871
+ }
2872
+
2873
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
2874
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
2875
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
2876
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
2877
+
2878
+ if (nrows_x % mmq_y == 0) {
2879
+ const bool need_check = false;
2880
+ /*
2881
+ DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
2882
+ the limit. To get the device limit, query
2883
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2884
+ */
2885
+ {
2886
+ dpct::has_capability_or_fail(stream->get_device(),
2887
+ {sycl::aspect::fp16});
2888
+
2889
+ stream->submit([&](sycl::handler &cgh) {
2890
+ sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2891
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2892
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2893
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2894
+ cgh);
2895
+ sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2896
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2897
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2898
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2899
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2900
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2901
+
2902
+ cgh.parallel_for(
2903
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2904
+ [=](sycl::nd_item<3> item_ct1) {
2905
+ mul_mat_q6_K<need_check>(
2906
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2907
+ nrows_dst, item_ct1,
2908
+ get_pointer(tile_x_ql_acc_ct1),
2909
+ get_pointer(tile_x_dm_acc_ct1),
2910
+ get_pointer(tile_x_sc_acc_ct1),
2911
+ get_pointer(tile_y_qs_acc_ct1),
2912
+ get_pointer(tile_y_ds_acc_ct1));
2913
+ });
2914
+ });
2915
+ }
2916
+ } else {
2917
+ const bool need_check = true;
2918
+ /*
2919
+ DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
2920
+ the limit. To get the device limit, query
2921
+ info::device::max_work_group_size. Adjust the work-group size if needed.
2922
+ */
2923
+ {
2924
+ dpct::has_capability_or_fail(stream->get_device(),
2925
+ {sycl::aspect::fp16});
2926
+
2927
+ stream->submit([&](sycl::handler &cgh) {
2928
+ sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
2929
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
2930
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
2931
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
2932
+ cgh);
2933
+ sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
2934
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
2935
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
2936
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
2937
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
2938
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
2939
+
2940
+ cgh.parallel_for(
2941
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2942
+ [=](sycl::nd_item<3> item_ct1) {
2943
+ mul_mat_q6_K<need_check>(
2944
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
2945
+ nrows_dst, item_ct1,
2946
+ get_pointer(tile_x_ql_acc_ct1),
2947
+ get_pointer(tile_x_dm_acc_ct1),
2948
+ get_pointer(tile_x_sc_acc_ct1),
2949
+ get_pointer(tile_y_qs_acc_ct1),
2950
+ get_pointer(tile_y_ds_acc_ct1));
2951
+ });
2952
+ });
2953
+ }
2954
+ }
2955
+ }
2956
+ catch (sycl::exception const &exc) {
2957
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2958
+ << ", line:" << __LINE__ << std::endl;
2959
+ std::exit(1);
2960
+ }
2961
+
2962
+ void ggml_sycl_op_mul_mat_q(
2963
+ ggml_backend_sycl_context & ctx,
2964
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
2965
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
2966
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
2967
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
2968
+ const dpct::queue_ptr &stream) try {
2969
+
2970
+ const int64_t ne00 = src0->ne[0];
2971
+
2972
+ const int64_t ne10 = src1->ne[0];
2973
+ GGML_ASSERT(ne10 % QK8_1 == 0);
2974
+
2975
+ const int64_t ne0 = dst->ne[0];
2976
+
2977
+ const int64_t row_diff = row_high - row_low;
2978
+
2979
+ int device_id;
2980
+ SYCL_CHECK(
2981
+ CHECK_TRY_ERROR(device_id = get_current_device_id()));
2982
+
2983
+ // the main device has a larger memory buffer to hold the results from all GPUs
2984
+ // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
2985
+ const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
2986
+
2987
+ switch (src0->type) {
2988
+ case GGML_TYPE_Q4_0:
2989
+ ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2990
+ break;
2991
+ case GGML_TYPE_Q4_1:
2992
+ ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2993
+ break;
2994
+ case GGML_TYPE_Q5_0:
2995
+ ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2996
+ break;
2997
+ case GGML_TYPE_Q5_1:
2998
+ ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
2999
+ break;
3000
+ case GGML_TYPE_Q8_0:
3001
+ ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3002
+ break;
3003
+ case GGML_TYPE_Q2_K:
3004
+ ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3005
+ break;
3006
+ case GGML_TYPE_Q3_K:
3007
+ ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3008
+ break;
3009
+ case GGML_TYPE_Q4_K:
3010
+ ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3011
+ break;
3012
+ case GGML_TYPE_Q5_K:
3013
+ ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3014
+ break;
3015
+ case GGML_TYPE_Q6_K:
3016
+ ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
3017
+ break;
3018
+ default:
3019
+ GGML_ABORT("fatal error");
3020
+ break;
3021
+ }
3022
+
3023
+ GGML_UNUSED(src1);
3024
+ GGML_UNUSED(dst);
3025
+ GGML_UNUSED(src1_ddf_i);
3026
+ }
3027
+ catch (sycl::exception const &exc) {
3028
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3029
+ << ", line:" << __LINE__ << std::endl;
3030
+ std::exit(1);
3031
+ }