whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
|
|
33
33
|
{
|
34
34
|
dpct::has_capability_or_fail(stream->get_device(),
|
35
35
|
{sycl::aspect::fp16});
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
[=](sycl::nd_item<3> item_ct1) {
|
42
|
-
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
|
43
|
-
});
|
36
|
+
sycl_parallel_for(
|
37
|
+
stream,
|
38
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
39
|
+
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
40
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
|
44
41
|
}
|
45
42
|
}
|
46
43
|
|
@@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
53
50
|
dpct::has_capability_or_fail(stream->get_device(),
|
54
51
|
{sycl::aspect::fp16});
|
55
52
|
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
[=](sycl::nd_item<3> item_ct1) {
|
60
|
-
dequantize_block_q2_K(vx, y, item_ct1);
|
61
|
-
});
|
53
|
+
sycl_parallel_for(
|
54
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
55
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
|
62
56
|
}
|
63
57
|
#else
|
64
58
|
{
|
65
59
|
dpct::has_capability_or_fail(stream->get_device(),
|
66
60
|
{sycl::aspect::fp16});
|
67
61
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
[=](sycl::nd_item<3> item_ct1) {
|
72
|
-
dequantize_block_q2_K(vx, y, item_ct1);
|
73
|
-
});
|
62
|
+
sycl_parallel_for(
|
63
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
64
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
|
74
65
|
}
|
75
66
|
|
76
67
|
#endif
|
@@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
85
76
|
dpct::has_capability_or_fail(stream->get_device(),
|
86
77
|
{sycl::aspect::fp16});
|
87
78
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
[=](sycl::nd_item<3> item_ct1) {
|
92
|
-
dequantize_block_q3_K(vx, y, item_ct1);
|
93
|
-
});
|
79
|
+
sycl_parallel_for(
|
80
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
81
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
|
94
82
|
}
|
95
83
|
#else
|
96
84
|
{
|
97
85
|
dpct::has_capability_or_fail(stream->get_device(),
|
98
86
|
{sycl::aspect::fp16});
|
99
87
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
[=](sycl::nd_item<3> item_ct1) {
|
104
|
-
dequantize_block_q3_K(vx, y, item_ct1);
|
105
|
-
});
|
88
|
+
sycl_parallel_for(
|
89
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
90
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
|
106
91
|
}
|
107
92
|
#endif
|
108
93
|
}
|
@@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
116
101
|
dpct::has_capability_or_fail(stream->get_device(),
|
117
102
|
{sycl::aspect::fp16});
|
118
103
|
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
[=](sycl::nd_item<3> item_ct1) {
|
123
|
-
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
124
|
-
});
|
104
|
+
sycl_parallel_for(
|
105
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
106
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
|
125
107
|
}
|
126
108
|
}
|
127
109
|
|
@@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
|
|
135
117
|
int constexpr WARP_K = WARP_SIZE * QK4_0;
|
136
118
|
const int n_warp = (k + WARP_K - 1) / WARP_K;
|
137
119
|
GGML_ASSERT(k % 2 == 0);
|
138
|
-
stream
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
120
|
+
sycl_parallel_for(stream,
|
121
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
|
122
|
+
sycl::range<3>(1, 1, WARP_SIZE)),
|
123
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
124
|
+
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
|
125
|
+
});
|
145
126
|
}
|
146
127
|
|
147
128
|
template <typename dst_t>
|
@@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
153
134
|
dpct::has_capability_or_fail(stream->get_device(),
|
154
135
|
{sycl::aspect::fp16});
|
155
136
|
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
[=](sycl::nd_item<3> item_ct1) {
|
160
|
-
dequantize_block_q4_1(vx, y, nb32, item_ct1);
|
161
|
-
});
|
137
|
+
sycl_parallel_for(
|
138
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
139
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
|
162
140
|
}
|
163
141
|
}
|
164
142
|
|
@@ -171,14 +149,13 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
171
149
|
dpct::has_capability_or_fail(stream->get_device(),
|
172
150
|
{sycl::aspect::fp16});
|
173
151
|
|
174
|
-
stream
|
152
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
175
153
|
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
});
|
154
|
+
sycl_parallel_for(
|
155
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
156
|
+
[=](sycl::nd_item<3> item_ct1) {
|
157
|
+
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
|
158
|
+
});
|
182
159
|
});
|
183
160
|
}
|
184
161
|
}
|
@@ -191,13 +168,13 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
|
|
191
168
|
|
192
169
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
193
170
|
|
194
|
-
stream
|
171
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
195
172
|
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
196
173
|
|
197
|
-
cgh
|
198
|
-
|
199
|
-
|
200
|
-
|
174
|
+
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
|
175
|
+
[=](sycl::nd_item<1> item_ct1) {
|
176
|
+
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
|
177
|
+
});
|
201
178
|
});
|
202
179
|
}
|
203
180
|
|
@@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
210
187
|
dpct::has_capability_or_fail(stream->get_device(),
|
211
188
|
{sycl::aspect::fp16});
|
212
189
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
[=](sycl::nd_item<3> item_ct1) {
|
217
|
-
dequantize_block_q5_K(vx, y, item_ct1);
|
218
|
-
});
|
190
|
+
sycl_parallel_for(
|
191
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
192
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
|
219
193
|
}
|
220
194
|
#else
|
221
195
|
{
|
222
196
|
dpct::has_capability_or_fail(stream->get_device(),
|
223
197
|
{sycl::aspect::fp16});
|
224
198
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
[=](sycl::nd_item<3> item_ct1) {
|
229
|
-
dequantize_block_q5_K(vx, y, item_ct1);
|
230
|
-
});
|
199
|
+
sycl_parallel_for(
|
200
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
201
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
|
231
202
|
}
|
232
203
|
|
233
204
|
#endif
|
@@ -242,29 +213,34 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
242
213
|
dpct::has_capability_or_fail(stream->get_device(),
|
243
214
|
{sycl::aspect::fp16});
|
244
215
|
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
[=](sycl::nd_item<3> item_ct1) {
|
249
|
-
dequantize_block_q6_K(vx, y, item_ct1);
|
250
|
-
});
|
216
|
+
sycl_parallel_for(
|
217
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
218
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
|
251
219
|
}
|
252
220
|
#else
|
253
221
|
{
|
254
222
|
dpct::has_capability_or_fail(stream->get_device(),
|
255
223
|
{sycl::aspect::fp16});
|
256
224
|
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
[=](sycl::nd_item<3> item_ct1) {
|
261
|
-
dequantize_block_q6_K(vx, y, item_ct1);
|
262
|
-
});
|
225
|
+
sycl_parallel_for(
|
226
|
+
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
227
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
|
263
228
|
}
|
264
229
|
|
265
230
|
#endif
|
266
231
|
}
|
267
232
|
|
233
|
+
template <typename dst_t>
|
234
|
+
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
235
|
+
const int64_t nb = k / QK_K;
|
236
|
+
|
237
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
238
|
+
|
239
|
+
sycl_parallel_for(stream,
|
240
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
241
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
242
|
+
}
|
243
|
+
|
268
244
|
template <typename dst_t>
|
269
245
|
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
270
246
|
dpct::queue_ptr stream) {
|
@@ -273,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
273
249
|
dpct::has_capability_or_fail(stream->get_device(),
|
274
250
|
{sycl::aspect::fp16});
|
275
251
|
|
276
|
-
stream
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
[=](sycl::nd_item<3> item_ct1) {
|
281
|
-
dequantize_block_iq1_s(
|
282
|
-
vx, y, item_ct1, iq1s_grid_gpu
|
283
|
-
);
|
284
|
-
});
|
252
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
253
|
+
sycl_parallel_for(
|
254
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
255
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
|
285
256
|
});
|
286
257
|
}
|
287
258
|
}
|
@@ -294,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
294
265
|
dpct::has_capability_or_fail(stream->get_device(),
|
295
266
|
{sycl::aspect::fp16});
|
296
267
|
|
297
|
-
stream
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
[=](sycl::nd_item<3> item_ct1) {
|
302
|
-
dequantize_block_iq1_m(
|
303
|
-
vx, y, item_ct1, iq1s_grid_gpu
|
304
|
-
);
|
305
|
-
});
|
268
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
269
|
+
sycl_parallel_for(
|
270
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
271
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
|
306
272
|
});
|
307
273
|
}
|
308
274
|
}
|
@@ -315,15 +281,12 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
|
|
315
281
|
dpct::has_capability_or_fail(stream->get_device(),
|
316
282
|
{sycl::aspect::fp16});
|
317
283
|
|
318
|
-
stream
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
vx, y, item_ct1, iq2xxs_grid,
|
325
|
-
ksigns_iq2xs, kmask_iq2xs);
|
326
|
-
});
|
284
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
285
|
+
sycl_parallel_for(
|
286
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
287
|
+
[=](sycl::nd_item<3> item_ct1) {
|
288
|
+
dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
|
289
|
+
});
|
327
290
|
});
|
328
291
|
}
|
329
292
|
}
|
@@ -336,15 +299,12 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
|
|
336
299
|
dpct::has_capability_or_fail(stream->get_device(),
|
337
300
|
{sycl::aspect::fp16});
|
338
301
|
|
339
|
-
stream
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
vx, y, item_ct1, iq2xs_grid,
|
346
|
-
ksigns_iq2xs, kmask_iq2xs);
|
347
|
-
});
|
302
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
303
|
+
sycl_parallel_for(
|
304
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
305
|
+
[=](sycl::nd_item<3> item_ct1) {
|
306
|
+
dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
|
307
|
+
});
|
348
308
|
});
|
349
309
|
}
|
350
310
|
}
|
@@ -357,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
357
317
|
dpct::has_capability_or_fail(stream->get_device(),
|
358
318
|
{sycl::aspect::fp16});
|
359
319
|
|
360
|
-
stream
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
[=](sycl::nd_item<3> item_ct1) {
|
365
|
-
dequantize_block_iq2_s(vx, y, item_ct1);
|
366
|
-
});
|
320
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
321
|
+
sycl_parallel_for(
|
322
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
323
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
|
367
324
|
});
|
368
325
|
}
|
369
326
|
}
|
@@ -377,15 +334,12 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
|
|
377
334
|
dpct::has_capability_or_fail(stream->get_device(),
|
378
335
|
{sycl::aspect::fp16});
|
379
336
|
|
380
|
-
stream
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
vx, y, item_ct1, iq3xxs_grid,
|
387
|
-
ksigns_iq2xs, kmask_iq2xs);
|
388
|
-
});
|
337
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
338
|
+
sycl_parallel_for(
|
339
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
340
|
+
[=](sycl::nd_item<3> item_ct1) {
|
341
|
+
dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
|
342
|
+
});
|
389
343
|
});
|
390
344
|
}
|
391
345
|
}
|
@@ -398,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
|
398
352
|
dpct::has_capability_or_fail(stream->get_device(),
|
399
353
|
{sycl::aspect::fp16});
|
400
354
|
|
401
|
-
stream
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
[=](sycl::nd_item<3> item_ct1) {
|
406
|
-
dequantize_block_iq3_s(
|
407
|
-
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
|
408
|
-
});
|
355
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
356
|
+
sycl_parallel_for(
|
357
|
+
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
358
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
|
409
359
|
});
|
410
360
|
}
|
411
361
|
}
|
@@ -421,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
|
|
421
371
|
dpct::has_capability_or_fail(stream->get_device(),
|
422
372
|
{sycl::aspect::fp16});
|
423
373
|
|
424
|
-
stream
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
[=](sycl::nd_item<3> item_ct1) {
|
430
|
-
dequantize_block_iq4_xs(vx, y, item_ct1);
|
431
|
-
});
|
374
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
375
|
+
sycl_parallel_for(
|
376
|
+
cgh,
|
377
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
378
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
|
432
379
|
});
|
433
380
|
}
|
434
381
|
#endif
|
@@ -442,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
|
|
442
389
|
dpct::has_capability_or_fail(stream->get_device(),
|
443
390
|
{sycl::aspect::fp16});
|
444
391
|
|
445
|
-
stream
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
[=](sycl::nd_item<3> item_ct1) {
|
451
|
-
dequantize_block_iq4_nl(vx, y, item_ct1);
|
452
|
-
});
|
392
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
393
|
+
sycl_parallel_for(
|
394
|
+
cgh,
|
395
|
+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
396
|
+
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
|
453
397
|
});
|
454
398
|
}
|
455
399
|
}
|
@@ -530,7 +474,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
|
530
474
|
case GGML_TYPE_Q5_K:
|
531
475
|
return dequantize_row_q5_K_sycl;
|
532
476
|
case GGML_TYPE_Q6_K:
|
533
|
-
|
477
|
+
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
478
|
+
return dequantize_row_q6_K_sycl_reorder;
|
479
|
+
} else {
|
480
|
+
return dequantize_row_q6_K_sycl;
|
481
|
+
}
|
534
482
|
case GGML_TYPE_IQ1_S:
|
535
483
|
return dequantize_row_iq1_s_sycl;
|
536
484
|
case GGML_TYPE_IQ1_M:
|
@@ -587,7 +535,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
|
587
535
|
case GGML_TYPE_Q5_K:
|
588
536
|
return dequantize_row_q5_K_sycl;
|
589
537
|
case GGML_TYPE_Q6_K:
|
590
|
-
|
538
|
+
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
539
|
+
return dequantize_row_q6_K_sycl_reorder;
|
540
|
+
} else {
|
541
|
+
return dequantize_row_q6_K_sycl;
|
542
|
+
}
|
591
543
|
case GGML_TYPE_IQ1_S:
|
592
544
|
return dequantize_row_iq1_s_sycl;
|
593
545
|
case GGML_TYPE_IQ1_M:
|