@fugood/llama.node 0.3.14 → 0.3.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/llama.cpp/.github/workflows/build.yml +30 -1
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/arg.cpp +20 -2
- package/src/llama.cpp/common/common.cpp +6 -3
- package/src/llama.cpp/common/speculative.cpp +4 -4
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +6 -6
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/run.cpp +91 -46
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +37 -15
- package/src/llama.cpp/examples/server/utils.hpp +3 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/tts/tts.cpp +20 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +24 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
- package/src/llama.cpp/ggml/src/ggml.c +85 -2
- package/src/llama.cpp/include/llama.h +86 -22
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +103 -16
- package/src/llama.cpp/src/llama-arch.h +18 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -110
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-model.cpp +8244 -173
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama.cpp +51 -9984
- package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
|
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|
|
180
180
|
}
|
|
181
181
|
}
|
|
182
182
|
|
|
183
|
+
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
|
184
|
+
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
|
185
|
+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
|
186
|
+
item_ct1.get_local_id(1);
|
|
187
|
+
const int tid = item_ct1.get_local_id(2);
|
|
188
|
+
const int nthreads = item_ct1.get_local_range(2);
|
|
189
|
+
const int nwarps = nthreads / WARP_SIZE;
|
|
190
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
|
191
|
+
|
|
192
|
+
for (int col = tid; col < ncols; col += block_size) {
|
|
193
|
+
const float xi = x[row * ncols + col];
|
|
194
|
+
tmp += xi * xi;
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
// sum up partial sums
|
|
198
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
|
199
|
+
if (block_size > WARP_SIZE) {
|
|
200
|
+
|
|
201
|
+
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
|
202
|
+
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
|
203
|
+
if (lane_id == 0) {
|
|
204
|
+
s_sum[warp_id] = tmp;
|
|
205
|
+
}
|
|
206
|
+
/*
|
|
207
|
+
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
|
208
|
+
converged control flow. You may need to adjust the code.
|
|
209
|
+
*/
|
|
210
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
211
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
|
212
|
+
tmp = 0.f;
|
|
213
|
+
for (size_t i = 0; i < nreduce; i += 1)
|
|
214
|
+
{
|
|
215
|
+
tmp += s_sum[lane_id + i * WARP_SIZE];
|
|
216
|
+
}
|
|
217
|
+
tmp = warp_reduce_sum(tmp, item_ct1);
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
|
221
|
+
|
|
222
|
+
for (int col = tid; col < ncols; col += block_size) {
|
|
223
|
+
dst[row * ncols + col] = scale * x[row * ncols + col];
|
|
224
|
+
}
|
|
225
|
+
}
|
|
226
|
+
|
|
183
227
|
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
184
228
|
const int nrows, const float eps,
|
|
185
229
|
queue_ptr stream, int device) {
|
|
@@ -191,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
191
235
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
192
236
|
block_dims),
|
|
193
237
|
[=](sycl::nd_item<3> item_ct1)
|
|
194
|
-
[[
|
|
238
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
195
239
|
norm_f32(x, dst, ncols, eps, item_ct1,
|
|
196
240
|
nullptr, WARP_SIZE);
|
|
197
241
|
});
|
|
@@ -214,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
214
258
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
215
259
|
block_dims),
|
|
216
260
|
[=](sycl::nd_item<3> item_ct1)
|
|
217
|
-
[[
|
|
261
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
218
262
|
norm_f32(x, dst, ncols, eps, item_ct1,
|
|
219
263
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
220
264
|
});
|
|
@@ -233,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
|
233
277
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
|
234
278
|
block_dims),
|
|
235
279
|
[=](sycl::nd_item<3> item_ct1)
|
|
236
|
-
[[
|
|
280
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
237
281
|
group_norm_f32(
|
|
238
282
|
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
|
239
283
|
nullptr, WARP_SIZE);
|
|
@@ -260,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
|
260
304
|
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
|
261
305
|
block_dims),
|
|
262
306
|
[=](sycl::nd_item<3> item_ct1)
|
|
263
|
-
[[
|
|
307
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
264
308
|
group_norm_f32(x, dst, group_size, ne_elements,
|
|
265
309
|
eps_ct4, item_ct1,
|
|
266
310
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
@@ -281,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
281
325
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
282
326
|
block_dims),
|
|
283
327
|
[=](sycl::nd_item<3> item_ct1)
|
|
284
|
-
[[
|
|
328
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
285
329
|
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
286
330
|
nullptr, WARP_SIZE);
|
|
287
331
|
});
|
|
@@ -303,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
303
347
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
304
348
|
block_dims),
|
|
305
349
|
[=](sycl::nd_item<3> item_ct1)
|
|
306
|
-
[[
|
|
350
|
+
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
307
351
|
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
308
352
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
309
353
|
});
|
|
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
311
355
|
}
|
|
312
356
|
}
|
|
313
357
|
|
|
358
|
+
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
359
|
+
const int nrows, const float eps,
|
|
360
|
+
queue_ptr stream, int device) {
|
|
361
|
+
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
|
362
|
+
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
|
363
|
+
if (ncols < 1024) {
|
|
364
|
+
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
365
|
+
stream->submit([&](sycl::handler& cgh) {
|
|
366
|
+
cgh.parallel_for(
|
|
367
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
368
|
+
block_dims),
|
|
369
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
370
|
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
|
371
|
+
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
372
|
+
nullptr, WARP_SIZE);
|
|
373
|
+
});
|
|
374
|
+
});
|
|
375
|
+
}
|
|
376
|
+
else {
|
|
377
|
+
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
378
|
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
|
379
|
+
const sycl::range<3> block_dims(1, 1, work_group_size);
|
|
380
|
+
/*
|
|
381
|
+
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
|
382
|
+
the limit. To get the device limit, query
|
|
383
|
+
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
384
|
+
*/
|
|
385
|
+
stream->submit([&](sycl::handler& cgh) {
|
|
386
|
+
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
|
387
|
+
cgh);
|
|
388
|
+
cgh.parallel_for(
|
|
389
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
|
390
|
+
block_dims),
|
|
391
|
+
[=](sycl::nd_item<3> item_ct1)
|
|
392
|
+
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
|
393
|
+
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
394
|
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
395
|
+
});
|
|
396
|
+
});
|
|
397
|
+
}
|
|
398
|
+
}
|
|
399
|
+
|
|
314
400
|
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
|
315
401
|
ggml_tensor* dst, const float* src0_dd,
|
|
316
402
|
const float* src1_dd, float* dst_dd,
|
|
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
|
|
376
462
|
(void)dst;
|
|
377
463
|
(void)src1_dd;
|
|
378
464
|
}
|
|
465
|
+
|
|
466
|
+
// Backend entry point for the L2-norm op: validates that source and destination
// are F32, reads epsilon from the start of dst->op_params, and launches the
// row-wise kernel on `main_stream` for the device recorded in `ctx`.
// `src1` and `src1_dd` are accepted for signature compatibility and ignored.
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                          const ggml_tensor* src1, ggml_tensor* dst,
                          const float* src0_dd, const float* src1_dd,
                          float* dst_dd,
                          const queue_ptr& main_stream) {
    // Not used by this op.
    (void)src1;
    (void)src1_dd;

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // Epsilon is copied out of op_params bytewise to avoid type-punning.
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    const int64_t row_len   = src0->ne[0];
    const int64_t row_count = ggml_nrows(src0);

    l2_norm_f32_sycl(src0_dd, dst_dd, row_len, row_count, eps, main_stream, ctx.device);
}
|
|
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
|
|
32
32
|
float* dst_dd,
|
|
33
33
|
const queue_ptr& main_stream);
|
|
34
34
|
|
|
35
|
+
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
36
|
+
const ggml_tensor* src1, ggml_tensor* dst,
|
|
37
|
+
const float* src0_dd, const float* src1_dd,
|
|
38
|
+
float* dst_dd,
|
|
39
|
+
const queue_ptr& main_stream);
|
|
40
|
+
|
|
35
41
|
#endif // GGML_SYCL_NORM_HPP
|
|
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
|
|
|
132
132
|
|
|
133
133
|
cgh.parallel_for(
|
|
134
134
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
135
|
-
[=](sycl::nd_item<3> item_ct1) [[
|
|
135
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
136
136
|
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
|
137
137
|
nrows_y, scale, max_bias, m0,
|
|
138
138
|
m1, n_head_log2, item_ct1,
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
#include <sycl/sycl.hpp>
|
|
2
|
+
#include "wkv.hpp"
|
|
3
|
+
|
|
4
|
+
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
|
5
|
+
|
|
6
|
+
// Helper function for the main kernel
|
|
7
|
+
template <int block_size>
|
|
8
|
+
static void rwkv_wkv6_f32_kernel(
|
|
9
|
+
const int B, const int T, const int C, const int H,
|
|
10
|
+
const float* k, const float* v, const float* r,
|
|
11
|
+
const float* tf, const float* td, const float* s,
|
|
12
|
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
|
13
|
+
|
|
14
|
+
const int tid = item_ct1.get_local_id(2);
|
|
15
|
+
const int bid = item_ct1.get_group(2);
|
|
16
|
+
|
|
17
|
+
const int head_size = block_size;
|
|
18
|
+
const int batch_i = bid / H;
|
|
19
|
+
const int head_i = bid % H;
|
|
20
|
+
const int state_size = C * head_size;
|
|
21
|
+
const int n_seq_tokens = T / B;
|
|
22
|
+
|
|
23
|
+
// Set up shared memory pointers
|
|
24
|
+
float* _k = shared_mem;
|
|
25
|
+
float* _r = _k + head_size;
|
|
26
|
+
float* _tf = _r + head_size;
|
|
27
|
+
float* _td = _tf + head_size;
|
|
28
|
+
|
|
29
|
+
// Local state array
|
|
30
|
+
float state[block_size];
|
|
31
|
+
|
|
32
|
+
// Load initial state
|
|
33
|
+
#pragma unroll
|
|
34
|
+
for (int i = 0; i < head_size; i++) {
|
|
35
|
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
// Sync threads before shared memory operations
|
|
39
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
40
|
+
|
|
41
|
+
// Load time-mixing parameters
|
|
42
|
+
_tf[tid] = tf[head_i * head_size + tid];
|
|
43
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
44
|
+
|
|
45
|
+
// Main sequence processing loop
|
|
46
|
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
|
47
|
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
|
48
|
+
t += C) {
|
|
49
|
+
|
|
50
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
51
|
+
|
|
52
|
+
// Load current timestep data to shared memory
|
|
53
|
+
_k[tid] = k[t];
|
|
54
|
+
_r[tid] = r[t];
|
|
55
|
+
_td[tid] = td[t];
|
|
56
|
+
|
|
57
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
58
|
+
|
|
59
|
+
const float _v = v[t];
|
|
60
|
+
float y = 0;
|
|
61
|
+
|
|
62
|
+
// Process in chunks of 4 for better vectorization
|
|
63
|
+
sycl::float4 k4, r4, tf4, td4, s4;
|
|
64
|
+
#pragma unroll
|
|
65
|
+
for (int j = 0; j < head_size; j += 4) {
|
|
66
|
+
// Load data in vec4 chunks
|
|
67
|
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
68
|
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
69
|
+
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
|
70
|
+
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
|
71
|
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
72
|
+
|
|
73
|
+
// Compute key-value product
|
|
74
|
+
sycl::float4 kv4 = k4 * _v;
|
|
75
|
+
|
|
76
|
+
// Accumulate weighted sum
|
|
77
|
+
y += sycl::dot(r4, tf4 * kv4 + s4);
|
|
78
|
+
|
|
79
|
+
// Update state
|
|
80
|
+
s4 = s4 * td4 + kv4;
|
|
81
|
+
|
|
82
|
+
// Store updated state
|
|
83
|
+
state[j] = s4.x();
|
|
84
|
+
state[j+1] = s4.y();
|
|
85
|
+
state[j+2] = s4.z();
|
|
86
|
+
state[j+3] = s4.w();
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
dst[t] = y;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
// Save final state
|
|
93
|
+
#pragma unroll
|
|
94
|
+
for (int i = 0; i < head_size; i++) {
|
|
95
|
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
|
96
|
+
}
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
template <int block_size>
|
|
100
|
+
static void rwkv_wkv7_f32_kernel(
|
|
101
|
+
const int B, const int T, const int C, const int H,
|
|
102
|
+
const float* r, const float* w, const float* k, const float* v,
|
|
103
|
+
const float* a, const float* b, const float* s,
|
|
104
|
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
|
105
|
+
|
|
106
|
+
const int tid = item_ct1.get_local_id(2);
|
|
107
|
+
const int bid = item_ct1.get_group(2);
|
|
108
|
+
|
|
109
|
+
const int head_size = block_size;
|
|
110
|
+
const int batch_i = bid / H;
|
|
111
|
+
const int head_i = bid % H;
|
|
112
|
+
const int state_size = C * head_size;
|
|
113
|
+
const int n_seq_tokens = T / B;
|
|
114
|
+
|
|
115
|
+
float* _r = shared_mem;
|
|
116
|
+
float* _w = _r + head_size;
|
|
117
|
+
float* _k = _w + head_size;
|
|
118
|
+
float* _a = _k + head_size;
|
|
119
|
+
float* _b = _a + head_size;
|
|
120
|
+
|
|
121
|
+
float state[block_size];
|
|
122
|
+
|
|
123
|
+
#pragma unroll
|
|
124
|
+
for (int i = 0; i < head_size; i++) {
|
|
125
|
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
|
129
|
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
|
130
|
+
t += C) {
|
|
131
|
+
|
|
132
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
133
|
+
|
|
134
|
+
_r[tid] = r[t];
|
|
135
|
+
_w[tid] = w[t];
|
|
136
|
+
_k[tid] = k[t];
|
|
137
|
+
_a[tid] = a[t];
|
|
138
|
+
_b[tid] = b[t];
|
|
139
|
+
|
|
140
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
141
|
+
|
|
142
|
+
const float _v = v[t];
|
|
143
|
+
float y = 0, sa = 0;
|
|
144
|
+
sycl::float4 a4, s4;
|
|
145
|
+
|
|
146
|
+
#pragma unroll
|
|
147
|
+
for (int j = 0; j < head_size; j += 4) {
|
|
148
|
+
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
|
149
|
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
150
|
+
sa += sycl::dot(a4, s4);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
sycl::float4 r4, w4, k4, b4;
|
|
154
|
+
#pragma unroll
|
|
155
|
+
for (int j = 0; j < head_size; j += 4) {
|
|
156
|
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
157
|
+
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
|
158
|
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
159
|
+
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
|
160
|
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
161
|
+
|
|
162
|
+
sycl::float4 kv4 = k4 * _v;
|
|
163
|
+
|
|
164
|
+
s4 = s4 * w4 + kv4 + sa * b4;
|
|
165
|
+
y += sycl::dot(r4, s4);
|
|
166
|
+
|
|
167
|
+
state[j] = s4.x();
|
|
168
|
+
state[j+1] = s4.y();
|
|
169
|
+
state[j+2] = s4.z();
|
|
170
|
+
state[j+3] = s4.w();
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
dst[t] = y;
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
#pragma unroll
|
|
177
|
+
for (int i = 0; i < head_size; i++) {
|
|
178
|
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
// Host-side launcher for the RWKV wkv6 kernel.
//
// Reads k, v, r, tf, td and the initial state from dst->src[0..5] and submits
// one work-group per (batch, head) pair to the context's stream. The kernel
// writes its results into dst->data. Only head sizes (C / H) equal to
// WKV_BLOCK_SIZE or 2 * WKV_BLOCK_SIZE are supported, because the kernel is
// instantiated for exactly those two block sizes.
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // kernel is instantiated for head sizes 64 and 128 only

    dpct::queue_ptr stream = ctx.stream();

    // Calculate execution configuration.
    // sycl::local_accessor<float, 1> is sized in ELEMENTS, not bytes, so the
    // scratch for k, r, tf, td is 4 * head_size floats (no sizeof factor).
    const size_t head_size = C / H;
    const size_t shared_mem_floats = head_size * 4;
    sycl::range<3> block_dims(1, 1, head_size);
    sycl::range<3> grid_dims(1, 1, B * H);

    // Submit kernel; the block size is a template argument, hence two branches.
    if (C / H == WKV_BLOCK_SIZE) {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    } else {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    }
}
|
|
243
|
+
|
|
244
|
+
// Host-side launcher for the RWKV wkv7 kernel.
//
// Reads r, w, k, v, a, b and the initial state from dst->src[0..6] and submits
// one work-group per (batch, head) pair to the context's stream. The kernel
// writes its results into dst->data. Only head sizes (C / H) equal to
// WKV_BLOCK_SIZE or 2 * WKV_BLOCK_SIZE are supported, because the kernel is
// instantiated for exactly those two block sizes.
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

    const float* r_d = (const float*)dst->src[0]->data;
    const float* w_d = (const float*)dst->src[1]->data;
    const float* k_d = (const float*)dst->src[2]->data;
    const float* v_d = (const float*)dst->src[3]->data;
    const float* a_d = (const float*)dst->src[4]->data;
    const float* b_d = (const float*)dst->src[5]->data;
    const float* s_d = (const float*)dst->src[6]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[6]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);

    dpct::queue_ptr stream = ctx.stream();

    // Calculate execution configuration.
    // sycl::local_accessor<float, 1> is sized in ELEMENTS, not bytes, so the
    // scratch for r, w, k, a, b is 5 * head_size floats (no sizeof factor).
    const size_t head_size = C / H;
    const size_t shared_mem_floats = head_size * 5;
    sycl::range<3> block_dims(1, 1, head_size);
    sycl::range<3> grid_dims(1, 1, B * H);

    // Submit kernel; the block size is a template argument, hence two branches.
    if (C / H == WKV_BLOCK_SIZE) {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    } else {
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_floats, cgh);

            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                    );
                });
        });
    }
}
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
#ifndef GGML_SYCL_WKV_HPP
|
|
2
|
+
#define GGML_SYCL_WKV_HPP
|
|
3
|
+
|
|
4
|
+
#include "common.hpp"
|
|
5
|
+
|
|
6
|
+
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
|
7
|
+
|
|
8
|
+
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
|
9
|
+
|
|
10
|
+
#endif // GGML_SYCL_WKV_HPP
|