@fugood/llama.node 0.3.13 → 0.3.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +1 -1
- package/package.json +1 -1
- package/src/LlamaContext.cpp +98 -76
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +89 -10
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +3 -3
- package/src/llama.cpp/common/arg.cpp +132 -13
- package/src/llama.cpp/common/chat.cpp +960 -266
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +33 -174
- package/src/llama.cpp/common/common.h +27 -67
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +45 -7
- package/src/llama.cpp/common/speculative.cpp +10 -9
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +45 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +79 -34
- package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +196 -108
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +113 -101
- package/src/llama.cpp/examples/server/utils.hpp +94 -105
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +263 -151
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
- package/src/llama.cpp/ggml/include/ggml.h +29 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
- package/src/llama.cpp/ggml/src/ggml.c +93 -5
- package/src/llama.cpp/include/llama.h +105 -27
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +123 -16
- package/src/llama.cpp/src/llama-arch.h +19 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +182 -182
- package/src/llama.cpp/src/llama-grammar.h +12 -3
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -109
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-model.cpp +8230 -122
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama-sampling.cpp +43 -10
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +51 -9837
- package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +593 -395
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -55
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
- /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
#ifndef GGML_SYCL_CPY_HPP
|
|
2
|
+
#define GGML_SYCL_CPY_HPP
|
|
3
|
+
|
|
4
|
+
#include "common.hpp"
|
|
5
|
+
|
|
6
|
+
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
|
7
|
+
|
|
8
|
+
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
|
|
9
|
+
void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
|
10
|
+
|
|
11
|
+
#endif // GGML_SYCL_CPY_HPP
|
|
@@ -16,6 +16,8 @@
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
18
|
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
|
19
|
+
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
|
|
20
|
+
const int iqs, dfloat2 &v);
|
|
19
21
|
|
|
20
22
|
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
|
|
21
23
|
const int iqs, dfloat2 &v) {
|
|
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
|
|
|
40
42
|
#endif // GGML_SYCL_F16
|
|
41
43
|
}
|
|
42
44
|
|
|
45
|
+
// Decodes one packed byte (two 4-bit weights) of a Q4_0 tensor whose scales
// and quants are stored in separate arrays ("reordered" layout).
//
//   d_ptr : base of the scale array; read as sycl::half and indexed by ib
//   ib    : index of the block whose scale is applied
//   qs    : base of the packed quant bytes
//   iqs   : byte index into qs of the value pair to decode
//   v     : output pair; the low nibble lands in the first component and the
//           high nibble in the second, each shifted by -8 and multiplied by
//           the block scale
static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
                                                    const int iqs, dfloat2 &v) {
    const sycl::half * scales = (const sycl::half *) d_ptr;
    const uint8_t    * quants = (const uint8_t *) qs;

    const dfloat d      = (const dfloat) scales[ib];
    const int    packed = quants[iqs];

    // split the byte into its two unsigned 4-bit values
    v.x() = packed & 0xF;
    v.y() = packed >> 4;

    // re-centre around zero and apply the scale
#ifdef GGML_SYCL_F16
    v.s0() = (v.s0() - 8.0f) * d;
    v.s1() = (v.s1() - 8.0f) * d;
#else
    v.x() = (v.x() - 8.0f) * d;
    v.y() = (v.y() - 8.0f) * d;
#endif // GGML_SYCL_F16
}
|
|
67
|
+
|
|
43
68
|
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
|
|
44
69
|
const int iqs, dfloat2 &v) {
|
|
45
70
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
|
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
|
|
|
167
192
|
}
|
|
168
193
|
}
|
|
169
194
|
|
|
195
|
+
template<typename dst_t>
|
|
196
|
+
static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
|
|
197
|
+
const sycl::nd_item<3> &item_ct1) {
|
|
198
|
+
|
|
199
|
+
const int64_t i = item_ct1.get_group(2);
|
|
200
|
+
auto k=nb32;
|
|
201
|
+
// assume 32 threads
|
|
202
|
+
const int64_t tid = item_ct1.get_local_id(2);
|
|
203
|
+
const int lane_ib = i * WARP_SIZE + tid;
|
|
204
|
+
|
|
205
|
+
if (lane_ib >= k / QK4_0) {
|
|
206
|
+
return;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
dst_t * y_ptr = yy + lane_ib * QK4_0;
|
|
210
|
+
|
|
211
|
+
auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
|
|
212
|
+
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
|
|
213
|
+
|
|
214
|
+
const float d = float(*s_ptr);
|
|
215
|
+
|
|
216
|
+
#pragma unroll
|
|
217
|
+
for (int l = 0; l < QK4_0 / 2; ++l) {
|
|
218
|
+
int vq = qs[l];
|
|
219
|
+
y_ptr[l + 0] = d * ((vq & 0xF) - 8);
|
|
220
|
+
y_ptr[l + 16] = d * ((vq >> 4) - 8);
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
}
|
|
224
|
+
|
|
170
225
|
template<typename dst_t>
|
|
171
226
|
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
|
|
172
227
|
const sycl::nd_item<3> &item_ct1) {
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
#include "dequantize.hpp"
|
|
4
4
|
#include "presets.hpp"
|
|
5
5
|
|
|
6
|
-
|
|
7
6
|
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
|
8
7
|
const sycl::half *x = (const sycl::half *)vx;
|
|
9
8
|
|
|
@@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
|
|
91
90
|
}
|
|
92
91
|
}
|
|
93
92
|
|
|
93
|
+
template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
|
|
94
|
+
static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
|
|
95
|
+
const sycl::nd_item<3> &item_ct1) {
|
|
96
|
+
// qk = quantized weights per x block
|
|
97
|
+
// qr = number of quantized weights per data value in x block
|
|
98
|
+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
|
99
|
+
item_ct1.get_local_id(1);
|
|
100
|
+
|
|
101
|
+
if (row >= nrows) {
|
|
102
|
+
return;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
const int tid = item_ct1.get_local_id(2);
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
const int ncols_left = ncols % (QK4_0*WARP_SIZE);
|
|
109
|
+
const int ncols_align = ncols - ncols_left;
|
|
110
|
+
const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
|
|
111
|
+
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
|
|
112
|
+
const int y_offset = qr == 1 ? 1 : qk/2;
|
|
113
|
+
|
|
114
|
+
// partial sum for each thread
|
|
115
|
+
#ifdef GGML_SYCL_F16
|
|
116
|
+
sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
|
|
117
|
+
#else
|
|
118
|
+
float tmp = 0.0f;
|
|
119
|
+
#endif // GGML_SYCL_F16
|
|
120
|
+
const char *d_ptr = (const char*)vx+ncols*nrows/2;
|
|
121
|
+
int i=0;
|
|
122
|
+
for (i = 0; i < ncols_align; i += iter_stride) {
|
|
123
|
+
const int col = i + vals_per_iter*tid;
|
|
124
|
+
const int ib = (row*ncols + col)/qk; // x block index
|
|
125
|
+
const int iqs = (col%qk)/qr; // x quant index
|
|
126
|
+
const int iybs = col - col%qk; // y block start index
|
|
127
|
+
|
|
128
|
+
// processing >2 values per i iter is faster for fast GPUs
|
|
129
|
+
#pragma unroll
|
|
130
|
+
for (int j = 0; j < vals_per_iter; j += 2) {
|
|
131
|
+
// process 2 vals per j iter
|
|
132
|
+
|
|
133
|
+
// dequantize
|
|
134
|
+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
|
135
|
+
dfloat2 v;
|
|
136
|
+
dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
|
|
137
|
+
|
|
138
|
+
// matrix multiplication
|
|
139
|
+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
|
140
|
+
#ifdef GGML_SYCL_F16
|
|
141
|
+
dfloat2 t1{y[iybs + iqs + j / qr + 0],
|
|
142
|
+
y[iybs + iqs + j / qr + y_offset]};
|
|
143
|
+
|
|
144
|
+
tmp += v * t1;
|
|
145
|
+
#else
|
|
146
|
+
tmp += v.x() * y[iybs + iqs + j / qr + 0];
|
|
147
|
+
tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
|
|
148
|
+
#endif // GGML_SYCL_F16
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
for (; i < ncols; i += iter_stride) {
|
|
153
|
+
if (tid>=ncols_left/QK4_0) continue;
|
|
154
|
+
const int col = i + vals_per_iter*tid;
|
|
155
|
+
const int ib = (row*ncols + col)/qk; // x block index
|
|
156
|
+
const int iqs = (col%qk)/qr; // x quant index
|
|
157
|
+
const int iybs = col - col%qk; // y block start index
|
|
158
|
+
|
|
159
|
+
// processing >2 values per i iter is faster for fast GPUs
|
|
160
|
+
#pragma unroll
|
|
161
|
+
for (int j = 0; j < vals_per_iter; j += 2) {
|
|
162
|
+
// process 2 vals per j iter
|
|
163
|
+
|
|
164
|
+
// dequantize
|
|
165
|
+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
|
166
|
+
dfloat2 v;
|
|
167
|
+
dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
|
|
168
|
+
|
|
169
|
+
// matrix multiplication
|
|
170
|
+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
|
171
|
+
#ifdef GGML_SYCL_F16
|
|
172
|
+
dfloat2 t1{y[iybs + iqs + j / qr + 0],
|
|
173
|
+
y[iybs + iqs + j / qr + y_offset]};
|
|
174
|
+
|
|
175
|
+
tmp += v * t1;
|
|
176
|
+
#else
|
|
177
|
+
tmp += v.x() * y[iybs + iqs + j / qr + 0];
|
|
178
|
+
tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
|
|
179
|
+
#endif // GGML_SYCL_F16
|
|
180
|
+
}
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
// sum up partial sums and write back result
|
|
184
|
+
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
|
|
185
|
+
for (int mask = mask_start; mask > 0; mask >>= 1) {
|
|
186
|
+
tmp +=
|
|
187
|
+
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
if (tid == 0) {
|
|
191
|
+
#ifdef GGML_SYCL_F16
|
|
192
|
+
dst[row] = tmp.x() + tmp.y();
|
|
193
|
+
#else
|
|
194
|
+
dst[row] = tmp;
|
|
195
|
+
#endif // GGML_SYCL_F16
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
94
199
|
static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
|
95
200
|
float *dst, const int ncols,
|
|
96
201
|
const int nrows,
|
|
@@ -105,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
|
|
105
210
|
|
|
106
211
|
stream->parallel_for(
|
|
107
212
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
108
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
213
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
109
214
|
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
|
|
110
215
|
nrows, item_ct1);
|
|
111
216
|
});
|
|
@@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
|
|
|
759
864
|
}
|
|
760
865
|
}
|
|
761
866
|
|
|
867
|
+
static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
|
|
868
|
+
float *dst, const int ncols,
|
|
869
|
+
const int nrows,
|
|
870
|
+
dpct::queue_ptr stream) {
|
|
871
|
+
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
|
|
872
|
+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
873
|
+
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
|
|
874
|
+
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
875
|
+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
876
|
+
{
|
|
877
|
+
dpct::has_capability_or_fail(stream->get_device(),
|
|
878
|
+
{sycl::aspect::fp16});
|
|
879
|
+
|
|
880
|
+
stream->parallel_for(
|
|
881
|
+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
882
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
883
|
+
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
|
|
884
|
+
vx, y, dst, ncols, nrows, item_ct1);
|
|
885
|
+
});
|
|
886
|
+
}
|
|
887
|
+
}
|
|
888
|
+
|
|
762
889
|
|
|
763
890
|
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|
764
891
|
float *dst, const int ncols,
|
|
@@ -775,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|
|
775
902
|
|
|
776
903
|
stream->parallel_for(
|
|
777
904
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
778
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
905
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
779
906
|
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
|
|
780
907
|
vx, y, dst, ncols, nrows, item_ct1);
|
|
781
908
|
});
|
|
@@ -796,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
|
|
796
923
|
|
|
797
924
|
stream->parallel_for(
|
|
798
925
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
799
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
926
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
800
927
|
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
|
|
801
928
|
vx, y, dst, ncols, nrows, item_ct1);
|
|
802
929
|
});
|
|
@@ -817,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
|
|
817
944
|
|
|
818
945
|
stream->parallel_for(
|
|
819
946
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
820
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
947
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
821
948
|
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
|
|
822
949
|
vx, y, dst, ncols, nrows, item_ct1);
|
|
823
950
|
});
|
|
@@ -838,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
|
|
838
965
|
|
|
839
966
|
stream->parallel_for(
|
|
840
967
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
841
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
968
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
842
969
|
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
|
|
843
970
|
vx, y, dst, ncols, nrows, item_ct1);
|
|
844
971
|
});
|
|
@@ -859,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
|
|
859
986
|
|
|
860
987
|
stream->parallel_for(
|
|
861
988
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
862
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
989
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
863
990
|
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
|
|
864
991
|
vx, y, dst, ncols, nrows, item_ct1);
|
|
865
992
|
});
|
|
@@ -877,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
|
|
877
1004
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
|
878
1005
|
stream->parallel_for(
|
|
879
1006
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
880
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
1007
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
|
881
1008
|
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
|
882
1009
|
});
|
|
883
1010
|
}
|
|
@@ -893,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
|
|
893
1020
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
|
894
1021
|
stream->parallel_for(
|
|
895
1022
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
896
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
1023
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
|
897
1024
|
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
|
898
1025
|
});
|
|
899
1026
|
}
|
|
@@ -909,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
|
|
909
1036
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
|
910
1037
|
stream->parallel_for(
|
|
911
1038
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
912
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
1039
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
|
913
1040
|
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
|
914
1041
|
});
|
|
915
1042
|
}
|
|
@@ -922,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
|
|
922
1049
|
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
|
|
923
1050
|
stream->parallel_for(
|
|
924
1051
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
|
925
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
1052
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
|
926
1053
|
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
|
927
1054
|
});
|
|
928
1055
|
}
|
|
@@ -938,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
|
|
938
1065
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
|
939
1066
|
stream->parallel_for(
|
|
940
1067
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
941
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
1068
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
|
942
1069
|
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
|
943
1070
|
});
|
|
944
1071
|
}
|
|
@@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
953
1080
|
|
|
954
1081
|
const int64_t ne00 = src0->ne[0];
|
|
955
1082
|
const int64_t row_diff = row_high - row_low;
|
|
956
|
-
|
|
957
1083
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
958
1084
|
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
|
959
1085
|
#ifdef GGML_SYCL_F16
|
|
@@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
967
1093
|
|
|
968
1094
|
if (src1_convert_f16) {
|
|
969
1095
|
src1_dfloat = src1_dfloat_a.alloc(ne00);
|
|
970
|
-
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
|
|
1096
|
+
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
|
971
1097
|
GGML_ASSERT(to_fp16_sycl != nullptr);
|
|
972
1098
|
to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
|
|
973
1099
|
}
|
|
@@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
977
1103
|
|
|
978
1104
|
switch (src0->type) {
|
|
979
1105
|
case GGML_TYPE_Q4_0:
|
|
980
|
-
|
|
1106
|
+
if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
|
|
1107
|
+
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1108
|
+
dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
|
1109
|
+
} else {
|
|
1110
|
+
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
|
1111
|
+
}
|
|
981
1112
|
break;
|
|
982
1113
|
case GGML_TYPE_Q4_1:
|
|
983
1114
|
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
|
|
@@ -1012,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
1012
1143
|
default:
|
|
1013
1144
|
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
|
|
1014
1145
|
GGML_ABORT("fatal error");
|
|
1015
|
-
break;
|
|
1016
1146
|
}
|
|
1017
1147
|
|
|
1018
1148
|
GGML_UNUSED(src1);
|
|
@@ -1020,4 +1150,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
|
|
1020
1150
|
GGML_UNUSED(src1_ddq_i);
|
|
1021
1151
|
GGML_UNUSED(src1_ncols);
|
|
1022
1152
|
GGML_UNUSED(src1_padded_row_size);
|
|
1153
|
+
GGML_UNUSED(ctx);
|
|
1023
1154
|
}
|