@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
|
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
|
|
|
968
968
|
grid1[0] ^ signs[0], signs[0], std::minus<>());
|
|
969
969
|
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
|
970
970
|
grid2[0] ^ signs[1], signs[1], std::minus<>());
|
|
971
|
-
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
|
972
|
-
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
|
971
|
+
sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
|
|
972
|
+
sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
|
|
973
973
|
q8 += 8;
|
|
974
974
|
aux32 >>= 7;
|
|
975
975
|
}
|
|
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
|
|
|
1009
1009
|
grid1[0] ^ signs0, signs0, std::minus<>());
|
|
1010
1010
|
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
|
1011
1011
|
grid2[0] ^ signs1, signs1, std::minus<>());
|
|
1012
|
-
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
|
1013
|
-
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
|
1012
|
+
sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
|
|
1013
|
+
sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
|
|
1014
1014
|
q8 += 8;
|
|
1015
1015
|
}
|
|
1016
1016
|
const float d =
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
#include <sycl/sycl.hpp>
|
|
2
|
+
#include "wkv6.hpp"
|
|
3
|
+
|
|
4
|
+
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
|
5
|
+
|
|
6
|
+
// Helper function for the main kernel
|
|
7
|
+
static void rwkv_wkv_f32_kernel(
|
|
8
|
+
const int B, const int T, const int C, const int H,
|
|
9
|
+
const float* k, const float* v, const float* r,
|
|
10
|
+
const float* tf, const float* td, const float* s,
|
|
11
|
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
|
12
|
+
|
|
13
|
+
const int tid = item_ct1.get_local_id(2);
|
|
14
|
+
const int bid = item_ct1.get_group(2);
|
|
15
|
+
|
|
16
|
+
const int head_size = WKV_BLOCK_SIZE;
|
|
17
|
+
const int batch_i = bid / H;
|
|
18
|
+
const int head_i = bid % H;
|
|
19
|
+
const int state_size = C * head_size;
|
|
20
|
+
const int n_seq_tokens = T / B;
|
|
21
|
+
|
|
22
|
+
// Set up shared memory pointers
|
|
23
|
+
float* _k = shared_mem;
|
|
24
|
+
float* _r = _k + head_size;
|
|
25
|
+
float* _tf = _r + head_size;
|
|
26
|
+
float* _td = _tf + head_size;
|
|
27
|
+
|
|
28
|
+
// Local state array
|
|
29
|
+
float state[WKV_BLOCK_SIZE];
|
|
30
|
+
|
|
31
|
+
// Load initial state
|
|
32
|
+
#pragma unroll
|
|
33
|
+
for (int i = 0; i < head_size; i++) {
|
|
34
|
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
// Sync threads before shared memory operations
|
|
38
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
39
|
+
|
|
40
|
+
// Load time-mixing parameters
|
|
41
|
+
_tf[tid] = tf[head_i * head_size + tid];
|
|
42
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
43
|
+
|
|
44
|
+
// Main sequence processing loop
|
|
45
|
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
|
46
|
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
|
47
|
+
t += C) {
|
|
48
|
+
|
|
49
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
50
|
+
|
|
51
|
+
// Load current timestep data to shared memory
|
|
52
|
+
_k[tid] = k[t];
|
|
53
|
+
_r[tid] = r[t];
|
|
54
|
+
_td[tid] = td[t];
|
|
55
|
+
|
|
56
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
57
|
+
|
|
58
|
+
const float _v = v[t];
|
|
59
|
+
float y = 0;
|
|
60
|
+
|
|
61
|
+
// Process in chunks of 4 for better vectorization
|
|
62
|
+
sycl::float4 k4, r4, tf4, td4, s4, kv4;
|
|
63
|
+
#pragma unroll
|
|
64
|
+
for (int j = 0; j < head_size; j += 4) {
|
|
65
|
+
// Load data in vec4 chunks
|
|
66
|
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
67
|
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
68
|
+
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
|
69
|
+
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
|
70
|
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
71
|
+
|
|
72
|
+
// Compute key-value product
|
|
73
|
+
sycl::float4 kv4 = k4 * _v;
|
|
74
|
+
|
|
75
|
+
// Accumulate weighted sum
|
|
76
|
+
y += sycl::dot(r4, tf4 * kv4 + s4);
|
|
77
|
+
|
|
78
|
+
// Update state
|
|
79
|
+
s4 = s4 * td4 + kv4;
|
|
80
|
+
|
|
81
|
+
// Store updated state
|
|
82
|
+
state[j] = s4.x();
|
|
83
|
+
state[j+1] = s4.y();
|
|
84
|
+
state[j+2] = s4.z();
|
|
85
|
+
state[j+3] = s4.w();
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
dst[t] = y;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
// Save final state
|
|
92
|
+
#pragma unroll
|
|
93
|
+
for (int i = 0; i < head_size; i++) {
|
|
94
|
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
// Host-side launcher for the RWKV6 WKV operator on the SYCL backend.
//
// All operands are read from dst->src[0..5] (k, v, r, tf, td, state in that
// order); the src0/src1 parameters are not used. The kernel is submitted to
// the stream returned by ctx.stream() with B * H work-groups of C / H
// work-items each, where
//   B = dst->src[5]->ne[1], T = dst->src[0]->ne[3],
//   C = dst->ne[0],         H = dst->src[0]->ne[2].
//
// Asserts that the state tensor is F32 and that C / H == WKV_BLOCK_SIZE, which
// is the only head size the kernel supports.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                            const ggml_tensor* src1, ggml_tensor* dst) {

    // Kept for signature compatibility; the operands come from dst->src[].
    (void) src0;
    (void) src1;

    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[3];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[2];

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64

    dpct::queue_ptr stream = ctx.stream();

    // Local scratch for k, r, tf and td: four segments of WKV_BLOCK_SIZE floats.
    // The local_accessor range counts elements of its value type (float), not
    // bytes, so no sizeof(float) factor belongs here.
    const size_t shared_mem_floats = WKV_BLOCK_SIZE * 4;

    // One work-group per (batch, head) pair, one work-item per channel of a head.
    sycl::range<3> block_dims(1, 1, C / H);
    sycl::range<3> grid_dims(1, 1, B * H);

    stream->submit([&](sycl::handler& cgh) {
        sycl::local_accessor<float, 1> shared_mem_acc(sycl::range<1>(shared_mem_floats), cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
                rwkv_wkv_f32_kernel(
                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                    item_ct1, shared_mem_acc.get_pointer()
                );
            });
    });
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
#include "ggml-threading.h"
|
|
2
|
+
#include <mutex>
|
|
3
|
+
|
|
4
|
+
// Single process-wide mutex backing the critical-section functions below.
// std::mutex is not recursive: a thread that already entered the critical
// section must leave it before entering again.
std::mutex ggml_critical_section_mutex;

// Enter the critical section; blocks until the global mutex is acquired.
void ggml_critical_section_start(void) {
    ggml_critical_section_mutex.lock();
}

// Leave the critical section; must be called by the thread that entered it.
void ggml_critical_section_end(void) {
    ggml_critical_section_mutex.unlock();
}
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Build rules for the ggml Vulkan backend library and its generated shaders.
#
# REQUIRED makes configuration stop with an error when Vulkan or the glslc
# component cannot be found, so the else() branch at the bottom is only a
# safety net and is not expected to run.
find_package(Vulkan COMPONENTS glslc REQUIRED)

if (Vulkan_FOUND)
    message(STATUS "Vulkan found")

    add_library(ggml-vulkan
                ggml-vulkan.cpp
                ../../include/ggml-vulkan.h
               )

    target_link_libraries(ggml-vulkan PRIVATE ggml-base Vulkan::Vulkan)
    # CMAKE_CURRENT_BINARY_DIR is on the include path because the generated
    # shader header is written there (see _ggml_vk_header below).
    target_include_directories(ggml-vulkan PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})

    # NOTE: the definitions below are intentionally directory-scoped and placed
    # before add_subdirectory(vulkan-shaders): add_compile_definitions() also
    # applies to targets created in subdirectories added afterwards, so the
    # shader generator is compiled with the same definitions as ggml-vulkan.

    # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
    # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
    endif()

    # Each GGML_VULKAN_* switch is forwarded to the compiler as a definition
    # of the same name.
    if (GGML_VULKAN_CHECK_RESULTS)
        add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
    endif()

    if (GGML_VULKAN_DEBUG)
        add_compile_definitions(GGML_VULKAN_DEBUG)
    endif()

    if (GGML_VULKAN_MEMORY_DEBUG)
        add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
    endif()

    if (GGML_VULKAN_SHADER_DEBUG_INFO)
        add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
    endif()

    if (GGML_VULKAN_PERF)
        add_compile_definitions(GGML_VULKAN_PERF)
    endif()

    if (GGML_VULKAN_VALIDATE)
        add_compile_definitions(GGML_VULKAN_VALIDATE)
    endif()

    if (GGML_VULKAN_RUN_TESTS)
        add_compile_definitions(GGML_VULKAN_RUN_TESTS)
    endif()

    # Expected to define the vulkan-shaders-gen executable used by the custom
    # command below.
    add_subdirectory(vulkan-shaders)

    set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
    set (_ggml_vk_header     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
    set (_ggml_vk_source     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
    set (_ggml_vk_input_dir  ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
    set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

    # CONFIGURE_DEPENDS re-runs the glob at build time so shaders added to or
    # removed from the directory are picked up without a manual reconfigure.
    file(GLOB _ggml_vk_shader_deps CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")

    add_custom_command(
        OUTPUT ${_ggml_vk_header}
               ${_ggml_vk_source}

        COMMAND ${_ggml_vk_genshaders_cmd}
            --glslc      ${Vulkan_GLSLC_EXECUTABLE}
            --input-dir  ${_ggml_vk_input_dir}
            --output-dir ${_ggml_vk_output_dir}
            --target-hpp ${_ggml_vk_header}
            --target-cpp ${_ggml_vk_source}
            --no-clean

        # Naming the generator target in COMMAND only orders the build; listing
        # it in DEPENDS adds the file-level dependency that re-runs the command
        # whenever the generator executable itself is rebuilt.
        DEPENDS ${_ggml_vk_shader_deps}
                ${_ggml_vk_genshaders_cmd}
        COMMENT "Generate vulkan shaders"
        VERBATIM
    )

    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})

else()
    message(WARNING "Vulkan not found")
endif()
|