@fugood/llama.node 0.3.14 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +32 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +12 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
  58. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  60. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  61. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  62. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  63. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
  64. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  65. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
  66. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  67. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  68. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  69. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  70. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
  71. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  72. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  73. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  74. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  75. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  78. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
  79. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
  81. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  82. package/src/llama.cpp/include/llama.h +86 -22
  83. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  84. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  85. package/src/llama.cpp/src/llama-adapter.h +11 -9
  86. package/src/llama.cpp/src/llama-arch.cpp +102 -16
  87. package/src/llama.cpp/src/llama-arch.h +18 -0
  88. package/src/llama.cpp/src/llama-batch.h +2 -2
  89. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  90. package/src/llama.cpp/src/llama-context.h +214 -77
  91. package/src/llama.cpp/src/llama-cparams.h +1 -0
  92. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  93. package/src/llama.cpp/src/llama-graph.h +574 -0
  94. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  95. package/src/llama.cpp/src/llama-hparams.h +9 -0
  96. package/src/llama.cpp/src/llama-io.cpp +15 -0
  97. package/src/llama.cpp/src/llama-io.h +35 -0
  98. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  99. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  100. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  101. package/src/llama.cpp/src/llama-memory.h +21 -0
  102. package/src/llama.cpp/src/llama-model.cpp +8207 -163
  103. package/src/llama.cpp/src/llama-model.h +34 -1
  104. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  105. package/src/llama.cpp/src/llama.cpp +51 -9984
  106. package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
  107. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  108. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
package/src/llama.cpp/tests/test-backend-ops.cpp
@@ -259,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
 
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
 static std::string var_to_str(ggml_op_pool pool) {
     switch (pool) {
         case GGML_OP_POOL_AVG: return "avg";
@@ -1916,6 +1920,40 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalization
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2972,6 +3010,32 @@ struct test_group_norm : public test_case {
     }
 };
 
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ACC
 struct test_acc : public test_case {
     const ggml_type type;
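Note: the hunk above only shows the ggml_l2_norm(ctx, a, eps) call. As a reading aid, the conventional row-wise L2 normalization it exercises is sketched below; how ggml applies eps internally is not visible in this diff, so the clamping form here is an assumption (eps only guards against division by zero):

    \hat{x}_i = \frac{x_i}{\max\!\left(\sqrt{\sum_j x_j^2},\ \epsilon\right)}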
@@ -3146,11 +3210,12 @@ struct test_flash_attn_ext : public test_case {
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
 
+    const ggml_prec prec;
     const ggml_type type_KV;
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3165,9 +3230,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
-            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
-            std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+            ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3201,6 +3266,7 @@ struct test_flash_attn_ext : public test_case {
         }
 
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
         return out;
@@ -4006,8 +4072,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
         }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
@@ -4019,6 +4088,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
@@ -4308,11 +4382,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (int kv : { 512, 1024, }) {
                     if (nr != 1 && kv != 512) continue;
                     for (int nb : { 1, 3, 32, 35, }) {
-                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                            // run fewer test cases permuted
-                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                            if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
+                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                test_cases.emplace_back(new test_flash_attn_ext(
+                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                // run fewer test cases permuted
+                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                        hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                }
                             }
                         }
                     }
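Note: hs, nh, nr, mask, max_bias and logit_softcap in the loop nest above come from enclosing loops outside this hunk. As a sketch of one case the extended nest registers, assuming the representative values hs=128, nh=32, nr=1 (taken from the constructor defaults; GGML_PREC_DEFAULT cases survive the hs filter only when hs is 128) and kv=512, nb=32 from the visible loops:

    // sketch only, inside make_test_cases_eval(): f16 KV cache, default precision
    test_cases.emplace_back(new test_flash_attn_ext(
        /*hs=*/128, /*nh=*/32, /*nr=*/1, /*kv=*/512, /*nb=*/32,
        /*mask=*/true, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f,
        GGML_PREC_DEFAULT, GGML_TYPE_F16));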
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp (deleted)
@@ -1,143 +0,0 @@
-#include <sycl/sycl.hpp>
-#include "wkv6.hpp"
-
-constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
-
-// Helper function for the main kernel
-static void rwkv_wkv_f32_kernel(
-    const int B, const int T, const int C, const int H,
-    const float* k, const float* v, const float* r,
-    const float* tf, const float* td, const float* s,
-    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
-
-    const int tid = item_ct1.get_local_id(2);
-    const int bid = item_ct1.get_group(2);
-
-    const int head_size = WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    // Set up shared memory pointers
-    float* _k = shared_mem;
-    float* _r = _k + head_size;
-    float* _tf = _r + head_size;
-    float* _td = _tf + head_size;
-
-    // Local state array
-    float state[WKV_BLOCK_SIZE];
-
-    // Load initial state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    // Sync threads before shared memory operations
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Load time-mixing parameters
-    _tf[tid] = tf[head_i * head_size + tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Main sequence processing loop
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
-         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
-         t += C) {
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        // Load current timestep data to shared memory
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        const float _v = v[t];
-        float y = 0;
-
-        // Process in chunks of 4 for better vectorization
-        sycl::float4 k4, r4, tf4, td4, s4;
-        #pragma unroll
-        for (int j = 0; j < head_size; j += 4) {
-            // Load data in vec4 chunks
-            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
-            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
-            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
-            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
-            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
-
-            // Compute key-value product
-            sycl::float4 kv4 = k4 * _v;
-
-            // Accumulate weighted sum
-            y += sycl::dot(r4, tf4 * kv4 + s4);
-
-            // Update state
-            s4 = s4 * td4 + kv4;
-
-            // Store updated state
-            state[j] = s4.x();
-            state[j+1] = s4.y();
-            state[j+2] = s4.z();
-            state[j+3] = s4.w();
-        }
-
-        dst[t] = y;
-    }
-
-    // Save final state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
-
-    const ggml_tensor *src0 = dst->src[0];
-    const ggml_tensor *src1 = dst->src[1];
-
-    const float* k_d = (const float*)dst->src[0]->data;
-    const float* v_d = (const float*)dst->src[1]->data;
-    const float* r_d = (const float*)dst->src[2]->data;
-    const float* tf_d = (const float*)dst->src[3]->data;
-    const float* td_d = (const float*)dst->src[4]->data;
-    const float* s_d = (const float*)dst->src[5]->data;
-    float* dst_d = (float*)dst->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    dpct::queue_ptr stream = ctx.stream();
-
-    // Calculate execution configuration
-    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
-    sycl::range<3> block_dims(1, 1, C / H);
-    sycl::range<3> grid_dims(1, 1, B * H);
-
-    // Submit kernel
-    stream->submit([&](sycl::handler& cgh) {
-        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rwkv_wkv_f32_kernel(
-                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                );
-            });
-    });
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-}
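Note: read off the removed kernel's vectorized inner loop (y += dot(r4, tf4 * kv4 + s4); s4 = s4 * td4 + kv4), the per-channel recurrence it implemented is

    y_t = \sum_j r_t[j]\,\bigl(tf[j]\,k_t[j]\,v_t + s[j]\bigr), \qquad s[j] \leftarrow td_t[j]\,s[j] + k_t[j]\,v_t

with j running over the head dimension and v_t the value for the channel handled by each work-item. The SYCL WKV support presumably moves to the new ggml-sycl/wkv.cpp added in this release (files 76-77 in the list above).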
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp (deleted)
@@ -1,9 +0,0 @@
-#ifndef GGML_SYCL_WKV6_HPP
-#define GGML_SYCL_WKV6_HPP
-
-#include "common.hpp"
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-
-#endif // GGML_SYCL_WKV6_HPP