@fugood/llama.node 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -0,0 +1,105 @@
1
+ #include <sycl/sycl.hpp>
2
+
3
+ #include "common.hpp"
4
+
5
+ template <u_int HEAD_SIZE>
6
+ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale,
7
+ const float * k, const float * v, const float * r, const float * td,
8
+ const float * s, float * dst) {
9
+ const u_int head_size = HEAD_SIZE;
10
+ const u_int state_size = C * head_size;
11
+ const u_int n_seq_tokens = T / B;
12
+ sycl::range<1> block_dims((C / H));
13
+ sycl::range<1> grid_dims((B * H));
14
+ stream->submit([&](sycl::handler & cgh) {
15
+ /* local memory accessors*/
16
+ auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
17
+ auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
18
+ auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
19
+
20
+ cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
21
+ u_int tid = item.get_local_id(0);
22
+ u_int bid = item.get_group(0);
23
+
24
+ u_int batch_i = bid / H;
25
+ u_int head_i = bid % H;
26
+
27
+ float state[head_size];
28
+
29
+ #pragma unroll
30
+ for (u_int i = 0; i < head_size; i++) {
31
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
32
+ }
33
+
34
+ for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
35
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
36
+
37
+ item.barrier(sycl::access::fence_space::local_space); //sync threads
38
+ _k[tid] = k[t];
39
+ _r[tid] = r[t];
40
+ _td[tid] = td[t];
41
+ item.barrier(sycl::access::fence_space::local_space); //sync threads
42
+
43
+ const float _v = v[t];
44
+ float y = 0;
45
+
46
+ for (u_int j = 0; j < head_size; j += 4) {
47
+ const sycl::float4 & k = (sycl::float4 &) (_k[j]);
48
+ const sycl::float4 & r = (sycl::float4 &) (_r[j]);
49
+ const sycl::float4 & td = (sycl::float4 &) (_td[j]);
50
+ sycl::float4 & s = (sycl::float4 &) (state[j]);
51
+ sycl::float4 kv;
52
+
53
+ kv.x() = k.x() * _v;
54
+ kv.y() = k.y() * _v;
55
+ kv.z() = k.z() * _v;
56
+ kv.w() = k.w() * _v;
57
+
58
+ s.x() = s.x() * td.x() + kv.x();
59
+ s.y() = s.y() * td.y() + kv.y();
60
+ s.z() = s.z() * td.z() + kv.z();
61
+ s.w() = s.w() * td.w() + kv.w();
62
+
63
+ y += r.x() * s.x();
64
+ y += r.y() * s.y();
65
+ y += r.z() * s.z();
66
+ y += r.w() * s.w();
67
+ }
68
+ dst[t] = y * scale;
69
+ }
70
+ #pragma unroll
71
+ for (u_int i = 0; i < head_size; i++) {
72
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
73
+ }
74
+ });
75
+ });
76
+ }
77
+
78
+ void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
79
+ const float * k_d = static_cast<const float *>(dst->src[0]->data);
80
+ const float * v_d = static_cast<const float *>(dst->src[1]->data);
81
+ const float * r_d = static_cast<const float *>(dst->src[2]->data);
82
+ const float * td_d = static_cast<const float *>(dst->src[3]->data);
83
+ const float * s_d = static_cast<const float *>(dst->src[4]->data);
84
+
85
+ const int64_t B = dst->src[4]->ne[1];
86
+ const int64_t T = dst->src[0]->ne[2];
87
+ const int64_t C = dst->ne[0];
88
+ const int64_t H = dst->src[0]->ne[1];
89
+
90
+ dpct::queue_ptr stream = ctx.stream();
91
+ GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
92
+ GGML_ASSERT(C % H == 0);
93
+ GGML_ASSERT(C / H == 64 || C / H == 128);
94
+
95
+ float scale;
96
+ memcpy(&scale, dst->op_params, sizeof(float));
97
+
98
+ float * dst_d = (float *) dst->data;
99
+
100
+ if (C / H == 64) {
101
+ gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
102
+ } else {
103
+ gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
104
+ }
105
+ }
@@ -0,0 +1,8 @@
1
#ifndef GGML_SYCL_GLA_HPP
#define GGML_SYCL_GLA_HPP

#include "common.hpp"

// SYCL gated linear attention: reads k, v, r, td and the initial state from
// dst->src[0..4] and the float scale from dst->op_params, then writes the
// per-token outputs followed by the final state into dst on ctx's stream.
void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_GLA_HPP
@@ -3,9 +3,9 @@
3
3
  #include "outprod.hpp"
4
4
 
5
5
 
6
- void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
7
- const ggml_tensor* src1, ggml_tensor* dst) {
8
-
6
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
7
+ const ggml_tensor *src0 = dst->src[0];
8
+ const ggml_tensor *src1 = dst->src[1];
9
9
 
10
10
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
11
11
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -3,8 +3,7 @@
3
3
 
4
4
  #include "common.hpp"
5
5
 
6
- void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
7
- const ggml_tensor* src1, ggml_tensor* dst);
6
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
8
7
 
9
8
 
10
9
  #endif // GGML_SYCL_OUTPROD_HPP
@@ -55,8 +55,9 @@ static void timestep_embedding_f32_sycl(
55
55
  });
56
56
  }
57
57
 
58
- void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
59
- const ggml_tensor *src1, ggml_tensor * dst) {
58
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
59
+ const ggml_tensor *src0 = dst->src[0];
60
+ const ggml_tensor *src1 = dst->src[1];
60
61
  const float * src0_d = (const float *)src0->data;
61
62
  float * dst_d = (float *)dst->data;
62
63
  dpct::queue_ptr stream = ctx.stream();
@@ -15,7 +15,6 @@
15
15
 
16
16
  #include "common.hpp"
17
17
 
18
- void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
19
- const ggml_tensor *src1, ggml_tensor * dst);
18
+ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
20
19
 
21
20
  #endif // GGML_SYCL_TSEMBD_HPP
@@ -95,8 +95,10 @@ static void rwkv_wkv_f32_kernel(
95
95
  }
96
96
  }
97
97
 
98
- void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
99
- const ggml_tensor* src1, ggml_tensor* dst) {
98
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
99
+
100
+ const ggml_tensor *src0 = dst->src[0];
101
+ const ggml_tensor *src1 = dst->src[1];
100
102
 
101
103
  const float* k_d = (const float*)dst->src[0]->data;
102
104
  const float* v_d = (const float*)dst->src[1]->data;
@@ -107,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
107
109
  float* dst_d = (float*)dst->data;
108
110
 
109
111
  const int64_t B = dst->src[5]->ne[1];
110
- const int64_t T = dst->src[0]->ne[3];
112
+ const int64_t T = dst->src[0]->ne[2];
111
113
  const int64_t C = dst->ne[0];
112
- const int64_t H = dst->src[0]->ne[2];
114
+ const int64_t H = dst->src[0]->ne[1];
113
115
 
114
116
  GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
115
117
  GGML_ASSERT(C % H == 0);
@@ -131,7 +133,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
131
133
  [=](sycl::nd_item<3> item_ct1) {
132
134
  rwkv_wkv_f32_kernel(
133
135
  B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
134
- item_ct1, shared_mem_acc.get_pointer()
136
+ item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
135
137
  );
136
138
  });
137
139
  });
@@ -3,8 +3,7 @@
3
3
 
4
4
  #include "common.hpp"
5
5
 
6
- void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
7
- const ggml_tensor *src1, ggml_tensor * dst);
6
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8
7
 
9
8
 
10
9
  #endif // GGML_SYCL_WKV6_HPP
@@ -1,5 +1,20 @@
1
# The Vulkan backend's build logic below needs CMake 3.19 or newer.
cmake_minimum_required(VERSION 3.19)

# CMP0114 governs ExternalProject step targets, which the cross-compile path
# below relies on (ExternalProject_Add_StepTargets). Requiring 3.19 already sets
# it to NEW; the guarded request keeps that dependency explicit.
if (POLICY CMP0114)
    cmake_policy(SET CMP0114 NEW)
endif()

# glslc is needed to compile the compute shaders; configuration stops if absent.
find_package(Vulkan COMPONENTS glslc REQUIRED)
2
5
 
6
# Locate a C and a C++ compiler that run on the build host itself. This is used
# on the cross-compiling path so vulkan-shaders-gen can be built as a host tool.
# NO_CMAKE_FIND_ROOT_PATH keeps find_program() out of the target sysroot.
# Results are exported to the caller as HOST_C_COMPILER / HOST_CXX_COMPILER and
# end in "-NOTFOUND" when no candidate was found.
function(detect_host_compiler)
    set(c_candidates gcc clang)
    set(cxx_candidates g++ clang++)
    if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
        # On Windows hosts MSVC's cl is tried first, for both languages.
        list(PREPEND c_candidates cl)
        list(PREPEND cxx_candidates cl)
    endif()

    find_program(HOST_C_COMPILER NAMES ${c_candidates} NO_CMAKE_FIND_ROOT_PATH)
    find_program(HOST_CXX_COMPILER NAMES ${cxx_candidates} NO_CMAKE_FIND_ROOT_PATH)

    foreach(result_var HOST_C_COMPILER HOST_CXX_COMPILER)
        set(${result_var} "${${result_var}}" PARENT_SCOPE)
    endforeach()
endfunction()
17
+
3
18
  if (Vulkan_FOUND)
4
19
  message(STATUS "Vulkan found")
5
20
 
@@ -8,6 +23,20 @@ if (Vulkan_FOUND)
8
23
  ../../include/ggml-vulkan.h
9
24
  )
10
25
 
26
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
# If it's not, glslc reports an error on stderr.
# If it's supported, set a define to indicate that we should compile those shaders.
#
# The stderr text is quoted in if(): unquoted, an empty value leaves no left
# operand for MATCHES, and a value containing ';' would be split into several
# arguments and make if() fail with "Unknown arguments specified".
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
                OUTPUT_VARIABLE glslc_output
                ERROR_VARIABLE glslc_error)

if ("${glslc_error}" MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
    message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
else()
    message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
    add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()
39
+
11
40
  # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
12
41
  # If it's not, there will be an error to stderr.
13
42
  # If it's supported, set a define to indicate that we should compile those shaders
@@ -59,15 +88,56 @@ if (Vulkan_FOUND)
59
88
  add_compile_definitions(GGML_VULKAN_RUN_TESTS)
60
89
  endif()
61
90
 
62
- add_subdirectory(vulkan-shaders)
63
-
64
- set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
91
# vulkan-shaders-gen turns the .comp shaders into the generated
# ggml-vulkan-shaders.{hpp,cpp}. Natively it is an ordinary subdirectory target;
# when cross-compiling it must run on the build host, so it is built through
# ExternalProject with a host toolchain instead.
if (NOT CMAKE_CROSSCOMPILING)
    add_subdirectory(vulkan-shaders)
    # Only when an output directory is configured: with an empty, unquoted value
    # set_target_properties() would get a property name without a value and fail.
    if (MSVC AND CMAKE_RUNTIME_OUTPUT_DIRECTORY)
        foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
            string(TOUPPER ${CONFIG} CONFIG)
            set_target_properties(vulkan-shaders-gen PROPERTIES
                RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
        endforeach()
    endif()
else()
    if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
        set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
    else()
        detect_host_compiler()
        if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
            message(FATAL_ERROR "Host compiler not found")
        else()
            message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
        endif()
        configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
        set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
    endif()
    message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")

    include(ExternalProject)
    # Native build through ExternalProject_Add
    ExternalProject_Add(
        vulkan-shaders-gen
        SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
        CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
                   -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
        BUILD_COMMAND ${CMAKE_COMMAND} --build .
        INSTALL_COMMAND ${CMAKE_COMMAND} --install .
        INSTALL_DIR ${CMAKE_BINARY_DIR}
    )
    ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
endif()

set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
if (CMAKE_CROSSCOMPILING)
    # An ExternalProject target has no TARGET_FILE, so the path is spelled out.
    # NOTE(review): assumes the install step places the tool in
    # CMAKE_RUNTIME_OUTPUT_DIRECTORY — confirm against vulkan-shaders/CMakeLists.txt.
    set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
else()
    # Ask CMake where the native tool really lands instead of rebuilding the path:
    # this follows per-config output subdirectories of multi-config generators,
    # an unset CMAKE_RUNTIME_OUTPUT_DIRECTORY and the host executable suffix.
    set (_ggml_vk_genshaders_cmd $<TARGET_FILE:vulkan-shaders-gen>)
endif()
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
# The generator target itself is a dependency of the shader generation step.
list(APPEND _ggml_vk_shader_deps vulkan-shaders-gen)

if (CMAKE_CROSSCOMPILING)
    # Make sure the ExternalProject has been built and installed before it is run.
    list(APPEND _ggml_vk_shader_deps vulkan-shaders-gen-build vulkan-shaders-gen-install)
endif()
71
141
 
72
142
  add_custom_command(
73
143
  OUTPUT ${_ggml_vk_header}
@@ -81,7 +151,7 @@ if (Vulkan_FOUND)
81
151
  --target-cpp ${_ggml_vk_source}
82
152
  --no-clean
83
153
 
84
- DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd}
154
+ DEPENDS ${_ggml_vk_shader_deps}
85
155
  COMMENT "Generate vulkan shaders"
86
156
  )
87
157