whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -29,24 +29,23 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
|
29
29
|
static_assert(blocks_per_subgroup > 0);
|
30
30
|
static_assert(block_elements_per_subgroup > 0);
|
31
31
|
|
32
|
-
const block_q8_1 * y = (const block_q8_1 *) vy;
|
33
|
-
|
34
32
|
float partial_sum = 0.0f;
|
35
33
|
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
36
|
-
const int ibx
|
37
|
-
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
38
|
-
const int bx_offset = block_type::get_block_offset(ibx);
|
39
|
-
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
34
|
+
const int ibx = row * blocks_per_row + i; // x block index
|
40
35
|
|
36
|
+
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
37
|
+
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
41
38
|
// Y block index that aligns with ibx
|
42
39
|
const int iby = i * block_type::block_to_q8_1_ratio();
|
40
|
+
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
41
|
+
const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
|
43
42
|
|
44
43
|
#pragma unroll
|
45
44
|
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
46
45
|
// x block quant index when casting the quants to int
|
47
46
|
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
48
47
|
|
49
|
-
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset,
|
48
|
+
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
50
49
|
}
|
51
50
|
}
|
52
51
|
|
@@ -545,12 +544,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
|
|
545
544
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
546
545
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
547
546
|
|
548
|
-
stream
|
549
|
-
cgh
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
547
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
548
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
549
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
550
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
551
|
+
nd_item);
|
552
|
+
});
|
554
553
|
});
|
555
554
|
}
|
556
555
|
|
@@ -562,12 +561,12 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
|
|
562
561
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
563
562
|
|
564
563
|
{
|
565
|
-
stream
|
566
|
-
cgh
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
564
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
565
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
566
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
567
|
+
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
568
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
569
|
+
});
|
571
570
|
});
|
572
571
|
}
|
573
572
|
}
|
@@ -581,17 +580,12 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|
581
580
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
582
581
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
583
582
|
{
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
591
|
-
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
592
|
-
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
593
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
594
|
-
});
|
583
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
584
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
585
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
586
|
+
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
587
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
588
|
+
});
|
595
589
|
});
|
596
590
|
}
|
597
591
|
}
|
@@ -605,17 +599,12 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|
605
599
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
606
600
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
607
601
|
{
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
615
|
-
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
616
|
-
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
617
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
618
|
-
});
|
602
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
603
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
604
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
605
|
+
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
606
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
607
|
+
});
|
619
608
|
});
|
620
609
|
}
|
621
610
|
}
|
@@ -629,17 +618,12 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|
629
618
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
630
619
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
631
620
|
{
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
639
|
-
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
640
|
-
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
641
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
642
|
-
});
|
621
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
622
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
623
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
624
|
+
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
625
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
626
|
+
});
|
643
627
|
});
|
644
628
|
}
|
645
629
|
}
|
@@ -653,17 +637,12 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|
653
637
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
654
638
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
655
639
|
{
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
663
|
-
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
664
|
-
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
665
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
666
|
-
});
|
640
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
641
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
642
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
643
|
+
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
644
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
645
|
+
});
|
667
646
|
});
|
668
647
|
}
|
669
648
|
}
|
@@ -677,17 +656,12 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|
677
656
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
678
657
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
679
658
|
{
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
687
|
-
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
688
|
-
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
689
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
690
|
-
});
|
659
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
660
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
661
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
662
|
+
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
663
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
664
|
+
});
|
691
665
|
});
|
692
666
|
}
|
693
667
|
}
|
@@ -701,17 +675,12 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|
701
675
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
702
676
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
703
677
|
{
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
711
|
-
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
712
|
-
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
713
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
714
|
-
});
|
678
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
679
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
680
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
681
|
+
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
682
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
683
|
+
});
|
715
684
|
});
|
716
685
|
}
|
717
686
|
}
|
@@ -725,17 +694,12 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|
725
694
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
726
695
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
727
696
|
{
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
735
|
-
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
736
|
-
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
737
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
738
|
-
});
|
697
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
698
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
699
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
700
|
+
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
701
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
702
|
+
});
|
739
703
|
});
|
740
704
|
}
|
741
705
|
}
|
@@ -751,12 +715,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
|
|
751
715
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
752
716
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
753
717
|
|
754
|
-
stream
|
755
|
-
cgh
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
718
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
719
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
720
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
721
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
|
722
|
+
nd_item);
|
723
|
+
});
|
760
724
|
});
|
761
725
|
}
|
762
726
|
|
@@ -770,21 +734,34 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
770
734
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
771
735
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
772
736
|
{
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
780
|
-
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
781
|
-
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
782
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
783
|
-
});
|
737
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
738
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
739
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
740
|
+
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
741
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
742
|
+
});
|
784
743
|
});
|
785
744
|
}
|
786
745
|
}
|
787
746
|
|
747
|
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
748
|
+
const int nrows, dpct::queue_ptr stream) {
|
749
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
750
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
751
|
+
constexpr size_t num_subgroups = 16;
|
752
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
753
|
+
|
754
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
755
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
756
|
+
|
757
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
758
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
759
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
760
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
761
|
+
nd_item);
|
762
|
+
});
|
763
|
+
});
|
764
|
+
}
|
788
765
|
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
789
766
|
float *dst, const int ncols,
|
790
767
|
const int nrows,
|
@@ -794,17 +771,12 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|
794
771
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
795
772
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
796
773
|
{
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
804
|
-
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
805
|
-
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
806
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
807
|
-
});
|
774
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
775
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
776
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
777
|
+
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
778
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
779
|
+
});
|
808
780
|
});
|
809
781
|
}
|
810
782
|
}
|
@@ -819,14 +791,12 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|
819
791
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
820
792
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
821
793
|
{
|
822
|
-
stream
|
823
|
-
cgh
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
829
|
-
});
|
794
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
795
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
796
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
797
|
+
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
|
798
|
+
nrows, item_ct1);
|
799
|
+
});
|
830
800
|
});
|
831
801
|
}
|
832
802
|
}
|
@@ -840,14 +810,12 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
|
840
810
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
841
811
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
842
812
|
{
|
843
|
-
stream
|
844
|
-
cgh
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
850
|
-
});
|
813
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
814
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
815
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
816
|
+
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
|
817
|
+
nrows, item_ct1);
|
818
|
+
});
|
851
819
|
});
|
852
820
|
}
|
853
821
|
}
|
@@ -861,15 +829,12 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
|
861
829
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
862
830
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
863
831
|
{
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
871
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
872
|
-
});
|
832
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
833
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
834
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
835
|
+
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
|
836
|
+
item_ct1);
|
837
|
+
});
|
873
838
|
});
|
874
839
|
}
|
875
840
|
}
|
@@ -883,15 +848,12 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|
883
848
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
884
849
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
885
850
|
{
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
893
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
894
|
-
});
|
851
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
852
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
853
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
854
|
+
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
|
855
|
+
nrows, item_ct1);
|
856
|
+
});
|
895
857
|
});
|
896
858
|
}
|
897
859
|
}
|
@@ -905,15 +867,12 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
|
905
867
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
906
868
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
907
869
|
{
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
915
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
916
|
-
});
|
870
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
871
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
872
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
873
|
+
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
|
874
|
+
item_ct1);
|
875
|
+
});
|
917
876
|
});
|
918
877
|
}
|
919
878
|
}
|
@@ -927,15 +886,12 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
|
927
886
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
928
887
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
929
888
|
{
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
937
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
938
|
-
});
|
889
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
890
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
891
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
892
|
+
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
|
893
|
+
item_ct1);
|
894
|
+
});
|
939
895
|
});
|
940
896
|
}
|
941
897
|
}
|
@@ -949,14 +905,12 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
|
949
905
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
950
906
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
951
907
|
{
|
952
|
-
stream
|
953
|
-
cgh
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
959
|
-
});
|
908
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
909
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
910
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
911
|
+
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
|
912
|
+
item_ct1);
|
913
|
+
});
|
960
914
|
});
|
961
915
|
}
|
962
916
|
}
|
@@ -970,15 +924,12 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
|
970
924
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
971
925
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
972
926
|
{
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
980
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
981
|
-
});
|
927
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
928
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
929
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
930
|
+
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
|
931
|
+
item_ct1);
|
932
|
+
});
|
982
933
|
});
|
983
934
|
}
|
984
935
|
}
|
@@ -992,15 +943,12 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
|
992
943
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
993
944
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
994
945
|
{
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
1002
|
-
vx, vy, dst, ncols, nrows, item_ct1);
|
1003
|
-
});
|
946
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
947
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
948
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
949
|
+
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
|
950
|
+
nrows, item_ct1);
|
951
|
+
});
|
1004
952
|
});
|
1005
953
|
}
|
1006
954
|
}
|
@@ -1070,7 +1018,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
1070
1018
|
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1071
1019
|
break;
|
1072
1020
|
case GGML_TYPE_Q6_K:
|
1073
|
-
|
1021
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
1022
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
1023
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
1024
|
+
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1025
|
+
} else {
|
1026
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
1027
|
+
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
1028
|
+
}
|
1074
1029
|
break;
|
1075
1030
|
case GGML_TYPE_IQ1_S:
|
1076
1031
|
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|