whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-sycl/norm.cpp:

```diff
@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler & cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                             nullptr, WARP_SIZE);
-                });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler & cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                 sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                             get_pointer(s_sum_acc_ct1), work_group_size);
-                });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                           get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
                                 const int ne_elements, queue_ptr stream, int device) {
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler & cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             const float eps_ct4 = eps;
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(
-                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
-                            nullptr, WARP_SIZE);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
+                                                 WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
 
-        stream->submit([&](sycl::handler & cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
 
             const float eps_ct4 = eps;
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        group_norm_f32(x, dst, group_size, ne_elements,
-                                       eps_ct4, item_ct1,
-                                       get_pointer(s_sum_acc_ct1), work_group_size);
-                    });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+                                                 get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler & cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                                 nullptr, WARP_SIZE);
-                });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler & cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                                 get_pointer(s_sum_acc_ct1), work_group_size);
-                });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                                               get_pointer(s_sum_acc_ct1), work_group_size);
+                              });
+        });
     }
 }
 
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler & cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    nullptr, WARP_SIZE);
-                    });
-        });
+        sycl_launch(stream, [&](sycl::handler & cgh) {
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
+                              });
+        });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->submit([&](sycl::handler & cgh) {
+        sycl_launch(stream, [&](sycl::handler & cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                                    get_pointer(s_sum_acc_ct1), work_group_size);
-                });
-        });
+            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
+                                              work_group_size);
+                              });
+        });
     }
 }
 
```
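The norm.cpp hunks above (and the softmax.cpp and tsembd.cpp hunks below) are one mechanical rewrite applied everywhere: `stream->submit(...)` becomes `sycl_launch(stream, ...)`, `cgh.parallel_for(...)` becomes `sycl_parallel_for(cgh, ...)`, and the kernel bodies are untouched. The wrapper definitions themselves are not part of this diff; a minimal sketch of definitions that would satisfy these call sites, assuming `queue_ptr` is a plain `sycl::queue *`, might look like this:

```cpp
// Hypothetical stand-ins for the sycl_launch / sycl_parallel_for helpers used
// above. The real definitions live elsewhere in the ggml SYCL backend and may
// add profiling or compatibility shims; this only mirrors the call shapes.
#include <sycl/sycl.hpp>
#include <utility>

using queue_ptr = sycl::queue *;  // assumption: matches ggml's queue_ptr alias

template <typename Fn>
void sycl_launch(queue_ptr stream, Fn && fn) {
    stream->submit(std::forward<Fn>(fn));  // direct pass-through to submit()
}

template <typename Range, typename Kernel>
void sycl_parallel_for(sycl::handler & cgh, Range range, Kernel && kernel) {
    cgh.parallel_for(range, std::forward<Kernel>(kernel));
}

// Queue-level overload, matching the rope.cpp and tsembd.cpp call sites below.
template <typename Range, typename Kernel>
void sycl_parallel_for(queue_ptr stream, Range range, Kernel && kernel) {
    stream->parallel_for(range, std::forward<Kernel>(kernel));
}
```

Funneling every launch through one helper gives the backend a single place to hook in logging, profiling, or compiler workarounds instead of touching each kernel call site again.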
data/ext/sources/ggml/src/ggml-sycl/quants.hpp:

```diff
@@ -14,12 +14,13 @@
 #ifndef GGML_SYCL_QUANTS_HPP
 #define GGML_SYCL_QUANTS_HPP
 
+#include <utility>
+
 #include "ggml-common.h"
 #include "ggml.h"
 
 namespace ggml_sycl_reordered {
 
-
 // The reordered block moves quants (qs) and scales(d) to two
 // uniform regions of memory that is contiguous in the same tensor.
 // What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
 
 template <ggml_type type> struct block_q_t;
 
-
 // qk number of weights / quants in a block
 // qr number of weights in a byte (described as 'before dequantization')
 // for quantization types that has low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
-        return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
         static constexpr uint32_t vdr_mmvq = 2;
     };
 
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
 
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
         auto nblocks = (nrows * (ncols / traits::qk));
-        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+        return { nblocks * (QK_K / 2),
+                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
     }
 
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 
     constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
-
-    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q6_K> {
+    struct traits {
+        static constexpr uint32_t qk = QK_K;
+        static constexpr uint32_t qi = QI6_K;
+        static constexpr uint32_t qr = QR6_K;
+        static constexpr uint32_t vdr_mmvq = 1;
+    };
+
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
+        auto low_bits_index = block_index * (traits::qk / traits::qr);
+        // the index of high bits it's after all low bits
+        auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
+        return { low_bits_index, high_bits_index };
+    }
+
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
+        auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
+        auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
+        return { block_scales, sb_scale };
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+};
 } // namespace ggml_sycl_reordered
 
 #endif // GGML_SYCL_QUANTS_HPP
```
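The new `block_q_t<GGML_TYPE_Q6_K>` specialization describes a reordered layout in which all low-bit quants are packed first, then all high-bit quants, then the per-block scales, then the `d` values; `get_block_offset` and `get_d_offset` now return `{low, high}` / `{scales, d}` pairs instead of a single offset. A standalone re-derivation of the low/high arithmetic (constants written out by hand, with ggml's usual `QK_K = 256` and `QR6_K = 2` assumed) can be used to sanity-check the offsets:

```cpp
// Toy re-derivation of the reordered Q6_K quant offsets from the hunk above.
// Nothing here is taken from the package; the tensor shape is illustrative.
#include <cstdio>
#include <utility>

constexpr int QK_K  = 256;  // weights per super-block (assumed ggml value)
constexpr int QR6_K = 2;    // 6-bit = 4 low bits + 2 high bits, so qr == 2

// Byte offsets of block i's low and high quant bits when n blocks are packed
// as [all ql][all qh][all scales][all d]:
constexpr std::pair<int, int> q6k_block_offset(int i, int n) {
    const int low  = i * (QK_K / QR6_K);               // 128 B of ql per block
    const int high = n * (QK_K / 2) + i * (QK_K / 4);  // qh region starts after all ql
    return { low, high };
}

int main() {
    const int nblocks = 256 * 256 / QK_K;  // a 256x256 tensor -> 256 blocks
    const auto [low, high] = q6k_block_offset(3, nblocks);
    std::printf("block 3: ql at byte %d, qh at byte %d\n", low, high);  // 384, 32960
}
```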
data/ext/sources/ggml/src/ggml-sycl/rope.cpp:

```diff
@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
 
     if (i0 >= n_dims) {
         const int i = row * ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
         return;
     }
 
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
 
     if (i0 >= n_dims) {
         const int i = row * ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
         return;
     }
 
@@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
     dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
 }
 
+template <typename T, bool has_ff>
+static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                       const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
+                       const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                       const float theta_scale, const float * freq_factors, const mrope_sections sections,
+                       const sycl::nd_item<3> & item_ct1) {
+    // get index pos
+    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
+    if (i0 >= ne0) {
+        return;
+    }
+    const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
+
+    if (i0 >= n_dims) {
+        const int i = row_dst*ne0 + i0;
+        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
+        return;
+    }
+
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
+    const int idst = (row_dst * ne0) + (i0 / 2);
+    const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
+
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+        theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + sections.v[2]) {
+        theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
+    float cos_theta;
+    float sin_theta;
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
+
+    // store results in dst
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+}
+
+
+
 template <typename T, bool has_ff>
 static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
                         const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
                            const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -184,20 +235,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                corr_dims, theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     } else {
         /*
         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                               corr_dims, theta_scale, freq_factors, item_ct1);
-        });
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
     }
 }
 
@@ -208,7 +261,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
                            const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -216,18 +269,54 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                corr_dims, theta_scale, freq_factors, item_ct1);
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
+    } else {
+        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                          [=](sycl::nd_item<3> item_ct1) {
+                              rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+                          });
+    }
+}
+
+template <typename T>
+static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                            const size_t s2, const int n_dims, const int nr, const int32_t * pos,
+                            const float freq_scale, const float freq_base, const float ext_factor,
+                            const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
+                            const mrope_sections sections, queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
+    const sycl::range<3> grid_dims(1, n_blocks_y, nr);
+    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
+
+    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
+    // Add FP16 capability check if T could be sycl::half
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+    }
+    // launch kernel
+    if (freq_factors == nullptr) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                               corr_dims, theta_scale, freq_factors, item_ct1);
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     }
 }
 
+
+
+
 // rope vision
 template <typename T>
 static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
@@ -237,7 +326,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
                              const mrope_sections sections, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
     const sycl::range<3> grid_dims(1, n_blocks_y, nr);
     const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
 
@@ -248,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
@@ -298,8 +387,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
+    if (is_mrope) {
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
+
     const int32_t * pos = (const int32_t *) dst->src[1]->data;
 
     const float * freq_factors = nullptr;
@@ -326,6 +424,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
         } else {
             GGML_ABORT("fatal error");
         }
+    } else if (is_mrope && !is_vision) {
+        GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
+        if (dst->src[0]->type == GGML_TYPE_F16) {
+            rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
+                            s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                            freq_factors, sections, main_stream);
+        } else if (dst->src[0]->type == GGML_TYPE_F32) {
+            rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
+                            nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                            main_stream);
+        } else {
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
+        }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F16) {
```
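Two recurring pieces in the rope.cpp hunks deserve a note: grid sizing now goes through `ceil_div` (each work-item rotates one pair of dims, hence the `2 * SYCL_ROPE_BLOCK_SIZE` divisor), and the new `rope_multi` kernel routes each dim pair to one of four position streams according to the `sections` split. A small host-side sketch, with an invented section split and an assumed block size of 256, walks through both; none of this code is from the package:

```cpp
// Illustrative re-derivation of rope_multi's sector -> position-stream mapping
// and of the ceil_div grid sizing. Section values and block size are made up.
#include <cstdio>
#include <initializer_list>

template <typename T> constexpr T ceil_div(T n, T d) { return (n + d - 1) / d; }

int main() {
    // Grid sizing: one work-item per rotated pair, so ne0/2 items per row.
    const int block = 256;  // assumed SYCL_ROPE_BLOCK_SIZE
    std::printf("blocks for ne0=1000: %d\n", ceil_div(1000, 2 * block));  // -> 2

    // Sector logic: which pos stream drives dim pair p (hypothetical split).
    const int v[4]      = { 16, 24, 24, 0 };  // e.g. time/height/width/extra
    const int sect_dims = v[0] + v[1] + v[2] + v[3];
    const int sec_w     = v[0] + v[1];
    for (int p : { 0, 16, 40, 63 }) {
        const int sector = p % sect_dims;
        const int stream = sector < v[0]         ? 0
                         : sector < sec_w        ? 1
                         : sector < sec_w + v[2] ? 2
                                                 : 3;
        std::printf("dim pair %2d -> position stream %d\n", p, stream);
    }
}
```

This mirrors the `theta_base` selection in the kernel, where stream `k` reads its position from `pos[channel_x + ne2 * k]`.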
data/ext/sources/ggml/src/ggml-sycl/softmax.cpp:

```diff
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
-    stream->submit([&](sycl::handler & cgh) {
+    sycl_launch(stream, [&](sycl::handler & cgh) {
         sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
 
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        sycl_parallel_for(
+            cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
```
data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp:

```diff
@@ -1,6 +1,7 @@
 #include "sycl_hw.hpp"
 
-
+// TODO: currently not used
+/*
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
   sycl_hw_info res;
   int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
@@ -11,3 +12,4 @@ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
 
   return res;
 }
+*/
```
data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp:

```diff
@@ -10,6 +10,8 @@
 
 namespace syclex = sycl::ext::oneapi::experimental;
 
+// TODO: currently not used
+/*
 struct sycl_hw_info {
   syclex::architecture arch;
   int32_t device_id;
@@ -18,6 +20,7 @@ struct sycl_hw_info {
 bool is_in_vector(std::vector<int> &vec, int item);
 
 sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
+*/
 
 
 #endif // SYCL_HW_HPP
```
data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp:

```diff
@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
     int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
     sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
     sycl::range<3> gridDim(1, ne00, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(
-            gridDim * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) {
-            timestep_embedding_f32(
-                x, dst, nb1, dim, max_period, item_ct1
-            );
-        });
+    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
+    });
 }
 
 void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
```