@fugood/llama.node 0.3.14 → 0.3.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +32 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +12 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +102 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8207 -163
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
package/src/llama.cpp/tests/test-backend-ops.cpp

@@ -259,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
 
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
 static std::string var_to_str(ggml_op_pool pool) {
     switch (pool) {
         case GGML_OP_POOL_AVG: return "avg";
@@ -1916,6 +1920,40 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalization
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
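Note: the struct above exercises `ggml_rwkv_wkv7`, the WKV7 operator this release adds to `ggml.h` (+24 lines). As a reading aid, here is a minimal standalone sketch of building the same graph node outside the test harness; the helper name `build_wkv7` is ours, and the `ggml.h` shipped in this revision is assumed:

    // Sketch only: builds a WKV7 node with the shapes used by test_rwkv_wkv7 above.
    #include "ggml.h"

    static ggml_tensor * build_wkv7(ggml_context * ctx,
            int64_t head_count, int64_t head_size, int64_t n_seq_tokens, int64_t n_seqs) {
        const int64_t n_tokens = n_seq_tokens * n_seqs;
        // r, w, k, v, a and b all have shape [head_size, head_count, n_tokens]
        ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
        // recurrent state: one [head_size * head_size * head_count] vector per sequence
        ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * head_count, n_seqs);
        return ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
    }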
@@ -2972,6 +3010,32 @@ struct test_group_norm : public test_case {
     }
 };
 
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ACC
 struct test_acc : public test_case {
     const ggml_type type;
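Note: `test_l2_norm` covers the new `GGML_OP_L2_NORM` operator that the WKV7 test above relies on for input normalization. A scalar sketch of the presumed per-row semantics, with `eps` clamping the norm to keep the division finite (our reading of the kernel, not an official reference):

    #include <cmath>

    // Presumed semantics of ggml_l2_norm for one row: y = x / max(||x||_2, eps).
    static void l2_norm_row(const float * x, float * y, int n, float eps) {
        float sum_sq = 0.0f;
        for (int i = 0; i < n; i++) {
            sum_sq += x[i] * x[i];
        }
        const float scale = 1.0f / std::fmax(std::sqrt(sum_sq), eps);
        for (int i = 0; i < n; i++) {
            y[i] = x[i] * scale;
        }
    }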
@@ -3146,11 +3210,12 @@ struct test_flash_attn_ext : public test_case {
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
 
+    const ggml_prec prec;
     const ggml_type type_KV;
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3165,9 +3230,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
-            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f,
-            ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+            ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3201,6 +3266,7 @@ struct test_flash_attn_ext : public test_case {
         }
 
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
         return out;
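Note: the hunk above wires the new `prec` knob into the test graph via `ggml_flash_attn_ext_set_prec`. Requesting f32 accumulation from user code takes the same two calls; a short sketch (the wrapper `attn_f32` is hypothetical, and `q`, `k`, `v`, `mask` are assumed to be pre-built graph tensors):

    // Sketch: build a flash-attention node and request f32 accumulation.
    static ggml_tensor * attn_f32(ggml_context * ctx,
            ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * mask, float kq_scale) {
        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, kq_scale,
                /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);
        // GGML_PREC_DEFAULT keeps the backend default; GGML_PREC_F32 forces f32 accumulation
        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
        return out;
    }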
@@ -4006,8 +4072,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
         }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
@@ -4019,6 +4088,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
@@ -4308,11 +4382,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int kv : { 512, 1024, }) {
                         if (nr != 1 && kv != 512) continue;
                         for (int nb : { 1, 3, 32, 35, }) {
-                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                // run fewer test cases permuted
-                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                        hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                    // run fewer test cases permuted
+                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                        test_cases.emplace_back(new test_flash_attn_ext(
+                                            hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                    }
                                 }
                             }
                         }
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp

@@ -1,143 +0,0 @@
-#include <sycl/sycl.hpp>
-#include "wkv6.hpp"
-
-constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
-
-// Helper function for the main kernel
-static void rwkv_wkv_f32_kernel(
-        const int B, const int T, const int C, const int H,
-        const float* k, const float* v, const float* r,
-        const float* tf, const float* td, const float* s,
-        float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
-
-    const int tid = item_ct1.get_local_id(2);
-    const int bid = item_ct1.get_group(2);
-
-    const int head_size = WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    // Set up shared memory pointers
-    float* _k = shared_mem;
-    float* _r = _k + head_size;
-    float* _tf = _r + head_size;
-    float* _td = _tf + head_size;
-
-    // Local state array
-    float state[WKV_BLOCK_SIZE];
-
-    // Load initial state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    // Sync threads before shared memory operations
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Load time-mixing parameters
-    _tf[tid] = tf[head_i * head_size + tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Main sequence processing loop
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
-         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
-         t += C) {
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        // Load current timestep data to shared memory
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        const float _v = v[t];
-        float y = 0;
-
-        // Process in chunks of 4 for better vectorization
-        sycl::float4 k4, r4, tf4, td4, s4;
-        #pragma unroll
-        for (int j = 0; j < head_size; j += 4) {
-            // Load data in vec4 chunks
-            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
-            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
-            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
-            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
-            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
-
-            // Compute key-value product
-            sycl::float4 kv4 = k4 * _v;
-
-            // Accumulate weighted sum
-            y += sycl::dot(r4, tf4 * kv4 + s4);
-
-            // Update state
-            s4 = s4 * td4 + kv4;
-
-            // Store updated state
-            state[j] = s4.x();
-            state[j+1] = s4.y();
-            state[j+2] = s4.z();
-            state[j+3] = s4.w();
-        }
-
-        dst[t] = y;
-    }
-
-    // Save final state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
-
-    const ggml_tensor *src0 = dst->src[0];
-    const ggml_tensor *src1 = dst->src[1];
-
-    const float* k_d = (const float*)dst->src[0]->data;
-    const float* v_d = (const float*)dst->src[1]->data;
-    const float* r_d = (const float*)dst->src[2]->data;
-    const float* tf_d = (const float*)dst->src[3]->data;
-    const float* td_d = (const float*)dst->src[4]->data;
-    const float* s_d = (const float*)dst->src[5]->data;
-    float* dst_d = (float*)dst->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    dpct::queue_ptr stream = ctx.stream();
-
-    // Calculate execution configuration
-    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
-    sycl::range<3> block_dims(1, 1, C / H);
-    sycl::range<3> grid_dims(1, 1, B * H);
-
-    // Submit kernel
-    stream->submit([&](sycl::handler& cgh) {
-        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rwkv_wkv_f32_kernel(
-                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                );
-            });
-    });
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-}