@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
--- /dev/null
+++ b/package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp
@@ -0,0 +1,11 @@
+#ifndef GGML_SYCL_CPY_HPP
+#define GGML_SYCL_CPY_HPP
+
+#include "common.hpp"
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
+void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_CPY_HPP
--- a/package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp
+++ b/package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp
@@ -16,6 +16,8 @@
 #include "common.hpp"
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
+                                            const int iqs, dfloat2 &v);
 
 static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
 #endif // GGML_SYCL_F16
 }
 
+static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
+                                                    const int iqs, dfloat2 &v) {
+    // const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);
+
+    const int vui = *((const uint8_t *)qs+iqs);
+
+    v.x() = vui & 0xF;
+    v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+    // v = v - {8.0f, 8.0f};
+    // v = v * {d, d};
+    v.s0() = (v.s0() - 8.0f) * d;
+    v.s1() = (v.s1() - 8.0f) * d;
+
+#else
+    v.x() = (v.x() - 8.0f) * d;
+    v.y() = (v.y() - 8.0f) * d;
+#endif // GGML_SYCL_F16
+}
+
 static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
     }
 }
 
+template<typename dst_t>
+static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
+                                          const sycl::nd_item<3> &item_ct1) {
+
+    const int64_t i = item_ct1.get_group(2);
+    auto k=nb32;
+    // assume 32 threads
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int lane_ib = i * WARP_SIZE + tid;
+
+    if (lane_ib >= k / QK4_0) {
+        return;
+    }
+
+    dst_t * y_ptr = yy + lane_ib * QK4_0;
+
+    auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
+    auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
+
+    const float d = float(*s_ptr);
+
+#pragma unroll
+    for (int l = 0; l < QK4_0 / 2; ++l) {
+        int vq = qs[l];
+        y_ptr[l + 0] = d * ((vq & 0xF) - 8);
+        y_ptr[l + 16] = d * ((vq >> 4) - 8);
+    }
+
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
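For orientation: the new `*_reorder` kernels above assume a repacked Q4_0 buffer in which all quantized nibbles of the tensor come first (`k / 2` bytes, blocks back to back), followed by one fp16 scale per block, rather than the interleaved `block_q4_0` {scale, nibbles} structs. A minimal host-side sketch of that assumed layout (names are illustrative, not part of this package's API; scales widened to `float` here, where the kernels read `sycl::half`):

```cpp
#include <cstdint>

constexpr int QK4_0 = 32; // weights per Q4_0 block, as in ggml

// Dequantize block `ib` of a reordered Q4_0 buffer into QK4_0 floats.
//   qs     - all nibble bytes of the tensor, blocks back to back (k/2 bytes)
//   scales - one scale per block, stored after the nibble region
static void dequantize_block_reorder_ref(const uint8_t * qs, const float * scales,
                                         int64_t ib, float * out) {
    const float d = scales[ib];
    const uint8_t * q = qs + ib * QK4_0 / 2;
    for (int l = 0; l < QK4_0 / 2; ++l) {
        // mirrors dequantize_block_q4_0_reorder: the low nibble fills the first
        // half of the block, the high nibble the second half, each offset by -8
        out[l]             = d * ((q[l] & 0xF) - 8);
        out[l + QK4_0 / 2] = d * ((q[l] >> 4) - 8);
    }
}
```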
--- a/package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp
+++ b/package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp
@@ -3,7 +3,6 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
-
 static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
     }
 }
 
+template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
+static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
+                                           const sycl::nd_item<3> &item_ct1) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = item_ct1.get_local_id(2);
+
+
+    const int ncols_left = ncols % (QK4_0*WARP_SIZE);
+    const int ncols_align = ncols - ncols_left;
+    const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // partial sum for each thread
+#ifdef GGML_SYCL_F16
+    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_SYCL_F16
+    const char *d_ptr = (const char*)vx+ncols*nrows/2;
+    int i=0;
+    for (i = 0; i < ncols_align; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+            dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                       y[iybs + iqs + j / qr + y_offset]};
+
+            tmp += v * t1;
+#else
+            tmp += v.x() * y[iybs + iqs + j / qr + 0];
+            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    for (; i < ncols; i += iter_stride) {
+        if (tid>=ncols_left/QK4_0) continue;
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+            dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                       y[iybs + iqs + j / qr + y_offset]};
+
+            tmp += v * t1;
+#else
+            tmp += v.x() * y[iybs + iqs + j / qr + 0];
+            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+    for (int mask = mask_start; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_SYCL_F16
+        dst[row] = tmp.x() + tmp.y();
+#else
+        dst[row] = tmp;
+#endif // GGML_SYCL_F16
+    }
+}
+
 static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
                                          float *dst, const int ncols,
                                          const int nrows,
@@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
     }
 }
 
+static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
+                                                     float *dst, const int ncols,
+                                                     const int nrows,
+                                                     dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
 
 static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
                                              float *dst, const int ncols,
@@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
-
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_SYCL_F16
@@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     if (src1_convert_f16) {
        src1_dfloat = src1_dfloat_a.alloc(ne00);
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         GGML_ASSERT(to_fp16_sycl != nullptr);
         to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
     }
@@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+                dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            } else {
+                dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q4_1:
             dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1020,4 +1151,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_ddq_i);
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
+    GGML_UNUSED(ctx);
 }
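The reorder path changes only where scales and nibbles are read from; the arithmetic is the standard Q4_0 dot product. As a sanity reference, a scalar C++ model of what `dequantize_mul_mat_vec_reorder` computes per row, under the same layout assumption sketched above (hypothetical helper names, not the package API):

```cpp
#include <cstdint>

constexpr int QK4_0 = 32; // weights per Q4_0 block, as in ggml

// dst[row] = sum over col of dequant(src0[row][col]) * y[col], with src0
// stored nibbles-first / scales-last (the reordered Q4_0 layout).
static void dmmv_q4_0_reorder_ref(const uint8_t * qs, const float * scales,
                                  const float * y, float * dst,
                                  int ncols, int nrows) {
    for (int row = 0; row < nrows; ++row) {
        float acc = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const int64_t ib = ((int64_t)row * ncols + col) / QK4_0; // block index
            const int     c  = col % QK4_0; // position within the block
            // the first half of a block pairs with low nibbles, the second
            // half with high nibbles of the same bytes (y_offset = qk/2 above)
            const uint8_t b  = qs[ib * QK4_0 / 2 + (c % (QK4_0 / 2))];
            const int     q  = c < QK4_0 / 2 ? (b & 0xF) : (b >> 4);
            acc += scales[ib] * (q - 8) * y[col];
        }
        dst[row] = acc;
    }
}
```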
--- /dev/null
+++ b/package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp
@@ -0,0 +1,308 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "ggml-impl.h"
+#include "common.hpp"
+#include "dequantize.hpp"
+#include "getrows.hpp"
+
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void k_get_rows(
+        const void * src0, const int32_t * src1, dst_t * dst,
+        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+        size_t s10, size_t s11, size_t s12,
+        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
+}
+
+template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
+static void k_get_rows_reorder(
+        const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
+        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+        size_t s10, size_t s11, size_t s12,
+        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+    auto ncols = ne00;
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+
+    const int src0_off = i01 * ncols + i00;
+    const int ib = src0_off / QK4_0; // block index
+    const int iqs = (i00%qk)/qr; // x quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
+
+    GGML_UNUSED(nb01);
+    GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03);
+}
+
+template<typename src0_t, typename dst_t>
+static void k_get_rows_float(
+        const src0_t * src0, const int32_t * src1, dst_t * dst,
+        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+        size_t s10, size_t s11, size_t s12,
+        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                    item_ct1.get_local_id(2);
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
+}
+
+template <int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst, const void *src0_dd,
+                          const int32_t *src1_dd, float *dst_dd,
+                          queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_get_rows<qk, qr, dq>(
+                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
+static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                                  ggml_tensor *dst, const void *src0_dd,
+                                  const int32_t *src1_dd, float *dst_dd,
+                                  queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    const uint8_t* src0_q = (const uint8_t*)src0_dd;
+    const size_t ncols = ne00;
+    const size_t nrows = ne01;
+    const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+                             k_get_rows_reorder<qk, qr, dq_reorder>(
+                                 src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+
+template <typename src0_t>
+static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const src0_t *src0_dd, const int32_t *src1_dd,
+                                float *dst_dd, queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+            });
+    }
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const float *src0_d, const float *src1_d,
+                           float *dst_d, const queue_ptr &stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+                                src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
+                get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            } else {
+                get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            }
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
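For reference, GGML's GET_ROWS op gathers whole rows: output row r is input row src1[r]. A compact scalar model of the float path this new file implements (illustrative sketch for contiguous tensors only; the SYCL kernels additionally dequantize quantized src0 on the fly):

```cpp
#include <cstdint>

// Scalar model of the float get_rows path: dst row r copies src0 row src1[r].
static void get_rows_ref(const float * src0, const int32_t * src1, float * dst,
                         int64_t ne00 /* row length */, int64_t nrows_out) {
    for (int64_t r = 0; r < nrows_out; ++r) {
        const float * src_row = src0 + (int64_t)src1[r] * ne00;
        for (int64_t c = 0; c < ne00; ++c) {
            dst[r * ne00 + c] = src_row[c];
        }
    }
}
```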
--- /dev/null
+++ b/package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp
@@ -0,0 +1,23 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GETROWS_HPP
+#define GGML_SYCL_GETROWS_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const float *src0_d, const float *src1_d,
+                           float *dst_d, const queue_ptr &stream);
+
+#endif // GGML_SYCL_GETROWS_HPP