whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -208,12 +208,10 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
|
208
208
|
dpct::has_capability_or_fail(stream->get_device(),
|
209
209
|
{sycl::aspect::fp16});
|
210
210
|
|
211
|
-
stream
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
nrows, item_ct1);
|
216
|
-
});
|
211
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
212
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
213
|
+
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
|
214
|
+
});
|
217
215
|
}
|
218
216
|
}
|
219
217
|
|
@@ -877,12 +875,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
|
|
877
875
|
dpct::has_capability_or_fail(stream->get_device(),
|
878
876
|
{sycl::aspect::fp16});
|
879
877
|
|
880
|
-
stream
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
});
|
878
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
879
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
880
|
+
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
|
881
|
+
nrows, item_ct1);
|
882
|
+
});
|
886
883
|
}
|
887
884
|
}
|
888
885
|
|
@@ -900,12 +897,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
|
900
897
|
dpct::has_capability_or_fail(stream->get_device(),
|
901
898
|
{sycl::aspect::fp16});
|
902
899
|
|
903
|
-
stream
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
vx, y, dst, ncols, nrows, item_ct1);
|
908
|
-
});
|
900
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
901
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
902
|
+
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
|
903
|
+
});
|
909
904
|
}
|
910
905
|
}
|
911
906
|
|
@@ -921,12 +916,10 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
|
921
916
|
dpct::has_capability_or_fail(stream->get_device(),
|
922
917
|
{sycl::aspect::fp16});
|
923
918
|
|
924
|
-
stream
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
vx, y, dst, ncols, nrows, item_ct1);
|
929
|
-
});
|
919
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
920
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
921
|
+
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
|
922
|
+
});
|
930
923
|
}
|
931
924
|
}
|
932
925
|
|
@@ -942,12 +935,10 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
|
942
935
|
dpct::has_capability_or_fail(stream->get_device(),
|
943
936
|
{sycl::aspect::fp16});
|
944
937
|
|
945
|
-
stream
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
vx, y, dst, ncols, nrows, item_ct1);
|
950
|
-
});
|
938
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
939
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
940
|
+
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
|
941
|
+
});
|
951
942
|
}
|
952
943
|
}
|
953
944
|
|
@@ -963,12 +954,10 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
|
963
954
|
dpct::has_capability_or_fail(stream->get_device(),
|
964
955
|
{sycl::aspect::fp16});
|
965
956
|
|
966
|
-
stream
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
vx, y, dst, ncols, nrows, item_ct1);
|
971
|
-
});
|
957
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
958
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
959
|
+
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
|
960
|
+
});
|
972
961
|
}
|
973
962
|
}
|
974
963
|
|
@@ -984,12 +973,10 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
|
984
973
|
dpct::has_capability_or_fail(stream->get_device(),
|
985
974
|
{sycl::aspect::fp16});
|
986
975
|
|
987
|
-
stream
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
vx, y, dst, ncols, nrows, item_ct1);
|
992
|
-
});
|
976
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
977
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
978
|
+
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
|
979
|
+
});
|
993
980
|
}
|
994
981
|
}
|
995
982
|
|
@@ -1002,11 +989,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
|
1002
989
|
const int block_num_y = (nrows + ny - 1) / ny;
|
1003
990
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
1004
991
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
1005
|
-
stream
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
});
|
992
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
993
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
994
|
+
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
995
|
+
});
|
1010
996
|
}
|
1011
997
|
|
1012
998
|
static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
@@ -1018,11 +1004,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
|
1018
1004
|
const int block_num_y = (nrows + ny - 1) / ny;
|
1019
1005
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
1020
1006
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
1021
|
-
stream
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
});
|
1007
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1008
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
1009
|
+
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
1010
|
+
});
|
1026
1011
|
}
|
1027
1012
|
|
1028
1013
|
static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
@@ -1034,11 +1019,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
|
1034
1019
|
const int block_num_y = (nrows + ny - 1) / ny;
|
1035
1020
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
1036
1021
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
1037
|
-
stream
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
});
|
1022
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1023
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
1024
|
+
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
1025
|
+
});
|
1042
1026
|
}
|
1043
1027
|
|
1044
1028
|
static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
@@ -1047,11 +1031,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
|
1047
1031
|
dpct::queue_ptr stream) {
|
1048
1032
|
GGML_ASSERT(ncols % QK_K == 0);
|
1049
1033
|
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
|
1050
|
-
stream
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
});
|
1034
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
1035
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
1036
|
+
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
1037
|
+
});
|
1055
1038
|
}
|
1056
1039
|
|
1057
1040
|
static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
@@ -1063,11 +1046,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
|
1063
1046
|
const int block_num_y = (nrows + ny - 1) / ny;
|
1064
1047
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
1065
1048
|
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
1066
|
-
stream
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
});
|
1049
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
1050
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
1051
|
+
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
1052
|
+
});
|
1071
1053
|
}
|
1072
1054
|
|
1073
1055
|
void ggml_sycl_op_dequantize_mul_mat_vec(
|
@@ -13,10 +13,10 @@
|
|
13
13
|
#ifndef GGML_SYCL_DPCT_HELPER_HPP
|
14
14
|
#define GGML_SYCL_DPCT_HELPER_HPP
|
15
15
|
|
16
|
+
#include <map>
|
16
17
|
#include <sycl/sycl.hpp>
|
17
18
|
#include <sycl/half_type.hpp>
|
18
19
|
#include <syclcompat/math.hpp>
|
19
|
-
#include <map>
|
20
20
|
|
21
21
|
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
22
22
|
#include <oneapi/mkl.hpp>
|
@@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
|
|
118
118
|
#endif
|
119
119
|
}
|
120
120
|
|
121
|
+
#if defined(SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS)
// Short alias for the oneAPI experimental extension namespace. It is only
// introduced when the implementation advertises the enqueue-functions
// extension, because the launch helpers reference syclex::nd_launch /
// syclex::submit only under this same guard.
namespace syclex = sycl::ext::oneapi::experimental;
#endif
|
124
|
+
|
125
|
+
template <int NR, typename Func>
|
126
|
+
__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
|
127
|
+
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
128
|
+
syclex::nd_launch(cgh, nd_range, func);
|
129
|
+
#else
|
130
|
+
cgh.parallel_for(nd_range, func);
|
131
|
+
#endif
|
132
|
+
}
|
133
|
+
|
134
|
+
template <int NR, typename Func>
|
135
|
+
__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
|
136
|
+
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
137
|
+
syclex::nd_launch(*q, nd_range, func);
|
138
|
+
#else
|
139
|
+
q->parallel_for(nd_range, func);
|
140
|
+
#endif
|
141
|
+
}
|
142
|
+
|
143
|
+
// Submit a command-group function to a queue.
//
// stream - queue to submit to; dereferenced, so it must be non-null.
// cgf    - command-group functor; it is passed on as an lvalue (not forwarded).
//
// Routes through syclex::submit when SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS is
// defined, otherwise through queue::submit. Any event the underlying call
// returns is discarded.
template <typename CommandGroup>
__dpct_inline__ void sycl_launch(sycl::queue * stream, CommandGroup && cgf) {
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
    stream->submit(cgf);
#else
    syclex::submit(*stream, cgf);
#endif
}
|
150
|
+
|
121
151
|
namespace dpct
|
122
152
|
{
|
123
153
|
typedef sycl::queue *queue_ptr;
|