@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
package/src/llama.cpp/tests/test-backend-ops.cpp (+145 -23):
@@ -259,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
 
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
 static std::string var_to_str(ggml_op_pool pool) {
     switch (pool) {
         case GGML_OP_POOL_AVG: return "avg";
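The new overload slots into the harness's existing var_to_str/VARS_TO_STRn machinery so that a ggml_prec parameter can be printed as part of a generated test description. A minimal standalone sketch of the behaviour (the enum below is a stand-in for illustration, not the real ggml_prec):

    #include <cstdio>
    #include <string>

    // Stand-in for ggml_prec, illustration only.
    enum demo_prec { DEMO_PREC_DEFAULT, DEMO_PREC_F32 };

    static std::string var_to_str(demo_prec prec) {
        return prec == DEMO_PREC_F32 ? "f32" : "def";
    }

    int main() {
        // A precision field stringifies to one of two tokens, matching the overload above.
        std::printf("prec=%s\n", var_to_str(DEMO_PREC_F32).c_str());     // prec=f32
        std::printf("prec=%s\n", var_to_str(DEMO_PREC_DEFAULT).c_str()); // prec=def
        return 0;
    }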
@@ -1459,11 +1463,13 @@ struct test_cpy : public test_case {
     const ggml_type type_src;
     const ggml_type type_dst;
     const std::array<int64_t, 4> ne;
-    const std::array<int64_t, 4> permute;
+    const std::array<int64_t, 4> permute_src;
+    const std::array<int64_t, 4> permute_dst;
     bool _src_use_permute;
+    bool _dst_use_permute;
 
     std::string vars() override {
-        return VARS_TO_STR4(type_src, type_dst, ne, permute);
+        return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
     }
 
     double max_nmse_err() override {
@@ -1476,9 +1482,11 @@ struct test_cpy : public test_case {
 
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            std::array<int64_t, 4> permute = {0, 0, 0, 0})
-        : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
-          _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
+            std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
+            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
+        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
+          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -1486,13 +1494,18 @@ struct test_cpy : public test_case {
         ggml_set_name(src, "src");
 
         if (_src_use_permute) {
-            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
             ggml_set_name(src, "src_permuted");
         }
 
-        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
         ggml_set_name(dst, "dst");
 
+        if (_dst_use_permute) {
+            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
+            ggml_set_name(dst, "dst_permuted");
+        }
+
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
         ggml_set_name(out, "out");
 
@@ -1916,6 +1929,40 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalization
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
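For orientation, the tensor sizes built by the test_rwkv_wkv7 case above follow directly from its constructor defaults: each of r, w, k, v, a, b is a (head_size, head_count, n_tokens) tensor, and the recurrent state s is flattened to (head_size * head_size * head_count, n_seqs). A small bookkeeping sketch of those sizes (illustrative only, not part of the test itself):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Defaults from test_rwkv_wkv7 above.
        const int64_t head_count = 32, head_size = 64, n_seq_tokens = 32, n_seqs = 32;
        const int64_t n_tokens = n_seq_tokens * n_seqs;

        // r, w, k, v, a, b: one (head_size x head_count) slab per token.
        const int64_t per_input = head_size * head_count * n_tokens;
        // s: one (head_size x head_size) state matrix per head, per sequence.
        const int64_t state_elems = head_size * head_size * head_count * n_seqs;

        std::printf("elements per input tensor: %lld\n", (long long) per_input);   // 2097152
        std::printf("state elements:            %lld\n", (long long) state_elems); // 4194304
        return 0;
    }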
@@ -1926,9 +1973,10 @@ struct test_mul_mat : public test_case {
     const std::array<int64_t, 2> bs;  // dims 3 and 4
     const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
     const std::array<int64_t, 4> per; // permutation of dimensions
+    const bool v; // whether a is a non-contiguous view
 
     std::string vars() override {
-        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
+        return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
     }
 
     double max_nmse_err() override {
@@ -1948,8 +1996,9 @@ struct test_mul_mat : public test_case {
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
             std::array<int64_t, 2> nr = {2, 2},
-            std::array<int64_t, 4> per = {0, 1, 2, 3})
-        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
+            std::array<int64_t, 4> per = {0, 1, 2, 3},
+            bool v = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -1959,6 +2008,7 @@ struct test_mul_mat : public test_case {
         const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
         if (npermuted > 0) {
             GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!v); // not handled
             GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
             GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
 
@@ -1982,7 +2032,13 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+
+            if (v) {
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
+                a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+            } else {
+                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            }
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
@@ -2972,6 +3028,32 @@ struct test_group_norm : public test_case {
     }
 };
 
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ACC
 struct test_acc : public test_case {
     const ggml_type type;
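The operation exercised by the new test_l2_norm case normalizes each row of the input to unit Euclidean length, with eps guarding against division by zero. As a hedged statement of the row-wise math (whether eps acts as a floor on the norm, as written here, or is added under the root is up to each backend kernel):

    y_i = \frac{x_i}{\max\!\left(\sqrt{\sum_j x_j^2},\ \varepsilon\right)}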
@@ -3146,11 +3228,12 @@ struct test_flash_attn_ext : public test_case {
     const float max_bias;      // ALiBi
     const float logit_softcap; // Gemma 2
 
+    const ggml_prec prec;
     const ggml_type type_KV;
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3165,9 +3248,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
-            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
-            std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+            ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3201,6 +3284,7 @@ struct test_flash_attn_ext : public test_case {
         }
 
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
         return out;
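The call to ggml_flash_attn_ext_set_prec above is what the new prec field ultimately exercises. As commonly interpreted (the exact effect is backend-specific), GGML_PREC_F32 asks the kernel to accumulate the attention result in f32, while GGML_PREC_DEFAULT lets the backend fall back to lower-precision (often f16) accumulation. A tiny standalone sketch of that distinction, using stand-in names rather than the real enum:

    #include <cstdio>

    // Stand-ins for GGML_PREC_DEFAULT / GGML_PREC_F32, illustration only.
    enum class attn_prec { def, f32 };

    static const char * describe(attn_prec p) {
        return p == attn_prec::f32
            ? "accumulate attention output in f32"
            : "backend default (may accumulate in f16)";
    }

    int main() {
        std::printf("GGML_PREC_F32     -> %s\n", describe(attn_prec::f32));
        std::printf("GGML_PREC_DEFAULT -> %s\n", describe(attn_prec::def));
        return 0;
    }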
@@ -3929,14 +4013,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
     }
 
-    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+    // same-type copy
+    for (ggml_type type : all_types) {
+        const auto nk = ggml_blck_size(type);
+
+        for (int k = 1; k < 4; ++k) {
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
+        }
+    }
+
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
-    for (ggml_type type_dst : {GGML_TYPE_F32}) {
-        for (ggml_type type_src : all_types) {
+    for (ggml_type type_src : all_types) {
+        for (ggml_type type_dst : {GGML_TYPE_F32}) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
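The first dimension of the new same-type copies is k*nk rather than a fixed size because quantized ggml types store elements in fixed-size blocks, so a row must contain a whole number of blocks; the third emplace_back also permutes the destination, driving the non-contiguous-dst path added to test_cpy earlier in this diff. A quick illustration of the k*nk sizing (the block size of 32 is only an example of what ggml_blck_size() can return for types such as Q4_0 or Q8_0):

    #include <cstdio>

    int main() {
        const int nk = 32; // e.g. block size of a 32-element quantized type
        for (int k = 1; k < 4; ++k) {
            std::printf("ne[0] = %d\n", k * nk); // 32, 64, 96 -> always a whole number of blocks
        }
        return 0;
    }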
@@ -4006,8 +4101,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
         }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
@@ -4019,6 +4117,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
@@ -4102,6 +4205,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
 
+    for (auto bs : {1,2,4,8}) {
+        for (auto nr : {1,4}) {
+            for (uint32_t m = 0; m < 2; ++m) {
+                for (uint32_t k = 0; k < 2; ++k) {
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
+            }
+        }
+    }
+
     // sycl backend will limit task global_range < MAX_INT
     // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
     // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
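The new loop above feeds the v == true constructor path shown earlier: a is allocated with 2*k columns and then viewed back down to k columns while keeping the parent's byte strides, so every row the matmul reads is followed by an unused gap, i.e. the input is non-contiguous. A stride sketch with illustrative numbers (f16 elements assumed):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const int    k   = 128;              // columns exposed by the view
        const size_t elt = 2;                // bytes per f16 element
        const size_t parent_nb1 = 2 * k * elt; // row stride inherited from the 2*k-wide parent
        const size_t used_bytes = k * elt;     // bytes actually read per row

        std::printf("row stride: %zu bytes, used: %zu bytes, gap: %zu bytes\n",
                    parent_nb1, used_bytes, parent_nb1 - used_bytes);
        return 0;
    }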
@@ -4308,11 +4422,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (int kv : { 512, 1024, }) {
                 if (nr != 1 && kv != 512) continue;
                 for (int nb : { 1, 3, 32, 35, }) {
-                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                        // run fewer test cases permuted
-                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
+                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                            test_cases.emplace_back(new test_flash_attn_ext(
+                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                            // run fewer test cases permuted
+                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                test_cases.emplace_back(new test_flash_attn_ext(
+                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                            }
                         }
                     }
                 }
@@ -4365,6 +4484,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp (+0 -143, deleted):
@@ -1,143 +0,0 @@
-#include <sycl/sycl.hpp>
-#include "wkv6.hpp"
-
-constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
-
-// Helper function for the main kernel
-static void rwkv_wkv_f32_kernel(
-    const int B, const int T, const int C, const int H,
-    const float* k, const float* v, const float* r,
-    const float* tf, const float* td, const float* s,
-    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
-
-    const int tid = item_ct1.get_local_id(2);
-    const int bid = item_ct1.get_group(2);
-
-    const int head_size = WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    // Set up shared memory pointers
-    float* _k = shared_mem;
-    float* _r = _k + head_size;
-    float* _tf = _r + head_size;
-    float* _td = _tf + head_size;
-
-    // Local state array
-    float state[WKV_BLOCK_SIZE];
-
-    // Load initial state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    // Sync threads before shared memory operations
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Load time-mixing parameters
-    _tf[tid] = tf[head_i * head_size + tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Main sequence processing loop
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
-         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
-         t += C) {
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        // Load current timestep data to shared memory
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        const float _v = v[t];
-        float y = 0;
-
-        // Process in chunks of 4 for better vectorization
-        sycl::float4 k4, r4, tf4, td4, s4;
-        #pragma unroll
-        for (int j = 0; j < head_size; j += 4) {
-            // Load data in vec4 chunks
-            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
-            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
-            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
-            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
-            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
-
-            // Compute key-value product
-            sycl::float4 kv4 = k4 * _v;
-
-            // Accumulate weighted sum
-            y += sycl::dot(r4, tf4 * kv4 + s4);
-
-            // Update state
-            s4 = s4 * td4 + kv4;
-
-            // Store updated state
-            state[j] = s4.x();
-            state[j+1] = s4.y();
-            state[j+2] = s4.z();
-            state[j+3] = s4.w();
-        }
-
-        dst[t] = y;
-    }
-
-    // Save final state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
-
-    const ggml_tensor *src0 = dst->src[0];
-    const ggml_tensor *src1 = dst->src[1];
-
-    const float* k_d = (const float*)dst->src[0]->data;
-    const float* v_d = (const float*)dst->src[1]->data;
-    const float* r_d = (const float*)dst->src[2]->data;
-    const float* tf_d = (const float*)dst->src[3]->data;
-    const float* td_d = (const float*)dst->src[4]->data;
-    const float* s_d = (const float*)dst->src[5]->data;
-    float* dst_d = (float*)dst->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    dpct::queue_ptr stream = ctx.stream();
-
-    // Calculate execution configuration
-    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
-    sycl::range<3> block_dims(1, 1, C / H);
-    sycl::range<3> grid_dims(1, 1, B * H);
-
-    // Submit kernel
-    stream->submit([&](sycl::handler& cgh) {
-        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rwkv_wkv_f32_kernel(
-                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                );
-            });
-    });
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-}
package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp (+0 -9, deleted):
@@ -1,9 +0,0 @@
-#ifndef GGML_SYCL_WKV6_HPP
-#define GGML_SYCL_WKV6_HPP
-
-#include "common.hpp"
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-
-#endif // GGML_SYCL_WKV6_HPP