whispercpp 1.3.1 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +7 -3
- data/README.md +161 -43
- data/Rakefile +45 -13
- data/ext/.gitignore +4 -8
- data/ext/dependencies.rb +73 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +85 -0
- data/ext/ruby_whisper.c +177 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +672 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1303 -0
- data/ext/ruby_whisper_segment.c +220 -0
- data/ext/ruby_whisper_transcribe.cpp +93 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +255 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
- data/ext/sources/examples/addon.node/addon.cpp +557 -0
- data/ext/sources/examples/addon.node/index.js +57 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +176 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1295 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +800 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +175 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +469 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +226 -0
- data/ext/sources/examples/server/CMakeLists.txt +15 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1238 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +435 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
- data/ext/sources/examples/talk-llama/llama-context.h +297 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
- data/ext/sources/examples/talk-llama/llama-model.h +452 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
- data/ext/sources/examples/talk-llama/llama.cpp +358 -0
- data/ext/sources/examples/talk-llama/llama.h +1484 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +435 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +50 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +404 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
- data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +1351 -0
- data/ext/{include → sources/include}/whisper.h +70 -2
- data/ext/sources/src/CMakeLists.txt +145 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1966 -386
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +39 -5
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +202 -126
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +510 -0
- data/test/helper.rb +24 -0
- data/{tests → test}/test_callback.rb +45 -3
- data/{tests → test}/test_error.rb +2 -2
- data/{tests → test}/test_model.rb +47 -0
- data/test/test_package.rb +51 -0
- data/test/test_params.rb +297 -0
- data/test/test_segment.rb +146 -0
- data/test/test_vad.rb +19 -0
- data/test/test_vad_params.rb +103 -0
- data/{tests → test}/test_whisper.rb +106 -36
- data/whispercpp.gemspec +5 -5
- metadata +837 -134
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
- data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- data/tests/helper.rb +0 -7
- data/tests/test_package.rb +0 -31
- data/tests/test_params.rb +0 -160
- data/tests/test_segment.rb +0 -83
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
- /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
#include "whisper.h"
|
2
|
-
|
3
|
-
#include "ggml-cpu.h"
|
2
|
+
#include "whisper-arch.h"
|
4
3
|
|
5
4
|
#include "ggml.h"
|
5
|
+
#include "ggml-cpp.h"
|
6
6
|
#include "ggml-alloc.h"
|
7
7
|
#include "ggml-backend.h"
|
8
8
|
|
@@ -17,37 +17,36 @@
|
|
17
17
|
#include <atomic>
|
18
18
|
#include <algorithm>
|
19
19
|
#include <cassert>
|
20
|
+
#include <cfloat>
|
20
21
|
#define _USE_MATH_DEFINES
|
21
22
|
#include <cmath>
|
22
|
-
#include <
|
23
|
+
#include <climits>
|
24
|
+
#include <codecvt>
|
23
25
|
#include <cstdarg>
|
26
|
+
#include <cstdio>
|
24
27
|
#include <cstring>
|
25
28
|
#include <fstream>
|
29
|
+
#include <functional>
|
26
30
|
#include <map>
|
31
|
+
#include <mutex>
|
32
|
+
#include <random>
|
33
|
+
#include <regex>
|
27
34
|
#include <set>
|
28
35
|
#include <string>
|
29
36
|
#include <thread>
|
30
37
|
#include <vector>
|
31
|
-
#include <regex>
|
32
|
-
#include <random>
|
33
|
-
#include <functional>
|
34
|
-
#include <codecvt>
|
35
|
-
|
36
|
-
#if defined(_MSC_VER)
|
37
|
-
#pragma warning(disable: 4244 4267) // possible loss of data
|
38
|
-
#endif
|
39
|
-
|
40
|
-
#if defined(GGML_BIG_ENDIAN)
|
41
|
-
#include <bit>
|
42
38
|
|
39
|
+
#if defined(WHISPER_BIG_ENDIAN)
|
43
40
|
template<typename T>
|
44
41
|
static T byteswap(T value) {
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
42
|
+
T value_swapped;
|
43
|
+
char * source = reinterpret_cast<char *>(&value);
|
44
|
+
char * target = reinterpret_cast<char *>(&value_swapped);
|
45
|
+
int size = sizeof(T);
|
46
|
+
for (int i = 0; i < size; i++) {
|
47
|
+
target[size - 1 - i] = source[i];
|
48
|
+
}
|
49
|
+
return value_swapped;
|
51
50
|
}
|
52
51
|
|
53
52
|
template<typename T>
|
@@ -83,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|
83
82
|
}
|
84
83
|
|
85
84
|
#define BYTESWAP_VALUE(d) d = byteswap(d)
|
86
|
-
#define BYTESWAP_FILTERS(f)
|
85
|
+
#define BYTESWAP_FILTERS(f) \
|
87
86
|
do { \
|
88
87
|
for (auto & datum : f.data) { \
|
89
88
|
datum = byteswap(datum); \
|
90
89
|
} \
|
91
90
|
} while (0)
|
92
|
-
#define BYTESWAP_TENSOR(t)
|
93
|
-
do {
|
91
|
+
#define BYTESWAP_TENSOR(t) \
|
92
|
+
do { \
|
94
93
|
byteswap_tensor(t); \
|
95
94
|
} while (0)
|
96
95
|
#else
|
@@ -141,34 +140,52 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|
141
140
|
#define WHISPER_MAX_DECODERS 8
|
142
141
|
#define WHISPER_MAX_NODES 4096
|
143
142
|
|
143
|
+
static std::string format(const char * fmt, ...) {
|
144
|
+
va_list ap;
|
145
|
+
va_list ap2;
|
146
|
+
va_start(ap, fmt);
|
147
|
+
va_copy(ap2, ap);
|
148
|
+
int size = vsnprintf(NULL, 0, fmt, ap);
|
149
|
+
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
150
|
+
std::vector<char> buf(size + 1);
|
151
|
+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
152
|
+
GGML_ASSERT(size2 == size);
|
153
|
+
va_end(ap2);
|
154
|
+
va_end(ap);
|
155
|
+
return std::string(buf.data(), size);
|
156
|
+
}
|
157
|
+
|
144
158
|
//
|
145
159
|
// ggml helpers
|
146
160
|
//
|
147
161
|
|
148
162
|
static bool ggml_graph_compute_helper(
|
149
163
|
struct ggml_cgraph * graph,
|
150
|
-
std::vector<uint8_t> & buf,
|
151
164
|
int n_threads,
|
152
165
|
ggml_abort_callback abort_callback,
|
153
166
|
void * abort_callback_data) {
|
154
|
-
|
167
|
+
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
168
|
+
|
169
|
+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
155
170
|
|
156
|
-
|
157
|
-
|
171
|
+
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
172
|
+
if (set_abort_callback_fn) {
|
173
|
+
set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
|
174
|
+
}
|
158
175
|
|
159
|
-
|
160
|
-
|
161
|
-
|
176
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
177
|
+
if (ggml_backend_set_n_threads_fn) {
|
178
|
+
ggml_backend_set_n_threads_fn(backend.get(), n_threads);
|
162
179
|
}
|
163
180
|
|
164
|
-
return
|
181
|
+
return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
|
165
182
|
}
|
166
183
|
|
167
184
|
static bool ggml_graph_compute_helper(
|
168
185
|
ggml_backend_sched_t sched,
|
169
186
|
struct ggml_cgraph * graph,
|
170
|
-
int n_threads
|
171
|
-
|
187
|
+
int n_threads,
|
188
|
+
bool sched_reset = true) {
|
172
189
|
for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
|
173
190
|
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
|
174
191
|
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
@@ -180,11 +197,61 @@ static bool ggml_graph_compute_helper(
|
|
180
197
|
}
|
181
198
|
}
|
182
199
|
|
183
|
-
bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
|
184
|
-
|
200
|
+
const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
|
201
|
+
|
202
|
+
if (!t || sched_reset) {
|
203
|
+
ggml_backend_sched_reset(sched);
|
204
|
+
}
|
205
|
+
|
206
|
+
return t;
|
207
|
+
}
|
208
|
+
|
209
|
+
// TODO: move these functions to ggml-base with support for ggml-backend?
|
210
|
+
|
211
|
+
static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
|
212
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
213
|
+
GGML_ASSERT(ggml_is_contiguous(t));
|
214
|
+
size_t nels = ggml_nelements(t);
|
215
|
+
for (size_t i = 0; i < nels; ++i) {
|
216
|
+
((float *) t->data)[i] = v;
|
217
|
+
}
|
218
|
+
return t;
|
219
|
+
}
|
220
|
+
|
221
|
+
static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
|
222
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
223
|
+
GGML_ASSERT(ggml_is_contiguous(t));
|
224
|
+
size_t nels = ggml_nelements(t);
|
225
|
+
for (size_t i = 0; i < nels; ++i) {
|
226
|
+
((int32_t *) t->data)[i] = v;
|
227
|
+
}
|
185
228
|
return t;
|
186
229
|
}
|
187
230
|
|
231
|
+
static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
232
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
233
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
234
|
+
return *(float *) data;
|
235
|
+
}
|
236
|
+
|
237
|
+
static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
|
238
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
239
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
240
|
+
*(float *) data = v;
|
241
|
+
}
|
242
|
+
|
243
|
+
static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
244
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
245
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
246
|
+
return *(int32_t *) data;
|
247
|
+
}
|
248
|
+
|
249
|
+
static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
|
250
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
251
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
252
|
+
*(int32_t *) data = v;
|
253
|
+
}
|
254
|
+
|
188
255
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
189
256
|
// the idea is to represent the original matrix multiplication:
|
190
257
|
//
|
@@ -428,6 +495,7 @@ struct whisper_segment {
|
|
428
495
|
int64_t t1;
|
429
496
|
|
430
497
|
std::string text;
|
498
|
+
float no_speech_prob;
|
431
499
|
|
432
500
|
std::vector<whisper_token_data> tokens;
|
433
501
|
|
@@ -520,7 +588,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
|
|
520
588
|
auto & sched = allocr.sched;
|
521
589
|
auto & meta = allocr.meta;
|
522
590
|
|
523
|
-
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
|
591
|
+
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
|
524
592
|
|
525
593
|
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
526
594
|
|
@@ -716,10 +784,10 @@ struct whisper_model {
|
|
716
784
|
std::vector<whisper_layer_decoder> layers_decoder;
|
717
785
|
|
718
786
|
// ggml context that contains all the meta information about the model tensors
|
719
|
-
|
787
|
+
std::vector<ggml_context *> ctxs;
|
720
788
|
|
721
789
|
// the model backend data is read-only and can be shared between processors
|
722
|
-
ggml_backend_buffer_t
|
790
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
723
791
|
|
724
792
|
// tensors
|
725
793
|
int n_loaded;
|
@@ -791,6 +859,11 @@ struct whisper_aheads_masks {
|
|
791
859
|
ggml_backend_buffer_t buffer = nullptr;
|
792
860
|
};
|
793
861
|
|
862
|
+
struct vad_time_mapping {
|
863
|
+
int64_t processed_time; // Time in processed (VAD) audio
|
864
|
+
int64_t original_time; // Corresponding time in original audio
|
865
|
+
};
|
866
|
+
|
794
867
|
struct whisper_state {
|
795
868
|
int64_t t_sample_us = 0;
|
796
869
|
int64_t t_encode_us = 0;
|
@@ -876,6 +949,19 @@ struct whisper_state {
|
|
876
949
|
|
877
950
|
// [EXPERIMENTAL] speed-up techniques
|
878
951
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
952
|
+
|
953
|
+
whisper_vad_context * vad_context = nullptr;
|
954
|
+
|
955
|
+
struct vad_segment_info {
|
956
|
+
int64_t orig_start;
|
957
|
+
int64_t orig_end;
|
958
|
+
int64_t vad_start;
|
959
|
+
int64_t vad_end;
|
960
|
+
};
|
961
|
+
std::vector<vad_segment_info> vad_segments;
|
962
|
+
bool has_vad_segments = false;
|
963
|
+
|
964
|
+
std::vector<vad_time_mapping> vad_mapping_table;
|
879
965
|
};
|
880
966
|
|
881
967
|
struct whisper_context {
|
@@ -1234,21 +1320,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
1234
1320
|
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
1235
1321
|
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
1236
1322
|
|
1323
|
+
ggml_backend_dev_t dev = nullptr;
|
1324
|
+
|
1325
|
+
int cnt = 0;
|
1237
1326
|
if (params.use_gpu) {
|
1238
1327
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1239
|
-
ggml_backend_dev_t
|
1240
|
-
if (ggml_backend_dev_type(
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1328
|
+
ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
|
1329
|
+
if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1330
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
1331
|
+
dev = dev_cur;
|
1332
|
+
}
|
1333
|
+
|
1334
|
+
if (++cnt > params.gpu_device) {
|
1335
|
+
break;
|
1245
1336
|
}
|
1246
|
-
return result;
|
1247
1337
|
}
|
1248
1338
|
}
|
1249
1339
|
}
|
1250
1340
|
|
1251
|
-
|
1341
|
+
if (dev == nullptr) {
|
1342
|
+
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
|
1343
|
+
return nullptr;
|
1344
|
+
}
|
1345
|
+
|
1346
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1347
|
+
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
|
1348
|
+
if (!result) {
|
1349
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1350
|
+
}
|
1351
|
+
|
1352
|
+
return result;
|
1252
1353
|
}
|
1253
1354
|
|
1254
1355
|
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
@@ -1274,28 +1375,118 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
|
|
1274
1375
|
}
|
1275
1376
|
}
|
1276
1377
|
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1378
|
+
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
1379
|
+
if (backend_cpu == nullptr) {
|
1380
|
+
throw std::runtime_error("failed to initialize CPU backend");
|
1381
|
+
}
|
1382
|
+
result.push_back(backend_cpu);
|
1280
1383
|
|
1281
1384
|
return result;
|
1282
1385
|
}
|
1283
1386
|
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1387
|
+
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
|
1388
|
+
|
1389
|
+
static buft_list_t make_buft_list(whisper_context_params & params) {
|
1390
|
+
// Prio order: GPU -> CPU Extra -> CPU
|
1391
|
+
buft_list_t buft_list;
|
1392
|
+
|
1393
|
+
// GPU
|
1394
|
+
if (params.use_gpu) {
|
1395
|
+
int cnt = 0;
|
1396
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1397
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1398
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1399
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
1400
|
+
auto * buft = ggml_backend_dev_buffer_type(dev);
|
1401
|
+
if (buft) {
|
1402
|
+
buft_list.emplace_back(dev, buft);
|
1403
|
+
}
|
1404
|
+
}
|
1405
|
+
|
1406
|
+
if (++cnt > params.gpu_device) {
|
1407
|
+
break;
|
1408
|
+
}
|
1409
|
+
}
|
1410
|
+
}
|
1411
|
+
}
|
1412
|
+
|
1413
|
+
// CPU Extra
|
1414
|
+
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
1415
|
+
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
1416
|
+
auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
1417
|
+
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
1418
|
+
if (get_extra_bufts_fn) {
|
1419
|
+
ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
|
1420
|
+
while (extra_bufts && *extra_bufts) {
|
1421
|
+
buft_list.emplace_back(cpu_dev, *extra_bufts);
|
1422
|
+
++extra_bufts;
|
1423
|
+
}
|
1287
1424
|
}
|
1288
1425
|
|
1289
|
-
//
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1426
|
+
// CPU
|
1427
|
+
buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
|
1428
|
+
|
1429
|
+
return buft_list;
|
1430
|
+
}
|
1431
|
+
|
1432
|
+
static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
1433
|
+
bool op_supported = true;
|
1434
|
+
|
1435
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
|
1436
|
+
(ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
|
1437
|
+
// GPU and default CPU backend support all operators
|
1438
|
+
op_supported = true;
|
1439
|
+
} else {
|
1440
|
+
switch (op) {
|
1441
|
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
|
1442
|
+
case GGML_OP_MUL_MAT: {
|
1443
|
+
ggml_init_params params = {
|
1444
|
+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
1445
|
+
/*.mem_buffer =*/ nullptr,
|
1446
|
+
/*.no_alloc =*/ true,
|
1447
|
+
};
|
1448
|
+
|
1449
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1450
|
+
if (!ctx_ptr) {
|
1451
|
+
throw std::runtime_error("failed to create ggml context");
|
1452
|
+
}
|
1453
|
+
ggml_context * ctx = ctx_ptr.get();
|
1454
|
+
|
1455
|
+
ggml_tensor * op_tensor = nullptr;
|
1456
|
+
|
1457
|
+
int64_t n_ctx = hparams.n_audio_ctx;
|
1458
|
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
1459
|
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
1460
|
+
|
1461
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
1462
|
+
GGML_ASSERT(w->buffer == nullptr);
|
1463
|
+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
1464
|
+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
1465
|
+
ggml_backend_buffer_free(w->buffer);
|
1466
|
+
w->buffer = nullptr;
|
1467
|
+
break;
|
1468
|
+
}
|
1469
|
+
default: {
|
1470
|
+
op_supported = false;
|
1471
|
+
break;
|
1472
|
+
}
|
1473
|
+
};
|
1474
|
+
}
|
1475
|
+
|
1476
|
+
return op_supported;
|
1477
|
+
}
|
1478
|
+
|
1479
|
+
static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
|
1480
|
+
GGML_ASSERT(!buft_list.empty());
|
1481
|
+
for (const auto & p : buft_list) {
|
1482
|
+
ggml_backend_dev_t dev = p.first;
|
1483
|
+
ggml_backend_buffer_type_t buft = p.second;
|
1484
|
+
if (weight_buft_supported(hparams, w, op, buft, dev)) {
|
1485
|
+
return buft;
|
1295
1486
|
}
|
1296
1487
|
}
|
1297
1488
|
|
1298
|
-
return
|
1489
|
+
return nullptr;
|
1299
1490
|
}
|
1300
1491
|
|
1301
1492
|
// load the model from a ggml file
|
@@ -1504,31 +1695,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1504
1695
|
const ggml_type wtype = wctx.wtype;
|
1505
1696
|
const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
|
1506
1697
|
|
1507
|
-
|
1508
|
-
{
|
1509
|
-
const auto & hparams = model.hparams;
|
1698
|
+
const auto & hparams = model.hparams;
|
1510
1699
|
|
1511
|
-
|
1512
|
-
|
1700
|
+
const int n_audio_layer = hparams.n_audio_layer;
|
1701
|
+
const int n_text_layer = hparams.n_text_layer;
|
1513
1702
|
|
1514
|
-
|
1703
|
+
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
1515
1704
|
|
1516
|
-
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1705
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
1706
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
1707
|
+
auto it = ctx_map.find(buft);
|
1708
|
+
if (it == ctx_map.end()) {
|
1709
|
+
ggml_init_params params = {
|
1710
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
1711
|
+
/*.mem_buffer =*/ nullptr,
|
1712
|
+
/*.no_alloc =*/ true,
|
1713
|
+
};
|
1521
1714
|
|
1522
|
-
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1715
|
+
ggml_context * ctx = ggml_init(params);
|
1716
|
+
if (!ctx) {
|
1717
|
+
throw std::runtime_error("failed to create ggml context");
|
1718
|
+
}
|
1719
|
+
|
1720
|
+
ctx_map[buft] = ctx;
|
1721
|
+
model.ctxs.emplace_back(ctx);
|
1722
|
+
|
1723
|
+
return ctx;
|
1526
1724
|
}
|
1527
|
-
|
1725
|
+
|
1726
|
+
return it->second;
|
1727
|
+
};
|
1728
|
+
|
1729
|
+
// Create a list of available bufts, in priority order
|
1730
|
+
buft_list_t buft_list = make_buft_list(wctx.params);
|
1731
|
+
|
1732
|
+
auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
|
1733
|
+
ggml_op op = ASR_TENSOR_INFO.at(type);
|
1734
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
1735
|
+
if (!buft) {
|
1736
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
|
1737
|
+
}
|
1738
|
+
|
1739
|
+
ggml_context * ctx = get_ctx(buft);
|
1740
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
1741
|
+
|
1742
|
+
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
|
1743
|
+
|
1744
|
+
return tensor;
|
1745
|
+
};
|
1746
|
+
|
1528
1747
|
|
1529
1748
|
// prepare tensors for the weights
|
1530
1749
|
{
|
1531
|
-
|
1750
|
+
ggml_init_params params = {
|
1751
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
1752
|
+
/*.mem_buffer =*/ nullptr,
|
1753
|
+
/*.no_alloc =*/ true,
|
1754
|
+
};
|
1755
|
+
|
1756
|
+
ggml_context * ctx = ggml_init(params);
|
1532
1757
|
|
1533
1758
|
const auto & hparams = model.hparams;
|
1534
1759
|
|
@@ -1548,189 +1773,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1548
1773
|
model.layers_decoder.resize(n_text_layer);
|
1549
1774
|
|
1550
1775
|
// encoder
|
1551
|
-
|
1552
|
-
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
1553
|
-
|
1554
|
-
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
1555
|
-
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1556
|
-
|
1557
|
-
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
1558
|
-
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1559
|
-
|
1560
|
-
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1561
|
-
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1562
|
-
|
1563
|
-
// map by name
|
1564
|
-
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
1565
|
-
|
1566
|
-
model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
|
1567
|
-
model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
|
1568
|
-
|
1569
|
-
model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
|
1570
|
-
model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
|
1571
|
-
|
1572
|
-
model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
|
1573
|
-
model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
|
1574
|
-
|
1575
|
-
for (int i = 0; i < n_audio_layer; ++i) {
|
1576
|
-
auto & layer = model.layers_encoder[i];
|
1577
|
-
|
1578
|
-
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1579
|
-
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1580
|
-
|
1581
|
-
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
|
1582
|
-
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
|
1776
|
+
model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
|
1583
1777
|
|
1584
|
-
|
1585
|
-
|
1778
|
+
model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
|
1779
|
+
model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
1586
1780
|
|
1587
|
-
|
1588
|
-
|
1781
|
+
model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
|
1782
|
+
model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
1589
1783
|
|
1590
|
-
|
1591
|
-
|
1784
|
+
model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
|
1785
|
+
model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
|
1592
1786
|
|
1593
|
-
|
1787
|
+
for (int i = 0; i < n_audio_layer; ++i) {
|
1788
|
+
auto & layer = model.layers_encoder[i];
|
1594
1789
|
|
1595
|
-
|
1596
|
-
|
1790
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1791
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1597
1792
|
|
1598
|
-
|
1599
|
-
|
1793
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
|
1794
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
|
1600
1795
|
|
1601
|
-
|
1602
|
-
|
1603
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
1796
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
|
1797
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1604
1798
|
|
1605
|
-
|
1606
|
-
|
1799
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1800
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1607
1801
|
|
1608
|
-
|
1609
|
-
|
1802
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1803
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1610
1804
|
|
1611
|
-
|
1612
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
1805
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1613
1806
|
|
1614
|
-
|
1615
|
-
|
1807
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1808
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1616
1809
|
|
1617
|
-
|
1618
|
-
|
1619
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
1620
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
1621
|
-
|
1622
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
1623
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
1624
|
-
}
|
1810
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1811
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1625
1812
|
}
|
1626
1813
|
|
1627
1814
|
// decoder
|
1628
|
-
|
1629
|
-
model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
|
1630
|
-
|
1631
|
-
model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
|
1632
|
-
|
1633
|
-
model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1634
|
-
model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1635
|
-
|
1636
|
-
// map by name
|
1637
|
-
model.tensors["decoder.positional_embedding"] = model.d_pe;
|
1638
|
-
|
1639
|
-
model.tensors["decoder.token_embedding.weight"] = model.d_te;
|
1640
|
-
|
1641
|
-
model.tensors["decoder.ln.weight"] = model.d_ln_w;
|
1642
|
-
model.tensors["decoder.ln.bias"] = model.d_ln_b;
|
1643
|
-
|
1644
|
-
for (int i = 0; i < n_text_layer; ++i) {
|
1645
|
-
auto & layer = model.layers_decoder[i];
|
1646
|
-
|
1647
|
-
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1648
|
-
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1649
|
-
|
1650
|
-
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
|
1651
|
-
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
|
1652
|
-
|
1653
|
-
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
|
1654
|
-
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1655
|
-
|
1656
|
-
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1657
|
-
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1658
|
-
|
1659
|
-
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1660
|
-
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1661
|
-
|
1662
|
-
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
|
1815
|
+
model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
|
1663
1816
|
|
1664
|
-
|
1665
|
-
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1817
|
+
model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
|
1666
1818
|
|
1667
|
-
|
1668
|
-
|
1819
|
+
model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
|
1820
|
+
model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
|
1669
1821
|
|
1670
|
-
|
1671
|
-
|
1822
|
+
for (int i = 0; i < n_text_layer; ++i) {
|
1823
|
+
auto & layer = model.layers_decoder[i];
|
1672
1824
|
|
1673
|
-
|
1674
|
-
|
1825
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1826
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1675
1827
|
|
1676
|
-
|
1828
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
|
1829
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
|
1677
1830
|
|
1678
|
-
|
1679
|
-
|
1831
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
|
1832
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1680
1833
|
|
1681
|
-
|
1682
|
-
|
1834
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1835
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1683
1836
|
|
1684
|
-
|
1685
|
-
|
1686
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
1837
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1838
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1687
1839
|
|
1688
|
-
|
1689
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
1840
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1690
1841
|
|
1691
|
-
|
1692
|
-
|
1842
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1843
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1693
1844
|
|
1694
|
-
|
1695
|
-
|
1845
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1846
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1696
1847
|
|
1697
|
-
|
1698
|
-
|
1848
|
+
layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1849
|
+
layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1699
1850
|
|
1700
|
-
|
1851
|
+
layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1852
|
+
layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1701
1853
|
|
1702
|
-
|
1703
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
1854
|
+
layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1704
1855
|
|
1705
|
-
|
1706
|
-
|
1856
|
+
layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1857
|
+
layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1707
1858
|
|
1708
|
-
|
1709
|
-
|
1710
|
-
|
1711
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
|
1712
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
|
1713
|
-
|
1714
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
|
1715
|
-
|
1716
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
|
1717
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
|
1718
|
-
|
1719
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
|
1720
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
|
1721
|
-
}
|
1859
|
+
layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1860
|
+
layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1722
1861
|
}
|
1862
|
+
|
1863
|
+
ggml_free(ctx);
|
1723
1864
|
}
|
1724
1865
|
|
1725
1866
|
// allocate tensors in the backend buffers
|
1726
|
-
|
1727
|
-
|
1728
|
-
|
1729
|
-
|
1730
|
-
|
1867
|
+
for (auto & p : ctx_map) {
|
1868
|
+
ggml_backend_buffer_type_t buft = p.first;
|
1869
|
+
ggml_context * ctx = p.second;
|
1870
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
1871
|
+
if (buf) {
|
1872
|
+
model.buffers.emplace_back(buf);
|
1731
1873
|
|
1732
|
-
|
1733
|
-
|
1874
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
1875
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
1876
|
+
}
|
1877
|
+
}
|
1734
1878
|
|
1735
1879
|
// load weights
|
1736
1880
|
{
|
@@ -1793,11 +1937,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1793
1937
|
return false;
|
1794
1938
|
}
|
1795
1939
|
|
1796
|
-
|
1797
|
-
|
1798
|
-
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
|
1799
|
-
|
1800
|
-
if (ggml_backend_buffer_is_host(model.buffer)) {
|
1940
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
1801
1941
|
// for the CPU and Metal backend, we can read directly into the tensor
|
1802
1942
|
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
1803
1943
|
BYTESWAP_TENSOR(tensor);
|
@@ -1810,7 +1950,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1810
1950
|
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
1811
1951
|
}
|
1812
1952
|
|
1813
|
-
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
|
1814
1953
|
total_size += ggml_nbytes(tensor);
|
1815
1954
|
model.n_loaded++;
|
1816
1955
|
}
|
@@ -1825,7 +1964,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1825
1964
|
}
|
1826
1965
|
}
|
1827
1966
|
|
1828
|
-
|
1967
|
+
for (auto & buf : model.buffers) {
|
1968
|
+
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
1969
|
+
}
|
1829
1970
|
|
1830
1971
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
1831
1972
|
|
@@ -3710,15 +3851,24 @@ void whisper_free_state(struct whisper_state * state) {
|
|
3710
3851
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
3711
3852
|
aheads_masks_free(state->aheads_masks);
|
3712
3853
|
|
3854
|
+
if (state->vad_context != nullptr) {
|
3855
|
+
whisper_vad_free(state->vad_context);
|
3856
|
+
state->vad_context = nullptr;
|
3857
|
+
}
|
3858
|
+
|
3713
3859
|
delete state;
|
3714
3860
|
}
|
3715
3861
|
}
|
3716
3862
|
|
3717
3863
|
void whisper_free(struct whisper_context * ctx) {
|
3718
3864
|
if (ctx) {
|
3719
|
-
|
3865
|
+
for (ggml_context * context : ctx->model.ctxs) {
|
3866
|
+
ggml_free(context);
|
3867
|
+
}
|
3720
3868
|
|
3721
|
-
|
3869
|
+
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
|
3870
|
+
ggml_backend_buffer_free(buf);
|
3871
|
+
}
|
3722
3872
|
|
3723
3873
|
whisper_free_state(ctx->state);
|
3724
3874
|
|
@@ -4136,11 +4286,11 @@ void whisper_print_timings(struct whisper_context * ctx) {
|
|
4136
4286
|
|
4137
4287
|
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
4138
4288
|
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
4139
|
-
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
4140
|
-
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
4141
|
-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
4142
|
-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
4143
|
-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
4289
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
4290
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
4291
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
4292
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
4293
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
4144
4294
|
}
|
4145
4295
|
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
4146
4296
|
}
|
@@ -4182,113 +4332,1238 @@ const char * whisper_print_system_info(void) {
|
|
4182
4332
|
static std::string s;
|
4183
4333
|
|
4184
4334
|
s = "";
|
4185
|
-
s += "
|
4186
|
-
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
4187
|
-
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
4188
|
-
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
4189
|
-
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
4190
|
-
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
4191
|
-
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
4192
|
-
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
4193
|
-
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
4194
|
-
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
4195
|
-
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
|
4196
|
-
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
4335
|
+
s += "WHISPER : ";
|
4197
4336
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
4198
4337
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
4199
4338
|
|
4339
|
+
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
4340
|
+
auto * reg = ggml_backend_reg_get(i);
|
4341
|
+
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
4342
|
+
if (get_features_fn) {
|
4343
|
+
ggml_backend_feature * features = get_features_fn(reg);
|
4344
|
+
s += ggml_backend_reg_name(reg);
|
4345
|
+
s += " : ";
|
4346
|
+
for (; features->name; features++) {
|
4347
|
+
s += features->name;
|
4348
|
+
s += " = ";
|
4349
|
+
s += features->value;
|
4350
|
+
s += " | ";
|
4351
|
+
}
|
4352
|
+
}
|
4353
|
+
}
|
4200
4354
|
return s.c_str();
|
4201
4355
|
}
|
4202
4356
|
|
4203
4357
|
//////////////////////////////////
|
4204
|
-
//
|
4358
|
+
// Voice Activity Detection (VAD)
|
4205
4359
|
//////////////////////////////////
|
4206
4360
|
|
4207
|
-
|
4208
|
-
|
4209
|
-
|
4210
|
-
|
4211
|
-
|
4212
|
-
|
4213
|
-
|
4214
|
-
|
4215
|
-
|
4216
|
-
|
4361
|
+
struct whisper_vad_hparams {
|
4362
|
+
int32_t n_encoder_layers;
|
4363
|
+
int32_t * encoder_in_channels;
|
4364
|
+
int32_t * encoder_out_channels;
|
4365
|
+
int32_t * kernel_sizes;
|
4366
|
+
int32_t lstm_input_size;
|
4367
|
+
int32_t lstm_hidden_size;
|
4368
|
+
int32_t final_conv_in;
|
4369
|
+
int32_t final_conv_out;
|
4370
|
+
};
|
4217
4371
|
|
4218
|
-
|
4219
|
-
|
4220
|
-
|
4221
|
-
|
4222
|
-
// invalid sequence, abort
|
4223
|
-
code_points.push_back(0);
|
4224
|
-
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
4225
|
-
}
|
4226
|
-
value = (value << 6) + (next_byte & 0x3F);
|
4227
|
-
++pos;
|
4228
|
-
--n_remain;
|
4229
|
-
}
|
4372
|
+
struct whisper_vad_model {
|
4373
|
+
std::string type;
|
4374
|
+
std::string version;
|
4375
|
+
whisper_vad_hparams hparams;
|
4230
4376
|
|
4231
|
-
|
4232
|
-
code_points.push_back(value);
|
4233
|
-
}
|
4377
|
+
struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
|
4234
4378
|
|
4235
|
-
//
|
4236
|
-
|
4237
|
-
|
4238
|
-
uint8_t highbits = first_byte >> 4;
|
4239
|
-
n_remain = lookup[highbits] - 1;
|
4379
|
+
// Encoder tensors - 4 convolutional layers
|
4380
|
+
struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
|
4381
|
+
struct ggml_tensor * encoder_0_bias; // [128]
|
4240
4382
|
|
4241
|
-
|
4242
|
-
|
4243
|
-
|
4244
|
-
code_points.push_back(0);
|
4245
|
-
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
|
4246
|
-
}
|
4383
|
+
// Second encoder layer
|
4384
|
+
struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
|
4385
|
+
struct ggml_tensor * encoder_1_bias; // [64]
|
4247
4386
|
|
4248
|
-
|
4249
|
-
|
4250
|
-
|
4251
|
-
while (*pos != 0 && n_remain > 0) {
|
4252
|
-
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
4253
|
-
++pos;
|
4254
|
-
--n_remain;
|
4255
|
-
}
|
4256
|
-
if (n_remain == 0) {
|
4257
|
-
code_points.push_back(value);
|
4258
|
-
}
|
4259
|
-
}
|
4260
|
-
code_points.push_back(0);
|
4387
|
+
// Third encoder layer
|
4388
|
+
struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
|
4389
|
+
struct ggml_tensor * encoder_2_bias; // [64]
|
4261
4390
|
|
4262
|
-
|
4263
|
-
|
4391
|
+
// Fourth encoder layer
|
4392
|
+
struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
|
4393
|
+
struct ggml_tensor * encoder_3_bias; // [128]
|
4264
4394
|
|
4265
|
-
//
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
default: return false;
|
4271
|
-
}
|
4272
|
-
}
|
4395
|
+
// LSTM decoder tensors
|
4396
|
+
struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
|
4397
|
+
struct ggml_tensor * lstm_ih_bias; // [512]
|
4398
|
+
struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
|
4399
|
+
struct ggml_tensor * lstm_hh_bias; // [512]
|
4273
4400
|
|
4274
|
-
//
|
4275
|
-
|
4276
|
-
|
4277
|
-
const whisper_grammar_element * pos,
|
4278
|
-
const uint32_t chr) {
|
4401
|
+
// Final conv layer
|
4402
|
+
struct ggml_tensor * final_conv_weight; // [128]
|
4403
|
+
struct ggml_tensor * final_conv_bias; // [1]
|
4279
4404
|
|
4280
|
-
|
4281
|
-
|
4405
|
+
// ggml contexts
|
4406
|
+
std::vector<ggml_context *> ctxs;
|
4282
4407
|
|
4283
|
-
|
4408
|
+
// buffer for the model tensors
|
4409
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
4284
4410
|
|
4285
|
-
|
4286
|
-
|
4287
|
-
|
4288
|
-
|
4289
|
-
|
4290
|
-
|
4291
|
-
|
4411
|
+
// tensors
|
4412
|
+
int n_loaded;
|
4413
|
+
std::map<std::string, struct ggml_tensor *> tensors;
|
4414
|
+
};
|
4415
|
+
|
4416
|
+
struct whisper_vad_segment {
|
4417
|
+
int64_t start;
|
4418
|
+
int64_t end;
|
4419
|
+
};
|
4420
|
+
|
4421
|
+
struct whisper_vad_segments {
|
4422
|
+
std::vector<whisper_vad_segment> data;
|
4423
|
+
};
|
4424
|
+
|
4425
|
+
struct whisper_vad_context {
|
4426
|
+
int64_t t_vad_us = 0;
|
4427
|
+
|
4428
|
+
int n_window;
|
4429
|
+
int n_context;
|
4430
|
+
int n_threads;
|
4431
|
+
|
4432
|
+
std::vector<ggml_backend_t> backends;
|
4433
|
+
ggml_backend_buffer_t buffer = nullptr;
|
4434
|
+
whisper_context_params params;
|
4435
|
+
std::vector<uint8_t> ctx_buf;
|
4436
|
+
whisper_sched sched;
|
4437
|
+
|
4438
|
+
whisper_vad_model model;
|
4439
|
+
std::string path_model;
|
4440
|
+
struct ggml_tensor * h_state;
|
4441
|
+
struct ggml_tensor * c_state;
|
4442
|
+
std::vector<float> probs;
|
4443
|
+
};
|
4444
|
+
|
4445
|
+
struct whisper_vad_context_params whisper_vad_default_context_params(void) {
|
4446
|
+
whisper_vad_context_params result = {
|
4447
|
+
/*.n_thread = */ 4,
|
4448
|
+
/*.use_gpu = */ false,
|
4449
|
+
/*.gpu_device = */ 0,
|
4450
|
+
};
|
4451
|
+
return result;
|
4452
|
+
}
|
4453
|
+
|
4454
|
+
struct whisper_vad_params whisper_vad_default_params(void) {
|
4455
|
+
whisper_vad_params result = {
|
4456
|
+
/* threshold = */ 0.5f,
|
4457
|
+
/* min_speech_duration_ms = */ 250,
|
4458
|
+
/* min_silence_duration_ms = */ 100,
|
4459
|
+
/* max_speech_duration_s = */ FLT_MAX,
|
4460
|
+
/* speech_pad_ms = */ 30,
|
4461
|
+
/* samples_overlap = */ 0.1,
|
4462
|
+
};
|
4463
|
+
return result;
|
4464
|
+
}
|
4465
|
+
|
4466
|
+
// Time conversion utility functions for whisper VAD
|
4467
|
+
static int cs_to_samples(int64_t cs) {
|
4468
|
+
return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5);
|
4469
|
+
}
|
4470
|
+
|
4471
|
+
static int64_t samples_to_cs(int samples) {
|
4472
|
+
return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 0.5);
|
4473
|
+
}
|
4474
|
+
|
4475
|
+
static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
4476
|
+
bool op_supported = true;
|
4477
|
+
|
4478
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
|
4479
|
+
(ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
|
4480
|
+
// GPU and default CPU backend support all operators
|
4481
|
+
op_supported = true;
|
4482
|
+
} else {
|
4483
|
+
switch (op) {
|
4484
|
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
|
4485
|
+
case GGML_OP_MUL_MAT: {
|
4486
|
+
ggml_init_params params = {
|
4487
|
+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
4488
|
+
/*.mem_buffer =*/ nullptr,
|
4489
|
+
/*.no_alloc =*/ true,
|
4490
|
+
};
|
4491
|
+
|
4492
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
4493
|
+
if (!ctx_ptr) {
|
4494
|
+
throw std::runtime_error("failed to create ggml context");
|
4495
|
+
}
|
4496
|
+
ggml_context * ctx = ctx_ptr.get();
|
4497
|
+
|
4498
|
+
ggml_tensor * op_tensor = nullptr;
|
4499
|
+
|
4500
|
+
int64_t n_ctx = hparams.lstm_hidden_size;
|
4501
|
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
4502
|
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
4503
|
+
|
4504
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
4505
|
+
GGML_ASSERT(w->buffer == nullptr);
|
4506
|
+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
4507
|
+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
4508
|
+
ggml_backend_buffer_free(w->buffer);
|
4509
|
+
w->buffer = nullptr;
|
4510
|
+
break;
|
4511
|
+
}
|
4512
|
+
default: {
|
4513
|
+
op_supported = false;
|
4514
|
+
break;
|
4515
|
+
}
|
4516
|
+
};
|
4517
|
+
}
|
4518
|
+
return op_supported;
|
4519
|
+
}
|
4520
|
+
|
4521
|
+
static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
|
4522
|
+
GGML_ASSERT(!buft_list.empty());
|
4523
|
+
for (const auto & p : buft_list) {
|
4524
|
+
ggml_backend_dev_t dev = p.first;
|
4525
|
+
ggml_backend_buffer_type_t buft = p.second;
|
4526
|
+
if (weight_buft_supported(hparams, w, op, buft, dev)) {
|
4527
|
+
return buft;
|
4528
|
+
}
|
4529
|
+
}
|
4530
|
+
|
4531
|
+
return nullptr;
|
4532
|
+
}
|
4533
|
+
|
4534
|
+
static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
|
4535
|
+
const whisper_vad_model & model, ggml_tensor * cur) {
|
4536
|
+
// Apply reflective padding to the input tensor
|
4537
|
+
ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
|
4538
|
+
|
4539
|
+
struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
|
4540
|
+
|
4541
|
+
// Calculate cutoff for real/imaginary parts
|
4542
|
+
int cutoff = model.stft_forward_basis->ne[2] / 2;
|
4543
|
+
|
4544
|
+
// Extract real part (first half of the STFT output).
|
4545
|
+
struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
|
4546
|
+
// Extract imaginary part (second half of the STFT output).
|
4547
|
+
struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
|
4548
|
+
|
4549
|
+
// Calculate magnitude: sqrt(real^2 + imag^2)
|
4550
|
+
struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
|
4551
|
+
struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
|
4552
|
+
struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
|
4553
|
+
struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
|
4554
|
+
return magnitude;
|
4555
|
+
}
|
4556
|
+
|
4557
|
+
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
|
4558
|
+
const whisper_vad_model & model, ggml_tensor * cur) {
|
4559
|
+
// First Conv1D: expands to 128 channels.
|
4560
|
+
cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
|
4561
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
|
4562
|
+
cur = ggml_relu(ctx0, cur);
|
4563
|
+
|
4564
|
+
// Second Conv1D: reduces to 64 channels.
|
4565
|
+
cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
|
4566
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
|
4567
|
+
cur = ggml_relu(ctx0, cur);
|
4568
|
+
|
4569
|
+
// Third Conv1D: maintains 64 channels
|
4570
|
+
cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
|
4571
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
|
4572
|
+
cur = ggml_relu(ctx0, cur);
|
4573
|
+
|
4574
|
+
// Fourth Conv1D: expands to 128 channels
|
4575
|
+
cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
|
4576
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
|
4577
|
+
cur = ggml_relu(ctx0, cur);
|
4578
|
+
|
4579
|
+
return cur;
|
4580
|
+
}
|
4581
|
+
|
4582
|
+
static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
|
4583
|
+
const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
|
4584
|
+
const whisper_vad_model & model = vctx.model;
|
4585
|
+
const int hdim = model.hparams.lstm_hidden_size;
|
4586
|
+
|
4587
|
+
struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
|
4588
|
+
|
4589
|
+
// Create operations using the input-to-hidden weights.
|
4590
|
+
struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
|
4591
|
+
inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
|
4592
|
+
|
4593
|
+
// Create operations using the hidden-to-hidden weights.
|
4594
|
+
struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
|
4595
|
+
hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
|
4596
|
+
|
4597
|
+
// Create add operation to get preactivations for all gates.
|
4598
|
+
struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
|
4599
|
+
|
4600
|
+
const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
|
4601
|
+
|
4602
|
+
// Create sigmoid for input gate (using the first 128 bytes from the preactivations).
|
4603
|
+
struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
|
4604
|
+
|
4605
|
+
// Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
|
4606
|
+
struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
|
4607
|
+
|
4608
|
+
// Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
|
4609
|
+
struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
|
4610
|
+
|
4611
|
+
// Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
|
4612
|
+
struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
|
4613
|
+
|
4614
|
+
// Update cell state
|
4615
|
+
struct ggml_tensor * c_out = ggml_add(ctx0,
|
4616
|
+
ggml_mul(ctx0, f_t, vctx.c_state),
|
4617
|
+
ggml_mul(ctx0, i_t, g_t));
|
4618
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
|
4619
|
+
|
4620
|
+
// Update hidden state
|
4621
|
+
struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
|
4622
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
|
4623
|
+
|
4624
|
+
return out;
|
4625
|
+
}
|
4626
|
+
|
4627
|
+
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
|
4628
|
+
const auto & model = vctx.model;
|
4629
|
+
|
4630
|
+
struct ggml_init_params params = {
|
4631
|
+
/*.mem_size =*/ vctx.sched.meta.size(),
|
4632
|
+
/*.mem_buffer =*/ vctx.sched.meta.data(),
|
4633
|
+
/*.no_alloc =*/ true,
|
4634
|
+
};
|
4635
|
+
|
4636
|
+
struct ggml_context * ctx0 = ggml_init(params);
|
4637
|
+
|
4638
|
+
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
4639
|
+
|
4640
|
+
struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
|
4641
|
+
ggml_set_name(frame, "frame");
|
4642
|
+
ggml_set_input(frame);
|
4643
|
+
|
4644
|
+
struct ggml_tensor * cur = nullptr;
|
4645
|
+
{
|
4646
|
+
cur = whisper_vad_build_stft_layer(ctx0, model, frame);
|
4647
|
+
|
4648
|
+
cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
|
4649
|
+
|
4650
|
+
// Extract the first element of the first dimension
|
4651
|
+
// (equivalent to pytorch's [:, :, 0])
|
4652
|
+
cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
|
4653
|
+
|
4654
|
+
cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
|
4655
|
+
cur = ggml_relu(ctx0, cur);
|
4656
|
+
cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
|
4657
|
+
cur = ggml_add(ctx0, cur, model.final_conv_bias);
|
4658
|
+
cur = ggml_sigmoid(ctx0, cur);
|
4659
|
+
ggml_set_name(cur, "prob");
|
4660
|
+
ggml_set_output(cur);
|
4661
|
+
}
|
4662
|
+
|
4663
|
+
ggml_build_forward_expand(gf, cur);
|
4664
|
+
|
4665
|
+
ggml_free(ctx0);
|
4666
|
+
|
4667
|
+
return gf;
|
4668
|
+
}
|
4669
|
+
|
4670
|
+
static bool whisper_vad_init_context(whisper_vad_context * vctx) {
|
4671
|
+
|
4672
|
+
auto whisper_context_params = whisper_context_default_params();
|
4673
|
+
// TODO: GPU VAD is forced disabled until the performance is improved
|
4674
|
+
//whisper_context_params.use_gpu = vctx->params.use_gpu;
|
4675
|
+
whisper_context_params.use_gpu = false;
|
4676
|
+
whisper_context_params.gpu_device = vctx->params.gpu_device;
|
4677
|
+
|
4678
|
+
vctx->backends = whisper_backend_init(whisper_context_params);
|
4679
|
+
if (vctx->backends.empty()) {
|
4680
|
+
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
4681
|
+
return false;
|
4682
|
+
}
|
4683
|
+
|
4684
|
+
const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
|
4685
|
+
|
4686
|
+
vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
|
4687
|
+
|
4688
|
+
struct ggml_init_params params = {
|
4689
|
+
/*.mem_size =*/ vctx->ctx_buf.size(),
|
4690
|
+
/*.mem_buffer =*/ vctx->ctx_buf.data(),
|
4691
|
+
/*.no_alloc =*/ true,
|
4692
|
+
};
|
4693
|
+
|
4694
|
+
ggml_context * ctx = ggml_init(params);
|
4695
|
+
if (!ctx) {
|
4696
|
+
WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
|
4697
|
+
return false;
|
4698
|
+
}
|
4699
|
+
|
4700
|
+
// LSTM Hidden state
|
4701
|
+
vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
|
4702
|
+
ggml_set_name(vctx->h_state, "h_state");
|
4703
|
+
|
4704
|
+
// LSTM Cell state
|
4705
|
+
vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
|
4706
|
+
ggml_set_name(vctx->c_state, "c_state");
|
4707
|
+
|
4708
|
+
vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
|
4709
|
+
if (!vctx->buffer) {
|
4710
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
|
4711
|
+
return false;
|
4712
|
+
}
|
4713
|
+
|
4714
|
+
{
|
4715
|
+
bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
|
4716
|
+
[&]() {
|
4717
|
+
return whisper_vad_build_graph(*vctx);
|
4718
|
+
});
|
4719
|
+
|
4720
|
+
if (!ok) {
|
4721
|
+
WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
|
4722
|
+
return false;
|
4723
|
+
}
|
4724
|
+
|
4725
|
+
WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
|
4726
|
+
}
|
4727
|
+
|
4728
|
+
return true;
|
4729
|
+
}
|
4730
|
+
|
4731
|
+
struct whisper_vad_context * whisper_vad_init_from_file_with_params(
|
4732
|
+
const char * path_model,
|
4733
|
+
struct whisper_vad_context_params params) {
|
4734
|
+
WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
|
4735
|
+
#ifdef _MSC_VER
|
4736
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
4737
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
4738
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
4739
|
+
#else
|
4740
|
+
auto fin = std::ifstream(path_model, std::ios::binary);
|
4741
|
+
#endif
|
4742
|
+
if (!fin) {
|
4743
|
+
WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
|
4744
|
+
return nullptr;
|
4745
|
+
}
|
4746
|
+
|
4747
|
+
whisper_model_loader loader = {};
|
4748
|
+
loader.context = &fin;
|
4749
|
+
|
4750
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
4751
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4752
|
+
fin->read((char *)output, read_size);
|
4753
|
+
return read_size;
|
4754
|
+
};
|
4755
|
+
|
4756
|
+
loader.eof = [](void * ctx) {
|
4757
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4758
|
+
return fin->eof();
|
4759
|
+
};
|
4760
|
+
|
4761
|
+
loader.close = [](void * ctx) {
|
4762
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4763
|
+
fin->close();
|
4764
|
+
};
|
4765
|
+
|
4766
|
+
auto ctx = whisper_vad_init_with_params(&loader, params);
|
4767
|
+
if (!ctx) {
|
4768
|
+
whisper_vad_free(ctx);
|
4769
|
+
return nullptr;
|
4770
|
+
}
|
4771
|
+
ctx->path_model = path_model;
|
4772
|
+
return ctx;
|
4773
|
+
}
|
4774
|
+
|
4775
|
+
struct whisper_vad_context * whisper_vad_init_with_params(
|
4776
|
+
struct whisper_model_loader * loader,
|
4777
|
+
struct whisper_vad_context_params params) {
|
4778
|
+
// Read the VAD model
|
4779
|
+
{
|
4780
|
+
uint32_t magic;
|
4781
|
+
read_safe(loader, magic);
|
4782
|
+
if (magic != GGML_FILE_MAGIC) {
|
4783
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
4784
|
+
return nullptr;
|
4785
|
+
}
|
4786
|
+
}
|
4787
|
+
|
4788
|
+
whisper_vad_context * vctx = new whisper_vad_context;
|
4789
|
+
vctx->n_threads = params.n_threads;
|
4790
|
+
vctx->params.use_gpu = params.use_gpu;
|
4791
|
+
vctx->params.gpu_device = params.gpu_device;
|
4792
|
+
|
4793
|
+
auto & model = vctx->model;
|
4794
|
+
auto & hparams = model.hparams;
|
4795
|
+
|
4796
|
+
// load model context params.
|
4797
|
+
{
|
4798
|
+
int32_t str_len;
|
4799
|
+
read_safe(loader, str_len);
|
4800
|
+
std::vector<char> buffer(str_len + 1, 0);
|
4801
|
+
loader->read(loader->context, buffer.data(), str_len);
|
4802
|
+
std::string model_type(buffer.data(), str_len);
|
4803
|
+
model.type = model_type;
|
4804
|
+
WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
|
4805
|
+
|
4806
|
+
int32_t major, minor, patch;
|
4807
|
+
read_safe(loader, major);
|
4808
|
+
read_safe(loader, minor);
|
4809
|
+
read_safe(loader, patch);
|
4810
|
+
std::string version_str = std::to_string(major) + "." +
|
4811
|
+
std::to_string(minor) + "." +
|
4812
|
+
std::to_string(patch);
|
4813
|
+
model.version = version_str;
|
4814
|
+
WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
|
4815
|
+
|
4816
|
+
read_safe(loader, vctx->n_window);
|
4817
|
+
read_safe(loader, vctx->n_context);
|
4818
|
+
}
|
4819
|
+
|
4820
|
+
// load model hyper params (hparams).
|
4821
|
+
{
|
4822
|
+
read_safe(loader, hparams.n_encoder_layers);
|
4823
|
+
|
4824
|
+
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
|
4825
|
+
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
|
4826
|
+
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
|
4827
|
+
|
4828
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4829
|
+
read_safe(loader, hparams.encoder_in_channels[i]);
|
4830
|
+
read_safe(loader, hparams.encoder_out_channels[i]);
|
4831
|
+
read_safe(loader, hparams.kernel_sizes[i]);
|
4832
|
+
}
|
4833
|
+
|
4834
|
+
read_safe(loader, hparams.lstm_input_size);
|
4835
|
+
read_safe(loader, hparams.lstm_hidden_size);
|
4836
|
+
read_safe(loader, hparams.final_conv_in);
|
4837
|
+
read_safe(loader, hparams.final_conv_out);
|
4838
|
+
|
4839
|
+
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
|
4840
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4841
|
+
WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
|
4842
|
+
}
|
4843
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4844
|
+
WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
|
4845
|
+
}
|
4846
|
+
WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
|
4847
|
+
WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
|
4848
|
+
WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
|
4849
|
+
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
|
4850
|
+
}
|
4851
|
+
|
4852
|
+
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
|
4853
|
+
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
|
4854
|
+
|
4855
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
4856
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
4857
|
+
auto it = ctx_map.find(buft);
|
4858
|
+
if (it == ctx_map.end()) {
|
4859
|
+
ggml_init_params params = {
|
4860
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4861
|
+
/*.mem_buffer =*/ nullptr,
|
4862
|
+
/*.no_alloc =*/ true,
|
4863
|
+
};
|
4864
|
+
|
4865
|
+
ggml_context * ctx = ggml_init(params);
|
4866
|
+
if (!ctx) {
|
4867
|
+
throw std::runtime_error("failed to create ggml context");
|
4868
|
+
}
|
4869
|
+
|
4870
|
+
ctx_map[buft] = ctx;
|
4871
|
+
model.ctxs.emplace_back(ctx);
|
4872
|
+
|
4873
|
+
return ctx;
|
4874
|
+
}
|
4875
|
+
|
4876
|
+
return it->second;
|
4877
|
+
};
|
4878
|
+
|
4879
|
+
whisper_context_params wparams = whisper_context_default_params();
|
4880
|
+
wparams.use_gpu = params.use_gpu;
|
4881
|
+
wparams.gpu_device = params.gpu_device;
|
4882
|
+
buft_list_t buft_list = make_buft_list(wparams);
|
4883
|
+
|
4884
|
+
auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
|
4885
|
+
ggml_op op = VAD_TENSOR_OPS.at(type);
|
4886
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
4887
|
+
if (!buft) {
|
4888
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
|
4889
|
+
}
|
4890
|
+
ggml_context * ctx = get_ctx(buft);
|
4891
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
4892
|
+
model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
|
4893
|
+
|
4894
|
+
return tensor;
|
4895
|
+
};
|
4896
|
+
|
4897
|
+
// create tensors
|
4898
|
+
{
|
4899
|
+
ggml_init_params params = {
|
4900
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4901
|
+
/*.mem_buffer =*/ nullptr,
|
4902
|
+
/*.no_alloc =*/ true,
|
4903
|
+
};
|
4904
|
+
|
4905
|
+
ggml_context * ctx = ggml_init(params);
|
4906
|
+
const auto & hparams = model.hparams;
|
4907
|
+
|
4908
|
+
// SFTF precomputed basis matrix
|
4909
|
+
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
|
4910
|
+
ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
|
4911
|
+
|
4912
|
+
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
|
4913
|
+
ggml_new_tensor_3d(
|
4914
|
+
ctx,
|
4915
|
+
GGML_TYPE_F16,
|
4916
|
+
hparams.kernel_sizes[0],
|
4917
|
+
hparams.encoder_in_channels[0],
|
4918
|
+
hparams.encoder_out_channels[0]
|
4919
|
+
));
|
4920
|
+
model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
|
4921
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
|
4922
|
+
|
4923
|
+
model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
|
4924
|
+
ggml_new_tensor_3d(
|
4925
|
+
ctx,
|
4926
|
+
GGML_TYPE_F16,
|
4927
|
+
hparams.kernel_sizes[1],
|
4928
|
+
hparams.encoder_in_channels[1],
|
4929
|
+
hparams.encoder_out_channels[1]
|
4930
|
+
));
|
4931
|
+
model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
|
4932
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
|
4933
|
+
|
4934
|
+
model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
|
4935
|
+
ggml_new_tensor_3d(
|
4936
|
+
ctx,
|
4937
|
+
GGML_TYPE_F16,
|
4938
|
+
hparams.kernel_sizes[2],
|
4939
|
+
hparams.encoder_in_channels[2],
|
4940
|
+
hparams.encoder_out_channels[2]
|
4941
|
+
));
|
4942
|
+
model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
|
4943
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
|
4944
|
+
|
4945
|
+
model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
|
4946
|
+
ggml_new_tensor_3d(
|
4947
|
+
ctx,
|
4948
|
+
GGML_TYPE_F16,
|
4949
|
+
hparams.kernel_sizes[3],
|
4950
|
+
hparams.encoder_in_channels[3],
|
4951
|
+
hparams.encoder_out_channels[3]
|
4952
|
+
));
|
4953
|
+
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
|
4954
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
|
4955
|
+
|
4956
|
+
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
|
4957
|
+
const int hstate_dim = hparams.lstm_hidden_size * 4;
|
4958
|
+
|
4959
|
+
// LSTM weights - input to hidden
|
4960
|
+
model.lstm_ih_weight = create_tensor(
|
4961
|
+
VAD_TENSOR_LSTM_WEIGHT_IH,
|
4962
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4963
|
+
);
|
4964
|
+
model.lstm_ih_bias = create_tensor(
|
4965
|
+
VAD_TENSOR_LSTM_BIAS_IH,
|
4966
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4967
|
+
);
|
4968
|
+
|
4969
|
+
// LSTM weights - hidden to hidden
|
4970
|
+
model.lstm_hh_weight = create_tensor(
|
4971
|
+
VAD_TENSOR_LSTM_WEIGHT_HH,
|
4972
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4973
|
+
);
|
4974
|
+
model.lstm_hh_bias = create_tensor(
|
4975
|
+
VAD_TENSOR_LSTM_BIAS_HH,
|
4976
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4977
|
+
);
|
4978
|
+
|
4979
|
+
// Final conv layer weight
|
4980
|
+
model.final_conv_weight = create_tensor(
|
4981
|
+
VAD_TENSOR_FINAL_CONV_WEIGHT,
|
4982
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
|
4983
|
+
);
|
4984
|
+
model.final_conv_bias = create_tensor(
|
4985
|
+
VAD_TENSOR_FINAL_CONV_BIAS,
|
4986
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
|
4987
|
+
);
|
4988
|
+
|
4989
|
+
ggml_free(ctx);
|
4990
|
+
}
|
4991
|
+
|
4992
|
+
// allocate tensors in the backend buffers
|
4993
|
+
for (auto & p : ctx_map) {
|
4994
|
+
ggml_backend_buffer_type_t buft = p.first;
|
4995
|
+
ggml_context * ctx = p.second;
|
4996
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
4997
|
+
if (buf) {
|
4998
|
+
model.buffers.emplace_back(buf);
|
4999
|
+
|
5000
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
5001
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
5002
|
+
}
|
5003
|
+
}
|
5004
|
+
|
5005
|
+
// load weights
|
5006
|
+
{
|
5007
|
+
size_t total_size = 0;
|
5008
|
+
model.n_loaded = 0;
|
5009
|
+
std::vector<char> read_buf;
|
5010
|
+
|
5011
|
+
while (true) {
|
5012
|
+
int32_t n_dims;
|
5013
|
+
int32_t length;
|
5014
|
+
int32_t ttype;
|
5015
|
+
|
5016
|
+
read_safe(loader, n_dims);
|
5017
|
+
read_safe(loader, length);
|
5018
|
+
read_safe(loader, ttype);
|
5019
|
+
|
5020
|
+
if (loader->eof(loader->context)) {
|
5021
|
+
break;
|
5022
|
+
}
|
5023
|
+
|
5024
|
+
int32_t nelements = 1;
|
5025
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
5026
|
+
for (int i = 0; i < n_dims; ++i) {
|
5027
|
+
read_safe(loader, ne[i]);
|
5028
|
+
nelements *= ne[i];
|
5029
|
+
}
|
5030
|
+
|
5031
|
+
std::string name;
|
5032
|
+
std::vector<char> tmp(length);
|
5033
|
+
loader->read(loader->context, &tmp[0], tmp.size());
|
5034
|
+
name.assign(&tmp[0], tmp.size());
|
5035
|
+
|
5036
|
+
if (model.tensors.find(name) == model.tensors.end()) {
|
5037
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
5038
|
+
return nullptr;
|
5039
|
+
}
|
5040
|
+
|
5041
|
+
auto tensor = model.tensors[name.data()];
|
5042
|
+
|
5043
|
+
if (ggml_nelements(tensor) != nelements) {
|
5044
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
5045
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
5046
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
5047
|
+
return nullptr;
|
5048
|
+
}
|
5049
|
+
|
5050
|
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
5051
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
5052
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
5053
|
+
return nullptr;
|
5054
|
+
}
|
5055
|
+
|
5056
|
+
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
5057
|
+
|
5058
|
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
5059
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
5060
|
+
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
5061
|
+
return nullptr;
|
5062
|
+
}
|
5063
|
+
|
5064
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
5065
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
5066
|
+
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
5067
|
+
BYTESWAP_TENSOR(tensor);
|
5068
|
+
} else {
|
5069
|
+
// read into a temporary buffer first, then copy to device memory
|
5070
|
+
read_buf.resize(ggml_nbytes(tensor));
|
5071
|
+
|
5072
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
5073
|
+
|
5074
|
+
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
5075
|
+
}
|
5076
|
+
|
5077
|
+
total_size += ggml_nbytes(tensor);
|
5078
|
+
model.n_loaded++;
|
5079
|
+
}
|
5080
|
+
|
5081
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
5082
|
+
|
5083
|
+
if (model.n_loaded == 0) {
|
5084
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
5085
|
+
} else if (model.n_loaded != (int) model.tensors.size()) {
|
5086
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
5087
|
+
return nullptr;
|
5088
|
+
}
|
5089
|
+
|
5090
|
+
}
|
5091
|
+
|
5092
|
+
if (!whisper_vad_init_context(vctx)) {
|
5093
|
+
whisper_vad_free(vctx);
|
5094
|
+
return nullptr;
|
5095
|
+
}
|
5096
|
+
|
5097
|
+
return vctx;
|
5098
|
+
}
|
5099
|
+
|
5100
|
+
bool whisper_vad_detect_speech(
|
5101
|
+
struct whisper_vad_context * vctx,
|
5102
|
+
const float * samples,
|
5103
|
+
int n_samples) {
|
5104
|
+
int n_chunks = n_samples / vctx->n_window;
|
5105
|
+
if (n_samples % vctx->n_window != 0) {
|
5106
|
+
n_chunks += 1; // Add one more chunk for remaining samples.
|
5107
|
+
}
|
5108
|
+
|
5109
|
+
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
|
5110
|
+
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
|
5111
|
+
|
5112
|
+
// Reset LSTM hidden/cell states
|
5113
|
+
ggml_backend_buffer_clear(vctx->buffer, 0);
|
5114
|
+
|
5115
|
+
vctx->probs.resize(n_chunks);
|
5116
|
+
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
|
5117
|
+
|
5118
|
+
std::vector<float> window(vctx->n_window, 0.0f);
|
5119
|
+
|
5120
|
+
auto & sched = vctx->sched.sched;
|
5121
|
+
|
5122
|
+
ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
|
5123
|
+
|
5124
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
5125
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
5126
|
+
return false;
|
5127
|
+
}
|
5128
|
+
|
5129
|
+
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
|
5130
|
+
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
|
5131
|
+
|
5132
|
+
// we are going to reuse the graph multiple times for each chunk
|
5133
|
+
const int64_t t_start_vad_us = ggml_time_us();
|
5134
|
+
|
5135
|
+
for (int i = 0; i < n_chunks; i++) {
|
5136
|
+
const int idx_start = i * vctx->n_window;
|
5137
|
+
const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
|
5138
|
+
|
5139
|
+
const int chunk_len = idx_end - idx_start;
|
5140
|
+
|
5141
|
+
if (chunk_len < vctx->n_window) {
|
5142
|
+
WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
|
5143
|
+
std::vector<float> partial_chunk(vctx->n_window, 0.0f);
|
5144
|
+
std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
|
5145
|
+
|
5146
|
+
// Copy the zero-padded chunk to the window.
|
5147
|
+
const int samples_to_copy_max = vctx->n_window;
|
5148
|
+
const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
|
5149
|
+
std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
|
5150
|
+
if (samples_to_copy_cur < samples_to_copy_max) {
|
5151
|
+
std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
|
5152
|
+
}
|
5153
|
+
} else {
|
5154
|
+
// Copy current frame samples to the window.
|
5155
|
+
const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
|
5156
|
+
std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
|
5157
|
+
}
|
5158
|
+
|
5159
|
+
// Set the frame tensor data with the samples.
|
5160
|
+
ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
|
5161
|
+
|
5162
|
+
// do not reset the scheduler - we will reuse the graph in the next chunk
|
5163
|
+
if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
|
5164
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
|
5165
|
+
break;
|
5166
|
+
}
|
5167
|
+
|
5168
|
+
// Get the probability for this chunk.
|
5169
|
+
ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
|
5170
|
+
|
5171
|
+
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
|
5172
|
+
}
|
5173
|
+
|
5174
|
+
vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
|
5175
|
+
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
|
5176
|
+
|
5177
|
+
ggml_backend_sched_reset(sched);
|
5178
|
+
|
5179
|
+
return true;
|
5180
|
+
}
|
5181
|
+
|
5182
|
+
int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
|
5183
|
+
return segments->data.size();
|
5184
|
+
}
|
5185
|
+
|
5186
|
+
float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
|
5187
|
+
return segments->data[i_segment].start;
|
5188
|
+
}
|
5189
|
+
|
5190
|
+
float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
|
5191
|
+
return segments->data[i_segment].end;
|
5192
|
+
}
|
5193
|
+
|
5194
|
+
int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
|
5195
|
+
return vctx->probs.size();
|
5196
|
+
}
|
5197
|
+
|
5198
|
+
float * whisper_vad_probs(struct whisper_vad_context * vctx) {
|
5199
|
+
return vctx->probs.data();
|
5200
|
+
}
|
5201
|
+
|
5202
|
+
// Convert the per-chunk speech probabilities held in `vctx` into a list of
// speech segments: a threshold-based state machine finds raw [start, end)
// sample ranges, which are then merged across small gaps, filtered by
// minimum duration, padded, and finally converted to centiseconds.
// Returns a newly-allocated whisper_vad_segments (caller frees with
// whisper_vad_free_segments), or nullptr on allocation failure.
struct whisper_vad_segments * whisper_vad_segments_from_probs(
        struct whisper_vad_context * vctx,
        whisper_vad_params params) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));

    int n_probs = whisper_vad_n_probs(vctx);
    float * probs = whisper_vad_probs(vctx);
    float threshold = params.threshold;
    int min_speech_duration_ms = params.min_speech_duration_ms;
    int min_silence_duration_ms = params.min_silence_duration_ms;
    float max_speech_duration_s = params.max_speech_duration_s;
    int speech_pad_ms = params.speech_pad_ms;
    int n_window = vctx->n_window;
    int sample_rate = WHISPER_SAMPLE_RATE;
    int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
    // Each probability covers one window, so total covered audio length is:
    int audio_length_samples = n_probs * n_window;

    // Min number of samples to be considered valid speech.
    int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
    int speech_pad_samples = sample_rate * speech_pad_ms / 1000;

    // Max number of samples that a speech segment can contain before it is
    // split into multiple segments.
    int max_speech_samples;
    if (max_speech_duration_s > 100000.0f) {
        // Effectively "unlimited" — INT_MAX/2 leaves headroom for additions.
        max_speech_samples = INT_MAX / 2;
    } else {
        // 64-bit intermediate guards against overflow before clamping to int.
        int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
        max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
        if (max_speech_samples < 0) {
            max_speech_samples = INT_MAX / 2;
        }
    }
    // Detect silence period that exceeds this value, then that location (sample)
    // is marked as a potential place where the segment could be split if
    // max_speech_samples is reached. The value 98 was taken from the original
    // silaro-vad python implementation:
    //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
    int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;

    // Calculate lower threshold for detecting end of speech segments.
    // Hysteresis: a segment ends only when probability drops below this,
    // not merely below `threshold`.
    float neg_threshold = threshold - 0.15f;
    if (neg_threshold < 0.01f) {
        neg_threshold = 0.01f;
    }

    // Raw segment boundaries in samples (pre-padding, pre-conversion).
    struct speech_segment_t {
        int start;
        int end;
    };

    std::vector<speech_segment_t> speeches;
    speeches.reserve(256);

    // State machine state:
    //   temp_end   - candidate end of the current segment (start of silence)
    //   prev_end   - last silence point long enough to split at if the
    //                segment exceeds max_speech_samples
    //   next_start - where speech resumed after that silence point
    bool is_speech_segment = false;
    int temp_end = 0;
    int prev_end = 0;
    int next_start = 0;
    int curr_speech_start = 0;
    bool has_curr_speech = false;

    for (int i = 0; i < n_probs; i++) {
        float curr_prob = probs[i];
        int curr_sample = n_window * i;

        // Reset temp_end when we get back to speech
        if ((curr_prob >= threshold) && temp_end) {
            temp_end = 0;
            if (next_start < prev_end) {
                next_start = curr_sample;
            }
        }

        // Start a new speech segment when probability exceeds threshold and not already in speech
        if ((curr_prob >= threshold) && !is_speech_segment) {
            is_speech_segment = true;
            curr_speech_start = curr_sample;
            has_curr_speech = true;
            continue;
        }

        // Handle maximum speech duration
        if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
            if (prev_end) {
                // Split at the last qualifying silence point.
                speeches.push_back({ curr_speech_start, prev_end });
                has_curr_speech = true;

                if (next_start < prev_end) { // Previously reached silence and is still not speech
                    is_speech_segment = false;
                    has_curr_speech = false;
                } else {
                    curr_speech_start = next_start;
                }
                prev_end = next_start = temp_end = 0;
            } else {
                // No silence point recorded — hard split at the current sample.
                speeches.push_back({ curr_speech_start, curr_sample });

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech = false;
                continue;
            }
        }

        // Handle silence after speech
        if ((curr_prob < neg_threshold) && is_speech_segment) {
            if (!temp_end) {
                temp_end = curr_sample;
            }

            // Track potential segment ends for max_speech handling
            if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
                prev_end = temp_end;
            }

            // Check if silence is long enough to end the segment
            if ((curr_sample - temp_end) < min_silence_samples) {
                continue;
            } else {
                // End the segment if it's long enough
                if ((temp_end - curr_speech_start) > min_speech_samples) {
                    speeches.push_back({ curr_speech_start, temp_end });
                }

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech = false;
                continue;
            }
        }
    }

    // Handle the case if we're still in a speech segment at the end
    if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
        speeches.push_back({ curr_speech_start, audio_length_samples });
    }

    // Merge adjacent segments with small gaps in between (post-processing)
    if (speeches.size() > 1) {
        int merged_count = 0;
        for (int i = 0; i < (int) speeches.size() - 1; i++) {
            // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
            int max_merge_gap_samples = sample_rate * 200 / 1000;

            // If the gap between this segment and the next is small enough
            if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
                // Merge by extending current segment to the end of next segment
                speeches[i].end = speeches[i+1].end;
                speeches.erase(speeches.begin() + i + 1);

                // Re-check the same index against its new neighbor.
                i--;
                merged_count++;
            }
        }
        WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
                __func__, merged_count, (int) speeches.size());
    }

    // Double-check for minimum speech duration
    for (int i = 0; i < (int) speeches.size(); i++) {
        if (speeches[i].end - speeches[i].start < min_speech_samples) {
            WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
                    __func__, i, speeches[i].end - speeches[i].start);

            speeches.erase(speeches.begin() + i);
            i--;
        }
    }

    WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());

    // Allocate final segments
    std::vector<whisper_vad_segment> segments;
    if (speeches.size() > 0) {
        try {
            segments.resize(speeches.size());
        } catch (const std::bad_alloc &) {
            WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
            return nullptr;
        }
    }

    // Apply padding to segments and copy to final segments
    for (int i = 0; i < (int) speeches.size(); i++) {
        // Apply padding to the start of the first segment
        if (i == 0) {
            speeches[i].start =
                (speeches[i].start > speech_pad_samples) ?
                (speeches[i].start - speech_pad_samples) : 0;
        }

        // Handle spacing between segments
        if (i < (int) speeches.size() - 1) {
            int silence_duration = speeches[i+1].start - speeches[i].end;

            if (silence_duration < 2 * speech_pad_samples) {
                // If segments are close, split the difference
                speeches[i].end += silence_duration / 2;
                speeches[i+1].start =
                    (speeches[i+1].start > silence_duration / 2) ?
                    (speeches[i+1].start - silence_duration / 2) : 0;
            } else {
                // Otherwise, apply full padding to both
                speeches[i].end =
                    (speeches[i].end + speech_pad_samples < audio_length_samples) ?
                    (speeches[i].end + speech_pad_samples) : audio_length_samples;
                speeches[i+1].start =
                    (speeches[i+1].start > speech_pad_samples) ?
                    (speeches[i+1].start - speech_pad_samples) : 0;
            }
        } else {
            // Apply padding to the end of the last segment
            speeches[i].end =
                (speeches[i].end + speech_pad_samples < audio_length_samples) ?
                (speeches[i].end + speech_pad_samples) : audio_length_samples;
        }

        // Convert from samples to centiseconds
        segments[i].start = samples_to_cs(speeches[i].start);
        segments[i].end = samples_to_cs(speeches[i].end);

        WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
                __func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0);
    }

    // NOTE(review): plain `new` throws on failure rather than returning null,
    // so this NULL check is dead code; kept as-is for behavior preservation.
    whisper_vad_segments * vad_segments = new whisper_vad_segments;
    if (vad_segments == NULL) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
        return nullptr;
    }

    vad_segments->data = std::move(segments);

    return vad_segments;
}
|
5437
|
+
|
5438
|
+
// Convenience wrapper: run speech detection on raw samples and then derive
// speech segments from the resulting probabilities.
// Returns nullptr if detection fails; otherwise a newly-allocated
// whisper_vad_segments (see whisper_vad_segments_from_probs).
struct whisper_vad_segments * whisper_vad_segments_from_samples(
        whisper_vad_context * vctx,
        whisper_vad_params params,
        const float * samples,
        int n_samples) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);

    const bool detected = whisper_vad_detect_speech(vctx, samples, n_samples);
    if (!detected) {
        WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
        return nullptr;
    }

    return whisper_vad_segments_from_probs(vctx, params);
}
|
5450
|
+
|
5451
|
+
// Release every resource owned by a VAD context: ggml contexts, backend
// buffers, the scheduler, the backends, and finally the context itself.
// Safe to call with nullptr.
void whisper_vad_free(whisper_vad_context * ctx) {
    if (!ctx) {
        return;
    }

    for (auto * model_ctx : ctx->model.ctxs) {
        ggml_free(model_ctx);
    }
    for (auto * model_buf : ctx->model.buffers) {
        ggml_backend_buffer_free(model_buf);
    }

    ggml_backend_sched_free(ctx->sched.sched);

    for (auto & backend : ctx->backends) {
        ggml_backend_free(backend);
    }

    delete ctx;
}
|
5471
|
+
|
5472
|
+
// Free a segment list returned by whisper_vad_segments_from_probs/samples.
// `delete` on a null pointer is a no-op, so no explicit check is needed.
void whisper_vad_free_segments(whisper_vad_segments * segments) {
    delete segments;
}
|
5477
|
+
|
5478
|
+
//////////////////////////////////
|
5479
|
+
// Grammar - ported from llama.cpp
|
5480
|
+
//////////////////////////////////
|
5481
|
+
|
5482
|
+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
5483
|
+
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
5484
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
        const char * src,
        whisper_partial_utf8 partial_start) {
    // Total byte count of a UTF-8 sequence, indexed by the high 4 bits of the
    // leading byte; 0 marks an invalid leading byte (a 10xxxxxx continuation).
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    const char * pos = src;
    std::vector<uint32_t> code_points;
    // Carry over the partially-decoded code point from a previous call, if any.
    uint32_t value = partial_start.value;
    int n_remain = partial_start.n_remain;

    // continue previous decode, if applicable
    while (*pos != 0 && n_remain > 0) {
        uint8_t next_byte = static_cast<uint8_t>(*pos);
        if ((next_byte >> 6) != 2) {
            // invalid sequence, abort (n_remain == -1 signals the error)
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
        }
        // Shift in the low 6 bits of each continuation byte.
        value = (value << 6) + (next_byte & 0x3F);
        ++pos;
        --n_remain;
    }

    // The carried-over code point completed on this call — emit it.
    if (partial_start.n_remain > 0 && n_remain == 0) {
        code_points.push_back(value);
    }

    // decode any subsequent utf-8 sequences, which may end in an incomplete one
    while (*pos != 0) {
        uint8_t first_byte = static_cast<uint8_t>(*pos);
        uint8_t highbits = first_byte >> 4;
        // Continuation bytes still expected after this leading byte.
        n_remain = lookup[highbits] - 1;

        if (n_remain < 0) {
            // invalid sequence, abort
            code_points.clear();
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
        }

        // Mask extracting the payload bits of the leading byte.
        uint8_t mask = (1 << (7 - n_remain)) - 1;
        value = first_byte & mask;
        ++pos;
        while (*pos != 0 && n_remain > 0) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
            ++pos;
            --n_remain;
        }
        // Only emit fully-decoded code points; a trailing truncated sequence
        // is returned via the whisper_partial_utf8 result instead.
        if (n_remain == 0) {
            code_points.push_back(value);
        }
    }
    // Terminating 0 so callers can treat the vector as a sentinel-ended array.
    code_points.push_back(0);

    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
}
|
5539
|
+
|
5540
|
+
// returns true iff pos points to the end of one of the definitions of a rule
|
5541
|
+
static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
|
5542
|
+
switch (pos->type) {
|
5543
|
+
case WHISPER_GRETYPE_END: return true; // NOLINT
|
5544
|
+
case WHISPER_GRETYPE_ALT: return true; // NOLINT
|
5545
|
+
default: return false;
|
5546
|
+
}
|
5547
|
+
}
|
5548
|
+
|
5549
|
+
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
5550
|
+
// asserts that pos is pointing to a char range element
|
5551
|
+
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
|
5552
|
+
const whisper_grammar_element * pos,
|
5553
|
+
const uint32_t chr) {
|
5554
|
+
|
5555
|
+
bool found = false;
|
5556
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
5557
|
+
|
5558
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
|
5559
|
+
|
5560
|
+
do {
|
5561
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
5562
|
+
// inclusive range, e.g. [a-z]
|
5563
|
+
found = found || (pos->value <= chr && chr <= pos[1].value);
|
5564
|
+
pos += 2;
|
5565
|
+
} else {
|
5566
|
+
// exact char match, e.g. [a] or "a"
|
4292
5567
|
found = found || pos->value == chr;
|
4293
5568
|
pos += 1;
|
4294
5569
|
}
|
@@ -4355,7 +5630,7 @@ static void whisper_grammar_advance_stack(
|
|
4355
5630
|
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
4356
5631
|
|
4357
5632
|
if (stack.empty()) {
|
4358
|
-
new_stacks.
|
5633
|
+
new_stacks.emplace_back();
|
4359
5634
|
return;
|
4360
5635
|
}
|
4361
5636
|
|
@@ -4676,7 +5951,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4676
5951
|
/*.detect_language =*/ false,
|
4677
5952
|
|
4678
5953
|
/*.suppress_blank =*/ true,
|
4679
|
-
/*.
|
5954
|
+
/*.suppress_nst =*/ false,
|
4680
5955
|
|
4681
5956
|
/*.temperature =*/ 0.0f,
|
4682
5957
|
/*.max_initial_ts =*/ 1.0f,
|
@@ -4716,6 +5991,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4716
5991
|
/*.n_grammar_rules =*/ 0,
|
4717
5992
|
/*.i_start_rule =*/ 0,
|
4718
5993
|
/*.grammar_penalty =*/ 100.0f,
|
5994
|
+
|
5995
|
+
/*.vad =*/ false,
|
5996
|
+
/*.vad_model_path =*/ nullptr,
|
5997
|
+
|
5998
|
+
/* vad_params =*/ whisper_vad_default_params(),
|
4719
5999
|
};
|
4720
6000
|
|
4721
6001
|
switch (strategy) {
|
@@ -4960,7 +6240,7 @@ static void whisper_process_logits(
|
|
4960
6240
|
|
4961
6241
|
// suppress non-speech tokens
|
4962
6242
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
4963
|
-
if (params.
|
6243
|
+
if (params.suppress_nst) {
|
4964
6244
|
for (const std::string & token : non_speech_tokens) {
|
4965
6245
|
const std::string suppress_tokens[] = {token, " " + token};
|
4966
6246
|
for (const std::string & suppress_token : suppress_tokens) {
|
@@ -5332,6 +6612,186 @@ static void whisper_sequence_score(
|
|
5332
6612
|
}
|
5333
6613
|
}
|
5334
6614
|
|
6615
|
+
static bool whisper_vad(
|
6616
|
+
struct whisper_context * ctx,
|
6617
|
+
struct whisper_state * state,
|
6618
|
+
struct whisper_full_params params,
|
6619
|
+
const float * samples,
|
6620
|
+
int n_samples,
|
6621
|
+
std::vector<float> & filtered_samples) {
|
6622
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
6623
|
+
int filtered_n_samples = 0;
|
6624
|
+
|
6625
|
+
// Clear any existing mapping table
|
6626
|
+
state->vad_mapping_table.clear();
|
6627
|
+
state->has_vad_segments = false;
|
6628
|
+
|
6629
|
+
if (state->vad_context == nullptr) {
|
6630
|
+
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
|
6631
|
+
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
|
6632
|
+
if (vctx == nullptr) {
|
6633
|
+
WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
|
6634
|
+
return false;
|
6635
|
+
}
|
6636
|
+
state->vad_context = vctx;
|
6637
|
+
}
|
6638
|
+
auto vctx = state->vad_context;
|
6639
|
+
|
6640
|
+
const whisper_vad_params & vad_params = params.vad_params;
|
6641
|
+
|
6642
|
+
whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
|
6643
|
+
|
6644
|
+
if (vad_segments->data.size() > 0) {
|
6645
|
+
state->has_vad_segments = true;
|
6646
|
+
ctx->state->vad_segments.clear();
|
6647
|
+
ctx->state->vad_segments.reserve(vad_segments->data.size());
|
6648
|
+
|
6649
|
+
// Initialize the time mapping table
|
6650
|
+
state->vad_mapping_table.clear();
|
6651
|
+
state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
|
6652
|
+
|
6653
|
+
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
|
6654
|
+
float overlap_seconds = vad_params.samples_overlap;
|
6655
|
+
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
|
6656
|
+
|
6657
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6658
|
+
int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
|
6659
|
+
int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
|
6660
|
+
|
6661
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6662
|
+
segment_end_samples += overlap_samples;
|
6663
|
+
}
|
6664
|
+
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
|
6665
|
+
filtered_n_samples += (segment_end_samples - segment_start_samples);
|
6666
|
+
|
6667
|
+
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
|
6668
|
+
__func__, i, vad_segments->data[i].start/100.0,
|
6669
|
+
(vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)),
|
6670
|
+
(vad_segments->data[i].end - vad_segments->data[i].start)/100.0 +
|
6671
|
+
(i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
|
6672
|
+
}
|
6673
|
+
|
6674
|
+
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
|
6675
|
+
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
|
6676
|
+
int total_samples_needed = filtered_n_samples + total_silence_samples;
|
6677
|
+
|
6678
|
+
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
|
6679
|
+
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
|
6680
|
+
|
6681
|
+
try {
|
6682
|
+
filtered_samples.resize(total_samples_needed);
|
6683
|
+
} catch (const std::bad_alloc & /* e */) {
|
6684
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
|
6685
|
+
whisper_vad_free_segments(vad_segments);
|
6686
|
+
whisper_vad_free(vctx);
|
6687
|
+
return false;
|
6688
|
+
}
|
6689
|
+
|
6690
|
+
int offset = 0;
|
6691
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6692
|
+
int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
|
6693
|
+
int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
|
6694
|
+
|
6695
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6696
|
+
segment_end_samples += overlap_samples;
|
6697
|
+
}
|
6698
|
+
|
6699
|
+
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
|
6700
|
+
segment_end_samples = std::min(segment_end_samples, n_samples);
|
6701
|
+
int segment_length = segment_end_samples - segment_start_samples;
|
6702
|
+
if (segment_length > 0) {
|
6703
|
+
whisper_state::vad_segment_info segment;
|
6704
|
+
|
6705
|
+
segment.orig_start = vad_segments->data[i].start;
|
6706
|
+
segment.orig_end = vad_segments->data[i].end;
|
6707
|
+
|
6708
|
+
segment.vad_start = samples_to_cs(offset);
|
6709
|
+
segment.vad_end = samples_to_cs(offset + segment_length);
|
6710
|
+
|
6711
|
+
// Add segment boundaries to mapping table
|
6712
|
+
vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
|
6713
|
+
vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
|
6714
|
+
|
6715
|
+
state->vad_mapping_table.push_back(start_mapping);
|
6716
|
+
state->vad_mapping_table.push_back(end_mapping);
|
6717
|
+
|
6718
|
+
// Add intermediate points for longer segments to improve interpolation accuracy
|
6719
|
+
const int64_t min_segment_length = 100; // 1 second
|
6720
|
+
const int64_t point_interval = 20; // Add a point every 200ms
|
6721
|
+
|
6722
|
+
if (segment.vad_end - segment.vad_start > min_segment_length) {
|
6723
|
+
int64_t segment_duration = segment.vad_end - segment.vad_start;
|
6724
|
+
int num_points = (int)(segment_duration / point_interval) - 1;
|
6725
|
+
|
6726
|
+
for (int j = 1; j <= num_points; j++) {
|
6727
|
+
int64_t vad_time = segment.vad_start + j * point_interval;
|
6728
|
+
|
6729
|
+
if (vad_time >= segment.vad_end) continue;
|
6730
|
+
|
6731
|
+
int64_t vad_elapsed = vad_time - segment.vad_start;
|
6732
|
+
int64_t vad_total = segment.vad_end - segment.vad_start;
|
6733
|
+
int64_t orig_total = segment.orig_end - segment.orig_start;
|
6734
|
+
int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total;
|
6735
|
+
|
6736
|
+
vad_time_mapping intermediate_mapping = {vad_time, orig_time};
|
6737
|
+
state->vad_mapping_table.push_back(intermediate_mapping);
|
6738
|
+
}
|
6739
|
+
}
|
6740
|
+
|
6741
|
+
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
|
6742
|
+
__func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
|
6743
|
+
ctx->state->vad_segments.push_back(segment);
|
6744
|
+
|
6745
|
+
// Copy this speech segment
|
6746
|
+
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
|
6747
|
+
offset += segment_length;
|
6748
|
+
|
6749
|
+
// Add silence after this segment (except after the last segment)
|
6750
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6751
|
+
// Calculate the start and end time of the silence gap in processed audio
|
6752
|
+
int64_t silence_start_vad = samples_to_cs(offset);
|
6753
|
+
int64_t silence_end_vad = samples_to_cs(offset + silence_samples);
|
6754
|
+
// Calculate the corresponding original times
|
6755
|
+
int64_t orig_silence_start = segment.orig_end;
|
6756
|
+
int64_t orig_silence_end = vad_segments->data[i+1].start;
|
6757
|
+
|
6758
|
+
// Add mapping points for silence boundaries
|
6759
|
+
state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
|
6760
|
+
state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
|
6761
|
+
|
6762
|
+
// Fill with zeros (silence)
|
6763
|
+
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
|
6764
|
+
offset += silence_samples;
|
6765
|
+
}
|
6766
|
+
}
|
6767
|
+
}
|
6768
|
+
|
6769
|
+
// Sort the mapping table by processed time
|
6770
|
+
std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
|
6771
|
+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
|
6772
|
+
return a.processed_time < b.processed_time;
|
6773
|
+
});
|
6774
|
+
|
6775
|
+
// Remove any duplicate processed times to ensure monotonicity which is
|
6776
|
+
// needed for binary search and interpolation later.
|
6777
|
+
if (!state->vad_mapping_table.empty()) {
|
6778
|
+
auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
|
6779
|
+
[](const vad_time_mapping& a, const vad_time_mapping& b) {
|
6780
|
+
return a.processed_time == b.processed_time;
|
6781
|
+
});
|
6782
|
+
state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
|
6783
|
+
}
|
6784
|
+
|
6785
|
+
WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
|
6786
|
+
|
6787
|
+
filtered_n_samples = offset;
|
6788
|
+
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
|
6789
|
+
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
|
6790
|
+
}
|
6791
|
+
|
6792
|
+
return true;
|
6793
|
+
}
|
6794
|
+
|
5335
6795
|
int whisper_full_with_state(
|
5336
6796
|
struct whisper_context * ctx,
|
5337
6797
|
struct whisper_state * state,
|
@@ -5381,11 +6841,13 @@ int whisper_full_with_state(
|
|
5381
6841
|
const int seek_start = params.offset_ms/10;
|
5382
6842
|
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
|
5383
6843
|
|
5384
|
-
// if length of spectrogram is less than
|
5385
|
-
// basically don't process anything that is less than
|
5386
|
-
//
|
5387
|
-
|
5388
|
-
|
6844
|
+
// if length of spectrogram is less than 100ms (10 frames), then return
|
6845
|
+
// basically don't process anything that is less than 100ms
|
6846
|
+
// ref: https://github.com/ggml-org/whisper.cpp/issues/2065
|
6847
|
+
const int delta_min = 10;
|
6848
|
+
|
6849
|
+
if (seek_end < seek_start + delta_min) {
|
6850
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
5389
6851
|
return 0;
|
5390
6852
|
}
|
5391
6853
|
|
@@ -5432,7 +6894,7 @@ int whisper_full_with_state(
|
|
5432
6894
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
5433
6895
|
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
5434
6896
|
|
5435
|
-
decoder.rng = std::mt19937(
|
6897
|
+
decoder.rng = std::mt19937(j);
|
5436
6898
|
}
|
5437
6899
|
|
5438
6900
|
// the accumulated text context so far
|
@@ -5529,8 +6991,8 @@ int whisper_full_with_state(
|
|
5529
6991
|
ctx, state, progress_cur, params.progress_callback_user_data);
|
5530
6992
|
}
|
5531
6993
|
|
5532
|
-
// if only
|
5533
|
-
if (seek +
|
6994
|
+
// if only 100ms left, then stop
|
6995
|
+
if (seek + delta_min >= seek_end) {
|
5534
6996
|
break;
|
5535
6997
|
}
|
5536
6998
|
|
@@ -5877,10 +7339,10 @@ int whisper_full_with_state(
|
|
5877
7339
|
// end of segment
|
5878
7340
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
5879
7341
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
5880
|
-
(has_ts && seek + seek_delta +
|
7342
|
+
(has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
|
5881
7343
|
) {
|
5882
7344
|
if (result_len == 0 && !params.no_timestamps) {
|
5883
|
-
if (seek + seek_delta +
|
7345
|
+
if (seek + seek_delta + delta_min >= seek_end) {
|
5884
7346
|
result_len = i + 1;
|
5885
7347
|
} else {
|
5886
7348
|
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
@@ -6147,7 +7609,7 @@ int whisper_full_with_state(
|
|
6147
7609
|
|
6148
7610
|
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
|
6149
7611
|
|
6150
|
-
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
|
7612
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6151
7613
|
for (int j = i0; j <= i; j++) {
|
6152
7614
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6153
7615
|
}
|
@@ -6192,7 +7654,7 @@ int whisper_full_with_state(
|
|
6192
7654
|
}
|
6193
7655
|
}
|
6194
7656
|
|
6195
|
-
result_all.push_back({ tt0, tt1, text, {}
|
7657
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6196
7658
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
6197
7659
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6198
7660
|
}
|
@@ -6229,7 +7691,7 @@ int whisper_full_with_state(
|
|
6229
7691
|
}
|
6230
7692
|
}
|
6231
7693
|
|
6232
|
-
// ref: https://github.com/
|
7694
|
+
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
|
6233
7695
|
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
|
6234
7696
|
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
|
6235
7697
|
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
|
@@ -6253,6 +7715,21 @@ int whisper_full(
|
|
6253
7715
|
struct whisper_full_params params,
|
6254
7716
|
const float * samples,
|
6255
7717
|
int n_samples) {
|
7718
|
+
|
7719
|
+
std::vector<float> vad_samples;
|
7720
|
+
if (params.vad) {
|
7721
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
7722
|
+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
|
7723
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
7724
|
+
return -1;
|
7725
|
+
}
|
7726
|
+
if (vad_samples.empty()) {
|
7727
|
+
ctx->state->result_all.clear();
|
7728
|
+
return 0;
|
7729
|
+
}
|
7730
|
+
samples = vad_samples.data();
|
7731
|
+
n_samples = vad_samples.size();
|
7732
|
+
}
|
6256
7733
|
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
|
6257
7734
|
}
|
6258
7735
|
|
@@ -6262,9 +7739,24 @@ int whisper_full_parallel(
|
|
6262
7739
|
const float * samples,
|
6263
7740
|
int n_samples,
|
6264
7741
|
int n_processors) {
|
7742
|
+
|
6265
7743
|
if (n_processors == 1) {
|
6266
7744
|
return whisper_full(ctx, params, samples, n_samples);
|
6267
7745
|
}
|
7746
|
+
|
7747
|
+
std::vector<float> vad_samples;
|
7748
|
+
if (params.vad) {
|
7749
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
7750
|
+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
|
7751
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
7752
|
+
return -1;
|
7753
|
+
}
|
7754
|
+
if (vad_samples.empty()) {
|
7755
|
+
return 0;
|
7756
|
+
}
|
7757
|
+
samples = vad_samples.data();
|
7758
|
+
n_samples = vad_samples.size();
|
7759
|
+
}
|
6268
7760
|
int ret = 0;
|
6269
7761
|
|
6270
7762
|
// prepare separate states for each thread
|
@@ -6387,20 +7879,93 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
|
|
6387
7879
|
return ctx->state->lang_id;
|
6388
7880
|
}
|
6389
7881
|
|
6390
|
-
int64_t
|
6391
|
-
|
7882
|
+
static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector<vad_time_mapping> & mapping_table) {
|
7883
|
+
if (mapping_table.empty()) {
|
7884
|
+
return processed_time;
|
7885
|
+
}
|
7886
|
+
|
7887
|
+
if (processed_time <= mapping_table.front().processed_time) {
|
7888
|
+
return mapping_table.front().original_time; // Before first mapping point
|
7889
|
+
}
|
7890
|
+
|
7891
|
+
if (processed_time >= mapping_table.back().processed_time) {
|
7892
|
+
return mapping_table.back().original_time; // After last mapping point
|
7893
|
+
}
|
7894
|
+
|
7895
|
+
// Binary search over the time map that finds the first entry that has a
|
7896
|
+
// processed time greater than or equal to the current processed time.
|
7897
|
+
auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
|
7898
|
+
[](const vad_time_mapping & entry, int64_t time) {
|
7899
|
+
return entry.processed_time < time;
|
7900
|
+
}
|
7901
|
+
);
|
7902
|
+
|
7903
|
+
// If exact match found
|
7904
|
+
if (upper->processed_time == processed_time) {
|
7905
|
+
return upper->original_time;
|
7906
|
+
}
|
7907
|
+
|
7908
|
+
// Need to interpolate between two points
|
7909
|
+
auto lower = upper - 1;
|
7910
|
+
|
7911
|
+
int64_t processed_diff = upper->processed_time - lower->processed_time;
|
7912
|
+
int64_t original_diff = upper->original_time - lower->original_time;
|
7913
|
+
int64_t offset = processed_time - lower->processed_time;
|
7914
|
+
|
7915
|
+
if (processed_diff == 0) {
|
7916
|
+
return lower->original_time;
|
7917
|
+
}
|
7918
|
+
|
7919
|
+
// Perform linear interpolation
|
7920
|
+
return lower->original_time + (offset * original_diff) / processed_diff;
|
6392
7921
|
}
|
6393
7922
|
|
6394
|
-
|
6395
|
-
|
7923
|
+
// Function to get the starting timestamp of a segment
|
7924
|
+
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
7925
|
+
// If VAD wasn't used, return the original timestamp
|
7926
|
+
if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
|
7927
|
+
return state->result_all[i_segment].t0;
|
7928
|
+
}
|
7929
|
+
|
7930
|
+
// Get the processed timestamp
|
7931
|
+
int64_t t0 = state->result_all[i_segment].t0;
|
7932
|
+
|
7933
|
+
// Map to original time using the mapping table
|
7934
|
+
return map_processed_to_original_time(t0, state->vad_mapping_table);
|
6396
7935
|
}
|
6397
7936
|
|
7937
|
+
// Function to get the ending timestamp of a segment
|
6398
7938
|
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
6399
|
-
return
|
7939
|
+
// If VAD wasn't used, return the original timestamp
|
7940
|
+
if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
|
7941
|
+
return state->result_all[i_segment].t1;
|
7942
|
+
}
|
7943
|
+
|
7944
|
+
// Get the processed timestamp
|
7945
|
+
int64_t t1 = state->result_all[i_segment].t1;
|
7946
|
+
|
7947
|
+
// Map to original time using the mapping table
|
7948
|
+
int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
|
7949
|
+
|
7950
|
+
// Get the corresponding t0 for this segment
|
7951
|
+
int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment);
|
7952
|
+
|
7953
|
+
// Ensure minimum duration to prevent zero-length segments
|
7954
|
+
const int64_t min_duration = 10; // 10ms minimum
|
7955
|
+
if (orig_t1 - orig_t0 < min_duration) {
|
7956
|
+
orig_t1 = orig_t0 + min_duration;
|
7957
|
+
}
|
7958
|
+
|
7959
|
+
return orig_t1;
|
7960
|
+
}
|
7961
|
+
|
7962
|
+
|
7963
|
+
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
7964
|
+
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
|
6400
7965
|
}
|
6401
7966
|
|
6402
7967
|
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
6403
|
-
return ctx->state
|
7968
|
+
return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
|
6404
7969
|
}
|
6405
7970
|
|
6406
7971
|
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
@@ -6459,6 +8024,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
|
|
6459
8024
|
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
6460
8025
|
}
|
6461
8026
|
|
8027
|
+
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
|
8028
|
+
return ctx->state->result_all[i_segment].no_speech_prob;
|
8029
|
+
}
|
8030
|
+
|
8031
|
+
float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
|
8032
|
+
return state->result_all[i_segment].no_speech_prob;
|
8033
|
+
}
|
8034
|
+
|
6462
8035
|
// =================================================================================================
|
6463
8036
|
|
6464
8037
|
//
|
@@ -6639,7 +8212,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6639
8212
|
// c: N*N*sizeof(float)
|
6640
8213
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
6641
8214
|
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
|
6642
|
-
std::vector<uint8_t> work;
|
6643
8215
|
|
6644
8216
|
// put a bunch of random data in the buffer
|
6645
8217
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
@@ -6696,12 +8268,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6696
8268
|
double tsum = 0.0;
|
6697
8269
|
|
6698
8270
|
// heat-up
|
6699
|
-
ggml_graph_compute_helper(gf,
|
8271
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6700
8272
|
|
6701
8273
|
for (int i = 0; i < n_max; ++i) {
|
6702
8274
|
const int64_t t0 = ggml_time_us();
|
6703
8275
|
|
6704
|
-
ggml_graph_compute_helper(gf,
|
8276
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6705
8277
|
|
6706
8278
|
const int64_t t1 = ggml_time_us();
|
6707
8279
|
|
@@ -6754,10 +8326,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6754
8326
|
// token-level timestamps
|
6755
8327
|
//
|
6756
8328
|
|
6757
|
-
static int timestamp_to_sample(int64_t t, int n_samples) {
|
6758
|
-
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
6759
|
-
}
|
6760
|
-
|
6761
8329
|
static int64_t sample_to_timestamp(int i_sample) {
|
6762
8330
|
return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
|
6763
8331
|
}
|
@@ -6807,6 +8375,18 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
|
|
6807
8375
|
return result;
|
6808
8376
|
}
|
6809
8377
|
|
8378
|
+
static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) {
|
8379
|
+
// Convert absolute timestamp to segment-relative timestamp
|
8380
|
+
int64_t relative_t = t - segment_t0;
|
8381
|
+
int sample = (int)((relative_t * WHISPER_SAMPLE_RATE) / 100);
|
8382
|
+
return std::max(0, std::min(n_samples - 1, sample));
|
8383
|
+
}
|
8384
|
+
|
8385
|
+
static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) {
|
8386
|
+
int64_t relative_timestamp = (100ll * i_sample) / WHISPER_SAMPLE_RATE;
|
8387
|
+
return relative_timestamp + segment_t0;
|
8388
|
+
}
|
8389
|
+
|
6810
8390
|
static void whisper_exp_compute_token_level_timestamps(
|
6811
8391
|
struct whisper_context & ctx,
|
6812
8392
|
struct whisper_state & state,
|
@@ -6862,12 +8442,6 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6862
8442
|
|
6863
8443
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
|
6864
8444
|
|
6865
|
-
tokens[j].id = token.id;
|
6866
|
-
tokens[j].tid = token.tid;
|
6867
|
-
tokens[j].p = token.p;
|
6868
|
-
tokens[j].pt = token.pt;
|
6869
|
-
tokens[j].ptsum = token.ptsum;
|
6870
|
-
|
6871
8445
|
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
|
6872
8446
|
|
6873
8447
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
@@ -6953,8 +8527,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6953
8527
|
continue;
|
6954
8528
|
}
|
6955
8529
|
|
6956
|
-
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
|
6957
|
-
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
|
8530
|
+
int s0 = timestamp_to_sample(tokens[j].t0, segment.t0, n_samples);
|
8531
|
+
int s1 = timestamp_to_sample(tokens[j].t1, segment.t0, n_samples);
|
6958
8532
|
|
6959
8533
|
const int ss0 = std::max(s0 - hw, 0);
|
6960
8534
|
const int ss1 = std::min(s1 + hw, n_samples);
|
@@ -6975,7 +8549,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6975
8549
|
while (k > 0 && state.energy[k] > thold) {
|
6976
8550
|
k--;
|
6977
8551
|
}
|
6978
|
-
tokens[j].t0 = sample_to_timestamp(k);
|
8552
|
+
tokens[j].t0 = sample_to_timestamp(k, segment.t0);
|
6979
8553
|
if (tokens[j].t0 < tokens[j - 1].t1) {
|
6980
8554
|
tokens[j].t0 = tokens[j - 1].t1;
|
6981
8555
|
} else {
|
@@ -6986,7 +8560,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6986
8560
|
k++;
|
6987
8561
|
}
|
6988
8562
|
s0 = k;
|
6989
|
-
tokens[j].t0 = sample_to_timestamp(k);
|
8563
|
+
tokens[j].t0 = sample_to_timestamp(k, segment.t0);
|
6990
8564
|
}
|
6991
8565
|
}
|
6992
8566
|
|
@@ -6996,7 +8570,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6996
8570
|
while (k < n_samples - 1 && state.energy[k] > thold) {
|
6997
8571
|
k++;
|
6998
8572
|
}
|
6999
|
-
tokens[j].t1 = sample_to_timestamp(k);
|
8573
|
+
tokens[j].t1 = sample_to_timestamp(k, segment.t0);
|
7000
8574
|
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
7001
8575
|
tokens[j].t1 = tokens[j + 1].t0;
|
7002
8576
|
} else {
|
@@ -7007,7 +8581,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
7007
8581
|
k--;
|
7008
8582
|
}
|
7009
8583
|
s1 = k;
|
7010
|
-
tokens[j].t1 = sample_to_timestamp(k);
|
8584
|
+
tokens[j].t1 = sample_to_timestamp(k, segment.t0);
|
7011
8585
|
}
|
7012
8586
|
}
|
7013
8587
|
}
|
@@ -7078,18 +8652,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
7078
8652
|
struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
|
7079
8653
|
struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
|
7080
8654
|
|
7081
|
-
cost =
|
7082
|
-
trace =
|
7083
|
-
|
8655
|
+
cost = whisper_set_f32(cost, INFINITY);
|
8656
|
+
trace = whisper_set_i32(trace, -1);
|
8657
|
+
whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
|
7084
8658
|
|
7085
8659
|
// dtw
|
7086
8660
|
// supposedly can be optmized by computing diagonals in parallel ?
|
7087
8661
|
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
|
7088
8662
|
for (int64_t j = 1; j < M + 1; ++j) {
|
7089
8663
|
for (int64_t i = 1; i < N + 1; ++i) {
|
7090
|
-
float c0 =
|
7091
|
-
float c1 =
|
7092
|
-
float c2 =
|
8664
|
+
float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
|
8665
|
+
float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
|
8666
|
+
float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
|
7093
8667
|
|
7094
8668
|
float c;
|
7095
8669
|
int32_t t;
|
@@ -7104,9 +8678,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
7104
8678
|
t = 2;
|
7105
8679
|
}
|
7106
8680
|
|
7107
|
-
c =
|
7108
|
-
|
7109
|
-
|
8681
|
+
c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
|
8682
|
+
whisper_set_f32_nd(cost, i, j, 0, 0, c);
|
8683
|
+
whisper_set_i32_nd(trace, i, j, 0, 0, t);
|
7110
8684
|
}
|
7111
8685
|
}
|
7112
8686
|
|
@@ -7115,19 +8689,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
7115
8689
|
struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
|
7116
8690
|
// trace[0, :] = 2;
|
7117
8691
|
for (int64_t i = 0; i < M + 1; ++i)
|
7118
|
-
|
8692
|
+
whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
|
7119
8693
|
//trace[:, 0] = 1;
|
7120
8694
|
for (int64_t i = 0; i < N + 1; ++i)
|
7121
|
-
|
8695
|
+
whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
|
7122
8696
|
int bt_row_idx = BT_MAX_ROWS - 1;
|
7123
8697
|
int64_t i = N;
|
7124
8698
|
int64_t j = M;
|
7125
8699
|
while (i > 0 || j > 0) {
|
7126
|
-
|
7127
|
-
|
8700
|
+
whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
|
8701
|
+
whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
|
7128
8702
|
--bt_row_idx;
|
7129
8703
|
|
7130
|
-
int32_t t =
|
8704
|
+
int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
|
7131
8705
|
if (t == 0) {
|
7132
8706
|
--i;
|
7133
8707
|
--j;
|
@@ -7148,8 +8722,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
7148
8722
|
ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
|
7149
8723
|
for (int64_t i = 0; i < 2; ++i) {
|
7150
8724
|
for (int64_t j = 0; j < result_n_cols; ++j) {
|
7151
|
-
int32_t v =
|
7152
|
-
|
8725
|
+
int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
|
8726
|
+
whisper_set_i32_nd(r, i, j, 0, 0, v);
|
7153
8727
|
}
|
7154
8728
|
}
|
7155
8729
|
|
@@ -7184,11 +8758,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
|
|
7184
8758
|
idx = 2*(a->ne[2] - 1) - idx;
|
7185
8759
|
}
|
7186
8760
|
|
7187
|
-
filter.push_back(
|
8761
|
+
filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
|
7188
8762
|
}
|
7189
8763
|
std::sort(filter.begin(), filter.end());
|
7190
8764
|
const float v = filter[filter.size()/2];
|
7191
|
-
|
8765
|
+
whisper_set_f32_nd(dst, i, j, k, 0, v);
|
7192
8766
|
filter.clear();
|
7193
8767
|
}
|
7194
8768
|
}
|
@@ -7310,7 +8884,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7310
8884
|
// Compute
|
7311
8885
|
struct ggml_cgraph * gf = ggml_new_graph(gctx);
|
7312
8886
|
ggml_build_forward_expand(gf, w);
|
7313
|
-
|
8887
|
+
|
8888
|
+
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
8889
|
+
ggml_backend_graph_compute(backend.get(), gf);
|
7314
8890
|
|
7315
8891
|
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
|
7316
8892
|
|
@@ -7319,9 +8895,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7319
8895
|
auto seg_i = state->result_all.begin() + i_segment;
|
7320
8896
|
auto tok_i = seg_i->tokens.begin();
|
7321
8897
|
for (int i = 0; i < alignment->ne[1]; ++i) {
|
7322
|
-
int32_t v =
|
8898
|
+
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
|
7323
8899
|
if (v != last_v) {
|
7324
|
-
int32_t time_index =
|
8900
|
+
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
|
7325
8901
|
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
|
7326
8902
|
last_v = v;
|
7327
8903
|
|
@@ -7362,6 +8938,10 @@ void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
|
|
7362
8938
|
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
7363
8939
|
}
|
7364
8940
|
|
8941
|
+
const char * whisper_version(void) {
|
8942
|
+
return WHISPER_VERSION;
|
8943
|
+
}
|
8944
|
+
|
7365
8945
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
7366
8946
|
static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
|
7367
8947
|
va_list args;
|