@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
968
968
  grid1[0] ^ signs[0], signs[0], std::minus<>());
969
969
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
970
970
  grid2[0] ^ signs[1], signs[1], std::minus<>());
971
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
972
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
971
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
972
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
973
973
  q8 += 8;
974
974
  aux32 >>= 7;
975
975
  }
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
1009
1009
  grid1[0] ^ signs0, signs0, std::minus<>());
1010
1010
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
1011
1011
  grid2[0] ^ signs1, signs1, std::minus<>());
1012
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
1013
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
1012
+ sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
1013
+ sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
1014
1014
  q8 += 8;
1015
1015
  }
1016
1016
  const float d =
@@ -0,0 +1,138 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include "wkv6.hpp"
3
+
4
+ constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
5
+
6
+ // Helper function for the main kernel
7
+ static void rwkv_wkv_f32_kernel(
8
+ const int B, const int T, const int C, const int H,
9
+ const float* k, const float* v, const float* r,
10
+ const float* tf, const float* td, const float* s,
11
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
12
+
13
+ const int tid = item_ct1.get_local_id(2);
14
+ const int bid = item_ct1.get_group(2);
15
+
16
+ const int head_size = WKV_BLOCK_SIZE;
17
+ const int batch_i = bid / H;
18
+ const int head_i = bid % H;
19
+ const int state_size = C * head_size;
20
+ const int n_seq_tokens = T / B;
21
+
22
+ // Set up shared memory pointers
23
+ float* _k = shared_mem;
24
+ float* _r = _k + head_size;
25
+ float* _tf = _r + head_size;
26
+ float* _td = _tf + head_size;
27
+
28
+ // Local state array
29
+ float state[WKV_BLOCK_SIZE];
30
+
31
+ // Load initial state
32
+ #pragma unroll
33
+ for (int i = 0; i < head_size; i++) {
34
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
35
+ }
36
+
37
+ // Sync threads before shared memory operations
38
+ item_ct1.barrier(sycl::access::fence_space::local_space);
39
+
40
+ // Load time-mixing parameters
41
+ _tf[tid] = tf[head_i * head_size + tid];
42
+ item_ct1.barrier(sycl::access::fence_space::local_space);
43
+
44
+ // Main sequence processing loop
45
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
46
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
47
+ t += C) {
48
+
49
+ item_ct1.barrier(sycl::access::fence_space::local_space);
50
+
51
+ // Load current timestep data to shared memory
52
+ _k[tid] = k[t];
53
+ _r[tid] = r[t];
54
+ _td[tid] = td[t];
55
+
56
+ item_ct1.barrier(sycl::access::fence_space::local_space);
57
+
58
+ const float _v = v[t];
59
+ float y = 0;
60
+
61
+ // Process in chunks of 4 for better vectorization
62
+ sycl::float4 k4, r4, tf4, td4, s4, kv4;
63
+ #pragma unroll
64
+ for (int j = 0; j < head_size; j += 4) {
65
+ // Load data in vec4 chunks
66
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
67
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
68
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
69
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
70
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
71
+
72
+ // Compute key-value product
73
+ sycl::float4 kv4 = k4 * _v;
74
+
75
+ // Accumulate weighted sum
76
+ y += sycl::dot(r4, tf4 * kv4 + s4);
77
+
78
+ // Update state
79
+ s4 = s4 * td4 + kv4;
80
+
81
+ // Store updated state
82
+ state[j] = s4.x();
83
+ state[j+1] = s4.y();
84
+ state[j+2] = s4.z();
85
+ state[j+3] = s4.w();
86
+ }
87
+
88
+ dst[t] = y;
89
+ }
90
+
91
+ // Save final state
92
+ #pragma unroll
93
+ for (int i = 0; i < head_size; i++) {
94
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
95
+ }
96
+ }
97
+
98
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
99
+ const ggml_tensor* src1, ggml_tensor* dst) {
100
+
101
+ const float* k_d = (const float*)dst->src[0]->data;
102
+ const float* v_d = (const float*)dst->src[1]->data;
103
+ const float* r_d = (const float*)dst->src[2]->data;
104
+ const float* tf_d = (const float*)dst->src[3]->data;
105
+ const float* td_d = (const float*)dst->src[4]->data;
106
+ const float* s_d = (const float*)dst->src[5]->data;
107
+ float* dst_d = (float*)dst->data;
108
+
109
+ const int64_t B = dst->src[5]->ne[1];
110
+ const int64_t T = dst->src[0]->ne[3];
111
+ const int64_t C = dst->ne[0];
112
+ const int64_t H = dst->src[0]->ne[2];
113
+
114
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
115
+ GGML_ASSERT(C % H == 0);
116
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
117
+
118
+ dpct::queue_ptr stream = ctx.stream();
119
+
120
+ // Calculate execution configuration
121
+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
122
+ sycl::range<3> block_dims(1, 1, C / H);
123
+ sycl::range<3> grid_dims(1, 1, B * H);
124
+
125
+ // Submit kernel
126
+ stream->submit([&](sycl::handler& cgh) {
127
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
128
+
129
+ cgh.parallel_for(
130
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
131
+ [=](sycl::nd_item<3> item_ct1) {
132
+ rwkv_wkv_f32_kernel(
133
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
134
+ item_ct1, shared_mem_acc.get_pointer()
135
+ );
136
+ });
137
+ });
138
+ }
@@ -0,0 +1,10 @@
1
+ #ifndef GGML_SYCL_WKV6_HPP
2
+ #define GGML_SYCL_WKV6_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
7
+ const ggml_tensor *src1, ggml_tensor * dst);
8
+
9
+
10
+ #endif // GGML_SYCL_WKV6_HPP
@@ -0,0 +1,12 @@
1
+ #include "ggml-threading.h"
2
+ #include <mutex>
3
+
4
// Single program-wide mutex that implements the ggml critical section.
std::mutex ggml_critical_section_mutex;

// Enters the critical section; blocks until the mutex has been acquired.
// Every call must be paired with ggml_critical_section_end().
void ggml_critical_section_start(void) {
    ggml_critical_section_mutex.lock();
}

// Leaves the critical section by releasing the mutex taken in
// ggml_critical_section_start().
void ggml_critical_section_end(void) {
    ggml_critical_section_mutex.unlock();
}
@@ -0,0 +1,12 @@
1
#pragma once

// C-callable interface to the ggml critical section.

#ifdef __cplusplus
extern "C" {
#endif

// Enter the critical section (blocks until it becomes available).
void ggml_critical_section_start(void);

// Leave the critical section entered with ggml_critical_section_start().
void ggml_critical_section_end(void);

#ifdef __cplusplus
}
#endif
@@ -0,0 +1,78 @@
1
# Builds the ggml-vulkan backend library together with the shader sources that
# are generated at build time by the vulkan-shaders generator.
#
# REQUIRED makes configuration stop right here when Vulkan or glslc is missing,
# so the else() branch at the bottom is only reachable if REQUIRED is dropped.
find_package(Vulkan COMPONENTS glslc REQUIRED)

if (Vulkan_FOUND)
    message(STATUS "Vulkan found")

    add_library(ggml-vulkan
                ggml-vulkan.cpp
                ../../include/ggml-vulkan.h
               )

    target_link_libraries(ggml-vulkan PRIVATE ggml-base Vulkan::Vulkan)
    target_include_directories(ggml-vulkan PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})

    # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
    # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
    endif()

    # Forward every enabled GGML_VULKAN_* switch as a compile definition of the same name.
    # The definitions stay at directory scope on purpose: they are inherited by the
    # vulkan-shaders subdirectory added below.
    # NOTE(review): confirm the shader generator does not rely on any of them before
    # narrowing this to target_compile_definitions(ggml-vulkan ...).
    foreach (_ggml_vk_opt IN ITEMS
             GGML_VULKAN_CHECK_RESULTS
             GGML_VULKAN_DEBUG
             GGML_VULKAN_MEMORY_DEBUG
             GGML_VULKAN_SHADER_DEBUG_INFO
             GGML_VULKAN_PERF
             GGML_VULKAN_VALIDATE
             GGML_VULKAN_RUN_TESTS)
        if (${_ggml_vk_opt})
            add_compile_definitions(${_ggml_vk_opt})
        endif()
    endforeach()

    add_subdirectory(vulkan-shaders)

    set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
    set (_ggml_vk_header     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
    set (_ggml_vk_source     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
    set (_ggml_vk_input_dir  ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
    set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

    # CONFIGURE_DEPENDS re-runs the glob at build time so that added or removed
    # .comp files are picked up without a manual re-configure.
    file(GLOB _ggml_vk_shader_deps CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")

    # Naming an executable target in COMMAND only orders the build; it does not
    # re-run the command when the executable changes. Listing the generator in
    # DEPENDS as well regenerates the shaders whenever the generator is rebuilt.
    set (_ggml_vk_genshaders_deps ${_ggml_vk_shader_deps})
    if (TARGET ${_ggml_vk_genshaders_cmd})
        list(APPEND _ggml_vk_genshaders_deps ${_ggml_vk_genshaders_cmd})
    endif()

    add_custom_command(
        OUTPUT ${_ggml_vk_header}
               ${_ggml_vk_source}

        COMMAND ${_ggml_vk_genshaders_cmd}
            --glslc      ${Vulkan_GLSLC_EXECUTABLE}
            --input-dir  ${_ggml_vk_input_dir}
            --output-dir ${_ggml_vk_output_dir}
            --target-hpp ${_ggml_vk_header}
            --target-cpp ${_ggml_vk_source}
            --no-clean

        DEPENDS ${_ggml_vk_genshaders_deps}
        COMMENT "Generate vulkan shaders"
        VERBATIM
    )

    # Compile the generated sources into the backend library
    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})

else()
    message(WARNING "Vulkan not found")
endif()