@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     }
 }
 
+static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+                        const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    const int tid = item_ct1.get_local_id(2);
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row * ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+
+        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        /*
+        DPCT1118:3: SYCL group functions and algorithms must be encountered in
+        converged control flow. You may need to adjust the code.
+        */
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        size_t nreduce = nwarps / WARP_SIZE;
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1)
+        {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row * ncols + col] = scale * x[row * ncols + col];
+    }
+}
+
 static void norm_f32_sycl(const float* x, float* dst, const int ncols,
                           const int nrows, const float eps,
                           queue_ptr stream, int device) {
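
For reference, the transform the new kernel implements is a plain per-row L2 normalization. A minimal scalar sketch of the same computation (ours, for illustration only; not part of the package), assuming the row-major layout used above:

#include <algorithm>
#include <cmath>

// Scalar reference for l2_norm_f32: dst[r][c] = x[r][c] / max(||x[r]||_2, eps),
// implemented as a multiply by rsqrt(max(sum of squares, eps^2)), matching the
// SYCL kernel's warp-reduced sum and final scaling loop.
static void l2_norm_f32_ref(const float* x, float* dst, int ncols, int nrows, float eps) {
    for (int row = 0; row < nrows; ++row) {
        float sumsq = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float xi = x[row * ncols + col];
            sumsq += xi * xi;
        }
        const float scale = 1.0f / std::sqrt(std::max(sumsq, eps * eps));
        for (int col = 0; col < ncols; ++col) {
            dst[row * ncols + col] = scale * x[row * ncols + col];
        }
    }
}
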
@@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     norm_f32(x, dst, ncols, eps, item_ct1,
                              nullptr, WARP_SIZE);
                 });
@@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     norm_f32(x, dst, ncols, eps, item_ct1,
                              get_pointer(s_sum_acc_ct1), work_group_size);
                 });
@@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     group_norm_f32(
                         x, dst, group_size, ne_elements, eps_ct4, item_ct1,
                         nullptr, WARP_SIZE);
@@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     group_norm_f32(x, dst, group_size, ne_elements,
                                    eps_ct4, item_ct1,
                                    get_pointer(s_sum_acc_ct1), work_group_size);
@@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                  nullptr, WARP_SIZE);
                 });
@@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                               block_dims),
             [=](sycl::nd_item<3> item_ct1)
-                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                     rms_norm_f32(x, dst, ncols, eps, item_ct1,
                                  get_pointer(s_sum_acc_ct1), work_group_size);
                 });
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
 }
 
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+                             const int nrows, const float eps,
+                             queue_ptr stream, int device) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    nullptr, WARP_SIZE);
+                    });
+        });
+    }
+    else {
+        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        const sycl::range<3> block_dims(1, 1, work_group_size);
+        /*
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+                                                         cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                                    get_pointer(s_sum_acc_ct1), work_group_size);
+                    });
+        });
+    }
+}
+
 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
                        ggml_tensor* dst, const float* src0_dd,
                        const float* src1_dd, float* dst_dd,
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
     (void)dst;
     (void)src1_dd;
 }
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                          const ggml_tensor* src1, ggml_tensor* dst,
+                          const float* src0_dd, const float* src1_dd,
+                          float* dst_dd,
+                          const queue_ptr& main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+    (void)src1;
+    (void)dst;
+    (void)src1_dd;
+}
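
The op reads `eps` back out of the destination tensor's `op_params`, which is how ggml threads small scalar parameters through the compute graph. A hedged usage sketch (ours; it assumes a `ggml_l2_norm(ctx, tensor, eps)` graph-API entry point that this backend op would service, so treat the exact signature as an assumption rather than a documented fact):

#include "ggml.h"

// Build an L2-normalize node; eps is stored in dst->op_params and
// recovered by the backend op above via memcpy.
struct ggml_tensor * build_l2_norm(struct ggml_context * ctx, struct ggml_tensor * x) {
    const float eps = 1e-12f;
    return ggml_l2_norm(ctx, x, eps);  // per-row x / max(||x||, eps)
}
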
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
                              float* dst_dd,
                              const queue_ptr& main_stream);
 
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                          const ggml_tensor* src1, ggml_tensor* dst,
+                          const float* src0_dd, const float* src1_dd,
+                          float* dst_dd,
+                          const queue_ptr& main_stream);
+
 #endif // GGML_SYCL_NORM_HPP
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
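
`[[sycl::reqd_sub_group_size(N)]]` is the standard SYCL 2020 spelling of the vendor attribute used in these kernels; it pins every work-item of the kernel to sub-groups of exactly N lanes, which the warp-style reductions above rely on. A standalone sketch (ours, not from the diff; assumes a SYCL 2020 compiler such as DPC++ with default unnamed-lambda kernels, and a device supporting the requested size):

#include <sycl/sycl.hpp>

constexpr int SUB_GROUP_SIZE = 32; // stands in for WARP_SIZE

int main() {
    sycl::queue q;
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(256), sycl::range<1>(SUB_GROUP_SIZE)),
        [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(SUB_GROUP_SIZE)]] {
            // every work-item here executes in a sub-group of exactly
            // SUB_GROUP_SIZE lanes, so lane-indexed reductions are well-defined
            (void) it;
        });
    q.wait();
    return 0;
}
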
@@ -0,0 +1,305 @@
+#include <sycl/sycl.hpp>
+#include "wkv.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+template <int block_size>
+static void rwkv_wkv6_f32_kernel(
+        const int B, const int T, const int C, const int H,
+        const float* k, const float* v, const float* r,
+        const float* tf, const float* td, const float* s,
+        float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[block_size];
+
+    // Load initial state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static void rwkv_wkv7_f32_kernel(
+        const int B, const int T, const int C, const int H,
+        const float* r, const float* w, const float* k, const float* v,
+        const float* a, const float* b, const float* s,
+        float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float* _r = shared_mem;
+    float* _w = _r + head_size;
+    float* _k = _w + head_size;
+    float* _a = _k + head_size;
+    float* _b = _a + head_size;
+
+    float state[block_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0, sa = 0;
+        sycl::float4 a4, s4;
+
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa += sycl::dot(a4, s4);
+        }
+
+        sycl::float4 r4, w4, k4, b4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            sycl::float4 kv4 = k4 * _v;
+
+            s4 = s4 * w4 + kv4 + sa * b4;
+            y += sycl::dot(r4, s4);
+
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    if (C / H == WKV_BLOCK_SIZE) {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
+                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    } else {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
+                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    }
+
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+}
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    const float* r_d = (const float*)dst->src[0]->data;
+    const float* w_d = (const float*)dst->src[1]->data;
+    const float* k_d = (const float*)dst->src[2]->data;
+    const float* v_d = (const float*)dst->src[3]->data;
+    const float* a_d = (const float*)dst->src[4]->data;
+    const float* b_d = (const float*)dst->src[5]->data;
+    const float* s_d = (const float*)dst->src[6]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    if (C / H == WKV_BLOCK_SIZE) {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
+                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    } else {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
+                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    }
+
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+}
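
In scalar form, the recurrence that rwkv_wkv6_f32_kernel vectorizes is, per head and output channel i: y_i = sum_j r_j * (tf_j * k_j * v_i + S[i][j]), followed by the state update S[i][j] <- td_j * S[i][j] + k_j * v_i. The WKV7 kernel differs in that it first forms sa = sum_j a_j * S[i][j], then updates S[i][j] <- w_j * S[i][j] + k_j * v_i + sa * b_j and reads out y_i = sum_j r_j * S[i][j]. A minimal single-head reference for WKV6 (ours, for illustration; names hypothetical):

// Scalar reference for one head of the WKV6 recurrence above.
// S is a head_size x head_size state matrix, indexed [i][j] where i is the
// output channel (the work-item's tid in the kernel); k, v, r, td hold one
// head's channels per token (stride head_size); tf is constant over time.
// S is read as the initial state and left holding the final state, matching
// the kernel's load from s and store past dst[T * C].
static void wkv6_head_ref(int n_tokens, int head_size,
                          const float* k, const float* v, const float* r,
                          const float* tf, const float* td,
                          float* S /* head_size * head_size */,
                          float* y /* n_tokens * head_size */) {
    for (int t = 0; t < n_tokens; ++t) {
        const float* kt  = k  + t * head_size;
        const float* vt  = v  + t * head_size;
        const float* rt  = r  + t * head_size;
        const float* tdt = td + t * head_size;
        for (int i = 0; i < head_size; ++i) {       // output channel
            float acc = 0.0f;
            for (int j = 0; j < head_size; ++j) {   // vectorized by 4 in the kernel
                const float kv = kt[j] * vt[i];
                acc += rt[j] * (tf[j] * kv + S[i * head_size + j]);
                S[i * head_size + j] = S[i * head_size + j] * tdt[j] + kv;
            }
            y[t * head_size + i] = acc;
        }
    }
}
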
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV_HPP
+#define GGML_SYCL_WKV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_WKV_HPP