whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,3031 @@
|
|
1
|
+
//
|
2
|
+
// MIT license
|
3
|
+
// Copyright (C) 2024 Intel Corporation
|
4
|
+
// SPDX-License-Identifier: MIT
|
5
|
+
//
|
6
|
+
|
7
|
+
//
|
8
|
+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
9
|
+
// See https://llvm.org/LICENSE.txt for license information.
|
10
|
+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
11
|
+
//
|
12
|
+
|
13
|
+
#include "mmq.hpp"
|
14
|
+
#include "vecdotq.hpp"
|
15
|
+
|
16
|
+
typedef void (*allocate_tiles_sycl_t)(
|
17
|
+
int** x_ql,
|
18
|
+
sycl::half2** x_dm,
|
19
|
+
int** x_qh,
|
20
|
+
int** x_sc);
|
21
|
+
typedef void (*load_tiles_sycl_t)(
|
22
|
+
const void* __restrict__ vx,
|
23
|
+
int* __restrict__ x_ql,
|
24
|
+
sycl::half2* __restrict__ x_dm,
|
25
|
+
int* __restrict__ x_qh,
|
26
|
+
int* __restrict__ x_sc,
|
27
|
+
const int& i_offset,
|
28
|
+
const int& i_max,
|
29
|
+
const int& k,
|
30
|
+
const int& blocks_per_row);
|
31
|
+
typedef float (*vec_dot_q_mul_mat_sycl_t)(
|
32
|
+
const int* __restrict__ x_ql,
|
33
|
+
const sycl::half2* __restrict__ x_dm,
|
34
|
+
const int* __restrict__ x_qh,
|
35
|
+
const int* __restrict__ x_sc,
|
36
|
+
const int* __restrict__ y_qs,
|
37
|
+
const sycl::half2* __restrict__ y_ms,
|
38
|
+
const int& i,
|
39
|
+
const int& j,
|
40
|
+
const int& k);
|
41
|
+
|
42
|
+
|
43
|
+
template <int mmq_y>
|
44
|
+
static __dpct_inline__ void
|
45
|
+
allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
46
|
+
int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
|
47
|
+
(void)x_qh; (void)x_sc;
|
48
|
+
|
49
|
+
*x_ql = tile_x_qs_q4_0;
|
50
|
+
*x_dm = (sycl::half2 *)tile_x_d_q4_0;
|
51
|
+
}
|
52
|
+
|
53
|
+
template <int mmq_y, int nwarps, bool need_check>
|
54
|
+
static __dpct_inline__ void
|
55
|
+
load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
|
56
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
57
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
58
|
+
const int &k, const int &blocks_per_row) {
|
59
|
+
(void)x_qh; (void)x_sc;
|
60
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
61
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
62
|
+
GGML_SYCL_ASSUME(k >= 0);
|
63
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
64
|
+
|
65
|
+
const int kbx = k / QI4_0;
|
66
|
+
const int kqsx = k % QI4_0;
|
67
|
+
|
68
|
+
const block_q4_0 * bx0 = (const block_q4_0 *) vx;
|
69
|
+
|
70
|
+
float * x_dmf = (float *) x_dm;
|
71
|
+
|
72
|
+
#pragma unroll
|
73
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
74
|
+
int i = i0 + i_offset;
|
75
|
+
|
76
|
+
if (need_check) {
|
77
|
+
i = sycl::min(i, i_max);
|
78
|
+
}
|
79
|
+
|
80
|
+
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
81
|
+
|
82
|
+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
|
83
|
+
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
|
84
|
+
}
|
85
|
+
|
86
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
|
87
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
88
|
+
|
89
|
+
#pragma unroll
|
90
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
|
91
|
+
int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
|
92
|
+
|
93
|
+
if (need_check) {
|
94
|
+
i = sycl::min(i, i_max);
|
95
|
+
}
|
96
|
+
|
97
|
+
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
98
|
+
|
99
|
+
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
|
100
|
+
}
|
101
|
+
}
|
102
|
+
|
103
|
+
/// Dot product of a q4_0 x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// Gathers 2*VDR_Q4_0_Q8_1_MMQ y ints (pairs QI4_0 apart, wrapping within the
/// WARP_SIZE-wide y row). It reads the x scale through a float* view of x_dm
/// and forwards everything to vec_dot_q4_0_q8_1_impl.
static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    static_cast<void>(x_qh);
    static_cast<void>(x_sc);

    const float *const x_scales = (const float *)x_dm;

    const int y_row = j * WARP_SIZE;
    const int y_start = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));

    int u[2 * VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        const int first = (y_start + l) % WARP_SIZE;
        const int second = (y_start + l + QI4_0) % WARP_SIZE;
        u[2 * l + 0] = y_qs[y_row + first];
        u[2 * l + 1] = y_qs[y_row + second];
    }

    const int *const x_quants = &x_ql[i * (WARP_SIZE + 1) + k];
    const float x_d = x_scales[i * (WARP_SIZE / QI4_0) + i / QI4_0 + k / QI4_0];
    const sycl::half2 y_d =
        y_ds[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)];

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(x_quants, u, x_d, y_d);
}
|
125
|
+
|
126
|
+
template <int mmq_y>
|
127
|
+
static __dpct_inline__ void
|
128
|
+
allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
129
|
+
int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
|
130
|
+
(void)x_qh; (void)x_sc;
|
131
|
+
|
132
|
+
*x_ql = tile_x_qs_q4_1;
|
133
|
+
*x_dm = tile_x_dm_q4_1;
|
134
|
+
}
|
135
|
+
|
136
|
+
|
137
|
+
template <int mmq_y, int nwarps, bool need_check>
|
138
|
+
static __dpct_inline__ void
|
139
|
+
load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
|
140
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
141
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
142
|
+
const int &k, const int &blocks_per_row) {
|
143
|
+
(void)x_qh; (void)x_sc;
|
144
|
+
|
145
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
146
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
147
|
+
GGML_SYCL_ASSUME(k >= 0);
|
148
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
149
|
+
|
150
|
+
const int kbx = k / QI4_1;
|
151
|
+
const int kqsx = k % QI4_1;
|
152
|
+
|
153
|
+
const block_q4_1 * bx0 = (const block_q4_1 *) vx;
|
154
|
+
|
155
|
+
#pragma unroll
|
156
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
157
|
+
int i = i0 + i_offset;
|
158
|
+
|
159
|
+
if (need_check) {
|
160
|
+
i = sycl::min(i, i_max);
|
161
|
+
}
|
162
|
+
|
163
|
+
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
|
164
|
+
|
165
|
+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
166
|
+
}
|
167
|
+
|
168
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
|
169
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
170
|
+
|
171
|
+
#pragma unroll
|
172
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
|
173
|
+
int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
|
174
|
+
|
175
|
+
if (need_check) {
|
176
|
+
i = sycl::min(i, i_max);
|
177
|
+
}
|
178
|
+
|
179
|
+
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
|
180
|
+
|
181
|
+
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
|
182
|
+
}
|
183
|
+
}
|
184
|
+
|
185
|
+
/// Dot product of a q4_1 x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// Gathers 2*VDR_Q4_1_Q8_1_MMQ y ints (pairs QI4_1 apart, wrapping within the
/// WARP_SIZE-wide y row), then calls vec_dot_q4_1_q8_1_impl with the block's
/// half2 `dm` from x_dm and the matching half2 y_ds entry.
static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    static_cast<void>(x_qh);
    static_cast<void>(x_sc);

    const int *const y_row = y_qs + j * WARP_SIZE;
    const int y_start = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));

    int u[2 * VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2 * l + 0] = y_row[(y_start + l) % WARP_SIZE];
        u[2 * l + 1] = y_row[(y_start + l + QI4_1) % WARP_SIZE];
    }

    const int *const x_quants = &x_ql[i * (WARP_SIZE + 1) + k];
    const sycl::half2 x_dm_k = x_dm[i * (WARP_SIZE / QI4_1) + i / QI4_1 + k / QI4_1];
    const sycl::half2 y_ds_k =
        y_ds[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)];

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(x_quants, u, x_dm_k, y_ds_k);
}
|
206
|
+
|
207
|
+
template <int mmq_y>
|
208
|
+
static __dpct_inline__ void
|
209
|
+
allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
210
|
+
int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
|
211
|
+
(void)x_qh; (void)x_sc;
|
212
|
+
|
213
|
+
*x_ql = tile_x_ql_q5_0;
|
214
|
+
*x_dm = (sycl::half2 *)tile_x_d_q5_0;
|
215
|
+
}
|
216
|
+
|
217
|
+
template <int mmq_y, int nwarps, bool need_check>
|
218
|
+
static __dpct_inline__ void
|
219
|
+
load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
|
220
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
221
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
222
|
+
const int &k, const int &blocks_per_row) {
|
223
|
+
(void)x_qh; (void)x_sc;
|
224
|
+
|
225
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
226
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
227
|
+
GGML_SYCL_ASSUME(k >= 0);
|
228
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
229
|
+
|
230
|
+
const int kbx = k / QI5_0;
|
231
|
+
const int kqsx = k % QI5_0;
|
232
|
+
|
233
|
+
const block_q5_0 * bx0 = (const block_q5_0 *) vx;
|
234
|
+
|
235
|
+
#pragma unroll
|
236
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
237
|
+
int i = i0 + i_offset;
|
238
|
+
|
239
|
+
if (need_check) {
|
240
|
+
i = sycl::min(i, i_max);
|
241
|
+
}
|
242
|
+
|
243
|
+
const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
244
|
+
|
245
|
+
const int ql = get_int_from_uint8(bxi->qs, kqsx);
|
246
|
+
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
|
247
|
+
|
248
|
+
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
249
|
+
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
250
|
+
qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
|
251
|
+
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
|
252
|
+
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
253
|
+
qs0 = dpct::vectorized_binary<sycl::char4>(
|
254
|
+
qs0, 0x10101010, dpct::sub_sat()); // subtract 16
|
255
|
+
|
256
|
+
x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
|
257
|
+
|
258
|
+
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
259
|
+
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
260
|
+
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
|
261
|
+
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
262
|
+
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
263
|
+
qs1 = dpct::vectorized_binary<sycl::char4>(
|
264
|
+
qs1, 0x10101010, dpct::sub_sat()); // subtract 16
|
265
|
+
|
266
|
+
x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
|
267
|
+
}
|
268
|
+
|
269
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
|
270
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
271
|
+
float * x_dmf = (float *) x_dm;
|
272
|
+
|
273
|
+
#pragma unroll
|
274
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
|
275
|
+
int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
|
276
|
+
|
277
|
+
if (need_check) {
|
278
|
+
i = sycl::min(i, i_max);
|
279
|
+
}
|
280
|
+
|
281
|
+
const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
282
|
+
|
283
|
+
x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
|
284
|
+
}
|
285
|
+
}
|
286
|
+
|
287
|
+
/// Dot product of a q5_0 x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// load_tiles_q5_0 stores the x values as bytes already offset by -16, two
/// ints per source int, so the q8_0 x q8_1 kernel is used here. Both the x
/// and y scales are read through float* views of x_dm / y_ds.
static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    static_cast<void>(x_qh);
    static_cast<void>(x_sc);

    const float *const x_scales = (const float *)x_dm;
    const float *const y_scales = (const float *)y_ds;

    const int y_row = j * WARP_SIZE;
    const int y_start = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));

    int u[2 * VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2 * l + 0] = y_qs[y_row + (y_start + l) % WARP_SIZE];
        u[2 * l + 1] = y_qs[y_row + (y_start + l + QI5_0) % WARP_SIZE];
    }

    const int x_scale_idx = i * (WARP_SIZE / QI5_0) + i / QI5_0 + k / QI5_0;
    const int y_scale_idx =
        j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1);
    const int *const x_quants = &x_ql[i * (2 * WARP_SIZE + 1) + 2 * k];

    return vec_dot_q8_0_q8_1_impl<QR5_0 * VDR_Q5_0_Q8_1_MMQ>(
        x_quants, u, x_scales[x_scale_idx], y_scales[y_scale_idx]);
}
|
310
|
+
|
311
|
+
template <int mmq_y>
|
312
|
+
static __dpct_inline__ void
|
313
|
+
allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
314
|
+
int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
|
315
|
+
(void)x_qh; (void)x_sc;
|
316
|
+
|
317
|
+
*x_ql = tile_x_ql_q5_1;
|
318
|
+
*x_dm = tile_x_dm_q5_1;
|
319
|
+
}
|
320
|
+
|
321
|
+
template <int mmq_y, int nwarps, bool need_check>
|
322
|
+
static __dpct_inline__ void
|
323
|
+
load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
|
324
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
325
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
326
|
+
const int &k, const int &blocks_per_row) {
|
327
|
+
(void)x_qh; (void)x_sc;
|
328
|
+
|
329
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
330
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
331
|
+
GGML_SYCL_ASSUME(k >= 0);
|
332
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
333
|
+
|
334
|
+
const int kbx = k / QI5_1;
|
335
|
+
const int kqsx = k % QI5_1;
|
336
|
+
|
337
|
+
const block_q5_1 * bx0 = (const block_q5_1 *) vx;
|
338
|
+
|
339
|
+
#pragma unroll
|
340
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
341
|
+
int i = i0 + i_offset;
|
342
|
+
|
343
|
+
if (need_check) {
|
344
|
+
i = sycl::min(i, i_max);
|
345
|
+
}
|
346
|
+
|
347
|
+
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
|
348
|
+
|
349
|
+
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
350
|
+
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
|
351
|
+
|
352
|
+
int qs0 = (ql >> 0) & 0x0F0F0F0F;
|
353
|
+
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
|
354
|
+
qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
|
355
|
+
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
|
356
|
+
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
357
|
+
|
358
|
+
x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
|
359
|
+
|
360
|
+
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
361
|
+
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
362
|
+
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
|
363
|
+
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
364
|
+
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
365
|
+
|
366
|
+
x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
|
367
|
+
}
|
368
|
+
|
369
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
|
370
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
371
|
+
|
372
|
+
#pragma unroll
|
373
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
|
374
|
+
int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
|
375
|
+
|
376
|
+
if (need_check) {
|
377
|
+
i = sycl::min(i, i_max);
|
378
|
+
}
|
379
|
+
|
380
|
+
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
|
381
|
+
|
382
|
+
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
|
383
|
+
}
|
384
|
+
}
|
385
|
+
|
386
|
+
/// Dot product of a q5_1 x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// load_tiles_q5_1 stores the x values as bytes (two ints per source int,
/// row stride 2*WARP_SIZE + 1) without subtracting an offset. The q8_1 x q8_1
/// kernel is therefore used, with the block's half2 `dm` from x_dm and the
/// matching half2 y_ds entry.
///
/// @param x_ql,x_dm  x tile quants and per-block dm values
/// @param x_qh,x_sc  unused for q5_1
/// @param y_qs,y_ds  y tile quants and per-block half2 values
/// @param i,j,k      x row, y column, and position within the tile row
/// @return           partial dot product for (i, j) at position k
static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;

    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    // Same indexing scheme as the other vec_dot_*_mul_mat helpers
    // (the stray `+ +` unary plus has been removed).
    const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;

    int u[2*VDR_Q5_1_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
    }

    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
|
407
|
+
|
408
|
+
template <int mmq_y>
|
409
|
+
static __dpct_inline__ void
|
410
|
+
allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
411
|
+
int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
|
412
|
+
(void)x_qh; (void)x_sc;
|
413
|
+
|
414
|
+
*x_ql = tile_x_qs_q8_0;
|
415
|
+
*x_dm = (sycl::half2 *)tile_x_d_q8_0;
|
416
|
+
}
|
417
|
+
|
418
|
+
template <int mmq_y, int nwarps, bool need_check>
|
419
|
+
static __dpct_inline__ void
|
420
|
+
load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
|
421
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
422
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
423
|
+
const int &k, const int &blocks_per_row) {
|
424
|
+
(void)x_qh; (void)x_sc;
|
425
|
+
|
426
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
427
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
428
|
+
GGML_SYCL_ASSUME(k >= 0);
|
429
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
430
|
+
|
431
|
+
const int kbx = k / QI8_0;
|
432
|
+
const int kqsx = k % QI8_0;
|
433
|
+
float * x_dmf = (float *) x_dm;
|
434
|
+
|
435
|
+
const block_q8_0 * bx0 = (const block_q8_0 *) vx;
|
436
|
+
|
437
|
+
#pragma unroll
|
438
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
439
|
+
int i = i0 + i_offset;
|
440
|
+
|
441
|
+
if (need_check) {
|
442
|
+
i = sycl::min(i, i_max);
|
443
|
+
}
|
444
|
+
|
445
|
+
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
446
|
+
|
447
|
+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
|
448
|
+
}
|
449
|
+
|
450
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
|
451
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
452
|
+
|
453
|
+
#pragma unroll
|
454
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
|
455
|
+
int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
|
456
|
+
|
457
|
+
if (need_check) {
|
458
|
+
i = sycl::min(i, i_max);
|
459
|
+
}
|
460
|
+
|
461
|
+
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
462
|
+
|
463
|
+
x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
|
464
|
+
}
|
465
|
+
}
|
466
|
+
|
467
|
+
/// Dot product of a q8_0 x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// x and y quants are read contiguously from position k. Both scales are read
/// through float* views of x_dm / y_ds, and the work is done by
/// vec_dot_q8_0_q8_1_impl.
static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    static_cast<void>(x_qh);
    static_cast<void>(x_sc);

    const float *const x_scales = (const float *)x_dm;
    const float *const y_scales = (const float *)y_ds;

    const int *const x_quants = &x_ql[i * (WARP_SIZE + 1) + k];
    const int *const y_quants = &y_qs[j * WARP_SIZE + k];

    const int x_scale_idx = i * (WARP_SIZE / QI8_0) + i / QI8_0 + k / QI8_0;
    const int y_scale_idx = j * (WARP_SIZE / QI8_1) + k / QI8_1;

    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(
        x_quants, y_quants, x_scales[x_scale_idx], y_scales[y_scale_idx]);
}
|
481
|
+
|
482
|
+
template <int mmq_y>
|
483
|
+
static __dpct_inline__ void
|
484
|
+
allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
485
|
+
int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
|
486
|
+
int *tile_x_sc_q2_K) {
|
487
|
+
(void)x_qh;
|
488
|
+
|
489
|
+
*x_ql = tile_x_ql_q2_K;
|
490
|
+
*x_dm = tile_x_dm_q2_K;
|
491
|
+
*x_sc = tile_x_sc_q2_K;
|
492
|
+
}
|
493
|
+
|
494
|
+
template <int mmq_y, int nwarps, bool need_check>
|
495
|
+
static __dpct_inline__ void
|
496
|
+
load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
|
497
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
498
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
499
|
+
const int &k, const int &blocks_per_row) {
|
500
|
+
(void)x_qh;
|
501
|
+
|
502
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
503
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
504
|
+
GGML_SYCL_ASSUME(k >= 0);
|
505
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
506
|
+
|
507
|
+
const int kbx = k / QI2_K;
|
508
|
+
const int kqsx = k % QI2_K;
|
509
|
+
|
510
|
+
const block_q2_K * bx0 = (const block_q2_K *) vx;
|
511
|
+
|
512
|
+
#pragma unroll
|
513
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
514
|
+
int i = i0 + i_offset;
|
515
|
+
|
516
|
+
if (need_check) {
|
517
|
+
i = sycl::min(i, i_max);
|
518
|
+
}
|
519
|
+
|
520
|
+
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
|
521
|
+
|
522
|
+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
523
|
+
}
|
524
|
+
|
525
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
|
526
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
527
|
+
|
528
|
+
#pragma unroll
|
529
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
|
530
|
+
int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
|
531
|
+
|
532
|
+
if (need_check) {
|
533
|
+
i = sycl::min(i, i_max);
|
534
|
+
}
|
535
|
+
|
536
|
+
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
537
|
+
|
538
|
+
x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
|
539
|
+
}
|
540
|
+
|
541
|
+
#pragma unroll
|
542
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
543
|
+
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
|
544
|
+
|
545
|
+
if (need_check) {
|
546
|
+
i = sycl::min(i, i_max);
|
547
|
+
}
|
548
|
+
|
549
|
+
const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
|
550
|
+
|
551
|
+
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
|
552
|
+
}
|
553
|
+
}
|
554
|
+
|
555
|
+
#define VDR_Q2_K_Q8_1_MMQ 2
|
556
|
+
// contiguous u/y values
|
557
|
+
static __dpct_inline__ float
|
558
|
+
vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
|
559
|
+
const uint8_t *__restrict__ scales,
|
560
|
+
const sycl::half2 &dm2, const float &d8) {
|
561
|
+
|
562
|
+
int sumi_d = 0;
|
563
|
+
int sumi_m = 0;
|
564
|
+
|
565
|
+
#pragma unroll
|
566
|
+
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
|
567
|
+
int sumi_d_sc = 0;
|
568
|
+
|
569
|
+
const int sc = scales[i0 / (QI8_1/2)];
|
570
|
+
|
571
|
+
// fill int with 4x m
|
572
|
+
int m = sc >> 4;
|
573
|
+
m |= m << 8;
|
574
|
+
m |= m << 16;
|
575
|
+
|
576
|
+
#pragma unroll
|
577
|
+
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
578
|
+
sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
|
579
|
+
sumi_m = dpct::dp4a(m, u[i],
|
580
|
+
sumi_m); // multiply sum of q8_1 values with m
|
581
|
+
}
|
582
|
+
|
583
|
+
sumi_d += sumi_d_sc * (sc & 0xF);
|
584
|
+
}
|
585
|
+
|
586
|
+
const sycl::float2 dm2f =
|
587
|
+
dm2.convert<float, sycl::rounding_mode::automatic>();
|
588
|
+
|
589
|
+
return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
|
590
|
+
}
|
591
|
+
|
592
|
+
/// Dot product of a q2_K x tile slice (row i, position k) with the q8_1 y
/// tile slice of column j.
///
/// It extracts QR2_K*VDR_Q2_K_Q8_1_MMQ ints of 2-bit values from x_ql (mask
/// 0x03030303 after a shift chosen from k) and points at the matching scale
/// bytes inside x_sc. It then forwards to vec_dot_q2_K_q8_1_impl_mmq with the
/// block's half2 `dm` and the y scale read through a float* view of y_ds.
static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    static_cast<void>(x_qh);

    const int blk = k / QI2_K;             // q2_K block within the tile row
    const int y_off = (k % QI2_K) * QR2_K; // offset inside that block

    const int x_base = i * (WARP_SIZE + 1) + blk * QI2_K +
                       (QI2_K / 2) * (y_off / (2 * QI2_K)) + y_off % (QI2_K / 2);
    const int bit_shift = 2 * ((y_off % (2 * QI2_K)) / (QI2_K / 2));

    int v[QR2_K * VDR_Q2_K_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < QR2_K * VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[x_base + l] >> bit_shift) & 0x03030303;
    }

    const uint8_t *const sc_bytes =
        reinterpret_cast<const uint8_t *>(
            &x_sc[i * (WARP_SIZE / 4) + i / 4 + blk * 4]) +
        y_off / 4;

    const float *const y_scales = (const float *)y_ds;
    const int y_idx = j * WARP_SIZE + (QR2_K * k) % WARP_SIZE;

    return vec_dot_q2_K_q8_1_impl_mmq(
        v, &y_qs[y_idx], sc_bytes,
        x_dm[i * (WARP_SIZE / QI2_K) + i / QI2_K + blk], y_scales[y_idx / QI8_1]);
}
|
618
|
+
|
619
|
+
template <int mmq_y>
|
620
|
+
static __dpct_inline__ void
|
621
|
+
allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
622
|
+
int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
|
623
|
+
int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
|
624
|
+
|
625
|
+
*x_ql = tile_x_ql_q3_K;
|
626
|
+
*x_dm = tile_x_dm_q3_K;
|
627
|
+
*x_qh = tile_x_qh_q3_K;
|
628
|
+
*x_sc = tile_x_sc_q3_K;
|
629
|
+
}
|
630
|
+
|
631
|
+
template <int mmq_y, int nwarps, bool need_check>
|
632
|
+
static __dpct_inline__ void
|
633
|
+
load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
|
634
|
+
sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
|
635
|
+
int *__restrict__ x_sc, const int &i_offset, const int &i_max,
|
636
|
+
const int &k, const int &blocks_per_row) {
|
637
|
+
|
638
|
+
GGML_SYCL_ASSUME(i_offset >= 0);
|
639
|
+
GGML_SYCL_ASSUME(i_offset < nwarps);
|
640
|
+
GGML_SYCL_ASSUME(k >= 0);
|
641
|
+
GGML_SYCL_ASSUME(k < WARP_SIZE);
|
642
|
+
|
643
|
+
const int kbx = k / QI3_K;
|
644
|
+
const int kqsx = k % QI3_K;
|
645
|
+
|
646
|
+
const block_q3_K * bx0 = (const block_q3_K *) vx;
|
647
|
+
|
648
|
+
#pragma unroll
|
649
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
650
|
+
int i = i0 + i_offset;
|
651
|
+
|
652
|
+
if (need_check) {
|
653
|
+
i = sycl::min(i, i_max);
|
654
|
+
}
|
655
|
+
|
656
|
+
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
|
657
|
+
|
658
|
+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
|
659
|
+
}
|
660
|
+
|
661
|
+
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
|
662
|
+
const int kbxd = k % blocks_per_tile_x_row;
|
663
|
+
float * x_dmf = (float *) x_dm;
|
664
|
+
|
665
|
+
#pragma unroll
|
666
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
|
667
|
+
int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
|
668
|
+
|
669
|
+
if (need_check) {
|
670
|
+
i = sycl::min(i, i_max);
|
671
|
+
}
|
672
|
+
|
673
|
+
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
674
|
+
|
675
|
+
x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
|
676
|
+
}
|
677
|
+
|
678
|
+
#pragma unroll
|
679
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
|
680
|
+
int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
|
681
|
+
|
682
|
+
if (need_check) {
|
683
|
+
i = sycl::min(i, i_max);
|
684
|
+
}
|
685
|
+
|
686
|
+
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
|
687
|
+
|
688
|
+
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
|
689
|
+
x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
|
690
|
+
}
|
691
|
+
|
692
|
+
#pragma unroll
|
693
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
694
|
+
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
|
695
|
+
|
696
|
+
if (need_check) {
|
697
|
+
i = sycl::min(i, i_max);
|
698
|
+
}
|
699
|
+
|
700
|
+
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
|
701
|
+
|
702
|
+
const int ksc = k % (QI3_K/4);
|
703
|
+
|
704
|
+
const int ksc_low = ksc % (QI3_K/8);
|
705
|
+
const int shift_low = 4 * (ksc / (QI3_K/8));
|
706
|
+
const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
|
707
|
+
|
708
|
+
const int ksc_high = QI3_K/8;
|
709
|
+
const int shift_high = 2 * ksc;
|
710
|
+
const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
|
711
|
+
|
712
|
+
const int sc = dpct::vectorized_binary<sycl::char4>(
|
713
|
+
sc_low | sc_high, 0x20202020, dpct::sub_sat());
|
714
|
+
|
715
|
+
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
|
716
|
+
}
|
717
|
+
}
|
718
|
+
|
719
|
+
#define VDR_Q3_K_Q8_1_MMQ 2
|
720
|
+
// contiguous u/y values
|
721
|
+
static __dpct_inline__ float
|
722
|
+
vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
|
723
|
+
const int8_t *__restrict__ scales, const float &d3,
|
724
|
+
const float &d8) {
|
725
|
+
|
726
|
+
int sumi = 0;
|
727
|
+
|
728
|
+
#pragma unroll
|
729
|
+
for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
|
730
|
+
int sumi_sc = 0;
|
731
|
+
|
732
|
+
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
733
|
+
sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
|
734
|
+
}
|
735
|
+
|
736
|
+
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|
737
|
+
}
|
738
|
+
|
739
|
+
return d3*d8 * sumi;
|
740
|
+
}
|
741
|
+
|
742
|
+
// Computes one q3_K x q8_1 partial dot product for tile row i, tile column j
// and lane k. The 3-bit values are rebuilt from the x tiles: 2 low bits come
// from x_ql, and 4 is subtracted (saturating) where the inverted high-bit mask
// in x_qh has its bit set. The result is handed to vec_dot_q3_K_q8_1_impl_mmq.
// x_dm and y_ds are read as float arrays. This presumably requires tile_y_ds to
// have been filled with need_sum == false in mul_mat_q (TODO confirm with the
// caller).
static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {

    const int block_x = k / QI3_K;
    const int y_off   = (k % QI3_K) * QR3_K;

    const float * x_scale = (const float *) x_dm;
    const float * y_scale = (const float *) y_ds;

    const int8_t * sub_scales =
        ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + block_x*4)) + y_off/4;

    // These addressing terms do not depend on the loop index.
    const int ql_base  = i * (WARP_SIZE + 1) + block_x*QI3_K
                       + (QI3_K/2) * (y_off/(2*QI3_K)) + y_off % (QI3_K/2);
    const int ql_shift = 2 * ((y_off % 32) / 8);
    const int qh_base  = i * (WARP_SIZE/2) + i/2 + block_x * (QI3_K/2);

    int vals[QR3_K*VDR_Q3_K_Q8_1_MMQ];

#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int low2   = (x_ql[ql_base + l] >> ql_shift) & 0x03030303;
        const int hbits  = x_qh[qh_base + (y_off + l) % 8] >> ((y_off + l) / 8);
        const int minus4 = (hbits << 2) & 0x04040404;

        vals[l] = dpct::vectorized_binary<sycl::char4>(low2, minus4, dpct::sub_sat());
    }

    const int y_start = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    const float d3    = x_scale[i * (WARP_SIZE/QI3_K) + i/QI3_K + block_x];

    return vec_dot_q3_K_q8_1_impl_mmq(vals, &y_qs[y_start], sub_scales, d3,
                                      y_scale[y_start/QI8_1]);
}
|
772
|
+
|
773
|
+
template <int mmq_y>
|
774
|
+
static __dpct_inline__ void
|
775
|
+
allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
776
|
+
int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
|
777
|
+
int *tile_x_sc_q4_K) {
|
778
|
+
(void)x_qh;
|
779
|
+
|
780
|
+
*x_ql = tile_x_ql_q4_K;
|
781
|
+
*x_dm = tile_x_dm_q4_K;
|
782
|
+
*x_sc = tile_x_sc_q4_K;
|
783
|
+
}
|
784
|
+
|
785
|
+
// Loads one mmq_y-row tile of q4_K-quantized x into the x tiles consumed by
// vec_dot_q4_K_q8_1_mul_mat. x_qh is not used for this format.
//   x_ql : 32-bit words of block_q4_K::qs (two 4-bit quants per byte), row
//          stride WARP_SIZE + 1
//   x_dm : block_q4_K::dm per block. vec_dot_q4_K_q8_1_impl_mmq multiplies
//          the scaled dot product by its .x and the min term by its .y.
//   x_sc : scales and mins unpacked from block_q4_K::scales (4 low bits |
//          2 high bits per byte)
// Parameters:
//   vx             - first block_q4_K of the tile
//   i_offset       - this work-item's row offset inside the tile, in [0, nwarps)
//   i_max          - largest row index that may be read; rows are clamped to it
//                    only when need_check is true
//   k              - this work-item's lane, in [0, WARP_SIZE)
//   blocks_per_row - number of block_q4_K per row of x
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256

    const block_q4_K * bx0 = (const block_q4_K *) vx;

    // Quantized nibbles: one 32-bit word per (row, lane), copied unchanged.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;

        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }

    constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // Per-block dm pair. The row index wraps modulo mmq_y so every work-item
    // writes a slot inside the tile.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }

    // Scales/mins: each of the WARP_SIZE/8 words per row gets four bytes whose
    // low 4 and high 2 bits are gathered from different words of bxi->scales.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
|
857
|
+
|
858
|
+
|
859
|
+
#define VDR_Q4_K_Q8_1_MMQ 8
|
860
|
+
|
861
|
+
// contiguous u/y values
|
862
|
+
static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
|
863
|
+
const int *__restrict__ v, const int *__restrict__ u,
|
864
|
+
const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
|
865
|
+
const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
|
866
|
+
|
867
|
+
float sumf_d = 0.0f;
|
868
|
+
float sumf_m = 0.0f;
|
869
|
+
|
870
|
+
#pragma unroll
|
871
|
+
for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
|
872
|
+
int sumi_d = 0;
|
873
|
+
|
874
|
+
#pragma unroll
|
875
|
+
for (int j = 0; j < QI8_1; ++j) {
|
876
|
+
sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
|
877
|
+
u[i * QI8_1 + j], sumi_d); // SIMD dot product
|
878
|
+
}
|
879
|
+
|
880
|
+
const sycl::float2 ds8f =
|
881
|
+
ds8[i].convert<float, sycl::rounding_mode::automatic>();
|
882
|
+
|
883
|
+
sumf_d += ds8f.x() * (sc[i] * sumi_d);
|
884
|
+
sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
|
885
|
+
}
|
886
|
+
|
887
|
+
const sycl::float2 dm4f =
|
888
|
+
dm4.convert<float, sycl::rounding_mode::automatic>();
|
889
|
+
|
890
|
+
return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
|
891
|
+
}
|
892
|
+
|
893
|
+
|
894
|
+
// Computes one q4_K x q8_1 partial dot product for tile row i, tile column j
// and lane k. It locates this lane's 8 scales (and the 8 mins that follow
// them) in x_sc, then forwards everything to vec_dot_q4_K_q8_1_impl_mmq.
static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; // q4_K has no separate high-bit tile

    const int       sc_word    = i * (WARP_SIZE/8) + i/8 + k/16;
    const uint8_t * sub_scales = ((const uint8_t *) &x_sc[sc_word]) + 2*((k % 16) / 8);
    const uint8_t * sub_mins   = sub_scales + 8;

    const int x_start = i * (WARP_SIZE + 1) + k;
    const int y_start = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;

    const sycl::half2 & dm = x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K];

    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[x_start], &y_qs[y_start],
                                      sub_scales, sub_mins, dm,
                                      &y_ds[y_start/QI8_1]);
}
|
907
|
+
|
908
|
+
template <int mmq_y>
|
909
|
+
static __dpct_inline__ void
|
910
|
+
allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
911
|
+
int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
|
912
|
+
int *tile_x_sc_q5_K) {
|
913
|
+
(void)x_qh;
|
914
|
+
|
915
|
+
*x_ql = tile_x_ql_q5_K;
|
916
|
+
*x_dm = tile_x_dm_q5_K;
|
917
|
+
*x_sc = tile_x_sc_q5_K;
|
918
|
+
}
|
919
|
+
|
920
|
+
// Loads one mmq_y-row tile of q5_K-quantized x into the x tiles consumed by
// vec_dot_q5_K_q8_1_mul_mat. x_qh is not used: the 5th bit of every value is
// merged into x_ql here.
//   x_ql : per byte, a 4-bit value from block_q5_K::qs OR'd with bit 4 taken
//          from block_q5_K::qh. Each lane writes two words (kq0, kq1); the row
//          stride is 2*WARP_SIZE + 1.
//   x_dm : block_q5_K::dm per block (written only when QK_K == 256)
//   x_sc : scales and mins unpacked from block_q5_K::scales, same packing as q4_K
// Parameters:
//   vx             - first block_q5_K of the tile
//   i_offset       - this work-item's row offset inside the tile, in [0, nwarps)
//   i_max          - largest row index that may be read; rows are clamped to it
//                    only when need_check is true
//   k              - this work-item's lane, in [0, WARP_SIZE)
//   blocks_per_row - number of block_q5_K per row of x
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256

    const block_q5_K * bx0 = (const block_q5_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;

        // Low and high nibbles of the packed 4-bit quants.
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // The matching high bits, moved into bit 4 of each byte.
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

        // NOTE(review): there is no #else branch, so x_dm is not written when
        // QK_K != 256. Confirm that configuration is unsupported.
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = k % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
|
1003
|
+
|
1004
|
+
#define VDR_Q5_K_Q8_1_MMQ 8
|
1005
|
+
|
1006
|
+
// contiguous u/y values
|
1007
|
+
static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
|
1008
|
+
const int *__restrict__ v, const int *__restrict__ u,
|
1009
|
+
const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
|
1010
|
+
const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
|
1011
|
+
|
1012
|
+
float sumf_d = 0.0f;
|
1013
|
+
float sumf_m = 0.0f;
|
1014
|
+
|
1015
|
+
#pragma unroll
|
1016
|
+
for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
|
1017
|
+
int sumi_d = 0;
|
1018
|
+
|
1019
|
+
#pragma unroll
|
1020
|
+
for (int j = 0; j < QI8_1; ++j) {
|
1021
|
+
sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
|
1022
|
+
sumi_d); // SIMD dot product
|
1023
|
+
}
|
1024
|
+
|
1025
|
+
const sycl::float2 ds8f =
|
1026
|
+
ds8[i].convert<float, sycl::rounding_mode::automatic>();
|
1027
|
+
|
1028
|
+
sumf_d += ds8f.x() * (sc[i] * sumi_d);
|
1029
|
+
sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
|
1030
|
+
}
|
1031
|
+
|
1032
|
+
const sycl::float2 dm4f =
|
1033
|
+
dm4.convert<float, sycl::rounding_mode::automatic>();
|
1034
|
+
|
1035
|
+
return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
|
1036
|
+
}
|
1037
|
+
|
1038
|
+
// Computes one q5_K x q8_1 partial dot product for tile row i, tile column j
// and lane k. x_ql rows hold QR5_K words per lane, with stride
// QR5_K*WARP_SIZE + 1. The 8 scales and the 8 mins that follow them are
// located in x_sc.
static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; // the 5th bit already lives inside x_ql

    const int       sc_word    = i * (WARP_SIZE/8) + i/8 + k/16;
    const uint8_t * sub_scales = ((const uint8_t *) &x_sc[sc_word]) + 2 * ((k % 16) / 8);
    const uint8_t * sub_mins   = sub_scales + 8;

    const int x_start = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
    const int y_start = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;

    const sycl::half2 & dm = x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K];

    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[x_start], &y_qs[y_start],
                                      sub_scales, sub_mins, dm,
                                      &y_ds[y_start/QI8_1]);
}
|
1052
|
+
|
1053
|
+
template <int mmq_y>
|
1054
|
+
static __dpct_inline__ void
|
1055
|
+
allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
|
1056
|
+
int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
|
1057
|
+
(void)x_qh;
|
1058
|
+
|
1059
|
+
*x_ql = tile_x_ql;
|
1060
|
+
*x_dm = tile_x_dm;
|
1061
|
+
*x_sc = tile_x_sc;
|
1062
|
+
}
|
1063
|
+
|
1064
|
+
// Loads one mmq_y-row tile of q6_K-quantized x into the x tiles consumed by
// vec_dot_q6_K_q8_1_mul_mat. x_qh is not used.
//   x_ql : per byte, a 6-bit value (4 low bits from block_q6_K::ql, 2 high
//          bits from block_q6_K::qh) minus 32 with signed saturation. Each lane
//          writes two words (kq0, kq1); the row stride is 2*WARP_SIZE + 1.
//   x_dm : reinterpreted as float; receives the per-block scale block_q6_K::d
//   x_sc : int8 sub-block scales copied from block_q6_K::scales
// Parameters:
//   vx             - first block_q6_K of the tile
//   i_offset       - this work-item's row offset inside the tile, in [0, nwarps)
//   i_max          - largest row index that may be read; rows are clamped to it
//                    only when need_check is true
//   k              - this work-item's lane, in [0, WARP_SIZE)
//   blocks_per_row - number of block_q6_K per row of x
template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;

    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);

    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256

    const block_q6_K * bx0 = (const block_q6_K *) vx;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;

        // Low and high nibbles of the packed low 4 bits.
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // The matching 2 high bits, moved into bits 4-5 of each byte.
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;

        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);

        // Subtract 32 (0x20 per byte, saturating) to make the values signed.
        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
                                                 dpct::sub_sat());
        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
                                                 dpct::sub_sat());
    }

    constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm;

    // Per-block scale d, stored as float in the x_dm slots.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // Signed sub-block scales, four per 32-bit word.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = sycl::min(i, i_max);
        }

        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}
|
1142
|
+
|
1143
|
+
#define VDR_Q6_K_Q8_1_MMQ 8
|
1144
|
+
|
1145
|
+
// contiguous u/y values
|
1146
|
+
static __dpct_inline__ float
|
1147
|
+
vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
|
1148
|
+
const int8_t *__restrict__ sc, const float &d6,
|
1149
|
+
const float *__restrict__ d8) {
|
1150
|
+
|
1151
|
+
float sumf_d = 0.0f;
|
1152
|
+
|
1153
|
+
#pragma unroll
|
1154
|
+
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
|
1155
|
+
sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
|
1156
|
+
|
1157
|
+
#pragma unroll
|
1158
|
+
for (int i = i0; i < i0 + 2; ++i) {
|
1159
|
+
sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
|
1160
|
+
sumi_d.x()); // SIMD dot product
|
1161
|
+
sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
|
1162
|
+
sumi_d.x()); // SIMD dot product
|
1163
|
+
|
1164
|
+
sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
|
1165
|
+
sumi_d.y()); // SIMD dot product
|
1166
|
+
sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
|
1167
|
+
sumi_d.y()); // SIMD dot product
|
1168
|
+
}
|
1169
|
+
|
1170
|
+
sumf_d += d8[i0 / 4] *
|
1171
|
+
(sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
|
1172
|
+
}
|
1173
|
+
|
1174
|
+
return d6 * sumf_d;
|
1175
|
+
}
|
1176
|
+
|
1177
|
+
// Computes one q6_K x q8_1 partial dot product for tile row i, tile column j
// and lane k. x_dm and y_ds are read as float arrays. This presumably requires
// tile_y_ds to have been filled with need_sum == false in mul_mat_q (TODO
// confirm with the caller).
static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; // the high bits already live inside x_ql

    const float * x_scale = (const float *) x_dm;
    const float * y_scale = (const float *) y_ds;

    const int      sc_word    = i * (WARP_SIZE/8) + i/8 + k/8;
    const int8_t * sub_scales = (const int8_t *) &x_sc[sc_word];

    const int x_start = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
    const int y_start = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;

    const float d6 = x_scale[i * (WARP_SIZE/QI6_K) + i/QI6_K];

    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[x_start], &y_qs[y_start],
                                      sub_scales, d6, &y_scale[y_start/QI8_1]);
}
|
1193
|
+
|
1194
|
+
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
          vec_dot_q_mul_mat_sycl_t vec_dot>
/*
DPCT1110:8: The total declared local variable size in device function mul_mat_q
exceeds 128 bytes and may cause high register pressure. Consult with your
hardware vendor to find the total register size available and adjust the code,
or use smaller sub-group size to avoid high register pressure.
*/
// Generic tiled product of a quantized matrix x (blocks of type block_q_t,
// qk values each) with a q8_1-quantized matrix y. dst is written as
//   dst[col*nrows_dst + row] = dot(row `row` of x, column `col` of y).
//
// Work distribution:
//   get_group(2) * mmq_y  -> first dst/x row handled by this work-group
//   get_group(1) * mmq_x  -> first dst/y column handled by this work-group
//   get_local_id(2)       -> lane; covers rows in steps of WARP_SIZE
//   get_local_id(1)       -> covers columns in steps of nwarps
// Each work-item accumulates sum[mmq_y/WARP_SIZE][mmq_x/nwarps].
//
// Template parameters:
//   qk, qr, qi  - values per x block and the format's QR*/QI* constants
//   need_sum    - true: copy each q8_1 block's half2 `ds` into tile_y_ds;
//                 false: store only ds[0], converted to float, in its slot
//   load_tiles  - fills the tile_x_* buffers for one slab of x blocks
//   vdr         - step of the innermost k loop
//   vec_dot     - partial dot product on the tiles for (row, col, k)
//
// The tile_* buffers are filled cooperatively and separated by work-group
// barriers, so they are presumably work-group local memory provided by the
// caller (TODO confirm at the launch site).
static __dpct_inline__ void
mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
          float *__restrict__ dst, const int ncols_x, const int nrows_x,
          const int ncols_y, const int nrows_y, const int nrows_dst,
          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
          sycl::half2 *tile_y_ds) {

    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ncols_dst = ncols_y;

    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
    const int & row_x_0 = row_dst_0;

    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
    const int & col_y_0 = col_dst_0;

    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

    // Walk along the shared dimension, blocks_per_warp x blocks at a time.
    // The bounds are uniform across the work-group, so every work-item reaches
    // the barriers below.
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
                   blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
            const int kbxd = kqs / QI8_1;

            // Stage the q8_1 quants of y for this slice into tile_y_qs.
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = dpct::min(
                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
                    ncols_y - 1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];

                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
                                    kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(
                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
            }

            // Stage the per-block q8_1 `ds` values of y into tile_y_ds.
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids =
                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
                    mmq_x;
                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const sycl::half2 *dsi_src =
                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
                       ir * (WARP_SIZE / QI8_1) + kby]
                         .ds;
                sycl::half2 *dsi_dst =
                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    // Only ds[0] is kept, written as a float into the half2 slot.
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = (*dsi_src)[0];
                }
            }

            /*
            DPCT1118:9: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            // Tiles must be fully written before any work-item reads them.
            item_ct1.barrier();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
                            item_ct1.get_local_id(1) + j, k);
                    }
                }
            }

            /*
            DPCT1118:10: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            // All reads must finish before the tiles are overwritten next iteration.
            item_ct1.barrier();
        }
    }

    // Write the accumulators back, skipping rows/columns outside dst.
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);

        // col_dst only grows with j, so once it is out of range the rest are too.
        if (col_dst >= ncols_dst) {
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;

            if (row_dst >= nrows_dst) {
                continue;
            }

            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
|
1336
|
+
|
1337
|
+
// Tile-shape presets for q4_0.
//   MMQ_X_*  - y/dst columns per work-group (mmq_x in mul_mat_q)
//   MMQ_Y_*  - x/dst rows per work-group (mmq_y in mul_mat_q)
//   NWARPS_* - nwarps in mul_mat_q; presumably the work-group extent in
//              dimension 1
// mul_mat_q declares sum[mmq_y/WARP_SIZE][mmq_x/nwarps], so MMQ_Y should be a
// multiple of WARP_SIZE and MMQ_X a multiple of NWARPS. The RDNA/AMPERE/PASCAL
// suffixes appear to mirror per-architecture presets from the CUDA backend.
// mul_mat_q4_0 below uses only the *_AMPERE set, which shrinks when
// SYCL_USE_XMX is defined.
#define MMQ_X_Q4_0_RDNA2 64
#define MMQ_Y_Q4_0_RDNA2 128
#define NWARPS_Q4_0_RDNA2 8
#define MMQ_X_Q4_0_RDNA1 64
#define MMQ_Y_Q4_0_RDNA1 64
#define NWARPS_Q4_0_RDNA1 8
#if defined(SYCL_USE_XMX)
#define MMQ_X_Q4_0_AMPERE 4
#define MMQ_Y_Q4_0_AMPERE 32
#define NWARPS_Q4_0_AMPERE 4
#else
#define MMQ_X_Q4_0_AMPERE 64
#define MMQ_Y_Q4_0_AMPERE 128
#define NWARPS_Q4_0_AMPERE 4
#endif
#define MMQ_X_Q4_0_PASCAL 64
#define MMQ_Y_Q4_0_PASCAL 64
#define NWARPS_Q4_0_PASCAL 8
|
1355
|
+
|
1356
|
+
template <bool need_check> static void
|
1357
|
+
mul_mat_q4_0(
|
1358
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1359
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1360
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
|
1361
|
+
int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1362
|
+
int * tile_x_ql = nullptr;
|
1363
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1364
|
+
int * tile_x_qh = nullptr;
|
1365
|
+
int * tile_x_sc = nullptr;
|
1366
|
+
|
1367
|
+
//sycl_todo: change according to hardware
|
1368
|
+
|
1369
|
+
const int mmq_x = MMQ_X_Q4_0_AMPERE;
|
1370
|
+
const int mmq_y = MMQ_Y_Q4_0_AMPERE;
|
1371
|
+
const int nwarps = NWARPS_Q4_0_AMPERE;
|
1372
|
+
allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1373
|
+
tile_x_qs_q4_0, tile_x_d_q4_0);
|
1374
|
+
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
|
1375
|
+
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
|
1376
|
+
vec_dot_q4_0_q8_1_mul_mat>(
|
1377
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1378
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1379
|
+
}
|
1380
|
+
|
1381
|
+
// Tile-shape presets for q4_1.
//   MMQ_X_*  - y/dst columns per work-group (mmq_x in mul_mat_q)
//   MMQ_Y_*  - x/dst rows per work-group (mmq_y in mul_mat_q)
//   NWARPS_* - nwarps in mul_mat_q; presumably the work-group extent in
//              dimension 1
// mul_mat_q declares sum[mmq_y/WARP_SIZE][mmq_x/nwarps], so MMQ_Y should be a
// multiple of WARP_SIZE and MMQ_X a multiple of NWARPS. mul_mat_q4_1 below
// uses only the *_AMPERE set, which shrinks when SYCL_USE_XMX is defined.
#define MMQ_X_Q4_1_RDNA2 64
#define MMQ_Y_Q4_1_RDNA2 128
#define NWARPS_Q4_1_RDNA2 8
#define MMQ_X_Q4_1_RDNA1 64
#define MMQ_Y_Q4_1_RDNA1 64
#define NWARPS_Q4_1_RDNA1 8
#if defined(SYCL_USE_XMX)
#define MMQ_X_Q4_1_AMPERE 4
#define MMQ_Y_Q4_1_AMPERE 32
#define NWARPS_Q4_1_AMPERE 4
#else
#define MMQ_X_Q4_1_AMPERE 64
#define MMQ_Y_Q4_1_AMPERE 128
#define NWARPS_Q4_1_AMPERE 4
#endif
#define MMQ_X_Q4_1_PASCAL 64
#define MMQ_Y_Q4_1_PASCAL 64
#define NWARPS_Q4_1_PASCAL 8
|
1399
|
+
|
1400
|
+
template <bool need_check> static void
|
1401
|
+
mul_mat_q4_1(
|
1402
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1403
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1404
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
|
1405
|
+
sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1406
|
+
int * tile_x_ql = nullptr;
|
1407
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1408
|
+
int * tile_x_qh = nullptr;
|
1409
|
+
int * tile_x_sc = nullptr;
|
1410
|
+
|
1411
|
+
//sycl_todo: change according to hardware
|
1412
|
+
const int mmq_x = MMQ_X_Q4_1_AMPERE;
|
1413
|
+
const int mmq_y = MMQ_Y_Q4_1_AMPERE;
|
1414
|
+
const int nwarps = NWARPS_Q4_1_AMPERE;
|
1415
|
+
allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1416
|
+
tile_x_qs_q4_1, tile_x_dm_q4_1);
|
1417
|
+
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
|
1418
|
+
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
|
1419
|
+
vec_dot_q4_1_q8_1_mul_mat>(
|
1420
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1421
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1422
|
+
}
|
1423
|
+
|
1424
|
+
// Tile-shape presets for q5_0.
//   MMQ_X_*  - y/dst columns per work-group (mmq_x in mul_mat_q)
//   MMQ_Y_*  - x/dst rows per work-group (mmq_y in mul_mat_q)
//   NWARPS_* - nwarps in mul_mat_q; presumably the work-group extent in
//              dimension 1
// mul_mat_q declares sum[mmq_y/WARP_SIZE][mmq_x/nwarps], so MMQ_Y should be a
// multiple of WARP_SIZE and MMQ_X a multiple of NWARPS. mul_mat_q5_0 below
// uses only the *_AMPERE set. Unlike q4_0/q4_1, its non-XMX shape is 128x64.
#define MMQ_X_Q5_0_RDNA2 64
#define MMQ_Y_Q5_0_RDNA2 128
#define NWARPS_Q5_0_RDNA2 8
#define MMQ_X_Q5_0_RDNA1 64
#define MMQ_Y_Q5_0_RDNA1 64
#define NWARPS_Q5_0_RDNA1 8
#if defined(SYCL_USE_XMX)
#define MMQ_X_Q5_0_AMPERE 4
#define MMQ_Y_Q5_0_AMPERE 32
#define NWARPS_Q5_0_AMPERE 4
#else
#define MMQ_X_Q5_0_AMPERE 128
#define MMQ_Y_Q5_0_AMPERE 64
#define NWARPS_Q5_0_AMPERE 4
#endif
#define MMQ_X_Q5_0_PASCAL 64
#define MMQ_Y_Q5_0_PASCAL 64
#define NWARPS_Q5_0_PASCAL 8
|
1442
|
+
|
1443
|
+
template <bool need_check> static void
|
1444
|
+
mul_mat_q5_0(
|
1445
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1446
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1447
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
|
1448
|
+
int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1449
|
+
int * tile_x_ql = nullptr;
|
1450
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1451
|
+
int * tile_x_qh = nullptr;
|
1452
|
+
int * tile_x_sc = nullptr;
|
1453
|
+
|
1454
|
+
//sycl_todo: change according to hardware
|
1455
|
+
const int mmq_x = MMQ_X_Q5_0_AMPERE;
|
1456
|
+
const int mmq_y = MMQ_Y_Q5_0_AMPERE;
|
1457
|
+
const int nwarps = NWARPS_Q5_0_AMPERE;
|
1458
|
+
allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1459
|
+
tile_x_ql_q5_0, tile_x_d_q5_0);
|
1460
|
+
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
|
1461
|
+
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
|
1462
|
+
vec_dot_q5_0_q8_1_mul_mat>(
|
1463
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1464
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1465
|
+
}
|
1466
|
+
|
1467
|
+
// Tile configurations for the Q5_1 MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q5_1 itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q5_1_RDNA2 64
#define MMQ_Y_Q5_1_RDNA2 128
#define NWARPS_Q5_1_RDNA2 8

#define MMQ_X_Q5_1_RDNA1 64
#define MMQ_Y_Q5_1_RDNA1 64
#define NWARPS_Q5_1_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q5_1_AMPERE 128
#define MMQ_Y_Q5_1_AMPERE 64
#define NWARPS_Q5_1_AMPERE 4
#else
#define MMQ_X_Q5_1_AMPERE 4
#define MMQ_Y_Q5_1_AMPERE 32
#define NWARPS_Q5_1_AMPERE 4
#endif

#define MMQ_X_Q5_1_PASCAL 64
#define MMQ_Y_Q5_1_PASCAL 64
#define NWARPS_Q5_1_PASCAL 8
|
1485
|
+
|
1486
|
+
template <bool need_check> static void
|
1487
|
+
mul_mat_q5_1(
|
1488
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1489
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1490
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
|
1491
|
+
sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1492
|
+
int * tile_x_ql = nullptr;
|
1493
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1494
|
+
int * tile_x_qh = nullptr;
|
1495
|
+
int * tile_x_sc = nullptr;
|
1496
|
+
|
1497
|
+
//sycl_todo: change according to hardware
|
1498
|
+
const int mmq_x = MMQ_X_Q5_1_AMPERE;
|
1499
|
+
const int mmq_y = MMQ_Y_Q5_1_AMPERE;
|
1500
|
+
const int nwarps = NWARPS_Q5_1_AMPERE;
|
1501
|
+
allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1502
|
+
tile_x_ql_q5_1, tile_x_dm_q5_1);
|
1503
|
+
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
|
1504
|
+
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
|
1505
|
+
vec_dot_q5_1_q8_1_mul_mat>(
|
1506
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1507
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1508
|
+
}
|
1509
|
+
|
1510
|
+
// Tile configurations for the Q8_0 MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q8_0 itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q8_0_RDNA2 64
#define MMQ_Y_Q8_0_RDNA2 128
#define NWARPS_Q8_0_RDNA2 8

#define MMQ_X_Q8_0_RDNA1 64
#define MMQ_Y_Q8_0_RDNA1 64
#define NWARPS_Q8_0_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q8_0_AMPERE 128
#define MMQ_Y_Q8_0_AMPERE 64
#define NWARPS_Q8_0_AMPERE 4
#else
#define MMQ_X_Q8_0_AMPERE 4
#define MMQ_Y_Q8_0_AMPERE 32
#define NWARPS_Q8_0_AMPERE 4
#endif

#define MMQ_X_Q8_0_PASCAL 64
#define MMQ_Y_Q8_0_PASCAL 64
#define NWARPS_Q8_0_PASCAL 8
|
1528
|
+
|
1529
|
+
template <bool need_check> static void
|
1530
|
+
mul_mat_q8_0(
|
1531
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1532
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1533
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
|
1534
|
+
int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1535
|
+
int * tile_x_ql = nullptr;
|
1536
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1537
|
+
int * tile_x_qh = nullptr;
|
1538
|
+
int * tile_x_sc = nullptr;
|
1539
|
+
|
1540
|
+
//sycl_todo: change according to hardware
|
1541
|
+
const int mmq_x = MMQ_X_Q8_0_AMPERE;
|
1542
|
+
const int mmq_y = MMQ_Y_Q8_0_AMPERE;
|
1543
|
+
const int nwarps = NWARPS_Q8_0_AMPERE;
|
1544
|
+
allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1545
|
+
tile_x_qs_q8_0, tile_x_d_q8_0);
|
1546
|
+
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
|
1547
|
+
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
|
1548
|
+
vec_dot_q8_0_q8_1_mul_mat>(
|
1549
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1550
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1551
|
+
}
|
1552
|
+
|
1553
|
+
// Tile configurations for the Q2_K MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q2_K itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q2_K_RDNA2 64
#define MMQ_Y_Q2_K_RDNA2 128
#define NWARPS_Q2_K_RDNA2 8

#define MMQ_X_Q2_K_RDNA1 128
#define MMQ_Y_Q2_K_RDNA1 32
#define NWARPS_Q2_K_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q2_K_AMPERE 64
#define MMQ_Y_Q2_K_AMPERE 128
#define NWARPS_Q2_K_AMPERE 4
#else
#define MMQ_X_Q2_K_AMPERE 4
#define MMQ_Y_Q2_K_AMPERE 32
#define NWARPS_Q2_K_AMPERE 4
#endif

#define MMQ_X_Q2_K_PASCAL 64
#define MMQ_Y_Q2_K_PASCAL 64
#define NWARPS_Q2_K_PASCAL 8
|
1571
|
+
|
1572
|
+
template <bool need_check> static void
|
1573
|
+
mul_mat_q2_K(
|
1574
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1575
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1576
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
|
1577
|
+
sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
|
1578
|
+
sycl::half2 *tile_y_ds) {
|
1579
|
+
int * tile_x_ql = nullptr;
|
1580
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1581
|
+
int * tile_x_qh = nullptr;
|
1582
|
+
int * tile_x_sc = nullptr;
|
1583
|
+
|
1584
|
+
//sycl_todo: change according to hardware
|
1585
|
+
const int mmq_x = MMQ_X_Q2_K_AMPERE;
|
1586
|
+
const int mmq_y = MMQ_Y_Q2_K_AMPERE;
|
1587
|
+
const int nwarps = NWARPS_Q2_K_AMPERE;
|
1588
|
+
allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1589
|
+
tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
|
1590
|
+
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
|
1591
|
+
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
|
1592
|
+
vec_dot_q2_K_q8_1_mul_mat>(
|
1593
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1594
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1595
|
+
}
|
1596
|
+
|
1597
|
+
// Tile configurations for the Q3_K MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q3_K itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q3_K_RDNA2 128
#define MMQ_Y_Q3_K_RDNA2 64
#define NWARPS_Q3_K_RDNA2 8

#define MMQ_X_Q3_K_RDNA1 32
#define MMQ_Y_Q3_K_RDNA1 128
#define NWARPS_Q3_K_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q3_K_AMPERE 128
#define MMQ_Y_Q3_K_AMPERE 128
#define NWARPS_Q3_K_AMPERE 4
#else
#define MMQ_X_Q3_K_AMPERE 4
#define MMQ_Y_Q3_K_AMPERE 32
#define NWARPS_Q3_K_AMPERE 4
#endif

#define MMQ_X_Q3_K_PASCAL 64
#define MMQ_Y_Q3_K_PASCAL 64
#define NWARPS_Q3_K_PASCAL 8
|
1615
|
+
|
1616
|
+
template <bool need_check> static void
|
1617
|
+
mul_mat_q3_K(
|
1618
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1619
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1620
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
|
1621
|
+
sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
|
1622
|
+
int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1623
|
+
int * tile_x_ql = nullptr;
|
1624
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1625
|
+
int * tile_x_qh = nullptr;
|
1626
|
+
int * tile_x_sc = nullptr;
|
1627
|
+
|
1628
|
+
//sycl_todo: change according to hardware
|
1629
|
+
const int mmq_x = MMQ_X_Q3_K_AMPERE;
|
1630
|
+
const int mmq_y = MMQ_Y_Q3_K_AMPERE;
|
1631
|
+
const int nwarps = NWARPS_Q3_K_AMPERE;
|
1632
|
+
allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1633
|
+
tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
|
1634
|
+
tile_x_sc_q3_K);
|
1635
|
+
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
|
1636
|
+
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
|
1637
|
+
vec_dot_q3_K_q8_1_mul_mat>(
|
1638
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1639
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1640
|
+
}
|
1641
|
+
|
1642
|
+
// Tile configurations for the Q4_K MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q4_K itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q4_K_RDNA2 64
#define MMQ_Y_Q4_K_RDNA2 128
#define NWARPS_Q4_K_RDNA2 8

#define MMQ_X_Q4_K_RDNA1 32
#define MMQ_Y_Q4_K_RDNA1 64
#define NWARPS_Q4_K_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q4_K_AMPERE 64
#define MMQ_Y_Q4_K_AMPERE 128
#define NWARPS_Q4_K_AMPERE 4
#else
#define MMQ_X_Q4_K_AMPERE 4
#define MMQ_Y_Q4_K_AMPERE 32
#define NWARPS_Q4_K_AMPERE 4
#endif

#define MMQ_X_Q4_K_PASCAL 64
#define MMQ_Y_Q4_K_PASCAL 64
#define NWARPS_Q4_K_PASCAL 8
|
1660
|
+
|
1661
|
+
template <bool need_check> static void
|
1662
|
+
mul_mat_q4_K(
|
1663
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1664
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1665
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
|
1666
|
+
sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
|
1667
|
+
sycl::half2 *tile_y_ds) {
|
1668
|
+
int * tile_x_ql = nullptr;
|
1669
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1670
|
+
int * tile_x_qh = nullptr;
|
1671
|
+
int * tile_x_sc = nullptr;
|
1672
|
+
|
1673
|
+
//sycl_todo: change according to hardware
|
1674
|
+
const int mmq_x = MMQ_X_Q4_K_AMPERE;
|
1675
|
+
const int mmq_y = MMQ_Y_Q4_K_AMPERE;
|
1676
|
+
const int nwarps = NWARPS_Q4_K_AMPERE;
|
1677
|
+
allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1678
|
+
tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
|
1679
|
+
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
|
1680
|
+
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
|
1681
|
+
vec_dot_q4_K_q8_1_mul_mat>(
|
1682
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1683
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1684
|
+
}
|
1685
|
+
|
1686
|
+
// Tile configurations for the Q5_K MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q5_K itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q5_K_RDNA2 64
#define MMQ_Y_Q5_K_RDNA2 128
#define NWARPS_Q5_K_RDNA2 8

#define MMQ_X_Q5_K_RDNA1 32
#define MMQ_Y_Q5_K_RDNA1 64
#define NWARPS_Q5_K_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q5_K_AMPERE 64
#define MMQ_Y_Q5_K_AMPERE 128
#define NWARPS_Q5_K_AMPERE 4
#else
#define MMQ_X_Q5_K_AMPERE 4
#define MMQ_Y_Q5_K_AMPERE 32
#define NWARPS_Q5_K_AMPERE 4
#endif

#define MMQ_X_Q5_K_PASCAL 64
#define MMQ_Y_Q5_K_PASCAL 64
#define NWARPS_Q5_K_PASCAL 8
|
1704
|
+
|
1705
|
+
template <bool need_check> static void
|
1706
|
+
mul_mat_q5_K(
|
1707
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1708
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1709
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
|
1710
|
+
sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
|
1711
|
+
sycl::half2 *tile_y_ds) {
|
1712
|
+
int * tile_x_ql = nullptr;
|
1713
|
+
sycl::half2 *tile_x_dm = nullptr;
|
1714
|
+
int * tile_x_qh = nullptr;
|
1715
|
+
int * tile_x_sc = nullptr;
|
1716
|
+
|
1717
|
+
//sycl_todo: change according to hardware
|
1718
|
+
const int mmq_x = MMQ_X_Q5_K_AMPERE;
|
1719
|
+
const int mmq_y = MMQ_Y_Q5_K_AMPERE;
|
1720
|
+
const int nwarps = NWARPS_Q5_K_AMPERE;
|
1721
|
+
allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1722
|
+
tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
|
1723
|
+
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
|
1724
|
+
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
|
1725
|
+
vec_dot_q5_K_q8_1_mul_mat>(
|
1726
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1727
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1728
|
+
}
|
1729
|
+
|
1730
|
+
// Tile configurations for the Q6_K MMQ kernel. MMQ_X_* / MMQ_Y_* / NWARPS_*
// are the mmq_x / mmq_y / nwarps template arguments handed to mul_mat_q;
// mul_mat_q6_K itself hard-codes the *_AMPERE set. The AMPERE tier shrinks
// when SYCL_USE_XMX is defined.
#define MMQ_X_Q6_K_RDNA2 64
#define MMQ_Y_Q6_K_RDNA2 128
#define NWARPS_Q6_K_RDNA2 8

#define MMQ_X_Q6_K_RDNA1 32
#define MMQ_Y_Q6_K_RDNA1 64
#define NWARPS_Q6_K_RDNA1 8

#ifndef SYCL_USE_XMX
#define MMQ_X_Q6_K_AMPERE 64
#define MMQ_Y_Q6_K_AMPERE 64
#define NWARPS_Q6_K_AMPERE 4
#else
#define MMQ_X_Q6_K_AMPERE 4
#define MMQ_Y_Q6_K_AMPERE 32
#define NWARPS_Q6_K_AMPERE 4
#endif

#define MMQ_X_Q6_K_PASCAL 64
#define MMQ_Y_Q6_K_PASCAL 64
#define NWARPS_Q6_K_PASCAL 8
|
1748
|
+
|
1749
|
+
template <bool need_check> static void
|
1750
|
+
mul_mat_q6_K(
|
1751
|
+
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
1752
|
+
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
|
1753
|
+
const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
|
1754
|
+
int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
|
1755
|
+
// int * tile_x_ql = nullptr;
|
1756
|
+
// sycl::half2 *tile_x_dm = nullptr;
|
1757
|
+
int * tile_x_qh = nullptr;
|
1758
|
+
// int * tile_x_sc = nullptr;
|
1759
|
+
|
1760
|
+
//sycl_todo: change according to hardware
|
1761
|
+
const int mmq_x = MMQ_X_Q6_K_AMPERE;
|
1762
|
+
const int mmq_y = MMQ_Y_Q6_K_AMPERE;
|
1763
|
+
const int nwarps = NWARPS_Q6_K_AMPERE;
|
1764
|
+
allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
|
1765
|
+
tile_x_ql, tile_x_dm, tile_x_sc);
|
1766
|
+
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
|
1767
|
+
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
|
1768
|
+
vec_dot_q6_K_q8_1_mul_mat>(
|
1769
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
|
1770
|
+
tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
|
1771
|
+
}
|
1772
|
+
|
1773
|
+
static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
1774
|
+
float *dst, const int ncols_x,
|
1775
|
+
const int nrows_x, const int ncols_y,
|
1776
|
+
const int nrows_y, const int nrows_dst,
|
1777
|
+
dpct::queue_ptr stream) try {
|
1778
|
+
|
1779
|
+
int id;
|
1780
|
+
SYCL_CHECK(
|
1781
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
1782
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
1783
|
+
|
1784
|
+
int mmq_x, mmq_y, nwarps;
|
1785
|
+
if (compute_capability >= VER_GEN13) {
|
1786
|
+
mmq_x = MMQ_X_Q4_0_RDNA2;
|
1787
|
+
mmq_y = MMQ_Y_Q4_0_RDNA2;
|
1788
|
+
nwarps = NWARPS_Q4_0_RDNA2;
|
1789
|
+
} else if (compute_capability >= VER_GEN12) {
|
1790
|
+
mmq_x = MMQ_X_Q4_0_RDNA1;
|
1791
|
+
mmq_y = MMQ_Y_Q4_0_RDNA1;
|
1792
|
+
nwarps = NWARPS_Q4_0_RDNA1;
|
1793
|
+
} else if (compute_capability >= VER_GEN9) {
|
1794
|
+
mmq_x = MMQ_X_Q4_0_AMPERE;
|
1795
|
+
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
1796
|
+
nwarps = NWARPS_Q4_0_AMPERE;
|
1797
|
+
} else if (compute_capability >= VER_4VEC) {
|
1798
|
+
mmq_x = MMQ_X_Q4_0_PASCAL;
|
1799
|
+
mmq_y = MMQ_Y_Q4_0_PASCAL;
|
1800
|
+
nwarps = NWARPS_Q4_0_PASCAL;
|
1801
|
+
} else {
|
1802
|
+
GGML_ABORT("fatal error");
|
1803
|
+
}
|
1804
|
+
|
1805
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
1806
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
1807
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
1808
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
1809
|
+
|
1810
|
+
if (nrows_x % mmq_y == 0) {
|
1811
|
+
const bool need_check = false;
|
1812
|
+
/*
|
1813
|
+
DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
|
1814
|
+
the limit. To get the device limit, query
|
1815
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
1816
|
+
*/
|
1817
|
+
{
|
1818
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
1819
|
+
{sycl::aspect::fp16});
|
1820
|
+
|
1821
|
+
stream->submit([&](sycl::handler &cgh) {
|
1822
|
+
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
|
1823
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
1824
|
+
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
|
1825
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
|
1826
|
+
cgh);
|
1827
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
1828
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
1829
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
1830
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
1831
|
+
|
1832
|
+
cgh.parallel_for(
|
1833
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1834
|
+
[=](sycl::nd_item<3> item_ct1) {
|
1835
|
+
mul_mat_q4_0<need_check>(
|
1836
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
1837
|
+
nrows_dst, item_ct1,
|
1838
|
+
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
1839
|
+
get_pointer(tile_x_d_q4_0_acc_ct1),
|
1840
|
+
get_pointer(tile_y_qs_acc_ct1),
|
1841
|
+
get_pointer(tile_y_ds_acc_ct1));
|
1842
|
+
});
|
1843
|
+
});
|
1844
|
+
}
|
1845
|
+
} else {
|
1846
|
+
const bool need_check = true;
|
1847
|
+
/*
|
1848
|
+
DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
|
1849
|
+
the limit. To get the device limit, query
|
1850
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
1851
|
+
*/
|
1852
|
+
{
|
1853
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
1854
|
+
{sycl::aspect::fp16});
|
1855
|
+
|
1856
|
+
stream->submit([&](sycl::handler &cgh) {
|
1857
|
+
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
|
1858
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
1859
|
+
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
|
1860
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
|
1861
|
+
cgh);
|
1862
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
1863
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
1864
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
1865
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
1866
|
+
|
1867
|
+
cgh.parallel_for(
|
1868
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1869
|
+
[=](sycl::nd_item<3> item_ct1) {
|
1870
|
+
mul_mat_q4_0<need_check>(
|
1871
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
1872
|
+
nrows_dst, item_ct1,
|
1873
|
+
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
1874
|
+
get_pointer(tile_x_d_q4_0_acc_ct1),
|
1875
|
+
get_pointer(tile_y_qs_acc_ct1),
|
1876
|
+
get_pointer(tile_y_ds_acc_ct1));
|
1877
|
+
});
|
1878
|
+
});
|
1879
|
+
}
|
1880
|
+
}
|
1881
|
+
}
|
1882
|
+
catch (sycl::exception const &exc) {
|
1883
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
1884
|
+
<< ", line:" << __LINE__ << std::endl;
|
1885
|
+
std::exit(1);
|
1886
|
+
}
|
1887
|
+
|
1888
|
+
static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
1889
|
+
float *dst, const int ncols_x,
|
1890
|
+
const int nrows_x, const int ncols_y,
|
1891
|
+
const int nrows_y, const int nrows_dst,
|
1892
|
+
dpct::queue_ptr stream) try {
|
1893
|
+
|
1894
|
+
int id;
|
1895
|
+
SYCL_CHECK(
|
1896
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
1897
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
1898
|
+
|
1899
|
+
int mmq_x, mmq_y, nwarps;
|
1900
|
+
if (compute_capability >= VER_GEN13) {
|
1901
|
+
mmq_x = MMQ_X_Q4_1_RDNA2;
|
1902
|
+
mmq_y = MMQ_Y_Q4_1_RDNA2;
|
1903
|
+
nwarps = NWARPS_Q4_1_RDNA2;
|
1904
|
+
} else if (compute_capability >= VER_GEN12) {
|
1905
|
+
mmq_x = MMQ_X_Q4_1_RDNA1;
|
1906
|
+
mmq_y = MMQ_Y_Q4_1_RDNA1;
|
1907
|
+
nwarps = NWARPS_Q4_1_RDNA1;
|
1908
|
+
} else if (compute_capability >= VER_GEN9) {
|
1909
|
+
mmq_x = MMQ_X_Q4_1_AMPERE;
|
1910
|
+
mmq_y = MMQ_Y_Q4_1_AMPERE;
|
1911
|
+
nwarps = NWARPS_Q4_1_AMPERE;
|
1912
|
+
} else if (compute_capability >= VER_4VEC) {
|
1913
|
+
mmq_x = MMQ_X_Q4_1_PASCAL;
|
1914
|
+
mmq_y = MMQ_Y_Q4_1_PASCAL;
|
1915
|
+
nwarps = NWARPS_Q4_1_PASCAL;
|
1916
|
+
} else {
|
1917
|
+
GGML_ABORT("fatal error");
|
1918
|
+
}
|
1919
|
+
|
1920
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
1921
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
1922
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
1923
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
1924
|
+
|
1925
|
+
if (nrows_x % mmq_y == 0) {
|
1926
|
+
const bool need_check = false;
|
1927
|
+
/*
|
1928
|
+
DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
|
1929
|
+
the limit. To get the device limit, query
|
1930
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
1931
|
+
*/
|
1932
|
+
{
|
1933
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
1934
|
+
{sycl::aspect::fp16});
|
1935
|
+
|
1936
|
+
stream->submit([&](sycl::handler &cgh) {
|
1937
|
+
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
|
1938
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
|
1939
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
|
1940
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
|
1941
|
+
cgh);
|
1942
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
1943
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
1944
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
1945
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
1946
|
+
|
1947
|
+
cgh.parallel_for(
|
1948
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1949
|
+
[=](sycl::nd_item<3> item_ct1) {
|
1950
|
+
mul_mat_q4_1<need_check>(
|
1951
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
1952
|
+
nrows_dst, item_ct1,
|
1953
|
+
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
1954
|
+
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
1955
|
+
get_pointer(tile_y_qs_acc_ct1),
|
1956
|
+
get_pointer(tile_y_ds_acc_ct1));
|
1957
|
+
});
|
1958
|
+
});
|
1959
|
+
}
|
1960
|
+
} else {
|
1961
|
+
const bool need_check = true;
|
1962
|
+
/*
|
1963
|
+
DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
|
1964
|
+
the limit. To get the device limit, query
|
1965
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
1966
|
+
*/
|
1967
|
+
{
|
1968
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
1969
|
+
{sycl::aspect::fp16});
|
1970
|
+
|
1971
|
+
stream->submit([&](sycl::handler &cgh) {
|
1972
|
+
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
|
1973
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
|
1974
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
|
1975
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
|
1976
|
+
cgh);
|
1977
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
1978
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
1979
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
1980
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
1981
|
+
|
1982
|
+
cgh.parallel_for(
|
1983
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1984
|
+
[=](sycl::nd_item<3> item_ct1) {
|
1985
|
+
mul_mat_q4_1<need_check>(
|
1986
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
1987
|
+
nrows_dst, item_ct1,
|
1988
|
+
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
1989
|
+
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
1990
|
+
get_pointer(tile_y_qs_acc_ct1),
|
1991
|
+
get_pointer(tile_y_ds_acc_ct1));
|
1992
|
+
});
|
1993
|
+
});
|
1994
|
+
}
|
1995
|
+
}
|
1996
|
+
}
|
1997
|
+
catch (sycl::exception const &exc) {
|
1998
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
1999
|
+
<< ", line:" << __LINE__ << std::endl;
|
2000
|
+
std::exit(1);
|
2001
|
+
}
|
2002
|
+
|
2003
|
+
static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
2004
|
+
float *dst, const int ncols_x,
|
2005
|
+
const int nrows_x, const int ncols_y,
|
2006
|
+
const int nrows_y, const int nrows_dst,
|
2007
|
+
dpct::queue_ptr stream) try {
|
2008
|
+
|
2009
|
+
int id;
|
2010
|
+
SYCL_CHECK(
|
2011
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2012
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2013
|
+
|
2014
|
+
int mmq_x, mmq_y, nwarps;
|
2015
|
+
if (compute_capability >= VER_GEN13) {
|
2016
|
+
mmq_x = MMQ_X_Q5_0_RDNA2;
|
2017
|
+
mmq_y = MMQ_Y_Q5_0_RDNA2;
|
2018
|
+
nwarps = NWARPS_Q5_0_RDNA2;
|
2019
|
+
} else if (compute_capability >= VER_GEN12) {
|
2020
|
+
mmq_x = MMQ_X_Q5_0_RDNA1;
|
2021
|
+
mmq_y = MMQ_Y_Q5_0_RDNA1;
|
2022
|
+
nwarps = NWARPS_Q5_0_RDNA1;
|
2023
|
+
} else if (compute_capability >= VER_GEN9) {
|
2024
|
+
mmq_x = MMQ_X_Q5_0_AMPERE;
|
2025
|
+
mmq_y = MMQ_Y_Q5_0_AMPERE;
|
2026
|
+
nwarps = NWARPS_Q5_0_AMPERE;
|
2027
|
+
} else if (compute_capability >= VER_4VEC) {
|
2028
|
+
mmq_x = MMQ_X_Q5_0_PASCAL;
|
2029
|
+
mmq_y = MMQ_Y_Q5_0_PASCAL;
|
2030
|
+
nwarps = NWARPS_Q5_0_PASCAL;
|
2031
|
+
} else {
|
2032
|
+
GGML_ABORT("fatal error");
|
2033
|
+
}
|
2034
|
+
|
2035
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2036
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2037
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2038
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2039
|
+
|
2040
|
+
if (nrows_x % mmq_y == 0) {
|
2041
|
+
const bool need_check = false;
|
2042
|
+
/*
|
2043
|
+
DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
|
2044
|
+
the limit. To get the device limit, query
|
2045
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2046
|
+
*/
|
2047
|
+
{
|
2048
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2049
|
+
{sycl::aspect::fp16});
|
2050
|
+
|
2051
|
+
stream->submit([&](sycl::handler &cgh) {
|
2052
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
|
2053
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2054
|
+
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
|
2055
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
|
2056
|
+
cgh);
|
2057
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2058
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2059
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2060
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2061
|
+
|
2062
|
+
cgh.parallel_for(
|
2063
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2064
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2065
|
+
mul_mat_q5_0<need_check>(
|
2066
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2067
|
+
nrows_dst, item_ct1,
|
2068
|
+
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
2069
|
+
get_pointer(tile_x_d_q5_0_acc_ct1),
|
2070
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2071
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2072
|
+
});
|
2073
|
+
});
|
2074
|
+
}
|
2075
|
+
} else {
|
2076
|
+
const bool need_check = true;
|
2077
|
+
/*
|
2078
|
+
DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
|
2079
|
+
the limit. To get the device limit, query
|
2080
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2081
|
+
*/
|
2082
|
+
{
|
2083
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2084
|
+
{sycl::aspect::fp16});
|
2085
|
+
|
2086
|
+
stream->submit([&](sycl::handler &cgh) {
|
2087
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
|
2088
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2089
|
+
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
|
2090
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
|
2091
|
+
cgh);
|
2092
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2093
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2094
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2095
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2096
|
+
|
2097
|
+
cgh.parallel_for(
|
2098
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2099
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2100
|
+
mul_mat_q5_0<need_check>(
|
2101
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2102
|
+
nrows_dst, item_ct1,
|
2103
|
+
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
2104
|
+
get_pointer(tile_x_d_q5_0_acc_ct1),
|
2105
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2106
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2107
|
+
});
|
2108
|
+
});
|
2109
|
+
}
|
2110
|
+
}
|
2111
|
+
}
|
2112
|
+
catch (sycl::exception const &exc) {
|
2113
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2114
|
+
<< ", line:" << __LINE__ << std::endl;
|
2115
|
+
std::exit(1);
|
2116
|
+
}
|
2117
|
+
|
2118
|
+
static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
2119
|
+
float *dst, const int ncols_x,
|
2120
|
+
const int nrows_x, const int ncols_y,
|
2121
|
+
const int nrows_y, const int nrows_dst,
|
2122
|
+
dpct::queue_ptr stream) try {
|
2123
|
+
|
2124
|
+
int id;
|
2125
|
+
SYCL_CHECK(
|
2126
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2127
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2128
|
+
|
2129
|
+
int mmq_x, mmq_y, nwarps;
|
2130
|
+
if (compute_capability >= VER_GEN13) {
|
2131
|
+
mmq_x = MMQ_X_Q5_1_RDNA2;
|
2132
|
+
mmq_y = MMQ_Y_Q5_1_RDNA2;
|
2133
|
+
nwarps = NWARPS_Q5_1_RDNA2;
|
2134
|
+
} else if (compute_capability >= VER_GEN12) {
|
2135
|
+
mmq_x = MMQ_X_Q5_1_RDNA1;
|
2136
|
+
mmq_y = MMQ_Y_Q5_1_RDNA1;
|
2137
|
+
nwarps = NWARPS_Q5_1_RDNA1;
|
2138
|
+
} else if (compute_capability >= VER_GEN9) {
|
2139
|
+
mmq_x = MMQ_X_Q5_1_AMPERE;
|
2140
|
+
mmq_y = MMQ_Y_Q5_1_AMPERE;
|
2141
|
+
nwarps = NWARPS_Q5_1_AMPERE;
|
2142
|
+
} else if (compute_capability >= VER_4VEC) {
|
2143
|
+
mmq_x = MMQ_X_Q5_1_PASCAL;
|
2144
|
+
mmq_y = MMQ_Y_Q5_1_PASCAL;
|
2145
|
+
nwarps = NWARPS_Q5_1_PASCAL;
|
2146
|
+
} else {
|
2147
|
+
GGML_ABORT("fatal error");
|
2148
|
+
}
|
2149
|
+
|
2150
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2151
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2152
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2153
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2154
|
+
|
2155
|
+
if (nrows_x % mmq_y == 0) {
|
2156
|
+
const bool need_check = false;
|
2157
|
+
/*
|
2158
|
+
DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
|
2159
|
+
the limit. To get the device limit, query
|
2160
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2161
|
+
*/
|
2162
|
+
{
|
2163
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2164
|
+
{sycl::aspect::fp16});
|
2165
|
+
|
2166
|
+
stream->submit([&](sycl::handler &cgh) {
|
2167
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
|
2168
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2169
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
|
2170
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
|
2171
|
+
cgh);
|
2172
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2173
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2174
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2175
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2176
|
+
|
2177
|
+
cgh.parallel_for(
|
2178
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2179
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2180
|
+
mul_mat_q5_1<need_check>(
|
2181
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2182
|
+
nrows_dst, item_ct1,
|
2183
|
+
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
2184
|
+
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
2185
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2186
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2187
|
+
});
|
2188
|
+
});
|
2189
|
+
}
|
2190
|
+
} else {
|
2191
|
+
const bool need_check = true;
|
2192
|
+
/*
|
2193
|
+
DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
|
2194
|
+
the limit. To get the device limit, query
|
2195
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2196
|
+
*/
|
2197
|
+
{
|
2198
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2199
|
+
{sycl::aspect::fp16});
|
2200
|
+
|
2201
|
+
stream->submit([&](sycl::handler &cgh) {
|
2202
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
|
2203
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2204
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
|
2205
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
|
2206
|
+
cgh);
|
2207
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2208
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2209
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2210
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2211
|
+
|
2212
|
+
cgh.parallel_for(
|
2213
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2214
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2215
|
+
mul_mat_q5_1<need_check>(
|
2216
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2217
|
+
nrows_dst, item_ct1,
|
2218
|
+
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
2219
|
+
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
2220
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2221
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2222
|
+
});
|
2223
|
+
});
|
2224
|
+
}
|
2225
|
+
}
|
2226
|
+
}
|
2227
|
+
catch (sycl::exception const &exc) {
|
2228
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2229
|
+
<< ", line:" << __LINE__ << std::endl;
|
2230
|
+
std::exit(1);
|
2231
|
+
}
|
2232
|
+
|
2233
|
+
static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
2234
|
+
float *dst, const int ncols_x,
|
2235
|
+
const int nrows_x, const int ncols_y,
|
2236
|
+
const int nrows_y, const int nrows_dst,
|
2237
|
+
dpct::queue_ptr stream) try {
|
2238
|
+
|
2239
|
+
int id;
|
2240
|
+
SYCL_CHECK(
|
2241
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2242
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2243
|
+
|
2244
|
+
int mmq_x, mmq_y, nwarps;
|
2245
|
+
if (compute_capability >= VER_GEN13) {
|
2246
|
+
mmq_x = MMQ_X_Q8_0_RDNA2;
|
2247
|
+
mmq_y = MMQ_Y_Q8_0_RDNA2;
|
2248
|
+
nwarps = NWARPS_Q8_0_RDNA2;
|
2249
|
+
} else if (compute_capability >= VER_GEN12) {
|
2250
|
+
mmq_x = MMQ_X_Q8_0_RDNA1;
|
2251
|
+
mmq_y = MMQ_Y_Q8_0_RDNA1;
|
2252
|
+
nwarps = NWARPS_Q8_0_RDNA1;
|
2253
|
+
} else if (compute_capability >= VER_GEN9) {
|
2254
|
+
mmq_x = MMQ_X_Q8_0_AMPERE;
|
2255
|
+
mmq_y = MMQ_Y_Q8_0_AMPERE;
|
2256
|
+
nwarps = NWARPS_Q8_0_AMPERE;
|
2257
|
+
} else if (compute_capability >= VER_4VEC) {
|
2258
|
+
mmq_x = MMQ_X_Q8_0_PASCAL;
|
2259
|
+
mmq_y = MMQ_Y_Q8_0_PASCAL;
|
2260
|
+
nwarps = NWARPS_Q8_0_PASCAL;
|
2261
|
+
} else {
|
2262
|
+
GGML_ABORT("fatal error");
|
2263
|
+
}
|
2264
|
+
|
2265
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2266
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2267
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2268
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2269
|
+
|
2270
|
+
if (nrows_x % mmq_y == 0) {
|
2271
|
+
const bool need_check = false;
|
2272
|
+
/*
|
2273
|
+
DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
|
2274
|
+
the limit. To get the device limit, query
|
2275
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2276
|
+
*/
|
2277
|
+
{
|
2278
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2279
|
+
{sycl::aspect::fp16});
|
2280
|
+
|
2281
|
+
stream->submit([&](sycl::handler &cgh) {
|
2282
|
+
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
|
2283
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2284
|
+
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
|
2285
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
|
2286
|
+
cgh);
|
2287
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2288
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2289
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2290
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2291
|
+
|
2292
|
+
cgh.parallel_for(
|
2293
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2294
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2295
|
+
mul_mat_q8_0<need_check>(
|
2296
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2297
|
+
nrows_dst, item_ct1,
|
2298
|
+
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
2299
|
+
get_pointer(tile_x_d_q8_0_acc_ct1),
|
2300
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2301
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2302
|
+
});
|
2303
|
+
});
|
2304
|
+
}
|
2305
|
+
} else {
|
2306
|
+
const bool need_check = true;
|
2307
|
+
/*
|
2308
|
+
DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
|
2309
|
+
the limit. To get the device limit, query
|
2310
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2311
|
+
*/
|
2312
|
+
{
|
2313
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2314
|
+
{sycl::aspect::fp16});
|
2315
|
+
|
2316
|
+
stream->submit([&](sycl::handler &cgh) {
|
2317
|
+
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
|
2318
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2319
|
+
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
|
2320
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
|
2321
|
+
cgh);
|
2322
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2323
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2324
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2325
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2326
|
+
|
2327
|
+
cgh.parallel_for(
|
2328
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2329
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2330
|
+
mul_mat_q8_0<need_check>(
|
2331
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2332
|
+
nrows_dst, item_ct1,
|
2333
|
+
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
2334
|
+
get_pointer(tile_x_d_q8_0_acc_ct1),
|
2335
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2336
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2337
|
+
});
|
2338
|
+
});
|
2339
|
+
}
|
2340
|
+
}
|
2341
|
+
}
|
2342
|
+
catch (sycl::exception const &exc) {
|
2343
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2344
|
+
<< ", line:" << __LINE__ << std::endl;
|
2345
|
+
std::exit(1);
|
2346
|
+
}
|
2347
|
+
|
2348
|
+
static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
2349
|
+
float *dst, const int ncols_x,
|
2350
|
+
const int nrows_x, const int ncols_y,
|
2351
|
+
const int nrows_y, const int nrows_dst,
|
2352
|
+
dpct::queue_ptr stream) try {
|
2353
|
+
|
2354
|
+
int id;
|
2355
|
+
SYCL_CHECK(
|
2356
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2357
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2358
|
+
|
2359
|
+
int mmq_x, mmq_y, nwarps;
|
2360
|
+
if (compute_capability >= VER_GEN13) {
|
2361
|
+
mmq_x = MMQ_X_Q2_K_RDNA2;
|
2362
|
+
mmq_y = MMQ_Y_Q2_K_RDNA2;
|
2363
|
+
nwarps = NWARPS_Q2_K_RDNA2;
|
2364
|
+
} else if (compute_capability >= VER_GEN12) {
|
2365
|
+
mmq_x = MMQ_X_Q2_K_RDNA1;
|
2366
|
+
mmq_y = MMQ_Y_Q2_K_RDNA1;
|
2367
|
+
nwarps = NWARPS_Q2_K_RDNA1;
|
2368
|
+
} else if (compute_capability >= VER_GEN9) {
|
2369
|
+
mmq_x = MMQ_X_Q2_K_AMPERE;
|
2370
|
+
mmq_y = MMQ_Y_Q2_K_AMPERE;
|
2371
|
+
nwarps = NWARPS_Q2_K_AMPERE;
|
2372
|
+
} else if (compute_capability >= VER_4VEC) {
|
2373
|
+
mmq_x = MMQ_X_Q2_K_PASCAL;
|
2374
|
+
mmq_y = MMQ_Y_Q2_K_PASCAL;
|
2375
|
+
nwarps = NWARPS_Q2_K_PASCAL;
|
2376
|
+
} else {
|
2377
|
+
GGML_ABORT("fatal error");
|
2378
|
+
}
|
2379
|
+
|
2380
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2381
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2382
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2383
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2384
|
+
|
2385
|
+
if (nrows_x % mmq_y == 0) {
|
2386
|
+
const bool need_check = false;
|
2387
|
+
/*
|
2388
|
+
DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
|
2389
|
+
the limit. To get the device limit, query
|
2390
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2391
|
+
*/
|
2392
|
+
{
|
2393
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2394
|
+
{sycl::aspect::fp16});
|
2395
|
+
|
2396
|
+
stream->submit([&](sycl::handler &cgh) {
|
2397
|
+
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
|
2398
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2399
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
|
2400
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
|
2401
|
+
cgh);
|
2402
|
+
sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
|
2403
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
|
2404
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2405
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2406
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2407
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2408
|
+
|
2409
|
+
cgh.parallel_for(
|
2410
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2411
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2412
|
+
mul_mat_q2_K<need_check>(
|
2413
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2414
|
+
nrows_dst, item_ct1,
|
2415
|
+
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
2416
|
+
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
2417
|
+
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
2418
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2419
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2420
|
+
});
|
2421
|
+
});
|
2422
|
+
}
|
2423
|
+
} else {
|
2424
|
+
const bool need_check = true;
|
2425
|
+
/*
|
2426
|
+
DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
|
2427
|
+
the limit. To get the device limit, query
|
2428
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2429
|
+
*/
|
2430
|
+
{
|
2431
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2432
|
+
{sycl::aspect::fp16});
|
2433
|
+
|
2434
|
+
stream->submit([&](sycl::handler &cgh) {
|
2435
|
+
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
|
2436
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2437
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
|
2438
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
|
2439
|
+
cgh);
|
2440
|
+
sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
|
2441
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
|
2442
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2443
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2444
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2445
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2446
|
+
|
2447
|
+
cgh.parallel_for(
|
2448
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2449
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2450
|
+
mul_mat_q2_K<need_check>(
|
2451
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2452
|
+
nrows_dst, item_ct1,
|
2453
|
+
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
2454
|
+
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
2455
|
+
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
2456
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2457
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2458
|
+
});
|
2459
|
+
});
|
2460
|
+
}
|
2461
|
+
}
|
2462
|
+
}
|
2463
|
+
catch (sycl::exception const &exc) {
|
2464
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2465
|
+
<< ", line:" << __LINE__ << std::endl;
|
2466
|
+
std::exit(1);
|
2467
|
+
}
|
2468
|
+
|
2469
|
+
static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
2470
|
+
float *dst, const int ncols_x,
|
2471
|
+
const int nrows_x, const int ncols_y,
|
2472
|
+
const int nrows_y, const int nrows_dst,
|
2473
|
+
dpct::queue_ptr stream) try {
|
2474
|
+
|
2475
|
+
#if QK_K == 256
|
2476
|
+
|
2477
|
+
int id;
|
2478
|
+
SYCL_CHECK(
|
2479
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2480
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2481
|
+
|
2482
|
+
int mmq_x, mmq_y, nwarps;
|
2483
|
+
if (compute_capability >= VER_GEN13) {
|
2484
|
+
mmq_x = MMQ_X_Q3_K_RDNA2;
|
2485
|
+
mmq_y = MMQ_Y_Q3_K_RDNA2;
|
2486
|
+
nwarps = NWARPS_Q3_K_RDNA2;
|
2487
|
+
} else if (compute_capability >= VER_GEN12) {
|
2488
|
+
mmq_x = MMQ_X_Q3_K_RDNA1;
|
2489
|
+
mmq_y = MMQ_Y_Q3_K_RDNA1;
|
2490
|
+
nwarps = NWARPS_Q3_K_RDNA1;
|
2491
|
+
} else if (compute_capability >= VER_GEN9) {
|
2492
|
+
mmq_x = MMQ_X_Q3_K_AMPERE;
|
2493
|
+
mmq_y = MMQ_Y_Q3_K_AMPERE;
|
2494
|
+
nwarps = NWARPS_Q3_K_AMPERE;
|
2495
|
+
} else if (compute_capability >= VER_4VEC) {
|
2496
|
+
mmq_x = MMQ_X_Q3_K_PASCAL;
|
2497
|
+
mmq_y = MMQ_Y_Q3_K_PASCAL;
|
2498
|
+
nwarps = NWARPS_Q3_K_PASCAL;
|
2499
|
+
} else {
|
2500
|
+
GGML_ABORT("fatal error");
|
2501
|
+
}
|
2502
|
+
|
2503
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2504
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2505
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2506
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2507
|
+
|
2508
|
+
if (nrows_x % mmq_y == 0) {
|
2509
|
+
const bool need_check = false;
|
2510
|
+
/*
|
2511
|
+
DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
|
2512
|
+
the limit. To get the device limit, query
|
2513
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2514
|
+
*/
|
2515
|
+
{
|
2516
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2517
|
+
{sycl::aspect::fp16});
|
2518
|
+
|
2519
|
+
stream->submit([&](sycl::handler &cgh) {
|
2520
|
+
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
|
2521
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2522
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
|
2523
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
|
2524
|
+
cgh);
|
2525
|
+
sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
|
2526
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
|
2527
|
+
sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
|
2528
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
|
2529
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2530
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2531
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2532
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2533
|
+
|
2534
|
+
cgh.parallel_for(
|
2535
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2536
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2537
|
+
mul_mat_q3_K<need_check>(
|
2538
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2539
|
+
nrows_dst, item_ct1,
|
2540
|
+
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
2541
|
+
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
2542
|
+
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
2543
|
+
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
2544
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2545
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2546
|
+
});
|
2547
|
+
});
|
2548
|
+
}
|
2549
|
+
} else {
|
2550
|
+
const bool need_check = true;
|
2551
|
+
/*
|
2552
|
+
DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
|
2553
|
+
the limit. To get the device limit, query
|
2554
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2555
|
+
*/
|
2556
|
+
{
|
2557
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2558
|
+
{sycl::aspect::fp16});
|
2559
|
+
|
2560
|
+
stream->submit([&](sycl::handler &cgh) {
|
2561
|
+
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
|
2562
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2563
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
|
2564
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
|
2565
|
+
cgh);
|
2566
|
+
sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
|
2567
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
|
2568
|
+
sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
|
2569
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
|
2570
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2571
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2572
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2573
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2574
|
+
|
2575
|
+
cgh.parallel_for(
|
2576
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2577
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2578
|
+
mul_mat_q3_K<need_check>(
|
2579
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2580
|
+
nrows_dst, item_ct1,
|
2581
|
+
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
2582
|
+
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
2583
|
+
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
2584
|
+
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
2585
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2586
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2587
|
+
});
|
2588
|
+
});
|
2589
|
+
}
|
2590
|
+
}
|
2591
|
+
#endif
|
2592
|
+
}
|
2593
|
+
catch (sycl::exception const &exc) {
|
2594
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2595
|
+
<< ", line:" << __LINE__ << std::endl;
|
2596
|
+
std::exit(1);
|
2597
|
+
}
|
2598
|
+
|
2599
|
+
static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
2600
|
+
float *dst, const int ncols_x,
|
2601
|
+
const int nrows_x, const int ncols_y,
|
2602
|
+
const int nrows_y, const int nrows_dst,
|
2603
|
+
dpct::queue_ptr stream) try {
|
2604
|
+
|
2605
|
+
int id;
|
2606
|
+
SYCL_CHECK(
|
2607
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2608
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2609
|
+
|
2610
|
+
int mmq_x, mmq_y, nwarps;
|
2611
|
+
if (compute_capability >= VER_GEN13) {
|
2612
|
+
mmq_x = MMQ_X_Q4_K_RDNA2;
|
2613
|
+
mmq_y = MMQ_Y_Q4_K_RDNA2;
|
2614
|
+
nwarps = NWARPS_Q4_K_RDNA2;
|
2615
|
+
} else if (compute_capability >= VER_GEN12) {
|
2616
|
+
mmq_x = MMQ_X_Q4_K_RDNA1;
|
2617
|
+
mmq_y = MMQ_Y_Q4_K_RDNA1;
|
2618
|
+
nwarps = NWARPS_Q4_K_RDNA1;
|
2619
|
+
} else if (compute_capability >= VER_GEN9) {
|
2620
|
+
mmq_x = MMQ_X_Q4_K_AMPERE;
|
2621
|
+
mmq_y = MMQ_Y_Q4_K_AMPERE;
|
2622
|
+
nwarps = NWARPS_Q4_K_AMPERE;
|
2623
|
+
} else if (compute_capability >= VER_4VEC) {
|
2624
|
+
mmq_x = MMQ_X_Q4_K_PASCAL;
|
2625
|
+
mmq_y = MMQ_Y_Q4_K_PASCAL;
|
2626
|
+
nwarps = NWARPS_Q4_K_PASCAL;
|
2627
|
+
} else {
|
2628
|
+
GGML_ABORT("fatal error");
|
2629
|
+
}
|
2630
|
+
|
2631
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2632
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2633
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2634
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2635
|
+
|
2636
|
+
if (nrows_x % mmq_y == 0) {
|
2637
|
+
const bool need_check = false;
|
2638
|
+
/*
|
2639
|
+
DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
|
2640
|
+
the limit. To get the device limit, query
|
2641
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2642
|
+
*/
|
2643
|
+
{
|
2644
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2645
|
+
{sycl::aspect::fp16});
|
2646
|
+
|
2647
|
+
stream->submit([&](sycl::handler &cgh) {
|
2648
|
+
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
|
2649
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2650
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
|
2651
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
|
2652
|
+
cgh);
|
2653
|
+
sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
|
2654
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2655
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2656
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2657
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2658
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2659
|
+
|
2660
|
+
cgh.parallel_for(
|
2661
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2662
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2663
|
+
mul_mat_q4_K<need_check>(
|
2664
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2665
|
+
nrows_dst, item_ct1,
|
2666
|
+
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
2667
|
+
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
2668
|
+
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
2669
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2670
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2671
|
+
});
|
2672
|
+
});
|
2673
|
+
}
|
2674
|
+
} else {
|
2675
|
+
const bool need_check = true;
|
2676
|
+
/*
|
2677
|
+
DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
|
2678
|
+
the limit. To get the device limit, query
|
2679
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2680
|
+
*/
|
2681
|
+
{
|
2682
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2683
|
+
{sycl::aspect::fp16});
|
2684
|
+
|
2685
|
+
stream->submit([&](sycl::handler &cgh) {
|
2686
|
+
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
|
2687
|
+
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
2688
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
|
2689
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
|
2690
|
+
cgh);
|
2691
|
+
sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
|
2692
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2693
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2694
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2695
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2696
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2697
|
+
|
2698
|
+
cgh.parallel_for(
|
2699
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2700
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2701
|
+
mul_mat_q4_K<need_check>(
|
2702
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2703
|
+
nrows_dst, item_ct1,
|
2704
|
+
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
2705
|
+
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
2706
|
+
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
2707
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2708
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2709
|
+
});
|
2710
|
+
});
|
2711
|
+
}
|
2712
|
+
}
|
2713
|
+
}
|
2714
|
+
catch (sycl::exception const &exc) {
|
2715
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2716
|
+
<< ", line:" << __LINE__ << std::endl;
|
2717
|
+
std::exit(1);
|
2718
|
+
}
|
2719
|
+
|
2720
|
+
static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
2721
|
+
float *dst, const int ncols_x,
|
2722
|
+
const int nrows_x, const int ncols_y,
|
2723
|
+
const int nrows_y, const int nrows_dst,
|
2724
|
+
dpct::queue_ptr stream) try {
|
2725
|
+
|
2726
|
+
int id;
|
2727
|
+
SYCL_CHECK(
|
2728
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2729
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2730
|
+
|
2731
|
+
int mmq_x, mmq_y, nwarps;
|
2732
|
+
if (compute_capability >= VER_GEN13) {
|
2733
|
+
mmq_x = MMQ_X_Q5_K_RDNA2;
|
2734
|
+
mmq_y = MMQ_Y_Q5_K_RDNA2;
|
2735
|
+
nwarps = NWARPS_Q5_K_RDNA2;
|
2736
|
+
} else if (compute_capability >= VER_GEN12) {
|
2737
|
+
mmq_x = MMQ_X_Q5_K_RDNA1;
|
2738
|
+
mmq_y = MMQ_Y_Q5_K_RDNA1;
|
2739
|
+
nwarps = NWARPS_Q5_K_RDNA1;
|
2740
|
+
} else if (compute_capability >= VER_GEN9) {
|
2741
|
+
mmq_x = MMQ_X_Q5_K_AMPERE;
|
2742
|
+
mmq_y = MMQ_Y_Q5_K_AMPERE;
|
2743
|
+
nwarps = NWARPS_Q5_K_AMPERE;
|
2744
|
+
} else if (compute_capability >= VER_4VEC) {
|
2745
|
+
mmq_x = MMQ_X_Q5_K_PASCAL;
|
2746
|
+
mmq_y = MMQ_Y_Q5_K_PASCAL;
|
2747
|
+
nwarps = NWARPS_Q5_K_PASCAL;
|
2748
|
+
} else {
|
2749
|
+
GGML_ABORT("fatal error");
|
2750
|
+
}
|
2751
|
+
|
2752
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2753
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2754
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2755
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2756
|
+
|
2757
|
+
if (nrows_x % mmq_y == 0) {
|
2758
|
+
const bool need_check = false;
|
2759
|
+
/*
|
2760
|
+
DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
|
2761
|
+
the limit. To get the device limit, query
|
2762
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2763
|
+
*/
|
2764
|
+
{
|
2765
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2766
|
+
{sycl::aspect::fp16});
|
2767
|
+
|
2768
|
+
stream->submit([&](sycl::handler &cgh) {
|
2769
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
|
2770
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2771
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
|
2772
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
|
2773
|
+
cgh);
|
2774
|
+
sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
|
2775
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2776
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2777
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2778
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2779
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2780
|
+
|
2781
|
+
cgh.parallel_for(
|
2782
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2783
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2784
|
+
mul_mat_q5_K<need_check>(
|
2785
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2786
|
+
nrows_dst, item_ct1,
|
2787
|
+
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
2788
|
+
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
2789
|
+
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
2790
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2791
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2792
|
+
});
|
2793
|
+
});
|
2794
|
+
}
|
2795
|
+
} else {
|
2796
|
+
const bool need_check = true;
|
2797
|
+
/*
|
2798
|
+
DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
|
2799
|
+
the limit. To get the device limit, query
|
2800
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2801
|
+
*/
|
2802
|
+
{
|
2803
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2804
|
+
{sycl::aspect::fp16});
|
2805
|
+
|
2806
|
+
stream->submit([&](sycl::handler &cgh) {
|
2807
|
+
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
|
2808
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2809
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
|
2810
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
|
2811
|
+
cgh);
|
2812
|
+
sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
|
2813
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2814
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2815
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2816
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2817
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2818
|
+
|
2819
|
+
cgh.parallel_for(
|
2820
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2821
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2822
|
+
mul_mat_q5_K<need_check>(
|
2823
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2824
|
+
nrows_dst, item_ct1,
|
2825
|
+
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
2826
|
+
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
2827
|
+
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
2828
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2829
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2830
|
+
});
|
2831
|
+
});
|
2832
|
+
}
|
2833
|
+
}
|
2834
|
+
}
|
2835
|
+
catch (sycl::exception const &exc) {
|
2836
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2837
|
+
<< ", line:" << __LINE__ << std::endl;
|
2838
|
+
std::exit(1);
|
2839
|
+
}
|
2840
|
+
|
2841
|
+
static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
2842
|
+
float *dst, const int ncols_x,
|
2843
|
+
const int nrows_x, const int ncols_y,
|
2844
|
+
const int nrows_y, const int nrows_dst,
|
2845
|
+
dpct::queue_ptr stream) try {
|
2846
|
+
|
2847
|
+
int id;
|
2848
|
+
SYCL_CHECK(
|
2849
|
+
CHECK_TRY_ERROR(id = get_current_device_id()));
|
2850
|
+
const int compute_capability = ggml_sycl_info().devices[id].cc;
|
2851
|
+
|
2852
|
+
int mmq_x, mmq_y, nwarps;
|
2853
|
+
if (compute_capability >= VER_GEN13) {
|
2854
|
+
mmq_x = MMQ_X_Q6_K_RDNA2;
|
2855
|
+
mmq_y = MMQ_Y_Q6_K_RDNA2;
|
2856
|
+
nwarps = NWARPS_Q6_K_RDNA2;
|
2857
|
+
} else if (compute_capability >= VER_GEN12) {
|
2858
|
+
mmq_x = MMQ_X_Q6_K_RDNA1;
|
2859
|
+
mmq_y = MMQ_Y_Q6_K_RDNA1;
|
2860
|
+
nwarps = NWARPS_Q6_K_RDNA1;
|
2861
|
+
} else if (compute_capability >= VER_GEN9) {
|
2862
|
+
mmq_x = MMQ_X_Q6_K_AMPERE;
|
2863
|
+
mmq_y = MMQ_Y_Q6_K_AMPERE;
|
2864
|
+
nwarps = NWARPS_Q6_K_AMPERE;
|
2865
|
+
} else if (compute_capability >= VER_4VEC) {
|
2866
|
+
mmq_x = MMQ_X_Q6_K_PASCAL;
|
2867
|
+
mmq_y = MMQ_Y_Q6_K_PASCAL;
|
2868
|
+
nwarps = NWARPS_Q6_K_PASCAL;
|
2869
|
+
} else {
|
2870
|
+
GGML_ABORT("fatal error");
|
2871
|
+
}
|
2872
|
+
|
2873
|
+
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
2874
|
+
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
2875
|
+
const sycl::range<3> block_nums(1, block_num_y, block_num_x);
|
2876
|
+
const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
|
2877
|
+
|
2878
|
+
if (nrows_x % mmq_y == 0) {
|
2879
|
+
const bool need_check = false;
|
2880
|
+
/*
|
2881
|
+
DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
|
2882
|
+
the limit. To get the device limit, query
|
2883
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2884
|
+
*/
|
2885
|
+
{
|
2886
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2887
|
+
{sycl::aspect::fp16});
|
2888
|
+
|
2889
|
+
stream->submit([&](sycl::handler &cgh) {
|
2890
|
+
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
|
2891
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2892
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
|
2893
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
|
2894
|
+
cgh);
|
2895
|
+
sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
|
2896
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2897
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2898
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2899
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2900
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2901
|
+
|
2902
|
+
cgh.parallel_for(
|
2903
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2904
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2905
|
+
mul_mat_q6_K<need_check>(
|
2906
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2907
|
+
nrows_dst, item_ct1,
|
2908
|
+
get_pointer(tile_x_ql_acc_ct1),
|
2909
|
+
get_pointer(tile_x_dm_acc_ct1),
|
2910
|
+
get_pointer(tile_x_sc_acc_ct1),
|
2911
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2912
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2913
|
+
});
|
2914
|
+
});
|
2915
|
+
}
|
2916
|
+
} else {
|
2917
|
+
const bool need_check = true;
|
2918
|
+
/*
|
2919
|
+
DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
|
2920
|
+
the limit. To get the device limit, query
|
2921
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
2922
|
+
*/
|
2923
|
+
{
|
2924
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
2925
|
+
{sycl::aspect::fp16});
|
2926
|
+
|
2927
|
+
stream->submit([&](sycl::handler &cgh) {
|
2928
|
+
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
|
2929
|
+
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
2930
|
+
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
|
2931
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
|
2932
|
+
cgh);
|
2933
|
+
sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
|
2934
|
+
sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
|
2935
|
+
sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
|
2936
|
+
sycl::range<1>(mmq_x * WARP_SIZE), cgh);
|
2937
|
+
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
2938
|
+
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
2939
|
+
|
2940
|
+
cgh.parallel_for(
|
2941
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
2942
|
+
[=](sycl::nd_item<3> item_ct1) {
|
2943
|
+
mul_mat_q6_K<need_check>(
|
2944
|
+
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
2945
|
+
nrows_dst, item_ct1,
|
2946
|
+
get_pointer(tile_x_ql_acc_ct1),
|
2947
|
+
get_pointer(tile_x_dm_acc_ct1),
|
2948
|
+
get_pointer(tile_x_sc_acc_ct1),
|
2949
|
+
get_pointer(tile_y_qs_acc_ct1),
|
2950
|
+
get_pointer(tile_y_ds_acc_ct1));
|
2951
|
+
});
|
2952
|
+
});
|
2953
|
+
}
|
2954
|
+
}
|
2955
|
+
}
|
2956
|
+
catch (sycl::exception const &exc) {
|
2957
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
2958
|
+
<< ", line:" << __LINE__ << std::endl;
|
2959
|
+
std::exit(1);
|
2960
|
+
}
|
2961
|
+
|
2962
|
+
void ggml_sycl_op_mul_mat_q(
|
2963
|
+
ggml_backend_sycl_context & ctx,
|
2964
|
+
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
2965
|
+
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
|
2966
|
+
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
2967
|
+
const int64_t src1_ncols, const int64_t src1_padded_row_size,
|
2968
|
+
const dpct::queue_ptr &stream) try {
|
2969
|
+
|
2970
|
+
const int64_t ne00 = src0->ne[0];
|
2971
|
+
|
2972
|
+
const int64_t ne10 = src1->ne[0];
|
2973
|
+
GGML_ASSERT(ne10 % QK8_1 == 0);
|
2974
|
+
|
2975
|
+
const int64_t ne0 = dst->ne[0];
|
2976
|
+
|
2977
|
+
const int64_t row_diff = row_high - row_low;
|
2978
|
+
|
2979
|
+
int device_id;
|
2980
|
+
SYCL_CHECK(
|
2981
|
+
CHECK_TRY_ERROR(device_id = get_current_device_id()));
|
2982
|
+
|
2983
|
+
// the main device has a larger memory buffer to hold the results from all GPUs
|
2984
|
+
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
|
2985
|
+
const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
|
2986
|
+
|
2987
|
+
switch (src0->type) {
|
2988
|
+
case GGML_TYPE_Q4_0:
|
2989
|
+
ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
2990
|
+
break;
|
2991
|
+
case GGML_TYPE_Q4_1:
|
2992
|
+
ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
2993
|
+
break;
|
2994
|
+
case GGML_TYPE_Q5_0:
|
2995
|
+
ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
2996
|
+
break;
|
2997
|
+
case GGML_TYPE_Q5_1:
|
2998
|
+
ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
2999
|
+
break;
|
3000
|
+
case GGML_TYPE_Q8_0:
|
3001
|
+
ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3002
|
+
break;
|
3003
|
+
case GGML_TYPE_Q2_K:
|
3004
|
+
ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3005
|
+
break;
|
3006
|
+
case GGML_TYPE_Q3_K:
|
3007
|
+
ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3008
|
+
break;
|
3009
|
+
case GGML_TYPE_Q4_K:
|
3010
|
+
ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3011
|
+
break;
|
3012
|
+
case GGML_TYPE_Q5_K:
|
3013
|
+
ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3014
|
+
break;
|
3015
|
+
case GGML_TYPE_Q6_K:
|
3016
|
+
ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
|
3017
|
+
break;
|
3018
|
+
default:
|
3019
|
+
GGML_ABORT("fatal error");
|
3020
|
+
break;
|
3021
|
+
}
|
3022
|
+
|
3023
|
+
GGML_UNUSED(src1);
|
3024
|
+
GGML_UNUSED(dst);
|
3025
|
+
GGML_UNUSED(src1_ddf_i);
|
3026
|
+
}
|
3027
|
+
catch (sycl::exception const &exc) {
|
3028
|
+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
3029
|
+
<< ", line:" << __LINE__ << std::endl;
|
3030
|
+
std::exit(1);
|
3031
|
+
}
|