@fugood/llama.node 0.3.6 → 0.3.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +3 -1
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +61 -6
- package/src/LlamaContext.h +1 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6

package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp
@@ -0,0 +1,105 @@
+#include <sycl/sycl.hpp>
+
+#include "common.hpp"
+
+template <u_int HEAD_SIZE>
+static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale,
+                                         const float * k, const float * v, const float * r, const float * td,
+                                         const float * s, float * dst) {
+    const u_int head_size = HEAD_SIZE;
+    const u_int state_size = C * head_size;
+    const u_int n_seq_tokens = T / B;
+    sycl::range<1> block_dims((C / H));
+    sycl::range<1> grid_dims((B * H));
+    stream->submit([&](sycl::handler & cgh) {
+        /* local memory accessors*/
+        auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
+        auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
+        auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
+
+        cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
+            u_int tid = item.get_local_id(0);
+            u_int bid = item.get_group(0);
+
+            u_int batch_i = bid / H;
+            u_int head_i = bid % H;
+
+            float state[head_size];
+
+#pragma unroll
+            for (u_int i = 0; i < head_size; i++) {
+                state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+            }
+
+            for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+                 t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+
+                item.barrier(sycl::access::fence_space::local_space); //sync threads
+                _k[tid] = k[t];
+                _r[tid] = r[t];
+                _td[tid] = td[t];
+                item.barrier(sycl::access::fence_space::local_space); //sync threads
+
+                const float _v = v[t];
+                float y = 0;
+
+                for (u_int j = 0; j < head_size; j += 4) {
+                    const sycl::float4 & k = (sycl::float4 &) (_k[j]);
+                    const sycl::float4 & r = (sycl::float4 &) (_r[j]);
+                    const sycl::float4 & td = (sycl::float4 &) (_td[j]);
+                    sycl::float4 & s = (sycl::float4 &) (state[j]);
+                    sycl::float4 kv;
+
+                    kv.x() = k.x() * _v;
+                    kv.y() = k.y() * _v;
+                    kv.z() = k.z() * _v;
+                    kv.w() = k.w() * _v;
+
+                    s.x() = s.x() * td.x() + kv.x();
+                    s.y() = s.y() * td.y() + kv.y();
+                    s.z() = s.z() * td.z() + kv.z();
+                    s.w() = s.w() * td.w() + kv.w();
+
+                    y += r.x() * s.x();
+                    y += r.y() * s.y();
+                    y += r.z() * s.z();
+                    y += r.w() * s.w();
+                }
+                dst[t] = y * scale;
+            }
+#pragma unroll
+            for (u_int i = 0; i < head_size; i++) {
+                dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+            }
+        });
+    });
+}
+
+void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const float * k_d = static_cast<const float *>(dst->src[0]->data);
+    const float * v_d = static_cast<const float *>(dst->src[1]->data);
+    const float * r_d = static_cast<const float *>(dst->src[2]->data);
+    const float * td_d = static_cast<const float *>(dst->src[3]->data);
+    const float * s_d = static_cast<const float *>(dst->src[4]->data);
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    dpct::queue_ptr stream = ctx.stream();
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    float * dst_d = (float *) dst->data;
+
+    if (C / H == 64) {
+        gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
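
For readers skimming the new SYCL gated-linear-attention kernel above: per (batch, head) pair it keeps a `head_size × head_size` state, and for every token it updates `state = state * td + k * v` and emits `y = Σ r * state`, scaled by `scale`, with the final state written after the outputs. The scalar reference below is a minimal sketch of that same recurrence for comparison; it is illustrative only and not part of the package (the helper name `gated_linear_attn_ref` is made up, and its parameter names simply mirror the kernel arguments).

```cpp
// Scalar reference for the gated linear attention recurrence that the SYCL
// kernel above vectorizes with sycl::float4. Illustrative sketch only.
#include <algorithm>
#include <cstdint>
#include <vector>

static void gated_linear_attn_ref(uint32_t B, uint32_t T, uint32_t C, uint32_t H, float scale,
                                  const float * k, const float * v, const float * r, const float * td,
                                  const float * s, float * dst) {
    const uint32_t head_size    = C / H;     // 64 or 128 in the kernel
    const uint32_t state_size   = C * head_size;
    const uint32_t n_seq_tokens = T / B;

    for (uint32_t b = 0; b < B; ++b) {
        for (uint32_t h = 0; h < H; ++h) {
            // per-(batch, head) running state, seeded from s; element [j][i] lives at j*head_size + i
            std::vector<float> state(s + b * state_size + h * head_size * head_size,
                                     s + b * state_size + (h + 1) * head_size * head_size);
            for (uint32_t tok = 0; tok < n_seq_tokens; ++tok) {
                const uint32_t base = (b * n_seq_tokens + tok) * C + h * head_size;
                for (uint32_t i = 0; i < head_size; ++i) {      // output channel (tid in the kernel)
                    float y = 0.0f;
                    for (uint32_t j = 0; j < head_size; ++j) {  // reduced channel
                        float & st = state[j * head_size + i];
                        st = st * td[base + j] + k[base + j] * v[base + i];
                        y += r[base + j] * st;
                    }
                    dst[base + i] = y * scale;
                }
            }
            // the final state is stored after the T*C outputs, as in the kernel
            std::copy(state.begin(), state.end(),
                      dst + T * C + b * state_size + h * head_size * head_size);
        }
    }
}
```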

package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp
@@ -3,9 +3,9 @@
 #include "outprod.hpp"


-void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx,
-        const ggml_tensor*
-
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
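
The out_prod hunk above, and the tsembd/wkv6/gla hunks around it, all follow the same upstream refactor: SYCL backend ops now receive only the destination tensor and recover their operands through `dst->src[...]`. The snippet below is a self-contained sketch of that calling convention using hypothetical stand-in types (`tensor`, `op_out_prod_like`), not the real ggml definitions.

```cpp
// Hypothetical illustration of the "dst-only" op signature used in these hunks.
#include <cassert>

struct tensor {                 // stand-in for ggml_tensor
    const tensor * src[2];
    float          value;
};

// Old style (removed in the diff): op(ctx, src0, src1, dst).
// New style (added): the op receives only dst and reads its inputs from dst->src[].
static void op_out_prod_like(tensor * dst) {
    const tensor * src0 = dst->src[0];
    const tensor * src1 = dst->src[1];
    assert(src0 && src1);
    dst->value = src0->value * src1->value;  // placeholder for the real kernel launch
}

int main() {
    tensor a   = { { nullptr, nullptr }, 2.0f };
    tensor b   = { { nullptr, nullptr }, 3.0f };
    tensor out = { { &a, &b }, 0.0f };
    op_out_prod_like(&out);
    return out.value == 6.0f ? 0 : 1;
}
```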

package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp
@@ -3,8 +3,7 @@

 #include "common.hpp"

-void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx,
-        const ggml_tensor* src1, ggml_tensor* dst);
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst);


 #endif // GGML_SYCL_OUTPROD_HPP

package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp
@@ -55,8 +55,9 @@ static void timestep_embedding_f32_sycl(
     });
 }

-void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx,
-        const ggml_tensor *
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
     dpct::queue_ptr stream = ctx.stream();

package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp
@@ -15,7 +15,6 @@

 #include "common.hpp"

-void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx,
-        const ggml_tensor *src1, ggml_tensor * dst);
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

 #endif // GGML_SYCL_TSEMBD_HPP

package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp
@@ -95,8 +95,10 @@ static void rwkv_wkv_f32_kernel(
     }
 }

-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx,
-
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];

     const float* k_d = (const float*)dst->src[0]->data;
     const float* v_d = (const float*)dst->src[1]->data;
@@ -107,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
     float* dst_d = (float*)dst->data;

     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[
+    const int64_t H = dst->src[0]->ne[1];

     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
@@ -131,7 +133,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
             [=](sycl::nd_item<3> item_ct1) {
                 rwkv_wkv_f32_kernel(
                     B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, shared_mem_acc.
+                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                 );
             });
         });
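
One note on the last wkv6.cpp hunk: the pointer to the kernel's shared memory is now obtained through the SYCL 2020 `get_multi_ptr<sycl::access::decorated::no>()` accessor API instead of the older accessor-to-pointer conversion. Below is a minimal standalone example of that call, assuming a SYCL 2020 compiler such as DPC++; it is illustrative and not code from the package.

```cpp
// Minimal SYCL 2020 example of getting a raw pointer from a local_accessor via
// get_multi_ptr(), the same call the updated wkv6 kernel launch uses.
#include <sycl/sycl.hpp>

int main() {
    sycl::queue q;
    constexpr size_t n = 64;
    float * out = sycl::malloc_shared<float>(n, q);

    q.submit([&](sycl::handler & cgh) {
        auto shared_mem_acc = sycl::local_accessor<float, 1>(sycl::range<1>(n), cgh);
        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(n)),
                         [=](sycl::nd_item<1> item) {
            // SYCL 2020 replacement for the deprecated shared_mem_acc.get_pointer():
            float * shared = shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get();
            const size_t i = item.get_local_id(0);
            shared[i] = float(i);
            item.barrier(sycl::access::fence_space::local_space);
            out[i] = shared[n - 1 - i];   // read another work-item's slot after the barrier
        });
    }).wait();

    sycl::free(out, q);
    return 0;
}
```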

package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp
@@ -3,8 +3,7 @@

 #include "common.hpp"

-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx,
-        const ggml_tensor *src1, ggml_tensor * dst);
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);


 #endif // GGML_SYCL_WKV6_HPP

package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt
@@ -1,5 +1,20 @@
+cmake_minimum_required(VERSION 3.19)
+cmake_policy(SET CMP0114 NEW)
+
 find_package(Vulkan COMPONENTS glslc REQUIRED)

+function(detect_host_compiler)
+    if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
+        find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
+        find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
+    else()
+        find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
+        find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
+    endif()
+    set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
+    set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
+endfunction()
+
 if (Vulkan_FOUND)
     message(STATUS "Vulkan found")

@@ -8,6 +23,20 @@ if (Vulkan_FOUND)
         ../../include/ggml-vulkan.h
     )

+    # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
+        message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
+    else()
+        message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
+        add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    endif()
+
     # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
     # If it's not, there will be an error to stderr.
     # If it's supported, set a define to indicate that we should compile those shaders
@@ -59,15 +88,56 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_RUN_TESTS)
     endif()

-
-
-
+    if (NOT CMAKE_CROSSCOMPILING)
+        add_subdirectory(vulkan-shaders)
+        if (MSVC)
+            foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
+                string(TOUPPER ${CONFIG} CONFIG)
+                set_target_properties(vulkan-shaders-gen PROPERTIES
+                    RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
+            endforeach()
+        endif()
+    else()
+        if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
+            set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
+        else()
+            detect_host_compiler()
+            if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
+                message(FATAL_ERROR "Host compiler not found")
+            else()
+                message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
+            endif()
+            configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
+            set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
+        endif()
+        message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
+
+        include(ExternalProject)
+        # Native build through ExternalProject_Add
+        ExternalProject_Add(
+            vulkan-shaders-gen
+            SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
+            CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
+                       -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
+            BUILD_COMMAND ${CMAKE_COMMAND} --build .
+            INSTALL_COMMAND ${CMAKE_COMMAND} --install .
+            INSTALL_DIR ${CMAKE_BINARY_DIR}
+        )
+        ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
+    endif()
+    set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
+    set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
     set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
     set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
     set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
     set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

     file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
+    set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
+
+    if (CMAKE_CROSSCOMPILING)
+        set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
+    endif()

     add_custom_command(
         OUTPUT ${_ggml_vk_header}
@@ -81,7 +151,7 @@ if (Vulkan_FOUND)
         --target-cpp ${_ggml_vk_source}
         --no-clean

-        DEPENDS ${_ggml_vk_shader_deps}
+        DEPENDS ${_ggml_vk_shader_deps}
         COMMENT "Generate vulkan shaders"
     )
