whispercpp 1.3.1 → 1.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +4 -3
- data/README.md +92 -31
- data/Rakefile +26 -7
- data/ext/.gitignore +5 -7
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +21 -198
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +17 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
- data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
- data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
- data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
- data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
- data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
- data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
- data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
- data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
- data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
- data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
- data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
- data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
- data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
- data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
- data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
- data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{include → sources/include}/whisper.h +68 -2
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
- data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
- data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{src → sources/src}/whisper.cpp +1905 -374
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +33 -5
- data/lib/whisper/model/uri.rb +149 -128
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +28 -0
- data/tests/test_callback.rb +45 -3
- data/tests/test_error.rb +2 -2
- data/tests/test_model.rb +38 -0
- data/tests/test_package.rb +18 -3
- data/tests/test_params.rb +145 -8
- data/tests/test_segment.rb +10 -19
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +37 -37
- data/whispercpp.gemspec +5 -4
- metadata +766 -111
- data/ext/cpu.mk +0 -9
- data/ext/examples/dr_wav.h +0 -8815
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
- data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
- data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
- data/ext/metal-embed.mk +0 -17
- data/ext/metal.mk +0 -6
- data/ext/ruby_whisper.cpp +0 -1909
- data/ext/scripts/get-flags.mk +0 -38
- data/lib/whisper.rb +0 -2
- /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
- /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
- /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
- /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
- /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
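The listing above shows the Ruby extension being restructured in this release: the monolithic data/ext/ruby_whisper.cpp is removed in favor of per-class C sources (ruby_whisper_context.c, ruby_whisper_params.c, ruby_whisper_segment.c, ruby_whisper_vad_params.c, ...), the upstream whisper.cpp tree is now vendored under data/ext/sources, and VAD support gains its own params class and tests. For orientation only, a minimal transcription sketch against the gem's documented Whisper::Context / Whisper::Params surface might look like the following; the constructor arguments and parameter names are assumptions based on the gem README, not something shown in this diff.

```ruby
require "whisper"

# Assumed README-style usage: load a model by name and build a params object.
whisper = Whisper::Context.new("base.en")

params = Whisper::Params.new
params.language = "en" # parameter names are assumptions; see data/ext/ruby_whisper_params.c

# Transcribe a 16 kHz mono WAV file; the block receives the recognized text.
whisper.transcribe("speech.wav", params) do |text|
  puts text
end
```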
@@ -1,8 +1,8 @@
|
|
1
1
|
#include "whisper.h"
|
2
|
-
|
3
|
-
#include "ggml-cpu.h"
|
2
|
+
#include "whisper-arch.h"
|
4
3
|
|
5
4
|
#include "ggml.h"
|
5
|
+
#include "ggml-cpp.h"
|
6
6
|
#include "ggml-alloc.h"
|
7
7
|
#include "ggml-backend.h"
|
8
8
|
|
@@ -17,37 +17,36 @@
 #include <atomic>
 #include <algorithm>
 #include <cassert>
+#include <cfloat>
 #define _USE_MATH_DEFINES
 #include <cmath>
-#include <
+#include <climits>
+#include <codecvt>
 #include <cstdarg>
+#include <cstdio>
 #include <cstring>
 #include <fstream>
+#include <functional>
 #include <map>
+#include <mutex>
+#include <random>
+#include <regex>
 #include <set>
 #include <string>
 #include <thread>
 #include <vector>
-#include <regex>
-#include <random>
-#include <functional>
-#include <codecvt>
-
-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
-#if defined(GGML_BIG_ENDIAN)
-#include <bit>
 
+#if defined(WHISPER_BIG_ENDIAN)
 template<typename T>
 static T byteswap(T value) {
-
-
-
-
-
-
+    T value_swapped;
+    char * source = reinterpret_cast<char *>(&value);
+    char * target = reinterpret_cast<char *>(&value_swapped);
+    int size = sizeof(T);
+    for (int i = 0; i < size; i++) {
+        target[size - 1 - i] = source[i];
+    }
+    return value_swapped;
 }
 
 template<typename T>
@@ -83,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 }
 
 #define BYTESWAP_VALUE(d) d = byteswap(d)
-#define BYTESWAP_FILTERS(f)
+#define BYTESWAP_FILTERS(f) \
     do { \
         for (auto & datum : f.data) { \
             datum = byteswap(datum); \
         } \
     } while (0)
-#define BYTESWAP_TENSOR(t)
-    do {
+#define BYTESWAP_TENSOR(t) \
+    do { \
         byteswap_tensor(t); \
     } while (0)
 #else
@@ -141,34 +140,52 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
 //
 // ggml helpers
 //
 
 static bool ggml_graph_compute_helper(
         struct ggml_cgraph * graph,
-        std::vector<uint8_t> & buf,
         int n_threads,
         ggml_abort_callback abort_callback,
         void * abort_callback_data) {
-
+    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+
+    auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
 
-
-
+    auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+    if (set_abort_callback_fn) {
+        set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
+    }
 
-
-
-
+    auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+    if (ggml_backend_set_n_threads_fn) {
+        ggml_backend_set_n_threads_fn(backend.get(), n_threads);
     }
 
-    return
+    return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
 }
 
 static bool ggml_graph_compute_helper(
         ggml_backend_sched_t sched,
         struct ggml_cgraph * graph,
-        int n_threads
-
+        int n_threads,
+        bool sched_reset = true) {
     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
         ggml_backend_dev_t dev = ggml_backend_get_device(backend);
@@ -180,11 +197,70 @@ static bool ggml_graph_compute_helper(
         }
     }
 
-    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
-
+    const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
+
+    if (!t || sched_reset) {
+        ggml_backend_sched_reset(sched);
+    }
+
+    return t;
+}
+
+static void whisper_load_backends() {
+#ifdef GGML_BACKEND_DL
+    static std::once_flag flag;
+    std::call_once(flag, []() {
+        ggml_backend_load_all();
+    });
+#endif
+}
+
+// TODO: move these functions to ggml-base with support for ggml-backend?
+
+static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
+    GGML_ASSERT(t->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(t));
+    size_t nels = ggml_nelements(t);
+    for (size_t i = 0; i < nels; ++i) {
+        ((float *) t->data)[i] = v;
+    }
+    return t;
+}
+
+static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
+    GGML_ASSERT(t->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(t));
+    size_t nels = ggml_nelements(t);
+    for (size_t i = 0; i < nels; ++i) {
+        ((int32_t *) t->data)[i] = v;
+    }
     return t;
 }
 
+static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+    GGML_ASSERT(t->type == GGML_TYPE_F32);
+    void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+    return *(float *) data;
+}
+
+static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
+    GGML_ASSERT(t->type == GGML_TYPE_F32);
+    void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+    *(float *) data = v;
+}
+
+static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+    GGML_ASSERT(t->type == GGML_TYPE_I32);
+    void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+    return *(int32_t *) data;
+}
+
+static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
+    GGML_ASSERT(t->type == GGML_TYPE_I32);
+    void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+    *(int32_t *) data = v;
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
@@ -428,6 +504,7 @@ struct whisper_segment {
     int64_t t1;
 
     std::string text;
+    float no_speech_prob;
 
     std::vector<whisper_token_data> tokens;
 
@@ -520,7 +597,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
     auto & sched = allocr.sched;
     auto & meta = allocr.meta;
 
-    sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
+    sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
@@ -716,10 +793,10 @@ struct whisper_model {
     std::vector<whisper_layer_decoder> layers_decoder;
 
     // ggml context that contains all the meta information about the model tensors
-
+    std::vector<ggml_context *> ctxs;
 
     // the model backend data is read-only and can be shared between processors
-    ggml_backend_buffer_t
+    std::vector<ggml_backend_buffer_t> buffers;
 
     // tensors
     int n_loaded;
@@ -876,6 +953,17 @@ struct whisper_state {
 
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+    whisper_vad_context * vad_context = nullptr;
+
+    struct vad_segment_info {
+        float orig_start;
+        float orig_end;
+        float vad_start;
+        float vad_end;
+    };
+    std::vector<vad_segment_info> vad_segments;
+    bool has_vad_segments = false;
 };
 
 struct whisper_context {
@@ -1234,21 +1322,38 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
 static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
+    whisper_load_backends();
+
+    ggml_backend_dev_t dev = nullptr;
+
+    int cnt = 0;
     if (params.use_gpu) {
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t
-            if (ggml_backend_dev_type(
-
-
-
-
+            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    dev = dev_cur;
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
                 }
-                return result;
             }
         }
     }
 
-
+    if (dev == nullptr) {
+        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+        return nullptr;
+    }
+
+    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+    if (!result) {
+        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+    }
+
+    return result;
 }
 
 static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@@ -1274,28 +1379,118 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
         }
     }
 
-
-
-
+    ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+    if (backend_cpu == nullptr) {
+        throw std::runtime_error("failed to initialize CPU backend");
+    }
+    result.push_back(backend_cpu);
 
     return result;
 }
 
-
-
-
+using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
+
+static buft_list_t make_buft_list(whisper_context_params & params) {
+    // Prio order: GPU -> CPU Extra -> CPU
+    buft_list_t buft_list;
+
+    // GPU
+    if (params.use_gpu) {
+        int cnt = 0;
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list.emplace_back(dev, buft);
+                    }
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
+                }
+            }
+        }
+    }
+
+    // CPU Extra
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+    if (get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(cpu_dev, *extra_bufts);
+            ++extra_bufts;
+        }
     }
 
-    //
-
-
-
-
-
+    // CPU
+    buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
+
+    return buft_list;
+}
+
+static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+    bool op_supported = true;
+
+    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+        (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+        // GPU and default CPU backend support all operators
+        op_supported = true;
+    } else {
+        switch (op) {
+            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+            case GGML_OP_MUL_MAT: {
+                ggml_init_params params = {
+                    /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+
+                ggml_context_ptr ctx_ptr { ggml_init(params) };
+                if (!ctx_ptr) {
+                    throw std::runtime_error("failed to create ggml context");
+                }
+                ggml_context * ctx = ctx_ptr.get();
+
+                ggml_tensor * op_tensor = nullptr;
+
+                int64_t n_ctx = hparams.n_audio_ctx;
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                op_tensor = ggml_mul_mat(ctx, w, b);
+
+                // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                GGML_ASSERT(w->buffer == nullptr);
+                w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                ggml_backend_buffer_free(w->buffer);
+                w->buffer = nullptr;
+                break;
+            }
+            default: {
+                op_supported = false;
+                break;
+            }
+        };
+    }
+
+    return op_supported;
+}
+
+static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+    GGML_ASSERT(!buft_list.empty());
+    for (const auto & p : buft_list) {
+        ggml_backend_dev_t dev = p.first;
+        ggml_backend_buffer_type_t buft = p.second;
+        if (weight_buft_supported(hparams, w, op, buft, dev)) {
+            return buft;
         }
     }
 
-    return
+    return nullptr;
 }
 
 // load the model from a ggml file
@@ -1504,31 +1699,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
-
-    {
-        const auto & hparams = model.hparams;
+    const auto & hparams = model.hparams;
 
-
-
+    const int n_audio_layer = hparams.n_audio_layer;
+    const int n_text_layer  = hparams.n_text_layer;
 
-
+    const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-
-
-
-
-
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
 
-
-
-
-
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+
+            return ctx;
         }
-
+
+        return it->second;
+    };
+
+    // Create a list of available bufts, in priority order
+    buft_list_t buft_list = make_buft_list(wctx.params);
+
+    auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+        ggml_op op = ASR_TENSOR_INFO.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
+        }
+
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+        model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+
+        return tensor;
+    };
+
 
     // prepare tensors for the weights
     {
-
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx = ggml_init(params);
 
         const auto & hparams = model.hparams;
 
@@ -1548,189 +1777,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         model.layers_decoder.resize(n_text_layer);
 
         // encoder
-
-        model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
-        model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
-        model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-        model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
-        model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-        model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-        model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-        // map by name
-        model.tensors["encoder.positional_embedding"] = model.e_pe;
-
-        model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
-        model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
-
-        model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
-        model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
-
-        model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
-        model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
-
-        for (int i = 0; i < n_audio_layer; ++i) {
-            auto & layer = model.layers_encoder[i];
-
-            layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-            layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+        model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
 
-
-
+        model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
+        model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
 
-
-
+        model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
+        model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
 
-
-
+        model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+        model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
 
-
-
+        for (int i = 0; i < n_audio_layer; ++i) {
+            auto & layer = model.layers_encoder[i];
 
-
+            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-
-
+            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
 
-
-
+            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-
-
-            model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-
-
+            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-
-            model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
 
-
-
+            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-
-
-
-            model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-            model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-            model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
-
-            model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-            model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
-        }
+            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
         }
 
         // decoder
-
-        model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
-        model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
-        model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-        model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-        // map by name
-        model.tensors["decoder.positional_embedding"] = model.d_pe;
-
-        model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
-        model.tensors["decoder.ln.weight"] = model.d_ln_w;
-        model.tensors["decoder.ln.bias"] = model.d_ln_b;
-
-        for (int i = 0; i < n_text_layer; ++i) {
-            auto & layer = model.layers_decoder[i];
-
-            layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-            layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
-            layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
-            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
-            layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-            layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-            layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-            layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-            layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+        model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
 
-
+        model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
 
-
-
+        model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
+        model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
 
-
-
+        for (int i = 0; i < n_text_layer; ++i) {
+            auto & layer = model.layers_decoder[i];
 
-
-
+            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
+            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
 
-
+            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
+            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-            model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
 
-
-
+            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
+            layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
+            layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
 
-
-
+            layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-
-
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
-
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
-            model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
-        }
+            layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
         }
+
+        ggml_free(ctx);
     }
 
     // allocate tensors in the backend buffers
-
-
-
-
-
+    for (auto & p : ctx_map) {
+        ggml_backend_buffer_type_t buft = p.first;
+        ggml_context * ctx = p.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (buf) {
+            model.buffers.emplace_back(buf);
 
-
-
+            size_t size_main = ggml_backend_buffer_get_size(buf);
+            WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+        }
+    }
 
     // load weights
     {
@@ -1793,11 +1941,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 return false;
             }
 
-
-
-            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
-
-            if (ggml_backend_buffer_is_host(model.buffer)) {
+            if (ggml_backend_buffer_is_host(tensor->buffer)) {
                 // for the CPU and Metal backend, we can read directly into the tensor
                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                 BYTESWAP_TENSOR(tensor);
@@ -1810,7 +1954,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
 
-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
@@ -1825,7 +1968,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-
+    for (auto & buf : model.buffers) {
+        ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+    }
 
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
@@ -3710,15 +3855,24 @@ void whisper_free_state(struct whisper_state * state) {
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
 
+        if (state->vad_context != nullptr) {
+            whisper_vad_free(state->vad_context);
+            state->vad_context = nullptr;
+        }
+
         delete state;
     }
 }
 
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
-
+        for (ggml_context * context : ctx->model.ctxs) {
+            ggml_free(context);
+        }
 
-
+        for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+            ggml_backend_buffer_free(buf);
+        }
 
         whisper_free_state(ctx->state);
 
@@ -4136,11 +4290,11 @@ void whisper_print_timings(struct whisper_context * ctx) {
 
         WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
         WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
-        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
+        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
@@ -4181,112 +4335,1230 @@ static int whisper_has_openvino(void) {
|
|
4181
4335
|
const char * whisper_print_system_info(void) {
|
4182
4336
|
static std::string s;
|
4183
4337
|
|
4338
|
+
whisper_load_backends();
|
4339
|
+
|
4184
4340
|
s = "";
|
4185
|
-
s += "
|
4186
|
-
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
|
4187
|
-
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
4188
|
-
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
4189
|
-
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
4190
|
-
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
4191
|
-
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
4192
|
-
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
4193
|
-
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
4194
|
-
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
4195
|
-
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
|
4196
|
-
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
4341
|
+
s += "WHISPER : ";
|
4197
4342
|
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
4198
4343
|
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
4199
4344
|
|
4345
|
+
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
4346
|
+
auto * reg = ggml_backend_reg_get(i);
|
4347
|
+
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
4348
|
+
if (get_features_fn) {
|
4349
|
+
ggml_backend_feature * features = get_features_fn(reg);
|
4350
|
+
s += ggml_backend_reg_name(reg);
|
4351
|
+
s += " : ";
|
4352
|
+
for (; features->name; features++) {
|
4353
|
+
s += features->name;
|
4354
|
+
s += " = ";
|
4355
|
+
s += features->value;
|
4356
|
+
s += " | ";
|
4357
|
+
}
|
4358
|
+
}
|
4359
|
+
}
|
4200
4360
|
return s.c_str();
|
4201
4361
|
}
|
4202
4362
|
|
4203
4363
|
//////////////////////////////////
|
4204
|
-
//
|
4364
|
+
// Voice Activity Detection (VAD)
|
4205
4365
|
//////////////////////////////////
|
4206
4366
|
|
4207
|
-
|
4208
|
-
|
4209
|
-
|
4210
|
-
|
4211
|
-
|
4212
|
-
|
4213
|
-
|
4214
|
-
|
4215
|
-
|
4216
|
-
|
4367
|
+
struct whisper_vad_hparams {
|
4368
|
+
int32_t n_encoder_layers;
|
4369
|
+
int32_t * encoder_in_channels;
|
4370
|
+
int32_t * encoder_out_channels;
|
4371
|
+
int32_t * kernel_sizes;
|
4372
|
+
int32_t lstm_input_size;
|
4373
|
+
int32_t lstm_hidden_size;
|
4374
|
+
int32_t final_conv_in;
|
4375
|
+
int32_t final_conv_out;
|
4376
|
+
};
|
4217
4377
|
|
4218
|
-
|
4219
|
-
|
4220
|
-
|
4221
|
-
|
4222
|
-
// invalid sequence, abort
|
4223
|
-
code_points.push_back(0);
|
4224
|
-
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
4225
|
-
}
|
4226
|
-
value = (value << 6) + (next_byte & 0x3F);
|
4227
|
-
++pos;
|
4228
|
-
--n_remain;
|
4229
|
-
}
|
4378
|
+
struct whisper_vad_model {
|
4379
|
+
std::string type;
|
4380
|
+
std::string version;
|
4381
|
+
whisper_vad_hparams hparams;
|
4230
4382
|
|
4231
|
-
|
4232
|
-
code_points.push_back(value);
|
4233
|
-
}
|
4383
|
+
struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
|
4234
4384
|
|
4235
|
-
//
|
4236
|
-
|
4237
|
-
|
4238
|
-
uint8_t highbits = first_byte >> 4;
|
4239
|
-
n_remain = lookup[highbits] - 1;
|
4385
|
+
// Encoder tensors - 4 convolutional layers
|
4386
|
+
struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
|
4387
|
+
struct ggml_tensor * encoder_0_bias; // [128]
|
4240
4388
|
|
4241
|
-
|
4242
|
-
|
4243
|
-
|
4244
|
-
code_points.push_back(0);
|
4245
|
-
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
|
4246
|
-
}
|
4389
|
+
// Second encoder layer
|
4390
|
+
struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
|
4391
|
+
struct ggml_tensor * encoder_1_bias; // [64]
|
4247
4392
|
|
4248
|
-
|
4249
|
-
|
4250
|
-
|
4251
|
-
while (*pos != 0 && n_remain > 0) {
|
4252
|
-
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
4253
|
-
++pos;
|
4254
|
-
--n_remain;
|
4255
|
-
}
|
4256
|
-
if (n_remain == 0) {
|
4257
|
-
code_points.push_back(value);
|
4258
|
-
}
|
4259
|
-
}
|
4260
|
-
code_points.push_back(0);
|
4393
|
+
// Third encoder layer
|
4394
|
+
struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
|
4395
|
+
struct ggml_tensor * encoder_2_bias; // [64]
|
4261
4396
|
|
4262
|
-
|
4263
|
-
|
4397
|
+
// Fourth encoder layer
|
4398
|
+
struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
|
4399
|
+
struct ggml_tensor * encoder_3_bias; // [128]
|
4264
4400
|
|
4265
|
-
//
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
default: return false;
|
4271
|
-
}
|
4272
|
-
}
|
4401
|
+
// LSTM decoder tensors
|
4402
|
+
struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
|
4403
|
+
struct ggml_tensor * lstm_ih_bias; // [512]
|
4404
|
+
struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
|
4405
|
+
struct ggml_tensor * lstm_hh_bias; // [512]
|
4273
4406
|
|
4274
|
-
//
|
4275
|
-
|
4276
|
-
|
4277
|
-
const whisper_grammar_element * pos,
|
4278
|
-
const uint32_t chr) {
|
4407
|
+
// Final conv layer
|
4408
|
+
struct ggml_tensor * final_conv_weight; // [128]
|
4409
|
+
struct ggml_tensor * final_conv_bias; // [1]
|
4279
4410
|
|
4280
|
-
|
4281
|
-
|
4411
|
+
// ggml contexts
|
4412
|
+
std::vector<ggml_context *> ctxs;
|
4282
4413
|
|
4283
|
-
|
4414
|
+
// buffer for the model tensors
|
4415
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
4284
4416
|
|
4285
|
-
|
4286
|
-
|
4287
|
-
|
4288
|
-
|
4289
|
-
|
4417
|
+
// tensors
|
4418
|
+
int n_loaded;
|
4419
|
+
std::map<std::string, struct ggml_tensor *> tensors;
|
4420
|
+
};
|
4421
|
+
|
4422
|
+
struct whisper_vad_segment {
|
4423
|
+
float start; // Start time in seconds
|
4424
|
+
float end; // End time in seconds
|
4425
|
+
};
|
4426
|
+
|
4427
|
+
struct whisper_vad_segments {
|
4428
|
+
std::vector<whisper_vad_segment> data;
|
4429
|
+
};
|
4430
|
+
|
4431
|
+
struct whisper_vad_context {
|
4432
|
+
int64_t t_vad_us = 0;
|
4433
|
+
|
4434
|
+
int n_window;
|
4435
|
+
int n_context;
|
4436
|
+
int n_threads;
|
4437
|
+
|
4438
|
+
std::vector<ggml_backend_t> backends;
|
4439
|
+
ggml_backend_buffer_t buffer = nullptr;
|
4440
|
+
whisper_context_params params;
|
4441
|
+
std::vector<uint8_t> ctx_buf;
|
4442
|
+
whisper_sched sched;
|
4443
|
+
|
4444
|
+
whisper_vad_model model;
|
4445
|
+
std::string path_model;
|
4446
|
+
struct ggml_tensor * h_state;
|
4447
|
+
struct ggml_tensor * c_state;
|
4448
|
+
std::vector<float> probs;
|
4449
|
+
};
|
4450
|
+
|
4451
|
+
struct whisper_vad_context_params whisper_vad_default_context_params(void) {
|
4452
|
+
whisper_vad_context_params result = {
|
4453
|
+
/*.n_thread = */ 4,
|
4454
|
+
/*.use_gpu = */ false,
|
4455
|
+
/*.gpu_device = */ 0,
|
4456
|
+
};
|
4457
|
+
return result;
|
4458
|
+
}
|
4459
|
+
|
4460
|
+
struct whisper_vad_params whisper_vad_default_params(void) {
|
4461
|
+
whisper_vad_params result = {
|
4462
|
+
/* threshold = */ 0.5f,
|
4463
|
+
/* min_speech_duration_ms = */ 250,
|
4464
|
+
/* min_silence_duration_ms = */ 100,
|
4465
|
+
/* max_speech_duration_s = */ FLT_MAX,
|
4466
|
+
/* speech_pad_ms = */ 30,
|
4467
|
+
/* samples_overlap = */ 0.1,
|
4468
|
+
};
|
4469
|
+
return result;
|
4470
|
+
}
|
4471
|
+
|
4472
|
+
static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
4473
|
+
bool op_supported = true;
|
4474
|
+
|
4475
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
|
4476
|
+
(ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
|
4477
|
+
// GPU and default CPU backend support all operators
|
4478
|
+
op_supported = true;
|
4479
|
+
} else {
|
4480
|
+
switch (op) {
|
4481
|
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
|
4482
|
+
case GGML_OP_MUL_MAT: {
|
4483
|
+
ggml_init_params params = {
|
4484
|
+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
4485
|
+
/*.mem_buffer =*/ nullptr,
|
4486
|
+
/*.no_alloc =*/ true,
|
4487
|
+
};
|
4488
|
+
|
4489
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
4490
|
+
if (!ctx_ptr) {
|
4491
|
+
throw std::runtime_error("failed to create ggml context");
|
4492
|
+
}
|
4493
|
+
ggml_context * ctx = ctx_ptr.get();
|
4494
|
+
|
4495
|
+
ggml_tensor * op_tensor = nullptr;
|
4496
|
+
|
4497
|
+
int64_t n_ctx = hparams.lstm_hidden_size;
|
4498
|
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
4499
|
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
4500
|
+
|
4501
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
4502
|
+
GGML_ASSERT(w->buffer == nullptr);
|
4503
|
+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
4504
|
+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
4505
|
+
ggml_backend_buffer_free(w->buffer);
|
4506
|
+
w->buffer = nullptr;
|
4507
|
+
break;
|
4508
|
+
}
|
4509
|
+
default: {
|
4510
|
+
op_supported = false;
|
4511
|
+
break;
|
4512
|
+
}
|
4513
|
+
};
|
4514
|
+
}
|
4515
|
+
return op_supported;
|
4516
|
+
}
|
4517
|
+
|
4518
|
+
static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
|
4519
|
+
GGML_ASSERT(!buft_list.empty());
|
4520
|
+
for (const auto & p : buft_list) {
|
4521
|
+
ggml_backend_dev_t dev = p.first;
|
4522
|
+
ggml_backend_buffer_type_t buft = p.second;
|
4523
|
+
if (weight_buft_supported(hparams, w, op, buft, dev)) {
|
4524
|
+
return buft;
|
4525
|
+
}
|
4526
|
+
}
|
4527
|
+
|
4528
|
+
return nullptr;
|
4529
|
+
}
|
4530
|
+
|
4531
|
+
static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
|
4532
|
+
const whisper_vad_model & model, ggml_tensor * cur) {
|
4533
|
+
// Apply reflective padding to the input tensor
|
4534
|
+
ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
|
4535
|
+
|
4536
|
+
struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
|
4537
|
+
|
4538
|
+
// Calculate cutoff for real/imaginary parts
|
4539
|
+
int cutoff = model.stft_forward_basis->ne[2] / 2;
|
4540
|
+
|
4541
|
+
// Extract real part (first half of the STFT output).
|
4542
|
+
struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
|
4543
|
+
// Extract imaginary part (second half of the STFT output).
|
4544
|
+
struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
|
4545
|
+
|
4546
|
+
// Calculate magnitude: sqrt(real^2 + imag^2)
|
4547
|
+
struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
|
4548
|
+
struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
|
4549
|
+
struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
|
4550
|
+
struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
|
4551
|
+
return magnitude;
|
4552
|
+
}
|
4553
|
+
|
4554
|
+
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
|
4555
|
+
const whisper_vad_model & model, ggml_tensor * cur) {
|
4556
|
+
// First Conv1D: expands to 128 channels.
|
4557
|
+
cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
|
4558
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
|
4559
|
+
cur = ggml_relu(ctx0, cur);
|
4560
|
+
|
4561
|
+
// Second Conv1D: reduces to 64 channels.
|
4562
|
+
cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
|
4563
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
|
4564
|
+
cur = ggml_relu(ctx0, cur);
|
4565
|
+
|
4566
|
+
// Third Conv1D: maintains 64 channels
|
4567
|
+
cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
|
4568
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
|
4569
|
+
cur = ggml_relu(ctx0, cur);
|
4570
|
+
|
4571
|
+
// Fourth Conv1D: expands to 128 channels
|
4572
|
+
cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
|
4573
|
+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
|
4574
|
+
cur = ggml_relu(ctx0, cur);
|
4575
|
+
|
4576
|
+
return cur;
|
4577
|
+
}
|
4578
|
+
|
4579
|
+
static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
|
4580
|
+
const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
|
4581
|
+
const whisper_vad_model & model = vctx.model;
|
4582
|
+
const int hdim = model.hparams.lstm_hidden_size;
|
4583
|
+
|
4584
|
+
struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
|
4585
|
+
|
4586
|
+
// Create operations using the input-to-hidden weights.
|
4587
|
+
struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
|
4588
|
+
inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
|
4589
|
+
|
4590
|
+
// Create operations using the hidden-to-hidden weights.
|
4591
|
+
struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
|
4592
|
+
hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
|
4593
|
+
|
4594
|
+
// Create add operation to get preactivations for all gates.
|
4595
|
+
struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
|
4596
|
+
|
4597
|
+
const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
|
4598
|
+
|
4599
|
+
// Create sigmoid for input gate (using the first 128 bytes from the preactivations).
|
4600
|
+
struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
|
4601
|
+
|
4602
|
+
// Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
|
4603
|
+
struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
|
4604
|
+
|
4605
|
+
// Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
|
4606
|
+
struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
|
4607
|
+
|
4608
|
+
// Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
|
4609
|
+
struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
|
4610
|
+
|
4611
|
+
// Update cell state
|
4612
|
+
struct ggml_tensor * c_out = ggml_add(ctx0,
|
4613
|
+
ggml_mul(ctx0, f_t, vctx.c_state),
|
4614
|
+
ggml_mul(ctx0, i_t, g_t));
|
4615
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
|
4616
|
+
|
4617
|
+
// Update hidden state
|
4618
|
+
struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
|
4619
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
|
4620
|
+
|
4621
|
+
return out;
|
4622
|
+
}
|
4623
|
+
|
4624
|
+
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
|
4625
|
+
const auto & model = vctx.model;
|
4626
|
+
|
4627
|
+
struct ggml_init_params params = {
|
4628
|
+
/*.mem_size =*/ vctx.sched.meta.size(),
|
4629
|
+
/*.mem_buffer =*/ vctx.sched.meta.data(),
|
4630
|
+
/*.no_alloc =*/ true,
|
4631
|
+
};
|
4632
|
+
|
4633
|
+
struct ggml_context * ctx0 = ggml_init(params);
|
4634
|
+
|
4635
|
+
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
4636
|
+
|
4637
|
+
struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
|
4638
|
+
ggml_set_name(frame, "frame");
|
4639
|
+
ggml_set_input(frame);
|
4640
|
+
|
4641
|
+
struct ggml_tensor * cur = nullptr;
|
4642
|
+
{
|
4643
|
+
cur = whisper_vad_build_stft_layer(ctx0, model, frame);
|
4644
|
+
|
4645
|
+
cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
|
4646
|
+
|
4647
|
+
// Extract the first element of the first dimension
|
4648
|
+
// (equivalent to pytorch's [:, :, 0])
|
4649
|
+
cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
|
4650
|
+
|
4651
|
+
cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
|
4652
|
+
cur = ggml_relu(ctx0, cur);
|
4653
|
+
cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
|
4654
|
+
cur = ggml_add(ctx0, cur, model.final_conv_bias);
|
4655
|
+
cur = ggml_sigmoid(ctx0, cur);
|
4656
|
+
ggml_set_name(cur, "prob");
|
4657
|
+
ggml_set_output(cur);
|
4658
|
+
}
|
4659
|
+
|
4660
|
+
ggml_build_forward_expand(gf, cur);
|
4661
|
+
|
4662
|
+
ggml_free(ctx0);
|
4663
|
+
|
4664
|
+
return gf;
|
4665
|
+
}
|
4666
|
+
|
4667
|
+
static bool whisper_vad_init_context(whisper_vad_context * vctx) {
|
4668
|
+
|
4669
|
+
auto whisper_context_params = whisper_context_default_params();
|
4670
|
+
// TODO: GPU VAD is forced disabled until the performance is improved
|
4671
|
+
//whisper_context_params.use_gpu = vctx->params.use_gpu;
|
4672
|
+
whisper_context_params.use_gpu = false;
|
4673
|
+
whisper_context_params.gpu_device = vctx->params.gpu_device;
|
4674
|
+
|
4675
|
+
vctx->backends = whisper_backend_init(whisper_context_params);
|
4676
|
+
if (vctx->backends.empty()) {
|
4677
|
+
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
4678
|
+
return false;
|
4679
|
+
}
|
4680
|
+
|
4681
|
+
const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
|
4682
|
+
|
4683
|
+
vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
|
4684
|
+
|
4685
|
+
struct ggml_init_params params = {
|
4686
|
+
/*.mem_size =*/ vctx->ctx_buf.size(),
|
4687
|
+
/*.mem_buffer =*/ vctx->ctx_buf.data(),
|
4688
|
+
/*.no_alloc =*/ true,
|
4689
|
+
};
|
4690
|
+
|
4691
|
+
ggml_context * ctx = ggml_init(params);
|
4692
|
+
if (!ctx) {
|
4693
|
+
WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
|
4694
|
+
return false;
|
4695
|
+
}
|
4696
|
+
|
4697
|
+
// LSTM Hidden state
|
4698
|
+
vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
|
4699
|
+
ggml_set_name(vctx->h_state, "h_state");
|
4700
|
+
|
4701
|
+
// LSTM Cell state
|
4702
|
+
vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
|
4703
|
+
ggml_set_name(vctx->c_state, "c_state");
|
4704
|
+
|
4705
|
+
vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
|
4706
|
+
if (!vctx->buffer) {
|
4707
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
|
4708
|
+
return false;
|
4709
|
+
}
|
4710
|
+
|
4711
|
+
{
|
4712
|
+
bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
|
4713
|
+
[&]() {
|
4714
|
+
return whisper_vad_build_graph(*vctx);
|
4715
|
+
});
|
4716
|
+
|
4717
|
+
if (!ok) {
|
4718
|
+
WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
|
4719
|
+
return false;
|
4720
|
+
}
|
4721
|
+
|
4722
|
+
WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
|
4723
|
+
}
|
4724
|
+
|
4725
|
+
return true;
|
4726
|
+
}
|
4727
|
+
|
4728
|
+
struct whisper_vad_context * whisper_vad_init_from_file_with_params(
|
4729
|
+
const char * path_model,
|
4730
|
+
struct whisper_vad_context_params params) {
|
4731
|
+
WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
|
4732
|
+
#ifdef _MSC_VER
|
4733
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
4734
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
4735
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
4736
|
+
#else
|
4737
|
+
auto fin = std::ifstream(path_model, std::ios::binary);
|
4738
|
+
#endif
|
4739
|
+
if (!fin) {
|
4740
|
+
WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
|
4741
|
+
return nullptr;
|
4742
|
+
}
|
4743
|
+
|
4744
|
+
whisper_model_loader loader = {};
|
4745
|
+
loader.context = &fin;
|
4746
|
+
|
4747
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
4748
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4749
|
+
fin->read((char *)output, read_size);
|
4750
|
+
return read_size;
|
4751
|
+
};
|
4752
|
+
|
4753
|
+
loader.eof = [](void * ctx) {
|
4754
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4755
|
+
return fin->eof();
|
4756
|
+
};
|
4757
|
+
|
4758
|
+
loader.close = [](void * ctx) {
|
4759
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4760
|
+
fin->close();
|
4761
|
+
};
|
4762
|
+
|
4763
|
+
auto ctx = whisper_vad_init_with_params(&loader, params);
|
4764
|
+
if (!ctx) {
|
4765
|
+
whisper_vad_free(ctx);
|
4766
|
+
return nullptr;
|
4767
|
+
}
|
4768
|
+
ctx->path_model = path_model;
|
4769
|
+
return ctx;
|
4770
|
+
}
|
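The three lambdas above define the whole whisper_model_loader contract: read(ctx, output, read_size), eof(ctx) and close(ctx) over an opaque context pointer. As a rough sketch of how the same contract can be fed from a memory buffer instead of a file — buffer_reader, model_bytes and model_size are illustrative names, not library API, and <cstring>/<algorithm> are assumed to be available:

    struct buffer_reader {
        const uint8_t * data;
        size_t          size;
        size_t          pos;
    };

    buffer_reader reader = { model_bytes, model_size, 0 };

    whisper_model_loader loader = {};
    loader.context = &reader;
    loader.read = [](void * ctx, void * output, size_t read_size) {
        buffer_reader * r = (buffer_reader *) ctx;
        const size_t n = std::min(read_size, r->size - r->pos); // clamp to what is left
        memcpy(output, r->data + r->pos, n);
        r->pos += n;
        return n;
    };
    loader.eof = [](void * ctx) {
        buffer_reader * r = (buffer_reader *) ctx;
        return r->pos >= r->size;
    };
    loader.close = [](void * ctx) { (void) ctx; }; // nothing to release

    whisper_vad_context * vctx = whisper_vad_init_with_params(&loader, whisper_vad_default_context_params());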
4771
|
+
|
4772
|
+
struct whisper_vad_context * whisper_vad_init_with_params(
|
4773
|
+
struct whisper_model_loader * loader,
|
4774
|
+
struct whisper_vad_context_params params) {
|
4775
|
+
// Read the VAD model
|
4776
|
+
{
|
4777
|
+
uint32_t magic;
|
4778
|
+
read_safe(loader, magic);
|
4779
|
+
if (magic != GGML_FILE_MAGIC) {
|
4780
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
4781
|
+
return nullptr;
|
4782
|
+
}
|
4783
|
+
}
|
4784
|
+
|
4785
|
+
whisper_vad_context * vctx = new whisper_vad_context;
|
4786
|
+
vctx->n_threads = params.n_threads;
|
4787
|
+
vctx->params.use_gpu = params.use_gpu;
|
4788
|
+
vctx->params.gpu_device = params.gpu_device;
|
4789
|
+
|
4790
|
+
auto & model = vctx->model;
|
4791
|
+
auto & hparams = model.hparams;
|
4792
|
+
|
4793
|
+
// load model context params.
|
4794
|
+
{
|
4795
|
+
int32_t str_len;
|
4796
|
+
read_safe(loader, str_len);
|
4797
|
+
std::vector<char> buffer(str_len + 1, 0);
|
4798
|
+
loader->read(loader->context, buffer.data(), str_len);
|
4799
|
+
std::string model_type(buffer.data(), str_len);
|
4800
|
+
model.type = model_type;
|
4801
|
+
WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
|
4802
|
+
|
4803
|
+
int32_t major, minor, patch;
|
4804
|
+
read_safe(loader, major);
|
4805
|
+
read_safe(loader, minor);
|
4806
|
+
read_safe(loader, patch);
|
4807
|
+
std::string version_str = std::to_string(major) + "." +
|
4808
|
+
std::to_string(minor) + "." +
|
4809
|
+
std::to_string(patch);
|
4810
|
+
model.version = version_str;
|
4811
|
+
WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
|
4812
|
+
|
4813
|
+
read_safe(loader, vctx->n_window);
|
4814
|
+
read_safe(loader, vctx->n_context);
|
4815
|
+
}
|
4816
|
+
|
4817
|
+
// load model hyper params (hparams).
|
4818
|
+
{
|
4819
|
+
read_safe(loader, hparams.n_encoder_layers);
|
4820
|
+
|
4821
|
+
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
|
4822
|
+
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
|
4823
|
+
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
|
4824
|
+
|
4825
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4826
|
+
read_safe(loader, hparams.encoder_in_channels[i]);
|
4827
|
+
read_safe(loader, hparams.encoder_out_channels[i]);
|
4828
|
+
read_safe(loader, hparams.kernel_sizes[i]);
|
4829
|
+
}
|
4830
|
+
|
4831
|
+
read_safe(loader, hparams.lstm_input_size);
|
4832
|
+
read_safe(loader, hparams.lstm_hidden_size);
|
4833
|
+
read_safe(loader, hparams.final_conv_in);
|
4834
|
+
read_safe(loader, hparams.final_conv_out);
|
4835
|
+
|
4836
|
+
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
|
4837
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4838
|
+
WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
|
4839
|
+
}
|
4840
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4841
|
+
WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
|
4842
|
+
}
|
4843
|
+
WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
|
4844
|
+
WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
|
4845
|
+
WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
|
4846
|
+
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
|
4847
|
+
}
|
4848
|
+
|
4849
|
+
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
|
4850
|
+
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
|
4851
|
+
|
4852
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
4853
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
4854
|
+
auto it = ctx_map.find(buft);
|
4855
|
+
if (it == ctx_map.end()) {
|
4856
|
+
ggml_init_params params = {
|
4857
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4858
|
+
/*.mem_buffer =*/ nullptr,
|
4859
|
+
/*.no_alloc =*/ true,
|
4860
|
+
};
|
4861
|
+
|
4862
|
+
ggml_context * ctx = ggml_init(params);
|
4863
|
+
if (!ctx) {
|
4864
|
+
throw std::runtime_error("failed to create ggml context");
|
4865
|
+
}
|
4866
|
+
|
4867
|
+
ctx_map[buft] = ctx;
|
4868
|
+
model.ctxs.emplace_back(ctx);
|
4869
|
+
|
4870
|
+
return ctx;
|
4871
|
+
}
|
4872
|
+
|
4873
|
+
return it->second;
|
4874
|
+
};
|
4875
|
+
|
4876
|
+
whisper_context_params wparams = whisper_context_default_params();
|
4877
|
+
wparams.use_gpu = params.use_gpu;
|
4878
|
+
wparams.gpu_device = params.gpu_device;
|
4879
|
+
buft_list_t buft_list = make_buft_list(wparams);
|
4880
|
+
|
4881
|
+
auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
|
4882
|
+
ggml_op op = VAD_TENSOR_OPS.at(type);
|
4883
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
4884
|
+
if (!buft) {
|
4885
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
|
4886
|
+
}
|
4887
|
+
ggml_context * ctx = get_ctx(buft);
|
4888
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
4889
|
+
model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
|
4890
|
+
|
4891
|
+
return tensor;
|
4892
|
+
};
|
4893
|
+
|
4894
|
+
// create tensors
|
4895
|
+
{
|
4896
|
+
ggml_init_params params = {
|
4897
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4898
|
+
/*.mem_buffer =*/ nullptr,
|
4899
|
+
/*.no_alloc =*/ true,
|
4900
|
+
};
|
4901
|
+
|
4902
|
+
ggml_context * ctx = ggml_init(params);
|
4903
|
+
const auto & hparams = model.hparams;
|
4904
|
+
|
4905
|
+
// STFT precomputed basis matrix
|
4906
|
+
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
|
4907
|
+
ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
|
4908
|
+
|
4909
|
+
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
|
4910
|
+
ggml_new_tensor_3d(
|
4911
|
+
ctx,
|
4912
|
+
GGML_TYPE_F16,
|
4913
|
+
hparams.kernel_sizes[0],
|
4914
|
+
hparams.encoder_in_channels[0],
|
4915
|
+
hparams.encoder_out_channels[0]
|
4916
|
+
));
|
4917
|
+
model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
|
4918
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
|
4919
|
+
|
4920
|
+
model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
|
4921
|
+
ggml_new_tensor_3d(
|
4922
|
+
ctx,
|
4923
|
+
GGML_TYPE_F16,
|
4924
|
+
hparams.kernel_sizes[1],
|
4925
|
+
hparams.encoder_in_channels[1],
|
4926
|
+
hparams.encoder_out_channels[1]
|
4927
|
+
));
|
4928
|
+
model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
|
4929
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
|
4930
|
+
|
4931
|
+
model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
|
4932
|
+
ggml_new_tensor_3d(
|
4933
|
+
ctx,
|
4934
|
+
GGML_TYPE_F16,
|
4935
|
+
hparams.kernel_sizes[2],
|
4936
|
+
hparams.encoder_in_channels[2],
|
4937
|
+
hparams.encoder_out_channels[2]
|
4938
|
+
));
|
4939
|
+
model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
|
4940
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
|
4941
|
+
|
4942
|
+
model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
|
4943
|
+
ggml_new_tensor_3d(
|
4944
|
+
ctx,
|
4945
|
+
GGML_TYPE_F16,
|
4946
|
+
hparams.kernel_sizes[3],
|
4947
|
+
hparams.encoder_in_channels[3],
|
4948
|
+
hparams.encoder_out_channels[3]
|
4949
|
+
));
|
4950
|
+
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
|
4951
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
|
4952
|
+
|
4953
|
+
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
|
4954
|
+
const int hstate_dim = hparams.lstm_hidden_size * 4;
|
4955
|
+
|
4956
|
+
// LSTM weights - input to hidden
|
4957
|
+
model.lstm_ih_weight = create_tensor(
|
4958
|
+
VAD_TENSOR_LSTM_WEIGHT_IH,
|
4959
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4960
|
+
);
|
4961
|
+
model.lstm_ih_bias = create_tensor(
|
4962
|
+
VAD_TENSOR_LSTM_BIAS_IH,
|
4963
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4964
|
+
);
|
4965
|
+
|
4966
|
+
// LSTM weights - hidden to hidden
|
4967
|
+
model.lstm_hh_weight = create_tensor(
|
4968
|
+
VAD_TENSOR_LSTM_WEIGHT_HH,
|
4969
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4970
|
+
);
|
4971
|
+
model.lstm_hh_bias = create_tensor(
|
4972
|
+
VAD_TENSOR_LSTM_BIAS_HH,
|
4973
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4974
|
+
);
|
4975
|
+
|
4976
|
+
// Final conv layer weight
|
4977
|
+
model.final_conv_weight = create_tensor(
|
4978
|
+
VAD_TENSOR_FINAL_CONV_WEIGHT,
|
4979
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
|
4980
|
+
);
|
4981
|
+
model.final_conv_bias = create_tensor(
|
4982
|
+
VAD_TENSOR_FINAL_CONV_BIAS,
|
4983
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
|
4984
|
+
);
|
4985
|
+
|
4986
|
+
ggml_free(ctx);
|
4987
|
+
}
|
4988
|
+
|
4989
|
+
// allocate tensors in the backend buffers
|
4990
|
+
for (auto & p : ctx_map) {
|
4991
|
+
ggml_backend_buffer_type_t buft = p.first;
|
4992
|
+
ggml_context * ctx = p.second;
|
4993
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
4994
|
+
if (buf) {
|
4995
|
+
model.buffers.emplace_back(buf);
|
4996
|
+
|
4997
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
4998
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
4999
|
+
}
|
5000
|
+
}
|
5001
|
+
|
5002
|
+
// load weights
|
5003
|
+
{
|
5004
|
+
size_t total_size = 0;
|
5005
|
+
model.n_loaded = 0;
|
5006
|
+
std::vector<char> read_buf;
|
5007
|
+
|
5008
|
+
while (true) {
|
5009
|
+
int32_t n_dims;
|
5010
|
+
int32_t length;
|
5011
|
+
int32_t ttype;
|
5012
|
+
|
5013
|
+
read_safe(loader, n_dims);
|
5014
|
+
read_safe(loader, length);
|
5015
|
+
read_safe(loader, ttype);
|
5016
|
+
|
5017
|
+
if (loader->eof(loader->context)) {
|
5018
|
+
break;
|
5019
|
+
}
|
5020
|
+
|
5021
|
+
int32_t nelements = 1;
|
5022
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
5023
|
+
for (int i = 0; i < n_dims; ++i) {
|
5024
|
+
read_safe(loader, ne[i]);
|
5025
|
+
nelements *= ne[i];
|
5026
|
+
}
|
5027
|
+
|
5028
|
+
std::string name;
|
5029
|
+
std::vector<char> tmp(length);
|
5030
|
+
loader->read(loader->context, &tmp[0], tmp.size());
|
5031
|
+
name.assign(&tmp[0], tmp.size());
|
5032
|
+
|
5033
|
+
if (model.tensors.find(name) == model.tensors.end()) {
|
5034
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
5035
|
+
return nullptr;
|
5036
|
+
}
|
5037
|
+
|
5038
|
+
auto tensor = model.tensors[name.data()];
|
5039
|
+
|
5040
|
+
if (ggml_nelements(tensor) != nelements) {
|
5041
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
5042
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
5043
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
5044
|
+
return nullptr;
|
5045
|
+
}
|
5046
|
+
|
5047
|
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
5048
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
5049
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
5050
|
+
return nullptr;
|
5051
|
+
}
|
5052
|
+
|
5053
|
+
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
5054
|
+
|
5055
|
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
5056
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
5057
|
+
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
5058
|
+
return nullptr;
|
5059
|
+
}
|
5060
|
+
|
5061
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
5062
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
5063
|
+
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
5064
|
+
BYTESWAP_TENSOR(tensor);
|
5065
|
+
} else {
|
5066
|
+
// read into a temporary buffer first, then copy to device memory
|
5067
|
+
read_buf.resize(ggml_nbytes(tensor));
|
5068
|
+
|
5069
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
5070
|
+
|
5071
|
+
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
5072
|
+
}
|
5073
|
+
|
5074
|
+
total_size += ggml_nbytes(tensor);
|
5075
|
+
model.n_loaded++;
|
5076
|
+
}
|
5077
|
+
|
5078
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
5079
|
+
|
5080
|
+
if (model.n_loaded == 0) {
|
5081
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
5082
|
+
} else if (model.n_loaded != (int) model.tensors.size()) {
|
5083
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
5084
|
+
return nullptr;
|
5085
|
+
}
|
5086
|
+
|
5087
|
+
}
|
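Read back from the loop above, each tensor in the VAD model file is a simple self-describing record. The schematic below is inferred from the loading code, not a formal specification of the file format:

    int32   n_dims                // number of dimensions (1..4)
    int32   name_length           // length of the tensor name in bytes
    int32   ttype                 // ggml_type of the stored data
    int32   ne[n_dims]            // size of each dimension
    char    name[name_length]     // tensor name, not NUL-terminated
    uint8   data[ggml_nbytes]     // raw payload, read in place or staged through read_buf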
5088
|
+
|
5089
|
+
if (!whisper_vad_init_context(vctx)) {
|
5090
|
+
whisper_vad_free(vctx);
|
5091
|
+
return nullptr;
|
5092
|
+
}
|
5093
|
+
|
5094
|
+
return vctx;
|
5095
|
+
}
|
5096
|
+
|
5097
|
+
bool whisper_vad_detect_speech(
|
5098
|
+
struct whisper_vad_context * vctx,
|
5099
|
+
const float * samples,
|
5100
|
+
int n_samples) {
|
5101
|
+
int n_chunks = n_samples / vctx->n_window;
|
5102
|
+
if (n_samples % vctx->n_window != 0) {
|
5103
|
+
n_chunks += 1; // Add one more chunk for remaining samples.
|
5104
|
+
}
|
5105
|
+
|
5106
|
+
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
|
5107
|
+
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
|
5108
|
+
|
5109
|
+
// Reset LSTM hidden/cell states
|
5110
|
+
ggml_backend_buffer_clear(vctx->buffer, 0);
|
5111
|
+
|
5112
|
+
vctx->probs.resize(n_chunks);
|
5113
|
+
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
|
5114
|
+
|
5115
|
+
std::vector<float> window(vctx->n_window, 0.0f);
|
5116
|
+
|
5117
|
+
auto & sched = vctx->sched.sched;
|
5118
|
+
|
5119
|
+
ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
|
5120
|
+
|
5121
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
5122
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
5123
|
+
return false;
|
5124
|
+
}
|
5125
|
+
|
5126
|
+
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
|
5127
|
+
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
|
5128
|
+
|
5129
|
+
// we are going to reuse the graph multiple times for each chunk
|
5130
|
+
const int64_t t_start_vad_us = ggml_time_us();
|
5131
|
+
|
5132
|
+
for (int i = 0; i < n_chunks; i++) {
|
5133
|
+
const int idx_start = i * vctx->n_window;
|
5134
|
+
const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
|
5135
|
+
|
5136
|
+
const int chunk_len = idx_end - idx_start;
|
5137
|
+
|
5138
|
+
if (chunk_len < vctx->n_window) {
|
5139
|
+
WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
|
5140
|
+
std::vector<float> partial_chunk(vctx->n_window, 0.0f);
|
5141
|
+
std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
|
5142
|
+
|
5143
|
+
// Copy the zero-padded chunk to the window.
|
5144
|
+
const int samples_to_copy_max = vctx->n_window;
|
5145
|
+
const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
|
5146
|
+
std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
|
5147
|
+
if (samples_to_copy_cur < samples_to_copy_max) {
|
5148
|
+
std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
|
5149
|
+
}
|
5150
|
+
} else {
|
5151
|
+
// Copy current frame samples to the window.
|
5152
|
+
const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
|
5153
|
+
std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
|
5154
|
+
}
|
5155
|
+
|
5156
|
+
// Set the frame tensor data with the samples.
|
5157
|
+
ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
|
5158
|
+
|
5159
|
+
// do not reset the scheduler - we will reuse the graph in the next chunk
|
5160
|
+
if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
|
5161
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
|
5162
|
+
break;
|
5163
|
+
}
|
5164
|
+
|
5165
|
+
// Get the probability for this chunk.
|
5166
|
+
ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
|
5167
|
+
|
5168
|
+
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
|
5169
|
+
}
|
5170
|
+
|
5171
|
+
vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
|
5172
|
+
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
|
5173
|
+
|
5174
|
+
ggml_backend_sched_reset(sched);
|
5175
|
+
|
5176
|
+
return true;
|
5177
|
+
}
|
5178
|
+
|
5179
|
+
int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
|
5180
|
+
return segments->data.size();
|
5181
|
+
}
|
5182
|
+
|
5183
|
+
float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
|
5184
|
+
return segments->data[i_segment].start;
|
5185
|
+
}
|
5186
|
+
|
5187
|
+
float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
|
5188
|
+
return segments->data[i_segment].end;
|
5189
|
+
}
|
5190
|
+
|
5191
|
+
int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
|
5192
|
+
return vctx->probs.size();
|
5193
|
+
}
|
5194
|
+
|
5195
|
+
float * whisper_vad_probs(struct whisper_vad_context * vctx) {
|
5196
|
+
return vctx->probs.data();
|
5197
|
+
}
|
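A minimal caller of the probability API above might look as follows; the model path and the pcm buffer (16 kHz mono float samples) are placeholders for this sketch:

    whisper_vad_context_params cparams = whisper_vad_default_context_params();
    whisper_vad_context * vctx = whisper_vad_init_from_file_with_params("ggml-silero-vad.bin", cparams);
    if (vctx && whisper_vad_detect_speech(vctx, pcm.data(), (int) pcm.size())) {
        const int     n = whisper_vad_n_probs(vctx);
        const float * p = whisper_vad_probs(vctx);
        for (int i = 0; i < n; i++) {
            // one speech probability per n_window chunk of audio
            printf("chunk %4d: p = %.3f\n", i, p[i]);
        }
    }
    whisper_vad_free(vctx);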
5198
|
+
|
5199
|
+
struct whisper_vad_segments * whisper_vad_segments_from_probs(
|
5200
|
+
struct whisper_vad_context * vctx,
|
5201
|
+
whisper_vad_params params) {
|
5202
|
+
WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
|
5203
|
+
|
5204
|
+
int n_probs = whisper_vad_n_probs(vctx);
|
5205
|
+
float * probs = whisper_vad_probs(vctx);
|
5206
|
+
float threshold = params.threshold;
|
5207
|
+
int min_speech_duration_ms = params.min_speech_duration_ms;
|
5208
|
+
int min_silence_duration_ms = params.min_silence_duration_ms;
|
5209
|
+
float max_speech_duration_s = params.max_speech_duration_s;
|
5210
|
+
int speech_pad_ms = params.speech_pad_ms;
|
5211
|
+
int n_window = vctx->n_window;
|
5212
|
+
int sample_rate = WHISPER_SAMPLE_RATE;
|
5213
|
+
int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
|
5214
|
+
int audio_length_samples = n_probs * n_window;
|
5215
|
+
|
5216
|
+
// Min number of samples to be considered valid speech.
|
5217
|
+
int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
|
5218
|
+
int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
|
5219
|
+
|
5220
|
+
// Max number of samples that a speech segment can contain before it is
|
5221
|
+
// split into multiple segments.
|
5222
|
+
int max_speech_samples;
|
5223
|
+
if (max_speech_duration_s > 100000.0f) {
|
5224
|
+
max_speech_samples = INT_MAX / 2;
|
5225
|
+
} else {
|
5226
|
+
int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
|
5227
|
+
max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
|
5228
|
+
if (max_speech_samples < 0) {
|
5229
|
+
max_speech_samples = INT_MAX / 2;
|
5230
|
+
}
|
5231
|
+
}
|
5232
|
+
// Detect silence period that exceeds this value, then that location (sample)
|
5233
|
+
// is marked as a potential place where the segment could be split if
|
5234
|
+
// max_speech_samples is reached. The value 98 was taken from the original
|
5235
|
+
// silero-vad python implementation:
|
5236
|
+
//https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
|
5237
|
+
int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
|
5238
|
+
|
5239
|
+
// Calculate lower threshold for detecting end of speech segments.
|
5240
|
+
float neg_threshold = threshold - 0.15f;
|
5241
|
+
if (neg_threshold < 0.01f) {
|
5242
|
+
neg_threshold = 0.01f;
|
5243
|
+
}
|
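To make the effect of the two thresholds concrete (values chosen for illustration): with params.threshold = 0.5 a chunk opens a speech segment as soon as its probability reaches 0.5, but the detector only starts counting silence once the probability falls below neg_threshold = 0.5 - 0.15 = 0.35. At the 16 kHz sample rate, min_speech_duration_ms = 250 and min_silence_duration_ms = 100 give min_speech_samples = 16000 * 250 / 1000 = 4000 and min_silence_samples = 16000 * 100 / 1000 = 1600, so a pause shorter than 1600 samples never closes a segment and a burst shorter than 4000 samples is never emitted.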
5244
|
+
|
5245
|
+
struct speech_segment_t {
|
5246
|
+
int start;
|
5247
|
+
int end;
|
5248
|
+
};
|
5249
|
+
|
5250
|
+
std::vector<speech_segment_t> speeches;
|
5251
|
+
speeches.reserve(256);
|
5252
|
+
|
5253
|
+
bool is_speech_segment = false;
|
5254
|
+
int temp_end = 0;
|
5255
|
+
int prev_end = 0;
|
5256
|
+
int next_start = 0;
|
5257
|
+
int curr_speech_start = 0;
|
5258
|
+
bool has_curr_speech = false;
|
5259
|
+
|
5260
|
+
for (int i = 0; i < n_probs; i++) {
|
5261
|
+
float curr_prob = probs[i];
|
5262
|
+
int curr_sample = n_window * i;
|
5263
|
+
|
5264
|
+
// Reset temp_end when we get back to speech
|
5265
|
+
if ((curr_prob >= threshold) && temp_end) {
|
5266
|
+
temp_end = 0;
|
5267
|
+
if (next_start < prev_end) {
|
5268
|
+
next_start = curr_sample;
|
5269
|
+
}
|
5270
|
+
}
|
5271
|
+
|
5272
|
+
// Start a new speech segment when probability exceeds threshold and not already in speech
|
5273
|
+
if ((curr_prob >= threshold) && !is_speech_segment) {
|
5274
|
+
is_speech_segment = true;
|
5275
|
+
curr_speech_start = curr_sample;
|
5276
|
+
has_curr_speech = true;
|
5277
|
+
continue;
|
5278
|
+
}
|
5279
|
+
|
5280
|
+
// Handle maximum speech duration
|
5281
|
+
if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
|
5282
|
+
if (prev_end) {
|
5283
|
+
speeches.push_back({ curr_speech_start, prev_end });
|
5284
|
+
has_curr_speech = true;
|
5285
|
+
|
5286
|
+
if (next_start < prev_end) { // Previously reached silence and is still not speech
|
5287
|
+
is_speech_segment = false;
|
5288
|
+
has_curr_speech = false;
|
5289
|
+
} else {
|
5290
|
+
curr_speech_start = next_start;
|
5291
|
+
}
|
5292
|
+
prev_end = next_start = temp_end = 0;
|
5293
|
+
} else {
|
5294
|
+
speeches.push_back({ curr_speech_start, curr_sample });
|
5295
|
+
|
5296
|
+
prev_end = next_start = temp_end = 0;
|
5297
|
+
is_speech_segment = false;
|
5298
|
+
has_curr_speech = false;
|
5299
|
+
continue;
|
5300
|
+
}
|
5301
|
+
}
|
5302
|
+
|
5303
|
+
// Handle silence after speech
|
5304
|
+
if ((curr_prob < neg_threshold) && is_speech_segment) {
|
5305
|
+
if (!temp_end) {
|
5306
|
+
temp_end = curr_sample;
|
5307
|
+
}
|
5308
|
+
|
5309
|
+
// Track potential segment ends for max_speech handling
|
5310
|
+
if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
|
5311
|
+
prev_end = temp_end;
|
5312
|
+
}
|
5313
|
+
|
5314
|
+
// Check if silence is long enough to end the segment
|
5315
|
+
if ((curr_sample - temp_end) < min_silence_samples) {
|
5316
|
+
continue;
|
5317
|
+
} else {
|
5318
|
+
// End the segment if it's long enough
|
5319
|
+
if ((temp_end - curr_speech_start) > min_speech_samples) {
|
5320
|
+
speeches.push_back({ curr_speech_start, temp_end });
|
5321
|
+
}
|
5322
|
+
|
5323
|
+
prev_end = next_start = temp_end = 0;
|
5324
|
+
is_speech_segment = false;
|
5325
|
+
has_curr_speech = false;
|
5326
|
+
continue;
|
5327
|
+
}
|
5328
|
+
}
|
5329
|
+
}
|
5330
|
+
|
5331
|
+
// Handle the case if we're still in a speech segment at the end
|
5332
|
+
if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
|
5333
|
+
speeches.push_back({ curr_speech_start, audio_length_samples });
|
5334
|
+
}
|
5335
|
+
|
5336
|
+
// Merge adjacent segments with small gaps in between (post-processing)
|
5337
|
+
if (speeches.size() > 1) {
|
5338
|
+
int merged_count = 0;
|
5339
|
+
for (int i = 0; i < (int) speeches.size() - 1; i++) {
|
5340
|
+
// Define maximum gap allowed for merging (e.g., 200ms converted to samples)
|
5341
|
+
int max_merge_gap_samples = sample_rate * 200 / 1000;
|
5342
|
+
|
5343
|
+
// If the gap between this segment and the next is small enough
|
5344
|
+
if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
|
5345
|
+
// Merge by extending current segment to the end of next segment
|
5346
|
+
speeches[i].end = speeches[i+1].end;
|
5347
|
+
speeches.erase(speeches.begin() + i + 1);
|
5348
|
+
|
5349
|
+
i--;
|
5350
|
+
merged_count++;
|
5351
|
+
}
|
5352
|
+
}
|
5353
|
+
WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
|
5354
|
+
__func__, merged_count, (int) speeches.size());
|
5355
|
+
}
|
5356
|
+
|
5357
|
+
// Double-check for minimum speech duration
|
5358
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
5359
|
+
if (speeches[i].end - speeches[i].start < min_speech_samples) {
|
5360
|
+
WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
|
5361
|
+
__func__, i, speeches[i].end - speeches[i].start);
|
5362
|
+
|
5363
|
+
speeches.erase(speeches.begin() + i);
|
5364
|
+
i--;
|
5365
|
+
}
|
5366
|
+
}
|
5367
|
+
|
5368
|
+
WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
|
5369
|
+
|
5370
|
+
// Allocate final segments
|
5371
|
+
std::vector<whisper_vad_segment> segments;
|
5372
|
+
if (speeches.size() > 0) {
|
5373
|
+
try {
|
5374
|
+
segments.resize(speeches.size());
|
5375
|
+
} catch (const std::bad_alloc &) {
|
5376
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
|
5377
|
+
return nullptr;
|
5378
|
+
}
|
5379
|
+
}
|
5380
|
+
|
5381
|
+
// Apply padding to segments and copy to final segments
|
5382
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
5383
|
+
// Apply padding to the start of the first segment
|
5384
|
+
if (i == 0) {
|
5385
|
+
speeches[i].start =
|
5386
|
+
(speeches[i].start > speech_pad_samples) ?
|
5387
|
+
(speeches[i].start - speech_pad_samples) : 0;
|
5388
|
+
}
|
5389
|
+
|
5390
|
+
// Handle spacing between segments
|
5391
|
+
if (i < (int) speeches.size() - 1) {
|
5392
|
+
int silence_duration = speeches[i+1].start - speeches[i].end;
|
5393
|
+
|
5394
|
+
if (silence_duration < 2 * speech_pad_samples) {
|
5395
|
+
// If segments are close, split the difference
|
5396
|
+
speeches[i].end += silence_duration / 2;
|
5397
|
+
speeches[i+1].start =
|
5398
|
+
(speeches[i+1].start > silence_duration / 2) ?
|
5399
|
+
(speeches[i+1].start - silence_duration / 2) : 0;
|
5400
|
+
} else {
|
5401
|
+
// Otherwise, apply full padding to both
|
5402
|
+
speeches[i].end =
|
5403
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
5404
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
5405
|
+
speeches[i+1].start =
|
5406
|
+
(speeches[i+1].start > speech_pad_samples) ?
|
5407
|
+
(speeches[i+1].start - speech_pad_samples) : 0;
|
5408
|
+
}
|
5409
|
+
} else {
|
5410
|
+
// Apply padding to the end of the last segment
|
5411
|
+
speeches[i].end =
|
5412
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
5413
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
5414
|
+
}
|
5415
|
+
|
5416
|
+
// Convert from samples to seconds and copy to final segments
|
5417
|
+
segments[i].start = (float)speeches[i].start / sample_rate;
|
5418
|
+
segments[i].end = (float)speeches[i].end / sample_rate;
|
5419
|
+
|
5420
|
+
WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
|
5421
|
+
__func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
|
5422
|
+
}
|
5423
|
+
|
5424
|
+
whisper_vad_segments * vad_segments = new whisper_vad_segments;
|
5425
|
+
if (vad_segments == NULL) {
|
5426
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
|
5427
|
+
return nullptr;
|
5428
|
+
}
|
5429
|
+
|
5430
|
+
vad_segments->data = std::move(segments);
|
5431
|
+
|
5432
|
+
return vad_segments;
|
5433
|
+
}
|
5434
|
+
|
5435
|
+
struct whisper_vad_segments * whisper_vad_segments_from_samples(
|
5436
|
+
whisper_vad_context * vctx,
|
5437
|
+
whisper_vad_params params,
|
5438
|
+
const float * samples,
|
5439
|
+
int n_samples) {
|
5440
|
+
WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
|
5441
|
+
if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
|
5442
|
+
WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
|
5443
|
+
return nullptr;
|
5444
|
+
}
|
5445
|
+
return whisper_vad_segments_from_probs(vctx, params);
|
5446
|
+
}
|
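Putting the two calls together, end-to-end segmentation of a PCM buffer reduces to the sketch below; vctx is an initialized whisper_vad_context and pcm holds 16 kHz mono float samples (both assumed here):

    whisper_vad_params vparams = whisper_vad_default_params();
    whisper_vad_segments * segs = whisper_vad_segments_from_samples(vctx, vparams, pcm.data(), (int) pcm.size());
    if (segs) {
        for (int i = 0; i < whisper_vad_segments_n_segments(segs); i++) {
            printf("speech %d: %.2f s -> %.2f s\n", i,
                   whisper_vad_segments_get_segment_t0(segs, i),
                   whisper_vad_segments_get_segment_t1(segs, i));
        }
        whisper_vad_free_segments(segs);
    }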
5447
|
+
|
5448
|
+
void whisper_vad_free(whisper_vad_context * ctx) {
|
5449
|
+
if (ctx) {
|
5450
|
+
for (ggml_context * context : ctx->model.ctxs) {
|
5451
|
+
ggml_free(context);
|
5452
|
+
}
|
5453
|
+
|
5454
|
+
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
|
5455
|
+
ggml_backend_buffer_free(buf);
|
5456
|
+
}
|
5457
|
+
|
5458
|
+
ggml_backend_sched_free(ctx->sched.sched);
|
5459
|
+
|
5460
|
+
for (auto & backend : ctx->backends) {
|
5461
|
+
ggml_backend_free(backend);
|
5462
|
+
}
|
5463
|
+
|
5464
|
+
|
5465
|
+
delete ctx;
|
5466
|
+
}
|
5467
|
+
}
|
5468
|
+
|
5469
|
+
void whisper_vad_free_segments(whisper_vad_segments * segments) {
|
5470
|
+
if (segments) {
|
5471
|
+
delete segments;
|
5472
|
+
}
|
5473
|
+
}
|
5474
|
+
|
5475
|
+
//////////////////////////////////
|
5476
|
+
// Grammar - ported from llama.cpp
|
5477
|
+
//////////////////////////////////
|
5478
|
+
|
5479
|
+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
5480
|
+
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
5481
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
5482
|
+
const char * src,
|
5483
|
+
whisper_partial_utf8 partial_start) {
|
5484
|
+
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
5485
|
+
const char * pos = src;
|
5486
|
+
std::vector<uint32_t> code_points;
|
5487
|
+
uint32_t value = partial_start.value;
|
5488
|
+
int n_remain = partial_start.n_remain;
|
5489
|
+
|
5490
|
+
// continue previous decode, if applicable
|
5491
|
+
while (*pos != 0 && n_remain > 0) {
|
5492
|
+
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
5493
|
+
if ((next_byte >> 6) != 2) {
|
5494
|
+
// invalid sequence, abort
|
5495
|
+
code_points.push_back(0);
|
5496
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
|
5497
|
+
}
|
5498
|
+
value = (value << 6) + (next_byte & 0x3F);
|
5499
|
+
++pos;
|
5500
|
+
--n_remain;
|
5501
|
+
}
|
5502
|
+
|
5503
|
+
if (partial_start.n_remain > 0 && n_remain == 0) {
|
5504
|
+
code_points.push_back(value);
|
5505
|
+
}
|
5506
|
+
|
5507
|
+
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
5508
|
+
while (*pos != 0) {
|
5509
|
+
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
5510
|
+
uint8_t highbits = first_byte >> 4;
|
5511
|
+
n_remain = lookup[highbits] - 1;
|
5512
|
+
|
5513
|
+
if (n_remain < 0) {
|
5514
|
+
// invalid sequence, abort
|
5515
|
+
code_points.clear();
|
5516
|
+
code_points.push_back(0);
|
5517
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
|
5518
|
+
}
|
5519
|
+
|
5520
|
+
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
5521
|
+
value = first_byte & mask;
|
5522
|
+
++pos;
|
5523
|
+
while (*pos != 0 && n_remain > 0) {
|
5524
|
+
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
5525
|
+
++pos;
|
5526
|
+
--n_remain;
|
5527
|
+
}
|
5528
|
+
if (n_remain == 0) {
|
5529
|
+
code_points.push_back(value);
|
5530
|
+
}
|
5531
|
+
}
|
5532
|
+
code_points.push_back(0);
|
5533
|
+
|
5534
|
+
return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
|
5535
|
+
}
|
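The partial-decode contract is easiest to see on a multi-byte character split across two calls. The trace below is worked through the function above ("é" is 0xC3 0xA9 in UTF-8); it is an illustration, not a test that exists in the repository:

    // first call sees only the lead byte: nothing is emitted yet, but the
    // partial state records 1 continuation byte still missing
    auto r1 = decode_utf8("\xC3", { 0, 0 });
    // r1.first  == { 0 }                        (just the terminating 0)
    // r1.second == { value = 0x03, n_remain = 1 }

    // second call resumes from that state and completes the code point
    auto r2 = decode_utf8("\xA9", r1.second);
    // r2.first  == { 0xE9, 0 }                  (U+00E9 followed by the terminating 0)
    // r2.second == { value = 0xE9, n_remain = 0 }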
5536
|
+
|
5537
|
+
// returns true iff pos points to the end of one of the definitions of a rule
|
5538
|
+
static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
|
5539
|
+
switch (pos->type) {
|
5540
|
+
case WHISPER_GRETYPE_END: return true; // NOLINT
|
5541
|
+
case WHISPER_GRETYPE_ALT: return true; // NOLINT
|
5542
|
+
default: return false;
|
5543
|
+
}
|
5544
|
+
}
|
5545
|
+
|
5546
|
+
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
5547
|
+
// asserts that pos is pointing to a char range element
|
5548
|
+
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
|
5549
|
+
const whisper_grammar_element * pos,
|
5550
|
+
const uint32_t chr) {
|
5551
|
+
|
5552
|
+
bool found = false;
|
5553
|
+
bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
|
5554
|
+
|
5555
|
+
WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
|
5556
|
+
|
5557
|
+
do {
|
5558
|
+
if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
|
5559
|
+
// inclusive range, e.g. [a-z]
|
5560
|
+
found = found || (pos->value <= chr && chr <= pos[1].value);
|
5561
|
+
pos += 2;
|
4290
5562
|
} else {
|
4291
5563
|
// exact char match, e.g. [a] or "a"
|
4292
5564
|
found = found || pos->value == chr;
|
@@ -4355,7 +5627,7 @@ static void whisper_grammar_advance_stack(
|
|
4355
5627
|
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
4356
5628
|
|
4357
5629
|
if (stack.empty()) {
|
4358
|
-
new_stacks.
|
5630
|
+
new_stacks.emplace_back();
|
4359
5631
|
return;
|
4360
5632
|
}
|
4361
5633
|
|
@@ -4676,7 +5948,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4676
5948
|
/*.detect_language =*/ false,
|
4677
5949
|
|
4678
5950
|
/*.suppress_blank =*/ true,
|
4679
|
-
/*.
|
5951
|
+
/*.suppress_nst =*/ false,
|
4680
5952
|
|
4681
5953
|
/*.temperature =*/ 0.0f,
|
4682
5954
|
/*.max_initial_ts =*/ 1.0f,
|
@@ -4716,6 +5988,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4716
5988
|
/*.n_grammar_rules =*/ 0,
|
4717
5989
|
/*.i_start_rule =*/ 0,
|
4718
5990
|
/*.grammar_penalty =*/ 100.0f,
|
5991
|
+
|
5992
|
+
/*.vad =*/ false,
|
5993
|
+
/*.vad_model_path =*/ nullptr,
|
5994
|
+
|
5995
|
+
/*.vad_params =*/ whisper_vad_default_params(),
|
4719
5996
|
};
|
4720
5997
|
|
4721
5998
|
switch (strategy) {
|
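With the three new fields above, turning VAD filtering on from the public API is a matter of setting them before calling whisper_full. In the sketch below, ctx is an initialized whisper_context, pcm a 16 kHz float buffer, and the model path is a placeholder:

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.vad            = true;
    wparams.vad_model_path = "models/ggml-silero-vad.bin"; // placeholder path
    wparams.vad_params     = whisper_vad_default_params();

    if (whisper_full(ctx, wparams, pcm.data(), (int) pcm.size()) != 0) {
        fprintf(stderr, "failed to transcribe\n");
    }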
@@ -4960,7 +6237,7 @@ static void whisper_process_logits(
|
|
4960
6237
|
|
4961
6238
|
// suppress non-speech tokens
|
4962
6239
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
4963
|
-
if (params.
|
6240
|
+
if (params.suppress_nst) {
|
4964
6241
|
for (const std::string & token : non_speech_tokens) {
|
4965
6242
|
const std::string suppress_tokens[] = {token, " " + token};
|
4966
6243
|
for (const std::string & suppress_token : suppress_tokens) {
|
@@ -5332,6 +6609,121 @@ static void whisper_sequence_score(
|
|
5332
6609
|
}
|
5333
6610
|
}
|
5334
6611
|
|
6612
|
+
static bool whisper_vad(
|
6613
|
+
struct whisper_context * ctx,
|
6614
|
+
struct whisper_state * state,
|
6615
|
+
struct whisper_full_params params,
|
6616
|
+
const float * samples,
|
6617
|
+
int n_samples,
|
6618
|
+
std::vector<float> & filtered_samples,
|
6619
|
+
int & filtered_n_samples) {
|
6620
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
|
6621
|
+
filtered_n_samples = 0;
|
6622
|
+
|
6623
|
+
if (state->vad_context == nullptr) {
|
6624
|
+
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
|
6625
|
+
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
|
6626
|
+
if (vctx == nullptr) {
|
6627
|
+
WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
|
6628
|
+
return false;
|
6629
|
+
}
|
6630
|
+
state->vad_context = vctx;
|
6631
|
+
}
|
6632
|
+
auto vctx = state->vad_context;
|
6633
|
+
|
6634
|
+
const whisper_vad_params & vad_params = params.vad_params;
|
6635
|
+
|
6636
|
+
whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
|
6637
|
+
|
6638
|
+
if (vad_segments->data.size() > 0) {
|
6639
|
+
state->has_vad_segments = true;
|
6640
|
+
ctx->state->vad_segments.clear();
|
6641
|
+
ctx->state->vad_segments.reserve(vad_segments->data.size());
|
6642
|
+
|
6643
|
+
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
|
6644
|
+
float overlap_seconds = vad_params.samples_overlap;
|
6645
|
+
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
|
6646
|
+
|
6647
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6648
|
+
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
|
6649
|
+
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
|
6650
|
+
|
6651
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6652
|
+
segment_end_samples += overlap_samples;
|
6653
|
+
}
|
6654
|
+
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
|
6655
|
+
filtered_n_samples += (segment_end_samples - segment_start_samples);
|
6656
|
+
|
6657
|
+
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
|
6658
|
+
__func__, i, vad_segments->data[i].start,
|
6659
|
+
vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
|
6660
|
+
(vad_segments->data[i].end - vad_segments->data[i].start) +
|
6661
|
+
(i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
|
6662
|
+
}
|
6663
|
+
|
6664
|
+
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
|
6665
|
+
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
|
6666
|
+
int total_samples_needed = filtered_n_samples + total_silence_samples;
|
6667
|
+
|
6668
|
+
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
|
6669
|
+
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
|
6670
|
+
|
6671
|
+
try {
|
6672
|
+
filtered_samples.resize(total_samples_needed);
|
6673
|
+
} catch (const std::bad_alloc & /* e */) {
|
6674
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
|
6675
|
+
whisper_vad_free_segments(vad_segments);
|
6676
|
+
whisper_vad_free(vctx);
|
6677
|
+
return false;
|
6678
|
+
}
|
6679
|
+
|
6680
|
+
int offset = 0;
|
6681
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6682
|
+
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
|
6683
|
+
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
|
6684
|
+
|
6685
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6686
|
+
segment_end_samples += overlap_samples;
|
6687
|
+
}
|
6688
|
+
|
6689
|
+
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
|
6690
|
+
segment_end_samples = std::min(segment_end_samples, n_samples);
|
6691
|
+
int segment_length = segment_end_samples - segment_start_samples;
|
6692
|
+
|
6693
|
+
if (segment_length > 0) {
|
6694
|
+
whisper_state::vad_segment_info segment;
|
6695
|
+
|
6696
|
+
segment.orig_start = vad_segments->data[i].start;
|
6697
|
+
segment.orig_end = vad_segments->data[i].end;
|
6698
|
+
|
6699
|
+
segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
|
6700
|
+
segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
|
6701
|
+
|
6702
|
+
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
|
6703
|
+
__func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
|
6704
|
+
ctx->state->vad_segments.push_back(segment);
|
6705
|
+
|
6706
|
+
// Copy this speech segment
|
6707
|
+
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
|
6708
|
+
offset += segment_length;
|
6709
|
+
|
6710
|
+
// Add silence after this segment (except after the last segment)
|
6711
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6712
|
+
// Fill with zeros (silence)
|
6713
|
+
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
|
6714
|
+
offset += silence_samples;
|
6715
|
+
}
|
6716
|
+
}
|
6717
|
+
}
|
6718
|
+
|
6719
|
+
filtered_n_samples = offset;
|
6720
|
+
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
|
6721
|
+
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
|
6722
|
+
}
|
6723
|
+
|
6724
|
+
return true;
|
6725
|
+
}
|
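As a worked example of the assembly above (illustrative numbers): suppose the detector returns two segments, 1.00–2.50 s and 5.00–6.00 s, in an 8 s clip, with vad_params.samples_overlap = 0.1 s. The first segment is copied with 0.1 s of trailing overlap (1.6 s = 25600 samples), followed by 0.1 s of inserted silence (1600 samples), then the last segment without overlap (1.0 s = 16000 samples), giving 43200 filtered samples (2.7 s) instead of 128000. The recorded mapping is orig 1.00–2.50 → vad 0.00–1.60 and orig 5.00–6.00 → vad 1.70–2.70, which is exactly what the timestamp remapping further below relies on.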
6726
|
+
|
5335
6727
|
int whisper_full_with_state(
|
5336
6728
|
struct whisper_context * ctx,
|
5337
6729
|
struct whisper_state * state,
|
@@ -5343,9 +6735,27 @@ int whisper_full_with_state(
|
|
5343
6735
|
|
5344
6736
|
result_all.clear();
|
5345
6737
|
|
5346
|
-
|
6738
|
+
const float * process_samples = samples;
|
6739
|
+
int n_process_samples = n_samples;
|
6740
|
+
std::vector<float> vad_samples;
|
6741
|
+
|
6742
|
+
if (params.vad) {
|
6743
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
6744
|
+
int vad_n_samples;
|
6745
|
+
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
|
6746
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
6747
|
+
return -1;
|
6748
|
+
}
|
6749
|
+
if (vad_n_samples == 0) {
|
6750
|
+
return 0;
|
6751
|
+
}
|
6752
|
+
process_samples = vad_samples.data();
|
6753
|
+
n_process_samples = vad_n_samples;
|
6754
|
+
}
|
6755
|
+
|
6756
|
+
if (n_process_samples > 0) {
|
5347
6757
|
// compute log mel spectrogram
|
5348
|
-
if (whisper_pcm_to_mel_with_state(ctx, state,
|
6758
|
+
if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
|
5349
6759
|
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
5350
6760
|
return -2;
|
5351
6761
|
}
|
@@ -5381,11 +6791,13 @@ int whisper_full_with_state(
|
|
5381
6791
|
const int seek_start = params.offset_ms/10;
|
5382
6792
|
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
|
5383
6793
|
|
5384
|
-
// if length of spectrogram is less than
|
5385
|
-
// basically don't process anything that is less than
|
5386
|
-
//
|
5387
|
-
|
5388
|
-
|
6794
|
+
// if length of spectrogram is less than 100ms (10 frames), then return
|
6795
|
+
// basically don't process anything that is less than 100ms
|
6796
|
+
// ref: https://github.com/ggml-org/whisper.cpp/issues/2065
|
6797
|
+
const int delta_min = 10;
|
6798
|
+
|
6799
|
+
if (seek_end < seek_start + delta_min) {
|
6800
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
5389
6801
|
return 0;
|
5390
6802
|
}
|
5391
6803
|
|
@@ -5432,7 +6844,7 @@ int whisper_full_with_state(
|
|
5432
6844
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
5433
6845
|
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
5434
6846
|
|
5435
|
-
decoder.rng = std::mt19937(
|
6847
|
+
decoder.rng = std::mt19937(j);
|
5436
6848
|
}
|
5437
6849
|
|
5438
6850
|
// the accumulated text context so far
|
@@ -5529,8 +6941,8 @@ int whisper_full_with_state(
|
|
5529
6941
|
ctx, state, progress_cur, params.progress_callback_user_data);
|
5530
6942
|
}
|
5531
6943
|
|
5532
|
-
// if only
|
5533
|
-
if (seek +
|
6944
|
+
// if only 100ms left, then stop
|
6945
|
+
if (seek + delta_min >= seek_end) {
|
5534
6946
|
break;
|
5535
6947
|
}
|
5536
6948
|
|
@@ -5877,10 +7289,10 @@ int whisper_full_with_state(
|
|
5877
7289
|
// end of segment
|
5878
7290
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
5879
7291
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
5880
|
-
(has_ts && seek + seek_delta +
|
7292
|
+
(has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
|
5881
7293
|
) {
|
5882
7294
|
if (result_len == 0 && !params.no_timestamps) {
|
5883
|
-
if (seek + seek_delta +
|
7295
|
+
if (seek + seek_delta + delta_min >= seek_end) {
|
5884
7296
|
result_len = i + 1;
|
5885
7297
|
} else {
|
5886
7298
|
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
@@ -6147,7 +7559,7 @@ int whisper_full_with_state(
|
|
6147
7559
|
|
6148
7560
|
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
|
6149
7561
|
|
6150
|
-
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
|
7562
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6151
7563
|
for (int j = i0; j <= i; j++) {
|
6152
7564
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6153
7565
|
}
|
@@ -6192,7 +7604,7 @@ int whisper_full_with_state(
|
|
6192
7604
|
}
|
6193
7605
|
}
|
6194
7606
|
|
6195
|
-
result_all.push_back({ tt0, tt1, text, {}
|
7607
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6196
7608
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
6197
7609
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6198
7610
|
}
|
@@ -6229,7 +7641,7 @@ int whisper_full_with_state(
|
|
6229
7641
|
}
|
6230
7642
|
}
|
6231
7643
|
|
6232
|
-
// ref: https://github.com/
|
7644
|
+
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
|
6233
7645
|
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
|
6234
7646
|
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
|
6235
7647
|
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
|
@@ -6388,19 +7800,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
|
|
6388
7800
|
}
|
6389
7801
|
|
6390
7802
|
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
6391
|
-
return
|
7803
|
+
// If VAD wasn't used, return the original timestamp
|
7804
|
+
if (!state->has_vad_segments || state->vad_segments.empty()) {
|
7805
|
+
return state->result_all[i_segment].t0;
|
7806
|
+
}
|
7807
|
+
|
7808
|
+
// Get the start timestamp produced by whisper_full. whisper_full processes
|
7809
|
+
// only the speech segments in this case so we need to map these timestamps
|
7810
|
+
// back to the original audio.
|
7811
|
+
float t0 = state->result_all[i_segment].t0 / 100.0f;
|
7812
|
+
|
7813
|
+
// Find which VAD segment this timestamp belongs to.
|
7814
|
+
// TODO(danbev) This could be optimized by using a binary search if the number
|
7815
|
+
// of segments exceeds a certain limit. Also we might be able to assume that
|
7816
|
+
// the access pattern is sequential and optimized for that too.
|
7817
|
+
for (size_t i = 0; i < state->vad_segments.size(); i++) {
|
7818
|
+
const auto & segment = state->vad_segments[i];
|
7819
|
+
|
7820
|
+
// Check if the timestamp falls within this segment.
|
7821
|
+
if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
|
7822
|
+
float proportion = 0.0f;
|
7823
|
+
if (segment.vad_end > segment.vad_start) {
|
7824
|
+
proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
|
7825
|
+
}
|
7826
|
+
float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
|
7827
|
+
return (int64_t)(orig_t0 * 100);
|
7828
|
+
}
|
7829
|
+
}
|
7830
|
+
|
7831
|
+
// Check if the timestamp falls between two segments.
|
7832
|
+
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
|
7833
|
+
const auto & curr = state->vad_segments[i];
|
7834
|
+
const auto & next = state->vad_segments[i + 1];
|
7835
|
+
|
7836
|
+
if (t0 > curr.vad_end && t0 < next.vad_start) {
|
7837
|
+
// Calculate how far we are through the gap as a proportion
|
7838
|
+
float gap_proportion = 0.0f;
|
7839
|
+
if (next.vad_start > curr.vad_end) {
|
7840
|
+
gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
|
7841
|
+
}
|
7842
|
+
float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
|
7843
|
+
return (int64_t)(orig_t0 * 100);
|
7844
|
+
}
|
7845
|
+
}
|
7846
|
+
|
7847
|
+
// Handle the case where the timestamp is after the last segment.
|
7848
|
+
if (t0 > state->vad_segments.back().vad_end) {
|
7849
|
+
// For timestamps after the last segment, add the extra time to the end of the last segment
|
7850
|
+
const auto& last = state->vad_segments.back();
|
7851
|
+
// Calculate how far beyond the last segment
|
7852
|
+
float extra_time = t0 - last.vad_end;
|
7853
|
+
// Add this extra time to the original end time
|
7854
|
+
float orig_t0 = last.orig_end + extra_time;
|
7855
|
+
return (int64_t)(orig_t0 * 100);
|
7856
|
+
}
|
7857
|
+
|
7858
|
+
WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
|
7859
|
+
return t0;
|
6392
7860
|
}
|
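As an illustration of the proportional mapping (numbers invented for the example): a decoder timestamp of t0 = 3.5 s that falls inside a VAD segment recorded as orig 10.0–12.0 s / vad 3.0–4.0 s lies at proportion (3.5 - 3.0) / (4.0 - 3.0) = 0.5 through the filtered segment, so it is reported as 10.0 + 0.5 * (12.0 - 10.0) = 11.0 s, i.e. 1100 in the 10 ms units that whisper_full_get_segment_t0 returns.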
6393
7861
|
|
6394
7862
|
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
6395
|
-
return ctx->state
|
7863
|
+
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
|
6396
7864
|
}
|
6397
7865
|
|
6398
7866
|
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
6399
|
-
return
|
7867
|
+
// If VAD wasn't used, return the original timestamp
|
7868
|
+
if (!state->has_vad_segments || state->vad_segments.empty()) {
|
7869
|
+
return state->result_all[i_segment].t1;
|
7870
|
+
}
|
7871
|
+
|
7872
|
+
// Get the end timestamp produced by whisper_full. whisper_full processes
|
7873
|
+
// only the speech segments in this case so we need to map these timestamps
|
7874
|
+
// back to the original audio.
|
7875
|
+
float t1 = state->result_all[i_segment].t1 / 100.0f;
|
7876
|
+
|
7877
|
+
// Find which VAD segment this timestamp belongs to.
|
7878
|
+
// TODO(danbev) This could be optimized by using a binary search if the number
|
7879
|
+
// of segments exceeds a certain limit. Also we might be able to assume that
|
7880
|
+
// the access pattern is sequential and optimized for that too.
|
7881
|
+
for (size_t i = 0; i < state->vad_segments.size(); i++) {
|
7882
|
+
const auto& segment = state->vad_segments[i];
|
7883
|
+
|
7884
|
+
// Check if the timestamp falls within this segment.
|
7885
|
+
if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
|
7886
|
+
// Calculate the proportion through the filtered segment.
|
7887
|
+
float proportion = 0.0f;
|
7888
|
+
if (segment.vad_end > segment.vad_start) {
|
7889
|
+
proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
|
7890
|
+
}
|
7891
|
+
float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
|
7892
|
+
return (int64_t)(orig_t1 * 100);
|
7893
|
+
}
|
7894
|
+
}
|
7895
|
+
|
7896
|
+
// Check if the timestamp falls between two segments.
|
7897
|
+
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
|
7898
|
+
const auto & curr = state->vad_segments[i];
|
7899
|
+
const auto & next = state->vad_segments[i + 1];
|
7900
|
+
|
7901
|
+
if (t1 > curr.vad_end && t1 < next.vad_start) {
|
7902
|
+
// Calculate how far we are through the gap as a proportion
|
7903
|
+
float gap_proportion = 0.0f;
|
7904
|
+
if (next.vad_start > curr.vad_end) {
|
7905
|
+
gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
|
7906
|
+
}
|
7907
|
+
// Map to the corresponding position in the original gap
|
7908
|
+
float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
|
7909
|
+
return (int64_t)(orig_t1 * 100);
|
7910
|
+
}
|
7911
|
+
}
|
7912
|
+
|
7913
|
+
// Handle the case where the timestamp is after the last segment
|
7914
|
+
if (t1 > state->vad_segments.back().vad_end) {
|
7915
|
+
// For the last segment, use the end of the last VAD segment
|
7916
|
+
const auto& last = state->vad_segments.back();
|
7917
|
+
// Calculate how far beyond the last segment
|
7918
|
+
float extra_time = t1 - last.vad_end;
|
7919
|
+
// Add this extra time to the original end time
|
7920
|
+
float orig_t1 = last.orig_end + extra_time;
|
7921
|
+
return (int64_t)(orig_t1 * 100);
|
7922
|
+
}
|
7923
|
+
|
7924
|
+
WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
|
7925
|
+
return t1;
|
6400
7926
|
}
|
6401
7927
|
|
6402
7928
|
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
6403
|
-
return ctx->state
|
7929
|
+
return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
|
6404
7930
|
}
|
6405
7931
|
|
6406
7932
|
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
@@ -6459,6 +7985,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
|
|
6459
7985
|
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
6460
7986
|
}
|
6461
7987
|
|
7988
|
+
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
|
7989
|
+
return ctx->state->result_all[i_segment].no_speech_prob;
|
7990
|
+
}
|
7991
|
+
|
7992
|
+
float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
|
7993
|
+
return state->result_all[i_segment].no_speech_prob;
|
7994
|
+
}
|
7995
|
+
|
6462
7996
|
// =================================================================================================
|
6463
7997
|
|
6464
7998
|
//
|
@@ -6620,6 +8154,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
|
|
6620
8154
|
}
|
6621
8155
|
|
6622
8156
|
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
8157
|
+
whisper_load_backends();
|
8158
|
+
|
6623
8159
|
static std::string s;
|
6624
8160
|
s = "";
|
6625
8161
|
char strbuf[256];
|
@@ -6639,7 +8175,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6639
8175
|
// c: N*N*sizeof(float)
|
6640
8176
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
6641
8177
|
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
|
6642
|
-
std::vector<uint8_t> work;
|
6643
8178
|
|
6644
8179
|
// put a bunch of random data in the buffer
|
6645
8180
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
@@ -6696,12 +8231,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6696
8231
|
double tsum = 0.0;
|
6697
8232
|
|
6698
8233
|
// heat-up
|
6699
|
-
ggml_graph_compute_helper(gf,
|
8234
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6700
8235
|
|
6701
8236
|
for (int i = 0; i < n_max; ++i) {
|
6702
8237
|
const int64_t t0 = ggml_time_us();
|
6703
8238
|
|
6704
|
-
ggml_graph_compute_helper(gf,
|
8239
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6705
8240
|
|
6706
8241
|
const int64_t t1 = ggml_time_us();
|
6707
8242
|
|
@@ -6862,12 +8397,6 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6862
8397
|
|
6863
8398
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
|
6864
8399
|
|
6865
|
-
tokens[j].id = token.id;
|
6866
|
-
tokens[j].tid = token.tid;
|
6867
|
-
tokens[j].p = token.p;
|
6868
|
-
tokens[j].pt = token.pt;
|
6869
|
-
tokens[j].ptsum = token.ptsum;
|
6870
|
-
|
6871
8400
|
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
|
6872
8401
|
|
6873
8402
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
@@ -7078,18 +8607,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
     struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
 
-    cost =
-    trace =
-
+    cost = whisper_set_f32(cost, INFINITY);
+    trace = whisper_set_i32(trace, -1);
+    whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
 
     // dtw
     // supposedly can be optmized by computing diagonals in parallel ?
     // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
     for (int64_t j = 1; j < M + 1; ++j) {
         for (int64_t i = 1; i < N + 1; ++i) {
-            float c0 =
-            float c1 =
-            float c2 =
+            float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
 
             float c;
             int32_t t;
@@ -7104,9 +8633,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
                 t = 2;
             }
 
-            c =
-
-
+            c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            whisper_set_f32_nd(cost, i, j, 0, 0, c);
+            whisper_set_i32_nd(trace, i, j, 0, 0, t);
         }
     }
 
@@ -7115,19 +8644,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
     // trace[0, :] = 2;
     for (int64_t i = 0; i < M + 1; ++i)
-
+        whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
     //trace[:, 0] = 1;
     for (int64_t i = 0; i < N + 1; ++i)
-
+        whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
     int bt_row_idx = BT_MAX_ROWS - 1;
     int64_t i = N;
     int64_t j = M;
     while (i > 0 || j > 0) {
-
-
+        whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
         --bt_row_idx;
 
-        int32_t t =
+        int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
         if (t == 0) {
             --i;
             --j;
@@ -7148,8 +8677,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
     for (int64_t i = 0; i < 2; ++i) {
        for (int64_t j = 0; j < result_n_cols; ++j) {
-            int32_t v =
-
+            int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
+            whisper_set_i32_nd(r, i, j, 0, 0, v);
         }
     }
 
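
`dtw_and_backtrace()` now reads and writes the `cost`/`trace` tensors through the `whisper_get_*`/`whisper_set_*` element accessors. Stripped of the tensor plumbing, the recurrence and backtrace it performs amount to the plain-vector sketch below; this is illustrative only (the tie-breaking is slightly simplified) and is not the library's implementation:

```c++
// Illustrative DTW over a plain distance matrix, mirroring the recurrence above.
// x is an N x M distance matrix; the result is the alignment path from (0,0) to (N-1,M-1).
#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

static std::vector<std::pair<int64_t, int64_t>> dtw_path(const std::vector<std::vector<float>> & x) {
    const int64_t N = (int64_t) x.size();
    const int64_t M = (int64_t) x[0].size();

    std::vector<std::vector<float>>   cost(N + 1, std::vector<float>(M + 1, std::numeric_limits<float>::infinity()));
    std::vector<std::vector<int32_t>> trace(N + 1, std::vector<int32_t>(M + 1, -1));
    cost[0][0] = 0.0f;

    for (int64_t j = 1; j < M + 1; ++j) {
        for (int64_t i = 1; i < N + 1; ++i) {
            const float c0 = cost[i - 1][j - 1]; // diagonal
            const float c1 = cost[i - 1][j];     // up
            const float c2 = cost[i][j - 1];     // left

            float c = c0; int32_t t = 0;
            if (c1 < c) { c = c1; t = 1; }
            if (c2 < c) { c = c2; t = 2; }

            cost[i][j]  = x[i - 1][j - 1] + c;
            trace[i][j] = t;
        }
    }

    // boundary guards, as in the code above, then walk back from the bottom-right corner
    for (int64_t i = 0; i < M + 1; ++i) trace[0][i] = 2;
    for (int64_t i = 0; i < N + 1; ++i) trace[i][0] = 1;

    std::vector<std::pair<int64_t, int64_t>> path;
    int64_t i = N, j = M;
    while (i > 0 || j > 0) {
        path.emplace_back(i - 1, j - 1);
        switch (trace[i][j]) {
            case 0:  --i; --j; break;
            case 1:  --i;      break;
            default:      --j; break;
        }
    }
    std::reverse(path.begin(), path.end());
    return path;
}
```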
@@ -7184,11 +8713,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
                     idx = 2*(a->ne[2] - 1) - idx;
                 }
 
-                filter.push_back(
+                filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
             }
             std::sort(filter.begin(), filter.end());
             const float v = filter[filter.size()/2];
-
+            whisper_set_f32_nd(dst, i, j, k, 0, v);
             filter.clear();
         }
     }
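
`median_filter()` likewise switches to the element accessors. The underlying operation is an ordinary sliding-window median with reflected borders; a standalone 1-D sketch of the same idea (`median_filter_1d` is a hypothetical helper, not part of the library, and assumes the window is much smaller than the signal):

```c++
// Illustrative 1-D median filter with reflected borders.
#include <algorithm>
#include <vector>

static std::vector<float> median_filter_1d(const std::vector<float> & a, int width) {
    const int n    = (int) a.size();
    const int half = width / 2;

    std::vector<float> out(n);
    std::vector<float> filter;
    filter.reserve(width);

    for (int j = 0; j < n; ++j) {
        for (int k = 0; k < width; ++k) {
            int idx = j - half + k;
            // reflect indices that fall outside the signal, as in the loop above
            if (idx < 0) {
                idx = -idx;
            }
            if (idx >= n) {
                idx = 2*(n - 1) - idx;
            }
            filter.push_back(a[idx]);
        }
        std::sort(filter.begin(), filter.end());
        out[j] = filter[filter.size()/2];
        filter.clear();
    }
    return out;
}
```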
@@ -7310,7 +8839,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // Compute
     struct ggml_cgraph * gf = ggml_new_graph(gctx);
     ggml_build_forward_expand(gf, w);
-
+
+    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+    ggml_backend_graph_compute(backend.get(), gf);
 
     ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
 
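
The CPU graph evaluation here now goes through the ggml backend interface rather than a compute helper. The whisper code wraps the handle in `ggml_backend_ptr`; the self-contained sketch below uses the raw handle and frees it explicitly, and the add-two-vectors graph is purely illustrative (builds that load backends dynamically may also need the backends loaded first, which is what `whisper_load_backends()` does elsewhere in this diff):

```c++
// Minimal sketch: build a tiny ggml graph and run it on a CPU backend.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <cstdio>

int main() {
    // metadata-only context; tensor data is allocated by the backend buffer below
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
    if (!backend) return 1; // no CPU backend registered

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    const float av[4] = { 1,  2,  3,  4};
    const float bv[4] = {10, 20, 30, 40};
    ggml_backend_tensor_set(a, av, 0, sizeof(av));
    ggml_backend_tensor_set(b, bv, 0, sizeof(bv));

    ggml_backend_graph_compute(backend, gf);

    float cv[4];
    ggml_backend_tensor_get(c, cv, 0, sizeof(cv));
    printf("%g %g %g %g\n", cv[0], cv[1], cv[2], cv[3]);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
```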
@@ -7319,9 +8850,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     auto seg_i = state->result_all.begin() + i_segment;
     auto tok_i = seg_i->tokens.begin();
     for (int i = 0; i < alignment->ne[1]; ++i) {
-        int32_t v =
+        int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
         if (v != last_v) {
-            int32_t time_index =
+            int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
             int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
             last_v = v;
 
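
This loop copies the DTW alignment back onto the decoded tokens. On the public API side, token-level DTW timestamps are requested through the context parameters and read from `whisper_token_data`; a hedged usage sketch, assuming the `dtw_token_timestamps`/`dtw_aheads_preset` context fields and the `t_dtw` token field exposed by upstream whisper.cpp, with the model path and audio buffer as placeholders:

```c++
// Sketch: request DTW token timestamps and read them back per token.
#include "whisper.h"
#include <cstdio>
#include <vector>

int main() {
    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.dtw_token_timestamps = true;                   // enables the alignment path above
    cparams.dtw_aheads_preset    = WHISPER_AHEADS_BASE_EN; // preset must match the loaded model

    struct whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
    if (!ctx) return 1;

    std::vector<float> pcmf32(16000, 0.0f); // placeholder: 1 s of 16 kHz mono audio

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) return 1;

    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
            const whisper_token_data td = whisper_full_get_token_data(ctx, i, j);
            // t_dtw uses the same time units as t0/t1; negative when no alignment was produced
            printf("%-16s t_dtw = %lld\n",
                   whisper_full_get_token_text(ctx, i, j), (long long) td.t_dtw);
        }
    }

    whisper_free(ctx);
    return 0;
}
```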