whispercpp 1.3.0 → 1.3.2
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -0
- data/LICENSE +1 -1
- data/README.md +216 -424
- data/Rakefile +79 -11
- data/ext/.gitignore +11 -0
- data/ext/dependencies.rb +61 -0
- data/ext/extconf.rb +18 -26
- data/ext/options.rb +221 -0
- data/ext/ruby_whisper.c +159 -0
- data/ext/ruby_whisper.h +27 -2
- data/ext/ruby_whisper_context.c +641 -0
- data/ext/ruby_whisper_error.c +52 -0
- data/ext/ruby_whisper_model.c +232 -0
- data/ext/ruby_whisper_params.c +1301 -0
- data/ext/ruby_whisper_segment.c +143 -0
- data/ext/ruby_whisper_transcribe.cpp +87 -0
- data/ext/ruby_whisper_vad_params.c +288 -0
- data/ext/sources/.dockerignore +3 -0
- data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
- data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
- data/ext/sources/CMakeLists.txt +251 -0
- data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
- data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
- data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
- data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
- data/ext/sources/bindings/javascript/package.json +26 -0
- data/ext/sources/bindings/javascript/whisper.js +19 -0
- data/ext/sources/build-xcframework.sh +547 -0
- data/ext/sources/ci/run.sh +336 -0
- data/ext/sources/close-issue.yml +28 -0
- data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
- data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
- data/ext/sources/cmake/build-info.cmake +60 -0
- data/ext/sources/cmake/git-vars.cmake +22 -0
- data/ext/sources/cmake/whisper-config.cmake.in +65 -0
- data/ext/sources/cmake/whisper.pc.in +10 -0
- data/ext/sources/examples/CMakeLists.txt +124 -0
- data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
- data/ext/sources/examples/addon.node/addon.cpp +438 -0
- data/ext/sources/examples/addon.node/index.js +54 -0
- data/ext/sources/examples/addon.node/package.json +16 -0
- data/ext/sources/examples/bench/CMakeLists.txt +8 -0
- data/ext/sources/examples/bench/bench.cpp +175 -0
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
- data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
- data/ext/sources/examples/cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/cli/cli.cpp +1294 -0
- data/ext/sources/examples/coi-serviceworker.js +146 -0
- data/ext/sources/examples/command/CMakeLists.txt +10 -0
- data/ext/sources/examples/command/command.cpp +776 -0
- data/ext/sources/examples/command/commands.txt +9 -0
- data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
- data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/common-ggml.cpp +238 -0
- data/ext/sources/examples/common-ggml.h +18 -0
- data/ext/sources/examples/common-sdl.cpp +227 -0
- data/ext/sources/examples/common-sdl.h +49 -0
- data/ext/sources/examples/common-whisper.cpp +168 -0
- data/ext/sources/examples/common-whisper.h +24 -0
- data/ext/sources/examples/common.cpp +675 -0
- data/ext/sources/examples/common.h +322 -0
- data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
- data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
- data/ext/sources/examples/generate-karaoke.sh +57 -0
- data/ext/sources/examples/grammar-parser.cpp +423 -0
- data/ext/sources/examples/grammar-parser.h +29 -0
- data/ext/sources/examples/helpers.js +191 -0
- data/ext/sources/examples/json.hpp +24596 -0
- data/ext/sources/examples/livestream.sh +112 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
- data/ext/sources/examples/lsp/lsp.cpp +467 -0
- data/ext/sources/examples/lsp/whisper.vim +362 -0
- data/ext/sources/examples/miniaudio.h +93468 -0
- data/ext/sources/examples/python/test_whisper_processor.py +7 -0
- data/ext/sources/examples/python/whisper_processor.py +54 -0
- data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
- data/ext/sources/examples/quantize/quantize.cpp +223 -0
- data/ext/sources/examples/server/CMakeLists.txt +12 -0
- data/ext/sources/examples/server/bench.js +29 -0
- data/ext/sources/examples/server/httplib.h +10497 -0
- data/ext/sources/examples/server/server.cpp +1091 -0
- data/ext/sources/examples/server.py +115 -0
- data/ext/sources/examples/stb_vorbis.c +5584 -0
- data/ext/sources/examples/stream/CMakeLists.txt +10 -0
- data/ext/sources/examples/stream/stream.cpp +429 -0
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
- data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
- data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
- data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
- data/ext/sources/examples/sycl/build.sh +22 -0
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
- data/ext/sources/examples/sycl/run-whisper.sh +17 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
- data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
- data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
- data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
- data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
- data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
- data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
- data/ext/sources/examples/talk-llama/llama-context.h +276 -0
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
- data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
- data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
- data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
- data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
- data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
- data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
- data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
- data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
- data/ext/sources/examples/talk-llama/llama-io.h +35 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
- data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
- data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
- data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
- data/ext/sources/examples/talk-llama/llama-model.h +425 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
- data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
- data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
- data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
- data/ext/sources/examples/talk-llama/llama.cpp +354 -0
- data/ext/sources/examples/talk-llama/llama.h +1377 -0
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
- data/ext/sources/examples/talk-llama/speak +40 -0
- data/ext/sources/examples/talk-llama/speak.bat +1 -0
- data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
- data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
- data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
- data/ext/sources/examples/talk-llama/unicode.h +66 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
- data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
- data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
- data/ext/sources/ggml/CMakeLists.txt +390 -0
- data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
- data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
- data/ext/sources/ggml/cmake/common.cmake +26 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
- data/ext/sources/ggml/include/ggml-alloc.h +76 -0
- data/ext/sources/ggml/include/ggml-backend.h +354 -0
- data/ext/sources/ggml/include/ggml-blas.h +25 -0
- data/ext/sources/ggml/include/ggml-cann.h +123 -0
- data/ext/sources/ggml/include/ggml-cpp.h +39 -0
- data/ext/sources/ggml/include/ggml-cpu.h +143 -0
- data/ext/sources/ggml/include/ggml-cuda.h +47 -0
- data/ext/sources/ggml/include/ggml-kompute.h +50 -0
- data/ext/sources/ggml/include/ggml-metal.h +66 -0
- data/ext/sources/ggml/include/ggml-opencl.h +26 -0
- data/ext/sources/ggml/include/ggml-opt.h +237 -0
- data/ext/sources/ggml/include/ggml-rpc.h +33 -0
- data/ext/sources/ggml/include/ggml-sycl.h +49 -0
- data/ext/sources/ggml/include/ggml-vulkan.h +29 -0
- data/ext/{ggml.h → sources/ggml/include/ggml.h} +621 -821
- data/ext/sources/ggml/include/gguf.h +202 -0
- data/ext/sources/ggml/src/CMakeLists.txt +346 -0
- data/ext/sources/ggml/src/ggml-alloc.c +1042 -0
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- data/ext/sources/ggml/src/ggml-amx/common.h +94 -0
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/sources/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +255 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +586 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +2011 -0
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +181 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +3193 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +420 -0
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +2606 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +234 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/sources/ggml/src/ggml-common.h +1857 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +221 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
- data/ext/sources/ggml/src/ggml-cpu/cpu-feats-x86.cpp +327 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +508 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +13747 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +671 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
- data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
- data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
- data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
- data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
- data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
- data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
- data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +15 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +243 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +140 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
- data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
- data/ext/sources/ggml/src/ggml-impl.h +601 -0
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +5998 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +7089 -0
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1037 -0
- data/ext/sources/ggml/src/ggml-quants.c +5232 -0
- data/ext/sources/ggml/src/ggml-quants.h +100 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +1813 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +195 -0
- data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +101 -0
- data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +623 -0
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +1162 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +4493 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
- data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +3030 -0
- data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +501 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +47 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +261 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
- data/ext/sources/ggml/src/ggml-threading.cpp +12 -0
- data/ext/sources/ggml/src/ggml-threading.h +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
- data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +10700 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +751 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
- data/ext/sources/ggml/src/ggml.c +6550 -0
- data/ext/sources/ggml/src/gguf.cpp +1330 -0
- data/ext/{whisper.h → sources/include/whisper.h} +91 -24
- data/ext/sources/src/CMakeLists.txt +143 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.h +158 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +226 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.h +154 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +222 -0
- data/ext/sources/src/coreml/whisper-encoder.h +26 -0
- data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
- data/ext/sources/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/sources/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/sources/src/whisper-arch.h +197 -0
- data/ext/{whisper.cpp → sources/src/whisper.cpp} +2535 -835
- data/ext/sources/tests/CMakeLists.txt +105 -0
- data/ext/sources/tests/earnings21/eval.mk +58 -0
- data/ext/sources/tests/earnings21/eval.py +68 -0
- data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
- data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
- data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
- data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
- data/ext/sources/tests/earnings21/requirements.txt +6 -0
- data/ext/sources/tests/en-0-ref.txt +1 -0
- data/ext/sources/tests/en-1-ref.txt +1 -0
- data/ext/sources/tests/en-2-ref.txt +1 -0
- data/ext/sources/tests/es-0-ref.txt +1 -0
- data/ext/sources/tests/librispeech/eval.mk +39 -0
- data/ext/sources/tests/librispeech/eval.py +47 -0
- data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
- data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
- data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
- data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
- data/ext/sources/tests/librispeech/requirements.txt +6 -0
- data/ext/sources/tests/run-tests.sh +130 -0
- data/ext/sources/tests/test-c.c +3 -0
- data/ext/sources/tests/test-vad-full.cpp +54 -0
- data/ext/sources/tests/test-vad.cpp +83 -0
- data/ext/sources/tests/test-whisper.js +58 -0
- data/extsources.rb +34 -0
- data/lib/whisper/model/uri.rb +178 -0
- data/sig/whisper.rbs +480 -0
- data/tests/helper.rb +35 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +202 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +109 -0
- data/tests/test_package.rb +46 -0
- data/tests/test_params.rb +297 -0
- data/tests/test_segment.rb +74 -0
- data/tests/test_vad.rb +19 -0
- data/tests/test_vad_params.rb +103 -0
- data/tests/test_whisper.rb +212 -124
- data/whispercpp.gemspec +37 -0
- metadata +794 -13
- data/ext/dr_wav.h +0 -6434
- data/ext/ggml.c +0 -21755
- data/ext/ruby_whisper.cpp +0 -426
@@ -1,62 +1,52 @@
|
|
1
1
|
#include "whisper.h"
|
2
|
+
#include "whisper-arch.h"
|
3
|
+
|
4
|
+
#include "ggml.h"
|
5
|
+
#include "ggml-cpp.h"
|
6
|
+
#include "ggml-alloc.h"
|
7
|
+
#include "ggml-backend.h"
|
2
8
|
|
3
9
|
#ifdef WHISPER_USE_COREML
|
4
10
|
#include "coreml/whisper-encoder.h"
|
5
11
|
#endif
|
6
12
|
|
7
|
-
#ifdef GGML_USE_METAL
|
8
|
-
#include "ggml-metal.h"
|
9
|
-
#endif
|
10
|
-
|
11
|
-
#ifdef GGML_USE_CUDA
|
12
|
-
#include "ggml-cuda.h"
|
13
|
-
#endif
|
14
|
-
|
15
|
-
#ifdef GGML_USE_SYCL
|
16
|
-
#include "ggml-sycl.h"
|
17
|
-
#endif
|
18
|
-
|
19
13
|
#ifdef WHISPER_USE_OPENVINO
|
20
14
|
#include "openvino/whisper-openvino-encoder.h"
|
21
15
|
#endif
|
22
16
|
|
23
|
-
#include "ggml.h"
|
24
|
-
#include "ggml-alloc.h"
|
25
|
-
#include "ggml-backend.h"
|
26
|
-
|
27
17
|
#include <atomic>
|
28
18
|
#include <algorithm>
|
29
19
|
#include <cassert>
|
20
|
+
#include <cfloat>
|
30
21
|
#define _USE_MATH_DEFINES
|
31
22
|
#include <cmath>
|
32
|
-
#include <
|
23
|
+
#include <climits>
|
24
|
+
#include <codecvt>
|
33
25
|
#include <cstdarg>
|
26
|
+
#include <cstdio>
|
34
27
|
#include <cstring>
|
35
28
|
#include <fstream>
|
29
|
+
#include <functional>
|
36
30
|
#include <map>
|
31
|
+
#include <mutex>
|
32
|
+
#include <random>
|
33
|
+
#include <regex>
|
37
34
|
#include <set>
|
38
35
|
#include <string>
|
39
36
|
#include <thread>
|
40
37
|
#include <vector>
|
41
|
-
#include <regex>
|
42
|
-
#include <random>
|
43
|
-
#include <functional>
|
44
|
-
|
45
|
-
#if defined(_MSC_VER)
|
46
|
-
#pragma warning(disable: 4244 4267) // possible loss of data
|
47
|
-
#endif
|
48
|
-
|
49
|
-
#if defined(GGML_BIG_ENDIAN)
|
50
|
-
#include <bit>
|
51
38
|
|
39
|
+
#if defined(WHISPER_BIG_ENDIAN)
|
52
40
|
template<typename T>
|
53
41
|
static T byteswap(T value) {
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
42
|
+
T value_swapped;
|
43
|
+
char * source = reinterpret_cast<char *>(&value);
|
44
|
+
char * target = reinterpret_cast<char *>(&value_swapped);
|
45
|
+
int size = sizeof(T);
|
46
|
+
for (int i = 0; i < size; i++) {
|
47
|
+
target[size - 1 - i] = source[i];
|
48
|
+
}
|
49
|
+
return value_swapped;
|
60
50
|
}
|
61
51
|
|
62
52
|
template<typename T>
|
@@ -92,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|
92
82
|
}
|
93
83
|
|
94
84
|
#define BYTESWAP_VALUE(d) d = byteswap(d)
|
95
|
-
#define BYTESWAP_FILTERS(f)
|
85
|
+
#define BYTESWAP_FILTERS(f) \
|
96
86
|
do { \
|
97
87
|
for (auto & datum : f.data) { \
|
98
88
|
datum = byteswap(datum); \
|
99
89
|
} \
|
100
90
|
} while (0)
|
101
|
-
#define BYTESWAP_TENSOR(t)
|
102
|
-
do {
|
91
|
+
#define BYTESWAP_TENSOR(t) \
|
92
|
+
do { \
|
103
93
|
byteswap_tensor(t); \
|
104
94
|
} while (0)
|
105
95
|
#else
|
@@ -147,47 +137,128 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|
147
137
|
} \
|
148
138
|
} while (0)
|
149
139
|
|
150
|
-
//#define WHISPER_USE_FLASH_ATTN
|
151
|
-
//#define WHISPER_USE_FLASH_FF
|
152
140
|
#define WHISPER_MAX_DECODERS 8
|
153
141
|
#define WHISPER_MAX_NODES 4096
|
154
142
|
|
143
|
+
static std::string format(const char * fmt, ...) {
|
144
|
+
va_list ap;
|
145
|
+
va_list ap2;
|
146
|
+
va_start(ap, fmt);
|
147
|
+
va_copy(ap2, ap);
|
148
|
+
int size = vsnprintf(NULL, 0, fmt, ap);
|
149
|
+
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
150
|
+
std::vector<char> buf(size + 1);
|
151
|
+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
152
|
+
GGML_ASSERT(size2 == size);
|
153
|
+
va_end(ap2);
|
154
|
+
va_end(ap);
|
155
|
+
return std::string(buf.data(), size);
|
156
|
+
}
|
157
|
+
|
155
158
|
//
|
156
159
|
// ggml helpers
|
157
160
|
//
|
158
161
|
|
159
162
|
static bool ggml_graph_compute_helper(
|
160
163
|
struct ggml_cgraph * graph,
|
161
|
-
std::vector<uint8_t> & buf,
|
162
164
|
int n_threads,
|
163
165
|
ggml_abort_callback abort_callback,
|
164
166
|
void * abort_callback_data) {
|
165
|
-
|
167
|
+
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
166
168
|
|
167
|
-
|
168
|
-
plan.abort_callback_data = abort_callback_data;
|
169
|
+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
|
169
170
|
|
170
|
-
|
171
|
-
|
172
|
-
|
171
|
+
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
|
172
|
+
if (set_abort_callback_fn) {
|
173
|
+
set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
|
173
174
|
}
|
174
175
|
|
175
|
-
|
176
|
+
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
177
|
+
if (ggml_backend_set_n_threads_fn) {
|
178
|
+
ggml_backend_set_n_threads_fn(backend.get(), n_threads);
|
179
|
+
}
|
180
|
+
|
181
|
+
return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
|
176
182
|
}
|
177
183
|
|
178
184
|
static bool ggml_graph_compute_helper(
|
179
|
-
|
185
|
+
ggml_backend_sched_t sched,
|
180
186
|
struct ggml_cgraph * graph,
|
181
|
-
int n_threads
|
182
|
-
|
183
|
-
|
187
|
+
int n_threads,
|
188
|
+
bool sched_reset = true) {
|
189
|
+
for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
|
190
|
+
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
|
191
|
+
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
192
|
+
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
193
|
+
|
194
|
+
auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
195
|
+
if (fn_set_n_threads) {
|
196
|
+
fn_set_n_threads(backend, n_threads);
|
197
|
+
}
|
184
198
|
}
|
185
|
-
|
186
|
-
|
187
|
-
|
199
|
+
|
200
|
+
const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
|
201
|
+
|
202
|
+
if (!t || sched_reset) {
|
203
|
+
ggml_backend_sched_reset(sched);
|
188
204
|
}
|
205
|
+
|
206
|
+
return t;
|
207
|
+
}
|
208
|
+
|
209
|
+
static void whisper_load_backends() {
|
210
|
+
#ifdef GGML_BACKEND_DL
|
211
|
+
static std::once_flag flag;
|
212
|
+
std::call_once(flag, []() {
|
213
|
+
ggml_backend_load_all();
|
214
|
+
});
|
189
215
|
#endif
|
190
|
-
|
216
|
+
}
|
217
|
+
|
218
|
+
// TODO: move these functions to ggml-base with support for ggml-backend?
|
219
|
+
|
220
|
+
static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
|
221
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
222
|
+
GGML_ASSERT(ggml_is_contiguous(t));
|
223
|
+
size_t nels = ggml_nelements(t);
|
224
|
+
for (size_t i = 0; i < nels; ++i) {
|
225
|
+
((float *) t->data)[i] = v;
|
226
|
+
}
|
227
|
+
return t;
|
228
|
+
}
|
229
|
+
|
230
|
+
static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
|
231
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
232
|
+
GGML_ASSERT(ggml_is_contiguous(t));
|
233
|
+
size_t nels = ggml_nelements(t);
|
234
|
+
for (size_t i = 0; i < nels; ++i) {
|
235
|
+
((int32_t *) t->data)[i] = v;
|
236
|
+
}
|
237
|
+
return t;
|
238
|
+
}
|
239
|
+
|
240
|
+
static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
241
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
242
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
243
|
+
return *(float *) data;
|
244
|
+
}
|
245
|
+
|
246
|
+
static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
|
247
|
+
GGML_ASSERT(t->type == GGML_TYPE_F32);
|
248
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
249
|
+
*(float *) data = v;
|
250
|
+
}
|
251
|
+
|
252
|
+
static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
253
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
254
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
255
|
+
return *(int32_t *) data;
|
256
|
+
}
|
257
|
+
|
258
|
+
static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
|
259
|
+
GGML_ASSERT(t->type == GGML_TYPE_I32);
|
260
|
+
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
|
261
|
+
*(int32_t *) data = v;
|
191
262
|
}
|
192
263
|
|
193
264
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
@@ -363,6 +434,7 @@ static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15},
|
|
363
434
|
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
|
364
435
|
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
|
365
436
|
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
|
437
|
+
static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
|
366
438
|
|
367
439
|
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
368
440
|
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
|
@@ -376,6 +448,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
|
|
376
448
|
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
|
377
449
|
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
|
378
450
|
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
|
451
|
+
{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
|
379
452
|
};
|
380
453
|
|
381
454
|
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
|
@@ -431,6 +504,7 @@ struct whisper_segment {
|
|
431
504
|
int64_t t1;
|
432
505
|
|
433
506
|
std::string text;
|
507
|
+
float no_speech_prob;
|
434
508
|
|
435
509
|
std::vector<whisper_token_data> tokens;
|
436
510
|
|
@@ -502,33 +576,41 @@ struct whisper_pair {
|
|
502
576
|
whisper_pair() : first(A()), second(B()) {}
|
503
577
|
};
|
504
578
|
|
505
|
-
//
|
506
|
-
struct
|
507
|
-
|
579
|
+
// ggml_backend_sched wrapper for whisper usage
|
580
|
+
struct whisper_sched {
|
581
|
+
ggml_backend_sched_t sched = nullptr;
|
508
582
|
|
509
583
|
std::vector<uint8_t> meta;
|
510
584
|
};
|
511
585
|
|
512
|
-
static size_t
|
513
|
-
|
586
|
+
static size_t whisper_sched_size(struct whisper_sched & allocr) {
|
587
|
+
size_t size = allocr.meta.size();
|
588
|
+
for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
|
589
|
+
ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
|
590
|
+
size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
|
591
|
+
}
|
592
|
+
return size;
|
514
593
|
}
|
515
594
|
|
516
595
|
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
517
|
-
static bool
|
518
|
-
auto &
|
596
|
+
static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
|
597
|
+
auto & sched = allocr.sched;
|
519
598
|
auto & meta = allocr.meta;
|
520
599
|
|
521
|
-
|
600
|
+
sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
|
522
601
|
|
523
602
|
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
524
603
|
|
525
604
|
// since there are dependencies between the different graphs,
|
526
605
|
// we need to allocate them instead of only reserving to get the correct compute buffer size
|
527
|
-
if (!
|
606
|
+
if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
|
528
607
|
// failed to allocate the compute buffer
|
529
608
|
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
530
609
|
return false;
|
531
610
|
}
|
611
|
+
|
612
|
+
ggml_backend_sched_reset(sched);
|
613
|
+
|
532
614
|
return true;
|
533
615
|
}
|
534
616
|
|
@@ -671,9 +753,9 @@ struct whisper_kv_cache {
|
|
671
753
|
struct ggml_tensor * k;
|
672
754
|
struct ggml_tensor * v;
|
673
755
|
|
674
|
-
struct ggml_context * ctx = nullptr;
|
675
|
-
|
676
756
|
ggml_backend_buffer_t buffer = nullptr;
|
757
|
+
|
758
|
+
std::vector<uint8_t> ctx_buf;
|
677
759
|
};
|
678
760
|
|
679
761
|
struct whisper_model {
|
@@ -711,10 +793,10 @@ struct whisper_model {
|
|
711
793
|
std::vector<whisper_layer_decoder> layers_decoder;
|
712
794
|
|
713
795
|
// ggml context that contains all the meta information about the model tensors
|
714
|
-
|
796
|
+
std::vector<ggml_context *> ctxs;
|
715
797
|
|
716
798
|
// the model backend data is read-only and can be shared between processors
|
717
|
-
ggml_backend_buffer_t
|
799
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
718
800
|
|
719
801
|
// tensors
|
720
802
|
int n_loaded;
|
@@ -802,6 +884,9 @@ struct whisper_state {
|
|
802
884
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
803
885
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
804
886
|
|
887
|
+
// number of decoders for which we have constructed the KV cache
|
888
|
+
int32_t kv_self_n_dec = 0;
|
889
|
+
|
805
890
|
// unified self-attention KV cache for all decoders
|
806
891
|
whisper_kv_cache kv_self;
|
807
892
|
|
@@ -809,21 +894,22 @@ struct whisper_state {
|
|
809
894
|
// shared between all decoders
|
810
895
|
whisper_kv_cache kv_cross;
|
811
896
|
|
897
|
+
// padded buffer for flash-attention
|
898
|
+
whisper_kv_cache kv_pad;
|
899
|
+
|
812
900
|
whisper_mel mel;
|
813
901
|
|
814
902
|
whisper_batch batch;
|
815
903
|
|
816
904
|
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
817
905
|
|
818
|
-
ggml_backend_t
|
906
|
+
std::vector<ggml_backend_t> backends;
|
819
907
|
|
820
|
-
// ggml-alloc:
|
821
908
|
// - stores meta info about the intermediate tensors into the `meta` buffers
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
whisper_allocr alloc_decode;
|
909
|
+
whisper_sched sched_conv;
|
910
|
+
whisper_sched sched_encode;
|
911
|
+
whisper_sched sched_cross;
|
912
|
+
whisper_sched sched_decode;
|
827
913
|
|
828
914
|
// result of the encoder
|
829
915
|
struct ggml_tensor * embd_conv = nullptr;
|
@@ -858,6 +944,7 @@ struct whisper_state {
|
|
858
944
|
whisper_token tid_last;
|
859
945
|
|
860
946
|
std::vector<float> energy; // PCM signal energy
|
947
|
+
float no_speech_prob = 0.0f;
|
861
948
|
|
862
949
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
863
950
|
whisper_aheads_masks aheads_masks;
|
@@ -866,6 +953,17 @@ struct whisper_state {
|
|
866
953
|
|
867
954
|
// [EXPERIMENTAL] speed-up techniques
|
868
955
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
956
|
+
|
957
|
+
whisper_vad_context * vad_context = nullptr;
|
958
|
+
|
959
|
+
struct vad_segment_info {
|
960
|
+
float orig_start;
|
961
|
+
float orig_end;
|
962
|
+
float vad_start;
|
963
|
+
float vad_end;
|
964
|
+
};
|
965
|
+
std::vector<vad_segment_info> vad_segments;
|
966
|
+
bool has_vad_segments = false;
|
869
967
|
};
|
870
968
|
|
871
969
|
struct whisper_context {
|
@@ -882,8 +980,6 @@ struct whisper_context {
|
|
882
980
|
|
883
981
|
whisper_state * state = nullptr;
|
884
982
|
|
885
|
-
ggml_backend_t backend = nullptr;
|
886
|
-
|
887
983
|
std::string path_model; // populated by whisper_init_from_file_with_params()
|
888
984
|
};
|
889
985
|
|
@@ -901,21 +997,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
901
997
|
BYTESWAP_VALUE(dest);
|
902
998
|
}
|
903
999
|
|
904
|
-
static bool
|
905
|
-
const struct whisper_hparams & hparams,
|
1000
|
+
static bool whisper_kv_cache_init(
|
906
1001
|
struct whisper_kv_cache & cache,
|
907
1002
|
ggml_backend_t backend,
|
908
1003
|
ggml_type wtype,
|
1004
|
+
int64_t n_text_state,
|
1005
|
+
int64_t n_text_layer,
|
909
1006
|
int n_ctx) {
|
910
|
-
const int64_t n_text_state = hparams.n_text_state;
|
911
|
-
const int64_t n_text_layer = hparams.n_text_layer;
|
912
|
-
|
913
1007
|
const int64_t n_mem = n_text_layer*n_ctx;
|
914
1008
|
const int64_t n_elements = n_text_state*n_mem;
|
915
1009
|
|
1010
|
+
cache.ctx_buf.resize(2*ggml_tensor_overhead());
|
1011
|
+
|
916
1012
|
struct ggml_init_params params = {
|
917
|
-
/*.mem_size =*/
|
918
|
-
/*.mem_buffer =*/
|
1013
|
+
/*.mem_size =*/ cache.ctx_buf.size(),
|
1014
|
+
/*.mem_buffer =*/ cache.ctx_buf.data(),
|
919
1015
|
/*.no_alloc =*/ true,
|
920
1016
|
};
|
921
1017
|
|
@@ -925,29 +1021,31 @@ static bool kv_cache_init(
|
|
925
1021
|
cache.cells.clear();
|
926
1022
|
cache.cells.resize(n_ctx);
|
927
1023
|
|
928
|
-
|
1024
|
+
struct ggml_context * ctx = ggml_init(params);
|
929
1025
|
|
930
|
-
if (!
|
1026
|
+
if (!ctx) {
|
931
1027
|
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
|
932
1028
|
return false;
|
933
1029
|
}
|
934
1030
|
|
935
|
-
cache.k = ggml_new_tensor_1d(
|
936
|
-
cache.v = ggml_new_tensor_1d(
|
1031
|
+
cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
1032
|
+
cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
|
937
1033
|
|
938
|
-
cache.buffer = ggml_backend_alloc_ctx_tensors(
|
1034
|
+
cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
939
1035
|
if (!cache.buffer) {
|
940
1036
|
WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
|
941
1037
|
return false;
|
942
1038
|
}
|
943
1039
|
|
1040
|
+
ggml_backend_buffer_clear(cache.buffer, 0);
|
1041
|
+
|
1042
|
+
ggml_free(ctx);
|
1043
|
+
|
944
1044
|
return true;
|
945
1045
|
}
|
946
1046
|
|
947
|
-
static void
|
948
|
-
ggml_free(cache.ctx);
|
1047
|
+
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
|
949
1048
|
ggml_backend_buffer_free(cache.buffer);
|
950
|
-
cache.ctx = nullptr;
|
951
1049
|
}
|
952
1050
|
|
953
1051
|
static bool whisper_kv_cache_find_slot(
|
@@ -1018,6 +1116,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
|
|
1018
1116
|
cache.cells[i].seq_id.clear();
|
1019
1117
|
}
|
1020
1118
|
cache.head = 0;
|
1119
|
+
|
1120
|
+
ggml_backend_buffer_clear(cache.buffer, 0);
|
1021
1121
|
}
|
1022
1122
|
|
1023
1123
|
static void whisper_kv_cache_seq_rm(
|
@@ -1068,6 +1168,26 @@ static void whisper_kv_cache_seq_cp(
|
|
1068
1168
|
}
|
1069
1169
|
}
|
1070
1170
|
|
1171
|
+
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
|
1172
|
+
if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
|
1173
|
+
return 1u;
|
1174
|
+
}
|
1175
|
+
|
1176
|
+
#ifdef GGML_USE_METAL
|
1177
|
+
if (wctx.params.use_gpu) {
|
1178
|
+
return 32u;
|
1179
|
+
}
|
1180
|
+
#endif
|
1181
|
+
|
1182
|
+
#ifdef GGML_USE_CUDA
|
1183
|
+
if (wctx.params.use_gpu) {
|
1184
|
+
return 256u;
|
1185
|
+
}
|
1186
|
+
#endif
|
1187
|
+
|
1188
|
+
return 1u;
|
1189
|
+
}
|
1190
|
+
|
1071
1191
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
1072
1192
|
static bool aheads_masks_init(
|
1073
1193
|
const whisper_context_params & cparams,
|
@@ -1199,49 +1319,178 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|
1199
1319
|
return size;
|
1200
1320
|
}
|
1201
1321
|
|
1202
|
-
static ggml_backend_t
|
1203
|
-
|
1322
|
+
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
1323
|
+
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
1204
1324
|
|
1205
|
-
|
1206
|
-
|
1325
|
+
whisper_load_backends();
|
1326
|
+
|
1327
|
+
ggml_backend_dev_t dev = nullptr;
|
1328
|
+
|
1329
|
+
int cnt = 0;
|
1207
1330
|
if (params.use_gpu) {
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1331
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1332
|
+
ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
|
1333
|
+
if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1334
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
1335
|
+
dev = dev_cur;
|
1336
|
+
}
|
1337
|
+
|
1338
|
+
if (++cnt > params.gpu_device) {
|
1339
|
+
break;
|
1340
|
+
}
|
1341
|
+
}
|
1212
1342
|
}
|
1213
1343
|
}
|
1214
|
-
#endif
|
1215
1344
|
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1345
|
+
if (dev == nullptr) {
|
1346
|
+
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
|
1347
|
+
return nullptr;
|
1348
|
+
}
|
1349
|
+
|
1350
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1351
|
+
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
|
1352
|
+
if (!result) {
|
1353
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1354
|
+
}
|
1355
|
+
|
1356
|
+
return result;
|
1357
|
+
}
|
1358
|
+
|
1359
|
+
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
1360
|
+
std::vector<ggml_backend_t> result;
|
1361
|
+
|
1362
|
+
ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
|
1363
|
+
|
1364
|
+
if (backend_gpu) {
|
1365
|
+
result.push_back(backend_gpu);
|
1366
|
+
}
|
1367
|
+
|
1368
|
+
// ACCEL backends
|
1369
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1370
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1371
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
1372
|
+
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1373
|
+
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
1374
|
+
if (!backend) {
|
1375
|
+
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
1376
|
+
continue;
|
1377
|
+
}
|
1378
|
+
result.push_back(backend);
|
1227
1379
|
}
|
1228
1380
|
}
|
1229
|
-
#endif
|
1230
1381
|
|
1231
|
-
|
1382
|
+
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
1383
|
+
if (backend_cpu == nullptr) {
|
1384
|
+
throw std::runtime_error("failed to initialize CPU backend");
|
1385
|
+
}
|
1386
|
+
result.push_back(backend_cpu);
|
1387
|
+
|
1388
|
+
return result;
|
1389
|
+
}
|
1390
|
+
|
1391
|
+
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
|
1392
|
+
|
1393
|
+
static buft_list_t make_buft_list(whisper_context_params & params) {
|
1394
|
+
// Prio order: GPU -> CPU Extra -> CPU
|
1395
|
+
buft_list_t buft_list;
|
1396
|
+
|
1397
|
+
// GPU
|
1232
1398
|
if (params.use_gpu) {
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1399
|
+
int cnt = 0;
|
1400
|
+
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
1401
|
+
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
1402
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
1403
|
+
if (cnt == 0 || cnt == params.gpu_device) {
|
1404
|
+
auto * buft = ggml_backend_dev_buffer_type(dev);
|
1405
|
+
if (buft) {
|
1406
|
+
buft_list.emplace_back(dev, buft);
|
1407
|
+
}
|
1408
|
+
}
|
1409
|
+
|
1410
|
+
if (++cnt > params.gpu_device) {
|
1411
|
+
break;
|
1412
|
+
}
|
1413
|
+
}
|
1237
1414
|
}
|
1238
1415
|
}
|
1239
|
-
#endif
|
1240
1416
|
|
1241
|
-
|
1242
|
-
|
1417
|
+
// CPU Extra
|
1418
|
+
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
1419
|
+
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
1420
|
+
auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
1421
|
+
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
1422
|
+
if (get_extra_bufts_fn) {
|
1423
|
+
ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
|
1424
|
+
while (extra_bufts && *extra_bufts) {
|
1425
|
+
buft_list.emplace_back(cpu_dev, *extra_bufts);
|
1426
|
+
++extra_bufts;
|
1427
|
+
}
|
1428
|
+
}
|
1429
|
+
|
1430
|
+
// CPU
|
1431
|
+
buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
|
1432
|
+
|
1433
|
+
return buft_list;
|
1434
|
+
}
|
1435
|
+
|
1436
|
+
static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
|
1437
|
+
bool op_supported = true;
|
1438
|
+
|
1439
|
+
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
|
1440
|
+
(ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
|
1441
|
+
// GPU and default CPU backend support all operators
|
1442
|
+
op_supported = true;
|
1443
|
+
} else {
|
1444
|
+
switch (op) {
|
1445
|
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
|
1446
|
+
case GGML_OP_MUL_MAT: {
|
1447
|
+
ggml_init_params params = {
|
1448
|
+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
1449
|
+
/*.mem_buffer =*/ nullptr,
|
1450
|
+
/*.no_alloc =*/ true,
|
1451
|
+
};
|
1452
|
+
|
1453
|
+
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
1454
|
+
if (!ctx_ptr) {
|
1455
|
+
throw std::runtime_error("failed to create ggml context");
|
1456
|
+
}
|
1457
|
+
ggml_context * ctx = ctx_ptr.get();
|
1458
|
+
|
1459
|
+
ggml_tensor * op_tensor = nullptr;
|
1460
|
+
|
1461
|
+
int64_t n_ctx = hparams.n_audio_ctx;
|
1462
|
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
1463
|
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
1464
|
+
|
1465
|
+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
1466
|
+
GGML_ASSERT(w->buffer == nullptr);
|
1467
|
+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
|
1468
|
+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
|
1469
|
+
ggml_backend_buffer_free(w->buffer);
|
1470
|
+
w->buffer = nullptr;
|
1471
|
+
break;
|
1472
|
+
}
|
1473
|
+
default: {
|
1474
|
+
op_supported = false;
|
1475
|
+
break;
|
1476
|
+
}
|
1477
|
+
};
|
1478
|
+
}
|
1479
|
+
|
1480
|
+
return op_supported;
|
1481
|
+
}
|
1482
|
+
|
1483
|
+
static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
|
1484
|
+
GGML_ASSERT(!buft_list.empty());
|
1485
|
+
for (const auto & p : buft_list) {
|
1486
|
+
ggml_backend_dev_t dev = p.first;
|
1487
|
+
ggml_backend_buffer_type_t buft = p.second;
|
1488
|
+
if (weight_buft_supported(hparams, w, op, buft, dev)) {
|
1489
|
+
return buft;
|
1490
|
+
}
|
1243
1491
|
}
|
1244
|
-
|
1492
|
+
|
1493
|
+
return nullptr;
|
1245
1494
|
}
|
1246
1495
|
|
1247
1496
|
// load the model from a ggml file
|
@@ -1450,31 +1699,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1450
1699
|
const ggml_type wtype = wctx.wtype;
|
1451
1700
|
const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
|
1452
1701
|
|
1453
|
-
|
1454
|
-
{
|
1455
|
-
const auto & hparams = model.hparams;
|
1702
|
+
const auto & hparams = model.hparams;
|
1456
1703
|
|
1457
|
-
|
1458
|
-
|
1704
|
+
const int n_audio_layer = hparams.n_audio_layer;
|
1705
|
+
const int n_text_layer = hparams.n_text_layer;
|
1459
1706
|
|
1460
|
-
|
1707
|
+
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
1461
1708
|
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
|
1466
|
-
|
1709
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
1710
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
1711
|
+
auto it = ctx_map.find(buft);
|
1712
|
+
if (it == ctx_map.end()) {
|
1713
|
+
ggml_init_params params = {
|
1714
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
1715
|
+
/*.mem_buffer =*/ nullptr,
|
1716
|
+
/*.no_alloc =*/ true,
|
1717
|
+
};
|
1467
1718
|
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1719
|
+
ggml_context * ctx = ggml_init(params);
|
1720
|
+
if (!ctx) {
|
1721
|
+
throw std::runtime_error("failed to create ggml context");
|
1722
|
+
}
|
1723
|
+
|
1724
|
+
ctx_map[buft] = ctx;
|
1725
|
+
model.ctxs.emplace_back(ctx);
|
1726
|
+
|
1727
|
+
return ctx;
|
1472
1728
|
}
|
1473
|
-
|
1729
|
+
|
1730
|
+
return it->second;
|
1731
|
+
};
|
1732
|
+
|
1733
|
+
// Create a list of available bufts, in priority order
|
1734
|
+
buft_list_t buft_list = make_buft_list(wctx.params);
|
1735
|
+
|
1736
|
+
auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
|
1737
|
+
ggml_op op = ASR_TENSOR_INFO.at(type);
|
1738
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
1739
|
+
if (!buft) {
|
1740
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
|
1741
|
+
}
|
1742
|
+
|
1743
|
+
ggml_context * ctx = get_ctx(buft);
|
1744
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
1745
|
+
|
1746
|
+
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
|
1747
|
+
|
1748
|
+
return tensor;
|
1749
|
+
};
|
1750
|
+
|
1474
1751
|
|
1475
1752
|
// prepare tensors for the weights
|
1476
1753
|
{
|
1477
|
-
|
1754
|
+
ggml_init_params params = {
|
1755
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
1756
|
+
/*.mem_buffer =*/ nullptr,
|
1757
|
+
/*.no_alloc =*/ true,
|
1758
|
+
};
|
1759
|
+
|
1760
|
+
ggml_context * ctx = ggml_init(params);
|
1478
1761
|
|
1479
1762
|
const auto & hparams = model.hparams;
|
1480
1763
|
|
@@ -1494,195 +1777,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1494
1777
|
model.layers_decoder.resize(n_text_layer);
|
1495
1778
|
|
1496
1779
|
// encoder
|
1497
|
-
|
1498
|
-
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
1499
|
-
|
1500
|
-
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
1501
|
-
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1502
|
-
|
1503
|
-
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
1504
|
-
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
1505
|
-
|
1506
|
-
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1507
|
-
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
1508
|
-
|
1509
|
-
// map by name
|
1510
|
-
model.tensors["encoder.positional_embedding"] = model.e_pe;
|
1780
|
+
model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
|
1511
1781
|
|
1512
|
-
|
1513
|
-
|
1782
|
+
model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
|
1783
|
+
model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
1514
1784
|
|
1515
|
-
|
1516
|
-
|
1785
|
+
model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
|
1786
|
+
model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
1517
1787
|
|
1518
|
-
|
1519
|
-
|
1788
|
+
model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
|
1789
|
+
model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
|
1520
1790
|
|
1521
|
-
|
1522
|
-
|
1791
|
+
for (int i = 0; i < n_audio_layer; ++i) {
|
1792
|
+
auto & layer = model.layers_encoder[i];
|
1523
1793
|
|
1524
|
-
|
1525
|
-
|
1794
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1795
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1526
1796
|
|
1527
|
-
|
1528
|
-
|
1797
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
|
1798
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
|
1529
1799
|
|
1530
|
-
|
1531
|
-
|
1800
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
|
1801
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1532
1802
|
|
1533
|
-
|
1534
|
-
|
1803
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1804
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1535
1805
|
|
1536
|
-
|
1537
|
-
|
1806
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1807
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1538
1808
|
|
1539
|
-
|
1809
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1540
1810
|
|
1541
|
-
|
1542
|
-
|
1811
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1812
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1543
1813
|
|
1544
|
-
|
1545
|
-
|
1546
|
-
|
1547
|
-
// map by name
|
1548
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
1549
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
1550
|
-
|
1551
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
|
1552
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
|
1553
|
-
|
1554
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
|
1555
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
|
1556
|
-
|
1557
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
1558
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
1559
|
-
|
1560
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
1561
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
1562
|
-
|
1563
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
1564
|
-
|
1565
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
1566
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
1567
|
-
|
1568
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
1569
|
-
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
1570
|
-
}
|
1814
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
|
1815
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
|
1571
1816
|
}
|
1572
1817
|
|
1573
1818
|
// decoder
|
1574
|
-
|
1575
|
-
model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
|
1576
|
-
|
1577
|
-
model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
|
1578
|
-
|
1579
|
-
model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1580
|
-
model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1581
|
-
|
1582
|
-
// map by name
|
1583
|
-
model.tensors["decoder.positional_embedding"] = model.d_pe;
|
1584
|
-
|
1585
|
-
model.tensors["decoder.token_embedding.weight"] = model.d_te;
|
1586
|
-
|
1587
|
-
model.tensors["decoder.ln.weight"] = model.d_ln_w;
|
1588
|
-
model.tensors["decoder.ln.bias"] = model.d_ln_b;
|
1589
|
-
|
1590
|
-
for (int i = 0; i < n_text_layer; ++i) {
|
1591
|
-
auto & layer = model.layers_decoder[i];
|
1819
|
+
model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
|
1592
1820
|
|
1593
|
-
|
1594
|
-
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1821
|
+
model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
|
1595
1822
|
|
1596
|
-
|
1597
|
-
|
1823
|
+
model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
|
1824
|
+
model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
|
1598
1825
|
|
1599
|
-
|
1600
|
-
|
1826
|
+
for (int i = 0; i < n_text_layer; ++i) {
|
1827
|
+
auto & layer = model.layers_decoder[i];
|
1601
1828
|
|
1602
|
-
|
1603
|
-
|
1829
|
+
layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1830
|
+
layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1604
1831
|
|
1605
|
-
|
1606
|
-
|
1832
|
+
layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
|
1833
|
+
layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
|
1607
1834
|
|
1608
|
-
|
1835
|
+
layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
|
1836
|
+
layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1609
1837
|
|
1610
|
-
|
1611
|
-
|
1838
|
+
layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1839
|
+
layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1612
1840
|
|
1613
|
-
|
1614
|
-
|
1841
|
+
layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1842
|
+
layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1615
1843
|
|
1616
|
-
|
1617
|
-
layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
|
1844
|
+
layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1618
1845
|
|
1619
|
-
|
1620
|
-
|
1846
|
+
layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1847
|
+
layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1621
1848
|
|
1622
|
-
|
1849
|
+
layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1850
|
+
layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1623
1851
|
|
1624
|
-
|
1625
|
-
|
1852
|
+
layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1853
|
+
layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1626
1854
|
|
1627
|
-
|
1628
|
-
|
1855
|
+
layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1856
|
+
layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1629
1857
|
|
1630
|
-
|
1631
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
|
1632
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
|
1858
|
+
layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1633
1859
|
|
1634
|
-
|
1635
|
-
|
1860
|
+
layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1861
|
+
layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1636
1862
|
|
1637
|
-
|
1638
|
-
|
1639
|
-
|
1640
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
|
1641
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
|
1642
|
-
|
1643
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
|
1644
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
|
1645
|
-
|
1646
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
|
1647
|
-
|
1648
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
|
1649
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
|
1650
|
-
|
1651
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
|
1652
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
|
1653
|
-
|
1654
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
|
1655
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
|
1656
|
-
|
1657
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
|
1658
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
|
1659
|
-
|
1660
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
|
1661
|
-
|
1662
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
|
1663
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
|
1664
|
-
|
1665
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
|
1666
|
-
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
|
1667
|
-
}
|
1863
|
+
layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
|
1864
|
+
layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
|
1668
1865
|
}
|
1669
|
-
}
|
1670
1866
|
|
1671
|
-
|
1672
|
-
if (!wctx.backend) {
|
1673
|
-
WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
|
1674
|
-
return false;
|
1867
|
+
ggml_free(ctx);
|
1675
1868
|
}
|
1676
1869
|
|
1677
1870
|
// allocate tensors in the backend buffers
|
1678
|
-
|
1679
|
-
|
1680
|
-
|
1681
|
-
|
1682
|
-
|
1871
|
+
for (auto & p : ctx_map) {
|
1872
|
+
ggml_backend_buffer_type_t buft = p.first;
|
1873
|
+
ggml_context * ctx = p.second;
|
1874
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
1875
|
+
if (buf) {
|
1876
|
+
model.buffers.emplace_back(buf);
|
1683
1877
|
|
1684
|
-
|
1685
|
-
|
1878
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
1879
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
1880
|
+
}
|
1881
|
+
}
|
1686
1882
|
|
1687
1883
|
// load weights
|
1688
1884
|
{
|
@@ -1745,11 +1941,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1745
1941
|
return false;
|
1746
1942
|
}
|
1747
1943
|
|
1748
|
-
|
1749
|
-
|
1750
|
-
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
|
1751
|
-
|
1752
|
-
if (ggml_backend_buffer_is_host(model.buffer)) {
|
1944
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
1753
1945
|
// for the CPU and Metal backend, we can read directly into the tensor
|
1754
1946
|
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
1755
1947
|
BYTESWAP_TENSOR(tensor);
|
@@ -1762,7 +1954,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1762
1954
|
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
1763
1955
|
}
|
1764
1956
|
|
1765
|
-
//printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
|
1766
1957
|
total_size += ggml_nbytes(tensor);
|
1767
1958
|
model.n_loaded++;
|
1768
1959
|
}
|
@@ -1777,6 +1968,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
1777
1968
|
}
|
1778
1969
|
}
|
1779
1970
|
|
1971
|
+
for (auto & buf : model.buffers) {
|
1972
|
+
ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
1973
|
+
}
|
1974
|
+
|
1780
1975
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
1781
1976
|
|
1782
1977
|
return true;
|
@@ -1812,8 +2007,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|
1812
2007
|
const int n_mels = hparams.n_mels;
|
1813
2008
|
|
1814
2009
|
struct ggml_init_params params = {
|
1815
|
-
/*.mem_size =*/ wstate.
|
1816
|
-
/*.mem_buffer =*/ wstate.
|
2010
|
+
/*.mem_size =*/ wstate.sched_conv.meta.size(),
|
2011
|
+
/*.mem_buffer =*/ wstate.sched_conv.meta.data(),
|
1817
2012
|
/*.no_alloc =*/ true,
|
1818
2013
|
};
|
1819
2014
|
|
@@ -1847,6 +2042,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|
1847
2042
|
ggml_build_forward_expand(gf, mel);
|
1848
2043
|
|
1849
2044
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
2045
|
+
ggml_set_input(cur); // the external encoder will write into this tensor
|
1850
2046
|
|
1851
2047
|
ggml_set_name(cur, "embd_enc");
|
1852
2048
|
wstate.embd_enc = cur;
|
@@ -1872,9 +2068,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1872
2068
|
const int n_head = hparams.n_audio_head;
|
1873
2069
|
const int n_layer = hparams.n_audio_layer;
|
1874
2070
|
|
2071
|
+
const int n_state_head = n_state/n_head;
|
2072
|
+
|
2073
|
+
auto & kv_pad = wstate.kv_pad;
|
2074
|
+
|
2075
|
+
WHISPER_ASSERT(!!kv_pad.buffer);
|
2076
|
+
|
2077
|
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
2078
|
+
|
1875
2079
|
struct ggml_init_params params = {
|
1876
|
-
/*.mem_size =*/ wstate.
|
1877
|
-
/*.mem_buffer =*/ wstate.
|
2080
|
+
/*.mem_size =*/ wstate.sched_encode.meta.size(),
|
2081
|
+
/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
|
1878
2082
|
/*.no_alloc =*/ true,
|
1879
2083
|
};
|
1880
2084
|
|
@@ -1884,7 +2088,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1884
2088
|
|
1885
2089
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
1886
2090
|
|
1887
|
-
const float KQscale = 1.0f/sqrtf(float(
|
2091
|
+
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
1888
2092
|
|
1889
2093
|
// ===================================================================
|
1890
2094
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
@@ -1934,14 +2138,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1934
2138
|
|
1935
2139
|
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
1936
2140
|
|
1937
|
-
//Qcur = ggml_scale(ctx0, Qcur, pow(float(
|
2141
|
+
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
1938
2142
|
|
1939
2143
|
// note: no bias for Key
|
1940
2144
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
1941
2145
|
layer.attn_k_w,
|
1942
2146
|
cur);
|
1943
2147
|
|
1944
|
-
//Kcur = ggml_scale(ctx0, Kcur, pow(float(
|
2148
|
+
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
1945
2149
|
|
1946
2150
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
1947
2151
|
layer.attn_v_w,
|
@@ -1951,70 +2155,60 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
1951
2155
|
|
1952
2156
|
// ------
|
1953
2157
|
|
1954
|
-
#ifdef WHISPER_USE_FLASH_ATTN
|
1955
2158
|
struct ggml_tensor * Q =
|
1956
2159
|
ggml_permute(ctx0,
|
1957
|
-
|
1958
|
-
Qcur,
|
1959
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
2160
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
|
1960
2161
|
0, 2, 1, 3);
|
1961
2162
|
|
1962
|
-
|
1963
|
-
|
1964
|
-
|
1965
|
-
Kcur,
|
1966
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
1967
|
-
0, 2, 1, 3);
|
2163
|
+
if (wctx.params.flash_attn) {
|
2164
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
|
2165
|
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
|
1968
2166
|
|
1969
|
-
|
1970
|
-
|
1971
|
-
|
1972
|
-
|
1973
|
-
|
1974
|
-
|
1975
|
-
1, 2, 0, 3),
|
1976
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
|
2167
|
+
struct ggml_tensor * K =
|
2168
|
+
ggml_view_3d(ctx0, kv_pad.k,
|
2169
|
+
n_state_head, n_ctx_pad, n_head,
|
2170
|
+
ggml_element_size(kv_pad.k)*n_state,
|
2171
|
+
ggml_element_size(kv_pad.k)*n_state_head,
|
2172
|
+
0);
|
1977
2173
|
|
1978
|
-
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
1985
|
-
0, 2, 1, 3);
|
2174
|
+
struct ggml_tensor * V =
|
2175
|
+
ggml_view_3d(ctx0, kv_pad.v,
|
2176
|
+
n_state_head, n_ctx_pad, n_head,
|
2177
|
+
ggml_element_size(kv_pad.v)*n_state,
|
2178
|
+
ggml_element_size(kv_pad.v)*n_state_head,
|
2179
|
+
0);
|
1986
2180
|
|
1987
|
-
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
2181
|
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
|
2182
|
+
|
2183
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
2184
|
+
} else {
|
2185
|
+
struct ggml_tensor * K =
|
2186
|
+
ggml_permute(ctx0,
|
2187
|
+
ggml_cast(ctx0,
|
2188
|
+
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
2189
|
+
wctx.itype),
|
2190
|
+
0, 2, 1, 3);
|
1993
2191
|
|
1994
|
-
|
1995
|
-
|
2192
|
+
// K * Q
|
2193
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
1996
2194
|
|
1997
|
-
|
2195
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
1998
2196
|
|
1999
|
-
|
2197
|
+
struct ggml_tensor * V =
|
2198
|
+
ggml_cast(ctx0,
|
2199
|
+
ggml_permute(ctx0,
|
2200
|
+
ggml_reshape_3d(ctx0,
|
2201
|
+
Vcur,
|
2202
|
+
n_state_head, n_head, n_ctx),
|
2203
|
+
1, 2, 0, 3),
|
2204
|
+
wctx.itype);
|
2000
2205
|
|
2001
|
-
|
2002
|
-
ggml_cpy(ctx0,
|
2003
|
-
ggml_permute(ctx0,
|
2004
|
-
ggml_reshape_3d(ctx0,
|
2005
|
-
Vcur,
|
2006
|
-
n_state/n_head, n_head, n_ctx),
|
2007
|
-
1, 2, 0, 3),
|
2008
|
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
2009
|
-
);
|
2206
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
2010
2207
|
|
2011
|
-
|
2012
|
-
#endif
|
2013
|
-
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2208
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2014
2209
|
|
2015
|
-
|
2016
|
-
|
2017
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
2210
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
|
2211
|
+
}
|
2018
2212
|
}
|
2019
2213
|
|
2020
2214
|
// projection
|
@@ -2043,11 +2237,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
2043
2237
|
layer.mlp_ln_b);
|
2044
2238
|
}
|
2045
2239
|
|
2046
|
-
#ifdef WHISPER_USE_FLASH_FF
|
2047
|
-
cur = ggml_flash_ff(ctx0,
|
2048
|
-
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
2049
|
-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
2050
|
-
#else
|
2051
2240
|
// fully connected
|
2052
2241
|
cur = ggml_mul_mat(ctx0,
|
2053
2242
|
layer.mlp_0_w,
|
@@ -2064,7 +2253,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
2064
2253
|
cur);
|
2065
2254
|
|
2066
2255
|
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
|
2067
|
-
#endif
|
2068
2256
|
}
|
2069
2257
|
|
2070
2258
|
inpL = ggml_add(ctx0, cur, inpFF);
|
@@ -2113,9 +2301,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2113
2301
|
const int n_state = hparams.n_audio_state;
|
2114
2302
|
const int n_head = hparams.n_audio_head;
|
2115
2303
|
|
2304
|
+
const int n_state_head = n_state/n_head;
|
2305
|
+
|
2306
|
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
2307
|
+
|
2116
2308
|
struct ggml_init_params params = {
|
2117
|
-
/*.mem_size =*/ wstate.
|
2118
|
-
/*.mem_buffer =*/ wstate.
|
2309
|
+
/*.mem_size =*/ wstate.sched_cross.meta.size(),
|
2310
|
+
/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
|
2119
2311
|
/*.no_alloc =*/ true,
|
2120
2312
|
};
|
2121
2313
|
|
@@ -2125,18 +2317,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2125
2317
|
|
2126
2318
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
2127
2319
|
|
2128
|
-
const float Kscale = pow(float(
|
2320
|
+
const float Kscale = pow(float(n_state_head), -0.25);
|
2129
2321
|
|
2130
2322
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
2131
2323
|
auto & layer = model.layers_decoder[il];
|
2132
2324
|
|
2133
|
-
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
2325
|
+
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
|
2134
2326
|
layer.cross_attn_k_w,
|
2135
2327
|
cur);
|
2136
2328
|
|
2137
2329
|
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
2138
2330
|
|
2139
|
-
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
2331
|
+
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
2140
2332
|
layer.cross_attn_v_w,
|
2141
2333
|
cur);
|
2142
2334
|
|
@@ -2144,15 +2336,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
2144
2336
|
Vcross,
|
2145
2337
|
layer.cross_attn_v_b);
|
2146
2338
|
|
2147
|
-
|
2339
|
+
struct ggml_tensor * k;
|
2340
|
+
struct ggml_tensor * v;
|
2341
|
+
|
2342
|
+
if (wctx.params.flash_attn) {
|
2343
|
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
2344
|
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
|
2148
2345
|
|
2149
|
-
|
2150
|
-
|
2151
|
-
|
2346
|
+
v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
|
2347
|
+
(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
|
2348
|
+
} else {
|
2349
|
+
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
2350
|
+
|
2351
|
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
2352
|
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
2152
2353
|
|
2153
|
-
|
2154
|
-
|
2155
|
-
|
2354
|
+
v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
2355
|
+
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
2356
|
+
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
2357
|
+
}
|
2156
2358
|
|
2157
2359
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
2158
2360
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
@@ -2186,11 +2388,11 @@ static bool whisper_encode_internal(
|
|
2186
2388
|
|
2187
2389
|
// conv
|
2188
2390
|
{
|
2189
|
-
auto &
|
2391
|
+
auto & sched = wstate.sched_conv.sched;
|
2190
2392
|
|
2191
2393
|
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
|
2192
2394
|
|
2193
|
-
if (!
|
2395
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2194
2396
|
// should never happen as we pre-allocate the memory
|
2195
2397
|
return false;
|
2196
2398
|
}
|
@@ -2223,7 +2425,7 @@ static bool whisper_encode_internal(
|
|
2223
2425
|
}
|
2224
2426
|
|
2225
2427
|
if (!whisper_encode_external(wstate)) {
|
2226
|
-
if (!ggml_graph_compute_helper(
|
2428
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2227
2429
|
return false;
|
2228
2430
|
}
|
2229
2431
|
} else {
|
@@ -2237,32 +2439,32 @@ static bool whisper_encode_internal(
|
|
2237
2439
|
|
2238
2440
|
// encoder
|
2239
2441
|
if (!whisper_encode_external(wstate)) {
|
2240
|
-
auto &
|
2442
|
+
auto & sched = wstate.sched_encode.sched;
|
2241
2443
|
|
2242
2444
|
ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
|
2243
2445
|
|
2244
|
-
if (!
|
2446
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2245
2447
|
// should never happen as we pre-allocate the memory
|
2246
2448
|
return false;
|
2247
2449
|
}
|
2248
2450
|
|
2249
|
-
if (!ggml_graph_compute_helper(
|
2451
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2250
2452
|
return false;
|
2251
2453
|
}
|
2252
2454
|
}
|
2253
2455
|
|
2254
2456
|
// cross
|
2255
2457
|
{
|
2256
|
-
auto &
|
2458
|
+
auto & sched = wstate.sched_cross.sched;
|
2257
2459
|
|
2258
2460
|
ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
2259
2461
|
|
2260
|
-
if (!
|
2462
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2261
2463
|
// should never happen as we pre-allocate the memory
|
2262
2464
|
return false;
|
2263
2465
|
}
|
2264
2466
|
|
2265
|
-
if (!ggml_graph_compute_helper(
|
2467
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2266
2468
|
return false;
|
2267
2469
|
}
|
2268
2470
|
}
|
@@ -2284,24 +2486,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2284
2486
|
|
2285
2487
|
auto & kv_self = wstate.kv_self;
|
2286
2488
|
|
2287
|
-
WHISPER_ASSERT(!!kv_self.
|
2489
|
+
WHISPER_ASSERT(!!kv_self.buffer);
|
2288
2490
|
|
2289
2491
|
const int n_ctx = kv_self.size;
|
2290
2492
|
const int n_state = hparams.n_text_state;
|
2291
2493
|
const int n_head = hparams.n_text_head;
|
2292
2494
|
const int n_layer = hparams.n_text_layer;
|
2293
2495
|
|
2496
|
+
const int n_state_head = n_state/n_head;
|
2497
|
+
|
2294
2498
|
const int n_tokens = batch.n_tokens;
|
2295
2499
|
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
2296
2500
|
|
2297
|
-
const
|
2298
|
-
|
2501
|
+
const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
|
2502
|
+
|
2503
|
+
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
2504
|
+
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
2299
2505
|
|
2300
2506
|
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
2301
2507
|
|
2302
2508
|
struct ggml_init_params params = {
|
2303
|
-
/*.mem_size =*/ wstate.
|
2304
|
-
/*.mem_buffer =*/ wstate.
|
2509
|
+
/*.mem_size =*/ wstate.sched_decode.meta.size(),
|
2510
|
+
/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
|
2305
2511
|
/*.no_alloc =*/ true,
|
2306
2512
|
};
|
2307
2513
|
|
@@ -2317,12 +2523,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2317
2523
|
ggml_set_name(position, "position");
|
2318
2524
|
ggml_set_input(position);
|
2319
2525
|
|
2320
|
-
const float KQscale = pow(float(
|
2526
|
+
const float KQscale = pow(float(n_state_head), -0.25);
|
2321
2527
|
|
2322
|
-
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
2528
|
+
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
|
2323
2529
|
ggml_set_name(KQ_mask, "KQ_mask");
|
2324
2530
|
ggml_set_input(KQ_mask);
|
2325
2531
|
|
2532
|
+
struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
|
2533
|
+
|
2326
2534
|
// token encoding + position encoding
|
2327
2535
|
struct ggml_tensor * cur =
|
2328
2536
|
ggml_add(ctx0,
|
@@ -2378,12 +2586,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2378
2586
|
Vcur,
|
2379
2587
|
layer.attn_v_b);
|
2380
2588
|
|
2381
|
-
|
2589
|
+
struct ggml_tensor * k;
|
2590
|
+
struct ggml_tensor * v;
|
2591
|
+
|
2592
|
+
if (wctx.params.flash_attn) {
|
2593
|
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
2594
|
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
2595
|
+
|
2596
|
+
v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
|
2597
|
+
(ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
|
2598
|
+
} else {
|
2599
|
+
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
2600
|
+
|
2601
|
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
2602
|
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
2382
2603
|
|
2383
|
-
|
2384
|
-
|
2385
|
-
|
2386
|
-
|
2604
|
+
v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
2605
|
+
( n_ctx)*ggml_element_size(kv_self.v),
|
2606
|
+
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
|
2607
|
+
}
|
2387
2608
|
|
2388
2609
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
2389
2610
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
@@ -2393,40 +2614,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2393
2614
|
|
2394
2615
|
struct ggml_tensor * Q =
|
2395
2616
|
ggml_permute(ctx0,
|
2396
|
-
ggml_reshape_3d(ctx0, Qcur,
|
2617
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
2397
2618
|
0, 2, 1, 3);
|
2398
2619
|
|
2399
2620
|
struct ggml_tensor * K =
|
2400
2621
|
ggml_view_3d(ctx0, kv_self.k,
|
2401
|
-
|
2622
|
+
n_state_head, n_kv, n_head,
|
2402
2623
|
ggml_element_size(kv_self.k)*n_state,
|
2403
|
-
ggml_element_size(kv_self.k)*
|
2624
|
+
ggml_element_size(kv_self.k)*n_state_head,
|
2404
2625
|
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
2405
2626
|
|
2406
|
-
|
2407
|
-
|
2627
|
+
if (wctx.params.flash_attn) {
|
2628
|
+
struct ggml_tensor * V =
|
2629
|
+
ggml_view_3d(ctx0, kv_self.v,
|
2630
|
+
n_state_head, n_kv, n_head,
|
2631
|
+
ggml_element_size(kv_self.v)*n_state,
|
2632
|
+
ggml_element_size(kv_self.v)*n_state_head,
|
2633
|
+
ggml_element_size(kv_self.v)*n_state*n_ctx*il);
|
2408
2634
|
|
2409
|
-
|
2635
|
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
|
2410
2636
|
|
2411
|
-
|
2412
|
-
|
2637
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
2638
|
+
} else {
|
2639
|
+
// K * Q
|
2640
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
2413
2641
|
|
2414
|
-
|
2642
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
|
2415
2643
|
|
2416
|
-
|
2417
|
-
|
2418
|
-
|
2419
|
-
|
2420
|
-
|
2421
|
-
|
2644
|
+
struct ggml_tensor * V =
|
2645
|
+
ggml_view_3d(ctx0, kv_self.v,
|
2646
|
+
n_kv, n_state_head, n_head,
|
2647
|
+
n_ctx*ggml_element_size(kv_self.v),
|
2648
|
+
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
|
2649
|
+
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
|
2422
2650
|
|
2423
|
-
|
2651
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
2424
2652
|
|
2425
|
-
|
2653
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2426
2654
|
|
2427
|
-
|
2428
|
-
|
2429
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
2655
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
2656
|
+
}
|
2430
2657
|
}
|
2431
2658
|
|
2432
2659
|
// projection
|
@@ -2465,80 +2692,75 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
2465
2692
|
Qcur,
|
2466
2693
|
layer.cross_attn_q_b);
|
2467
2694
|
|
2468
|
-
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
2469
|
-
|
2470
|
-
// Kcross is already scaled
|
2471
|
-
struct ggml_tensor * Kcross =
|
2472
|
-
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2473
|
-
n_state/n_head, n_audio_ctx, n_head,
|
2474
|
-
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2475
|
-
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
2476
|
-
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
2477
|
-
|
2478
|
-
//struct ggml_tensor * Vcross =
|
2479
|
-
// ggml_reshape_3d(ctx0,
|
2480
|
-
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
2481
|
-
// n_state/n_head, n_head, n_audio_ctx);
|
2482
|
-
|
2483
|
-
//struct ggml_tensor * V_trans =
|
2484
|
-
// ggml_cpy(ctx0,
|
2485
|
-
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
2486
|
-
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
2487
|
-
|
2488
|
-
struct ggml_tensor * V =
|
2489
|
-
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2490
|
-
n_audio_ctx, n_state/n_head, n_head,
|
2491
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
2492
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
2493
|
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
2494
|
-
|
2495
|
-
// ------
|
2496
|
-
|
2497
2695
|
struct ggml_tensor * Q =
|
2498
2696
|
ggml_permute(ctx0,
|
2499
|
-
ggml_reshape_3d(ctx0, Qcur,
|
2697
|
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
2500
2698
|
0, 2, 1, 3);
|
2501
2699
|
|
2502
|
-
|
2503
|
-
|
2700
|
+
if (wctx.params.flash_attn) {
|
2701
|
+
struct ggml_tensor * Kcross =
|
2702
|
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2703
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
2704
|
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2705
|
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
2706
|
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
|
2504
2707
|
|
2505
|
-
|
2506
|
-
|
2507
|
-
|
2508
|
-
|
2509
|
-
|
2708
|
+
struct ggml_tensor * Vcross =
|
2709
|
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2710
|
+
n_state_head, n_audio_ctx_pad, n_head,
|
2711
|
+
ggml_element_size(wstate.kv_cross.v)*n_state,
|
2712
|
+
ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
2713
|
+
ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
|
2510
2714
|
|
2511
|
-
|
2512
|
-
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
2715
|
+
cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
|
2513
2716
|
|
2514
|
-
|
2515
|
-
|
2516
|
-
|
2517
|
-
|
2518
|
-
|
2519
|
-
|
2520
|
-
|
2521
|
-
|
2522
|
-
|
2523
|
-
|
2524
|
-
|
2525
|
-
|
2526
|
-
|
2527
|
-
|
2528
|
-
|
2529
|
-
|
2717
|
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
2718
|
+
} else {
|
2719
|
+
struct ggml_tensor * Kcross =
|
2720
|
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
2721
|
+
n_state_head, n_audio_ctx, n_head,
|
2722
|
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
2723
|
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
2724
|
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
2725
|
+
|
2726
|
+
struct ggml_tensor * Vcross =
|
2727
|
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
2728
|
+
n_audio_ctx, n_state_head, n_head,
|
2729
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
2730
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
2731
|
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
2732
|
+
|
2733
|
+
// ------
|
2734
|
+
|
2735
|
+
// K * Q
|
2736
|
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
|
2737
|
+
|
2738
|
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
2739
|
+
|
2740
|
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
2741
|
+
if (wctx.params.dtw_token_timestamps) {
|
2742
|
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
2743
|
+
struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
2744
|
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
2745
|
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
2746
|
+
aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
2747
|
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
2748
|
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
2749
|
+
aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
2750
|
+
if (aheads_cross_QKs == NULL) {
|
2751
|
+
aheads_cross_QKs = aheads_KQs;
|
2752
|
+
} else {
|
2753
|
+
aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
|
2754
|
+
}
|
2530
2755
|
}
|
2531
2756
|
}
|
2532
|
-
}
|
2533
2757
|
|
2534
|
-
|
2758
|
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
|
2535
2759
|
|
2536
|
-
|
2760
|
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
2537
2761
|
|
2538
|
-
|
2539
|
-
|
2540
|
-
KQV_merged,
|
2541
|
-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
2762
|
+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
2763
|
+
}
|
2542
2764
|
}
|
2543
2765
|
|
2544
2766
|
// projection
|
@@ -2671,18 +2893,20 @@ static bool whisper_decode_internal(
|
|
2671
2893
|
return false;
|
2672
2894
|
}
|
2673
2895
|
|
2674
|
-
|
2896
|
+
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
|
2897
|
+
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
|
2898
|
+
|
2675
2899
|
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
2676
2900
|
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
2677
2901
|
}
|
2678
2902
|
|
2679
2903
|
// decoder
|
2680
2904
|
{
|
2681
|
-
auto &
|
2905
|
+
auto & sched = wstate.sched_decode.sched;
|
2682
2906
|
|
2683
2907
|
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
|
2684
2908
|
|
2685
|
-
if (!
|
2909
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
2686
2910
|
// should never happen as we pre-allocate the memory
|
2687
2911
|
return false;
|
2688
2912
|
}
|
@@ -2705,9 +2929,10 @@ static bool whisper_decode_internal(
|
|
2705
2929
|
struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
|
2706
2930
|
|
2707
2931
|
auto & kv_self = wstate.kv_self;
|
2708
|
-
const int32_t n_kv = kv_self.n;
|
2709
2932
|
|
2710
|
-
|
2933
|
+
const int32_t n_kv = kv_self.n;
|
2934
|
+
|
2935
|
+
wstate.inp_mask.resize(ggml_nelements(KQ_mask));
|
2711
2936
|
|
2712
2937
|
float * data = wstate.inp_mask.data();
|
2713
2938
|
memset(data, 0, ggml_nbytes(KQ_mask));
|
@@ -2723,14 +2948,20 @@ static bool whisper_decode_internal(
|
|
2723
2948
|
}
|
2724
2949
|
}
|
2725
2950
|
}
|
2951
|
+
|
2952
|
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
2953
|
+
for (int j = 0; j < n_kv; ++j) {
|
2954
|
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
2955
|
+
}
|
2956
|
+
}
|
2726
2957
|
}
|
2727
2958
|
|
2728
2959
|
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
|
2729
2960
|
}
|
2730
2961
|
|
2731
|
-
logits = gf
|
2962
|
+
logits = ggml_graph_node(gf, -1);
|
2732
2963
|
|
2733
|
-
if (!ggml_graph_compute_helper(
|
2964
|
+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
|
2734
2965
|
return false;
|
2735
2966
|
}
|
2736
2967
|
}
|
@@ -2784,29 +3015,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
|
|
2784
3015
|
}
|
2785
3016
|
|
2786
3017
|
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
2787
|
-
|
2788
|
-
|
3018
|
+
namespace {
|
3019
|
+
struct whisper_global_cache {
|
3020
|
+
// In FFT, we frequently use sine and cosine operations with the same values.
|
3021
|
+
// We can use precalculated values to speed up the process.
|
3022
|
+
float sin_vals[SIN_COS_N_COUNT];
|
3023
|
+
float cos_vals[SIN_COS_N_COUNT];
|
3024
|
+
|
3025
|
+
// Hann window (Use cosf to eliminate difference)
|
3026
|
+
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
3027
|
+
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
3028
|
+
float hann_window[WHISPER_N_FFT];
|
3029
|
+
|
3030
|
+
whisper_global_cache() {
|
3031
|
+
fill_sin_cos_table();
|
3032
|
+
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
3033
|
+
}
|
3034
|
+
|
3035
|
+
void fill_sin_cos_table() {
|
3036
|
+
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
3037
|
+
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
3038
|
+
sin_vals[i] = sinf(theta);
|
3039
|
+
cos_vals[i] = cosf(theta);
|
3040
|
+
}
|
3041
|
+
}
|
2789
3042
|
|
2790
|
-
|
2791
|
-
|
2792
|
-
|
2793
|
-
|
2794
|
-
|
2795
|
-
|
2796
|
-
|
2797
|
-
|
2798
|
-
cos_vals[i] = cosf(theta);
|
3043
|
+
void fill_hann_window(int length, bool periodic, float * output) {
|
3044
|
+
int offset = -1;
|
3045
|
+
if (periodic) {
|
3046
|
+
offset = 0;
|
3047
|
+
}
|
3048
|
+
for (int i = 0; i < length; i++) {
|
3049
|
+
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
3050
|
+
}
|
2799
3051
|
}
|
2800
|
-
|
3052
|
+
} global_cache;
|
2801
3053
|
}
|
2802
3054
|
|
2803
3055
|
// naive Discrete Fourier Transform
|
2804
3056
|
// input is real-valued
|
2805
3057
|
// output is complex-valued
|
2806
|
-
static void dft(const
|
2807
|
-
int N = in.size();
|
2808
|
-
|
2809
|
-
out.resize(N*2);
|
3058
|
+
static void dft(const float* in, int N, float* out) {
|
2810
3059
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
2811
3060
|
|
2812
3061
|
for (int k = 0; k < N; k++) {
|
@@ -2815,8 +3064,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2815
3064
|
|
2816
3065
|
for (int n = 0; n < N; n++) {
|
2817
3066
|
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
2818
|
-
re += in[n]*cos_vals[idx]; // cos(t)
|
2819
|
-
im -= in[n]*sin_vals[idx]; // sin(t)
|
3067
|
+
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
3068
|
+
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
2820
3069
|
}
|
2821
3070
|
|
2822
3071
|
out[k*2 + 0] = re;
|
@@ -2828,47 +3077,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2828
3077
|
// poor man's implementation - use something better
|
2829
3078
|
// input is real-valued
|
2830
3079
|
// output is complex-valued
|
2831
|
-
static void fft(
|
2832
|
-
out.resize(in.size()*2);
|
2833
|
-
|
2834
|
-
int N = in.size();
|
2835
|
-
|
3080
|
+
static void fft(float* in, int N, float* out) {
|
2836
3081
|
if (N == 1) {
|
2837
3082
|
out[0] = in[0];
|
2838
3083
|
out[1] = 0;
|
2839
3084
|
return;
|
2840
3085
|
}
|
2841
3086
|
|
2842
|
-
|
2843
|
-
|
3087
|
+
const int half_N = N / 2;
|
3088
|
+
if (N - half_N*2 == 1) {
|
3089
|
+
dft(in, N, out);
|
2844
3090
|
return;
|
2845
3091
|
}
|
2846
3092
|
|
2847
|
-
|
2848
|
-
|
2849
|
-
|
2850
|
-
even.reserve(N/2);
|
2851
|
-
odd.reserve(N/2);
|
2852
|
-
|
2853
|
-
for (int i = 0; i < N; i++) {
|
2854
|
-
if (i % 2 == 0) {
|
2855
|
-
even.push_back(in[i]);
|
2856
|
-
} else {
|
2857
|
-
odd.push_back(in[i]);
|
2858
|
-
}
|
3093
|
+
float* even = in + N;
|
3094
|
+
for (int i = 0; i < half_N; ++i) {
|
3095
|
+
even[i]= in[2*i];
|
2859
3096
|
}
|
3097
|
+
float* even_fft = out + 2 * N;
|
3098
|
+
fft(even, half_N, even_fft);
|
2860
3099
|
|
2861
|
-
|
2862
|
-
|
2863
|
-
|
2864
|
-
|
2865
|
-
|
3100
|
+
float* odd = even;
|
3101
|
+
for (int i = 0; i < half_N; ++i) {
|
3102
|
+
odd[i] = in[2*i + 1];
|
3103
|
+
}
|
3104
|
+
float* odd_fft = even_fft + N;
|
3105
|
+
fft(odd, half_N, odd_fft);
|
2866
3106
|
|
2867
3107
|
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
2868
|
-
for (int k = 0; k <
|
3108
|
+
for (int k = 0; k < half_N; k++) {
|
2869
3109
|
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
2870
|
-
float re = cos_vals[idx]; // cos(t)
|
2871
|
-
float im = -sin_vals[idx]; // sin(t)
|
3110
|
+
float re = global_cache.cos_vals[idx]; // cos(t)
|
3111
|
+
float im = -global_cache.sin_vals[idx]; // sin(t)
|
2872
3112
|
|
2873
3113
|
float re_odd = odd_fft[2*k + 0];
|
2874
3114
|
float im_odd = odd_fft[2*k + 1];
|
@@ -2876,52 +3116,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
2876
3116
|
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
2877
3117
|
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
2878
3118
|
|
2879
|
-
out[2*(k +
|
2880
|
-
out[2*(k +
|
2881
|
-
}
|
2882
|
-
}
|
2883
|
-
|
2884
|
-
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
|
2885
|
-
if (output.size() < static_cast<size_t>(length)) {
|
2886
|
-
output.resize(length);
|
3119
|
+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
3120
|
+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
2887
3121
|
}
|
2888
|
-
int offset = -1;
|
2889
|
-
if (periodic) {
|
2890
|
-
offset = 0;
|
2891
|
-
}
|
2892
|
-
for (int i = 0; i < length; i++) {
|
2893
|
-
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
|
2894
|
-
}
|
2895
|
-
|
2896
|
-
return true;
|
2897
3122
|
}
|
2898
3123
|
|
2899
|
-
static void log_mel_spectrogram_worker_thread(int ith, const
|
3124
|
+
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
2900
3125
|
int n_samples, int frame_size, int frame_step, int n_threads,
|
2901
3126
|
const whisper_filters & filters, whisper_mel & mel) {
|
2902
|
-
std::vector<float> fft_in(frame_size, 0.0);
|
2903
|
-
std::vector<float> fft_out(2 *
|
3127
|
+
std::vector<float> fft_in(frame_size * 2, 0.0);
|
3128
|
+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
3129
|
+
|
2904
3130
|
int n_fft = filters.n_fft;
|
2905
3131
|
int i = ith;
|
2906
3132
|
|
2907
3133
|
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
2908
|
-
assert(
|
2909
|
-
|
3134
|
+
assert(n_fft == 1 + (frame_size / 2));
|
3135
|
+
|
2910
3136
|
// calculate FFT only when fft_in are not all zero
|
2911
3137
|
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
2912
3138
|
const int offset = i * frame_step;
|
2913
3139
|
|
2914
|
-
// apply
|
3140
|
+
// apply Hann window (~10% faster)
|
2915
3141
|
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
2916
3142
|
fft_in[j] = hann[j] * samples[offset + j];
|
2917
3143
|
}
|
3144
|
+
|
2918
3145
|
// fill the rest with zeros
|
2919
3146
|
if (n_samples - offset < frame_size) {
|
2920
3147
|
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
|
2921
3148
|
}
|
2922
3149
|
|
2923
3150
|
// FFT
|
2924
|
-
fft(fft_in, fft_out);
|
3151
|
+
fft(fft_in.data(), frame_size, fft_out.data());
|
2925
3152
|
|
2926
3153
|
// Calculate modulus^2 of complex numbers
|
2927
3154
|
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
@@ -2932,7 +3159,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
2932
3159
|
// mel spectrogram
|
2933
3160
|
for (int j = 0; j < mel.n_mel; j++) {
|
2934
3161
|
double sum = 0.0;
|
2935
|
-
|
2936
3162
|
// unroll loop (suggested by GH user @lunixbochs)
|
2937
3163
|
int k = 0;
|
2938
3164
|
for (k = 0; k < n_fft - 3; k += 4) {
|
@@ -2942,14 +3168,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|
2942
3168
|
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
|
2943
3169
|
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
|
2944
3170
|
}
|
2945
|
-
|
2946
3171
|
// handle n_fft remainder
|
2947
3172
|
for (; k < n_fft; k++) {
|
2948
3173
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
2949
3174
|
}
|
2950
|
-
|
2951
3175
|
sum = log10(std::max(sum, 1e-10));
|
2952
|
-
|
2953
3176
|
mel.data[j * mel.n_len + i] = sum;
|
2954
3177
|
}
|
2955
3178
|
}
|
@@ -2978,12 +3201,9 @@ static bool log_mel_spectrogram(
|
|
2978
3201
|
whisper_mel & mel) {
|
2979
3202
|
const int64_t t_start_us = ggml_time_us();
|
2980
3203
|
|
2981
|
-
//
|
2982
|
-
|
2983
|
-
|
2984
|
-
std::vector<float> hann;
|
2985
|
-
hann_window(frame_size, true, hann);
|
2986
|
-
|
3204
|
+
// Hann window
|
3205
|
+
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
|
3206
|
+
const float * hann = global_cache.hann_window;
|
2987
3207
|
|
2988
3208
|
// Calculate the length of padding
|
2989
3209
|
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
@@ -3008,12 +3228,11 @@ static bool log_mel_spectrogram(
|
|
3008
3228
|
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
|
3009
3229
|
mel.data.resize(mel.n_mel * mel.n_len);
|
3010
3230
|
|
3011
|
-
|
3012
3231
|
{
|
3013
3232
|
std::vector<std::thread> workers(n_threads - 1);
|
3014
3233
|
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
3015
3234
|
workers[iw] = std::thread(
|
3016
|
-
log_mel_spectrogram_worker_thread, iw + 1, std::cref(
|
3235
|
+
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
|
3017
3236
|
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
3018
3237
|
std::cref(filters), std::ref(mel));
|
3019
3238
|
}
|
@@ -3173,23 +3392,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
|
|
3173
3392
|
#endif
|
3174
3393
|
|
3175
3394
|
struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
3176
|
-
fill_sin_cos_table();
|
3177
|
-
|
3178
3395
|
whisper_state * state = new whisper_state;
|
3179
3396
|
|
3180
|
-
state->
|
3181
|
-
if (
|
3397
|
+
state->backends = whisper_backend_init(ctx->params);
|
3398
|
+
if (state->backends.empty()) {
|
3182
3399
|
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
3183
3400
|
whisper_free_state(state);
|
3184
3401
|
return nullptr;
|
3185
3402
|
}
|
3186
3403
|
|
3187
|
-
// at this point, we don't know yet how many decoders will be used
|
3188
|
-
//
|
3189
|
-
|
3190
|
-
|
3191
|
-
|
3192
|
-
|
3404
|
+
// at this point, we don't know yet how many decoders will be used
|
3405
|
+
// later during decoding, if more decoders are used, we will recreate the KV cache respectively
|
3406
|
+
state->kv_self_n_dec = 1;
|
3407
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
3408
|
+
ctx->model.hparams.n_text_state,
|
3409
|
+
ctx->model.hparams.n_text_layer,
|
3410
|
+
GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
|
3411
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
3193
3412
|
whisper_free_state(state);
|
3194
3413
|
return nullptr;
|
3195
3414
|
}
|
@@ -3199,8 +3418,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3199
3418
|
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
3200
3419
|
}
|
3201
3420
|
|
3202
|
-
if (!
|
3203
|
-
|
3421
|
+
if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
|
3422
|
+
ctx->model.hparams.n_text_state,
|
3423
|
+
ctx->model.hparams.n_text_layer,
|
3424
|
+
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
3425
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
|
3204
3426
|
whisper_free_state(state);
|
3205
3427
|
return nullptr;
|
3206
3428
|
}
|
@@ -3210,9 +3432,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3210
3432
|
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
3211
3433
|
}
|
3212
3434
|
|
3435
|
+
if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
|
3436
|
+
ctx->model.hparams.n_audio_state,
|
3437
|
+
1,
|
3438
|
+
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
3439
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
3440
|
+
whisper_free_state(state);
|
3441
|
+
return nullptr;
|
3442
|
+
}
|
3443
|
+
|
3444
|
+
{
|
3445
|
+
const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
|
3446
|
+
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
|
3447
|
+
}
|
3448
|
+
|
3213
3449
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
3214
3450
|
if (ctx->params.dtw_token_timestamps) {
|
3215
|
-
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks,
|
3451
|
+
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
|
3216
3452
|
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
|
3217
3453
|
whisper_free_state(state);
|
3218
3454
|
return nullptr;
|
@@ -3255,7 +3491,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3255
3491
|
|
3256
3492
|
// conv allocator
|
3257
3493
|
{
|
3258
|
-
bool ok =
|
3494
|
+
bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
|
3259
3495
|
[&]() {
|
3260
3496
|
return whisper_build_graph_conv(*ctx, *state);
|
3261
3497
|
});
|
@@ -3266,12 +3502,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3266
3502
|
return nullptr;
|
3267
3503
|
}
|
3268
3504
|
|
3269
|
-
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__,
|
3505
|
+
WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
|
3270
3506
|
}
|
3271
3507
|
|
3272
3508
|
// encoder allocator
|
3273
3509
|
if (!whisper_encode_external(*state)) {
|
3274
|
-
bool ok =
|
3510
|
+
bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
|
3275
3511
|
[&]() {
|
3276
3512
|
return whisper_build_graph_encoder(*ctx, *state);
|
3277
3513
|
});
|
@@ -3282,12 +3518,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3282
3518
|
return nullptr;
|
3283
3519
|
}
|
3284
3520
|
|
3285
|
-
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__,
|
3521
|
+
WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
|
3286
3522
|
}
|
3287
3523
|
|
3288
3524
|
// cross allocator
|
3289
3525
|
{
|
3290
|
-
bool ok =
|
3526
|
+
bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
|
3291
3527
|
[&]() {
|
3292
3528
|
return whisper_build_graph_cross(*ctx, *state);
|
3293
3529
|
});
|
@@ -3298,12 +3534,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3298
3534
|
return nullptr;
|
3299
3535
|
}
|
3300
3536
|
|
3301
|
-
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__,
|
3537
|
+
WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
|
3302
3538
|
}
|
3303
3539
|
|
3304
3540
|
// decoder allocator
|
3305
3541
|
{
|
3306
|
-
bool ok =
|
3542
|
+
bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
|
3307
3543
|
[&]() {
|
3308
3544
|
const auto & hparams = ctx->model.hparams;
|
3309
3545
|
|
@@ -3322,19 +3558,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
3322
3558
|
return nullptr;
|
3323
3559
|
}
|
3324
3560
|
|
3325
|
-
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__,
|
3561
|
+
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
|
3326
3562
|
}
|
3327
3563
|
|
3328
3564
|
return state;
|
3329
3565
|
}
|
3330
3566
|
|
3331
|
-
int
|
3567
|
+
int whisper_ctx_init_openvino_encoder_with_state(
|
3332
3568
|
struct whisper_context * ctx,
|
3569
|
+
struct whisper_state * state,
|
3333
3570
|
const char * model_path,
|
3334
3571
|
const char * device,
|
3335
3572
|
const char * cache_dir) {
|
3336
3573
|
#ifndef WHISPER_USE_OPENVINO
|
3337
3574
|
(void)(ctx);
|
3575
|
+
(void)(state);
|
3338
3576
|
(void)(model_path);
|
3339
3577
|
(void)(device);
|
3340
3578
|
(void)(cache_dir);
|
@@ -3365,8 +3603,8 @@ int whisper_ctx_init_openvino_encoder(
|
|
3365
3603
|
WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
|
3366
3604
|
WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
|
3367
3605
|
|
3368
|
-
|
3369
|
-
if (!
|
3606
|
+
state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
|
3607
|
+
if (!state->ctx_openvino) {
|
3370
3608
|
WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
|
3371
3609
|
return 1;
|
3372
3610
|
} else {
|
@@ -3377,9 +3615,18 @@ int whisper_ctx_init_openvino_encoder(
|
|
3377
3615
|
#endif
|
3378
3616
|
}
|
3379
3617
|
|
3618
|
+
int whisper_ctx_init_openvino_encoder(
|
3619
|
+
struct whisper_context * ctx,
|
3620
|
+
const char * model_path,
|
3621
|
+
const char * device,
|
3622
|
+
const char * cache_dir) {
|
3623
|
+
return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
|
3624
|
+
}
|
3625
|
+
|
3380
3626
|
struct whisper_context_params whisper_context_default_params() {
|
3381
3627
|
struct whisper_context_params result = {
|
3382
3628
|
/*.use_gpu =*/ true,
|
3629
|
+
/*.flash_attn =*/ false,
|
3383
3630
|
/*.gpu_device =*/ 0,
|
3384
3631
|
|
3385
3632
|
/*.dtw_token_timestamps =*/ false,
|
@@ -3396,8 +3643,14 @@ struct whisper_context_params whisper_context_default_params() {
|
|
3396
3643
|
|
3397
3644
|
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
|
3398
3645
|
WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
|
3399
|
-
|
3646
|
+
#ifdef _MSC_VER
|
3647
|
+
// Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
|
3648
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
3649
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
3650
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
3651
|
+
#else
|
3400
3652
|
auto fin = std::ifstream(path_model, std::ios::binary);
|
3653
|
+
#endif
|
3401
3654
|
if (!fin) {
|
3402
3655
|
WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
|
3403
3656
|
return nullptr;
|
@@ -3472,6 +3725,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|
3472
3725
|
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
3473
3726
|
ggml_time_init();
|
3474
3727
|
|
3728
|
+
if (params.flash_attn && params.dtw_token_timestamps) {
|
3729
|
+
WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
|
3730
|
+
params.dtw_token_timestamps = false;
|
3731
|
+
}
|
3732
|
+
|
3733
|
+
WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
|
3734
|
+
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
|
3735
|
+
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
3736
|
+
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
|
3737
|
+
WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
|
3738
|
+
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());
|
3739
|
+
|
3475
3740
|
whisper_context * ctx = new whisper_context;
|
3476
3741
|
ctx->params = params;
|
3477
3742
|
|
@@ -3558,8 +3823,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
|
|
3558
3823
|
|
3559
3824
|
void whisper_free_state(struct whisper_state * state) {
|
3560
3825
|
if (state) {
|
3561
|
-
|
3562
|
-
|
3826
|
+
whisper_kv_cache_free(state->kv_self);
|
3827
|
+
whisper_kv_cache_free(state->kv_cross);
|
3828
|
+
whisper_kv_cache_free(state->kv_pad);
|
3563
3829
|
|
3564
3830
|
#ifdef WHISPER_USE_COREML
|
3565
3831
|
if (state->ctx_coreml != nullptr) {
|
@@ -3577,30 +3843,39 @@ void whisper_free_state(struct whisper_state * state) {
|
|
3577
3843
|
|
3578
3844
|
whisper_batch_free(state->batch);
|
3579
3845
|
|
3580
|
-
|
3581
|
-
|
3582
|
-
|
3583
|
-
|
3846
|
+
ggml_backend_sched_free(state->sched_conv.sched);
|
3847
|
+
ggml_backend_sched_free(state->sched_encode.sched);
|
3848
|
+
ggml_backend_sched_free(state->sched_cross.sched);
|
3849
|
+
ggml_backend_sched_free(state->sched_decode.sched);
|
3584
3850
|
|
3585
|
-
|
3851
|
+
for (auto & backend : state->backends) {
|
3852
|
+
ggml_backend_free(backend);
|
3853
|
+
}
|
3586
3854
|
|
3587
3855
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
3588
3856
|
aheads_masks_free(state->aheads_masks);
|
3589
3857
|
|
3858
|
+
if (state->vad_context != nullptr) {
|
3859
|
+
whisper_vad_free(state->vad_context);
|
3860
|
+
state->vad_context = nullptr;
|
3861
|
+
}
|
3862
|
+
|
3590
3863
|
delete state;
|
3591
3864
|
}
|
3592
3865
|
}
|
3593
3866
|
|
3594
3867
|
void whisper_free(struct whisper_context * ctx) {
|
3595
3868
|
if (ctx) {
|
3596
|
-
|
3869
|
+
for (ggml_context * context : ctx->model.ctxs) {
|
3870
|
+
ggml_free(context);
|
3871
|
+
}
|
3597
3872
|
|
3598
|
-
|
3873
|
+
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
|
3874
|
+
ggml_backend_buffer_free(buf);
|
3875
|
+
}
|
3599
3876
|
|
3600
3877
|
whisper_free_state(ctx->state);
|
3601
3878
|
|
3602
|
-
ggml_backend_free(ctx->backend);
|
3603
|
-
|
3604
3879
|
delete ctx;
|
3605
3880
|
}
|
3606
3881
|
}
|
@@ -3630,30 +3905,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
3630
3905
|
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
3631
3906
|
}
|
3632
3907
|
|
3633
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
3634
|
-
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
3635
|
-
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
|
3636
|
-
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
|
3637
|
-
return -1;
|
3638
|
-
}
|
3639
|
-
|
3640
|
-
return 0;
|
3641
|
-
}
|
3642
|
-
|
3643
|
-
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
|
3644
|
-
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
3645
|
-
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
3646
|
-
}
|
3647
|
-
|
3648
|
-
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
|
3649
|
-
// TODO
|
3650
|
-
|
3651
|
-
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
|
3652
|
-
// TODO
|
3653
|
-
|
3654
|
-
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
|
3655
|
-
// TODO
|
3656
|
-
|
3657
3908
|
int whisper_set_mel_with_state(
|
3658
3909
|
struct whisper_context * ctx,
|
3659
3910
|
struct whisper_state * state,
|
@@ -3742,7 +3993,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
|
|
3742
3993
|
return -whisper_tokenize(ctx, text, NULL, 0);
|
3743
3994
|
}
|
3744
3995
|
|
3745
|
-
int whisper_lang_max_id() {
|
3996
|
+
int whisper_lang_max_id(void) {
|
3746
3997
|
auto max_id = 0;
|
3747
3998
|
for (const auto & kv : g_lang) {
|
3748
3999
|
max_id = std::max(max_id, kv.second.first);
|
@@ -3963,134 +4214,1262 @@ float * whisper_get_logits(struct whisper_context * ctx) {
|
|
3963
4214
|
return ctx->state->logits.data();
|
3964
4215
|
}
|
3965
4216
|
|
3966
|
-
float * whisper_get_logits_from_state(struct whisper_state * state) {
|
3967
|
-
return state->logits.data();
|
3968
|
-
}
|
4217
|
+
float * whisper_get_logits_from_state(struct whisper_state * state) {
|
4218
|
+
return state->logits.data();
|
4219
|
+
}
|
4220
|
+
|
4221
|
+
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
|
4222
|
+
return ctx->vocab.id_to_token.at(token).c_str();
|
4223
|
+
}
|
4224
|
+
|
4225
|
+
whisper_token whisper_token_eot(struct whisper_context * ctx) {
|
4226
|
+
return ctx->vocab.token_eot;
|
4227
|
+
}
|
4228
|
+
|
4229
|
+
whisper_token whisper_token_sot(struct whisper_context * ctx) {
|
4230
|
+
return ctx->vocab.token_sot;
|
4231
|
+
}
|
4232
|
+
|
4233
|
+
whisper_token whisper_token_solm(struct whisper_context * ctx) {
|
4234
|
+
return ctx->vocab.token_solm;
|
4235
|
+
}
|
4236
|
+
|
4237
|
+
whisper_token whisper_token_prev(struct whisper_context * ctx) {
|
4238
|
+
return ctx->vocab.token_prev;
|
4239
|
+
}
|
4240
|
+
|
4241
|
+
whisper_token whisper_token_nosp(struct whisper_context * ctx) {
|
4242
|
+
return ctx->vocab.token_nosp;
|
4243
|
+
}
|
4244
|
+
|
4245
|
+
whisper_token whisper_token_not(struct whisper_context * ctx) {
|
4246
|
+
return ctx->vocab.token_not;
|
4247
|
+
}
|
4248
|
+
|
4249
|
+
whisper_token whisper_token_beg(struct whisper_context * ctx) {
|
4250
|
+
return ctx->vocab.token_beg;
|
4251
|
+
}
|
4252
|
+
|
4253
|
+
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
|
4254
|
+
return whisper_token_sot(ctx) + 1 + lang_id;
|
4255
|
+
}
|
4256
|
+
|
4257
|
+
whisper_token whisper_token_translate(struct whisper_context * ctx) {
|
4258
|
+
return ctx->vocab.token_translate;
|
4259
|
+
}
|
4260
|
+
|
4261
|
+
whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
|
4262
|
+
return ctx->vocab.token_transcribe;
|
4263
|
+
}
|
4264
|
+
|
4265
|
+
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
|
4266
|
+
if (ctx->state == nullptr) {
|
4267
|
+
return nullptr;
|
4268
|
+
}
|
4269
|
+
whisper_timings * timings = new whisper_timings;
|
4270
|
+
timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
|
4271
|
+
timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
|
4272
|
+
timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
|
4273
|
+
timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
|
4274
|
+
timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
|
4275
|
+
return timings;
|
4276
|
+
}
|
4277
|
+
|
4278
|
+
void whisper_print_timings(struct whisper_context * ctx) {
|
4279
|
+
const int64_t t_end_us = ggml_time_us();
|
4280
|
+
|
4281
|
+
WHISPER_LOG_INFO("\n");
|
4282
|
+
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
4283
|
+
if (ctx->state != nullptr) {
|
4284
|
+
|
4285
|
+
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
4286
|
+
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
4287
|
+
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
4288
|
+
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
4289
|
+
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
4290
|
+
|
4291
|
+
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
4292
|
+
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
4293
|
+
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
4294
|
+
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
4295
|
+
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
4296
|
+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
4297
|
+
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
4298
|
+
}
|
4299
|
+
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
4300
|
+
}
|
4301
|
+
|
4302
|
+
void whisper_reset_timings(struct whisper_context * ctx) {
|
4303
|
+
ctx->t_start_us = ggml_time_us();
|
4304
|
+
if (ctx->state != nullptr) {
|
4305
|
+
ctx->state->t_mel_us = 0;
|
4306
|
+
ctx->state->t_sample_us = 0;
|
4307
|
+
ctx->state->t_encode_us = 0;
|
4308
|
+
ctx->state->t_decode_us = 0;
|
4309
|
+
ctx->state->t_batchd_us = 0;
|
4310
|
+
ctx->state->t_prompt_us = 0;
|
4311
|
+
ctx->state->n_sample = 0;
|
4312
|
+
ctx->state->n_encode = 0;
|
4313
|
+
ctx->state->n_decode = 0;
|
4314
|
+
ctx->state->n_batchd = 0;
|
4315
|
+
ctx->state->n_prompt = 0;
|
4316
|
+
}
|
4317
|
+
}
|
4318
|
+
|
4319
|
+
static int whisper_has_coreml(void) {
|
4320
|
+
#ifdef WHISPER_USE_COREML
|
4321
|
+
return 1;
|
4322
|
+
#else
|
4323
|
+
return 0;
|
4324
|
+
#endif
|
4325
|
+
}
|
4326
|
+
|
4327
|
+
static int whisper_has_openvino(void) {
|
4328
|
+
#ifdef WHISPER_USE_OPENVINO
|
4329
|
+
return 1;
|
4330
|
+
#else
|
4331
|
+
return 0;
|
4332
|
+
#endif
|
4333
|
+
}
|
4334
|
+
|
4335
|
+
const char * whisper_print_system_info(void) {
|
4336
|
+
static std::string s;
|
4337
|
+
|
4338
|
+
whisper_load_backends();
|
4339
|
+
|
4340
|
+
s = "";
|
4341
|
+
s += "WHISPER : ";
|
4342
|
+
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
4343
|
+
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
|
4344
|
+
|
4345
|
+
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
4346
|
+
auto * reg = ggml_backend_reg_get(i);
|
4347
|
+
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
4348
|
+
if (get_features_fn) {
|
4349
|
+
ggml_backend_feature * features = get_features_fn(reg);
|
4350
|
+
s += ggml_backend_reg_name(reg);
|
4351
|
+
s += " : ";
|
4352
|
+
for (; features->name; features++) {
|
4353
|
+
s += features->name;
|
4354
|
+
s += " = ";
|
4355
|
+
s += features->value;
|
4356
|
+
s += " | ";
|
4357
|
+
}
|
4358
|
+
}
|
4359
|
+
}
|
4360
|
+
return s.c_str();
|
4361
|
+
}
|
4362
|
+
|
4363
|
+
//////////////////////////////////
|
4364
|
+
// Voice Activity Detection (VAD)
|
4365
|
+
//////////////////////////////////
|
4366
|
+
|
4367
|
+
struct whisper_vad_hparams {
|
4368
|
+
int32_t n_encoder_layers;
|
4369
|
+
int32_t * encoder_in_channels;
|
4370
|
+
int32_t * encoder_out_channels;
|
4371
|
+
int32_t * kernel_sizes;
|
4372
|
+
int32_t lstm_input_size;
|
4373
|
+
int32_t lstm_hidden_size;
|
4374
|
+
int32_t final_conv_in;
|
4375
|
+
int32_t final_conv_out;
|
4376
|
+
};
|
4377
|
+
|
4378
|
+
struct whisper_vad_model {
|
4379
|
+
std::string type;
|
4380
|
+
std::string version;
|
4381
|
+
whisper_vad_hparams hparams;
|
4382
|
+
|
4383
|
+
struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
|
4384
|
+
|
4385
|
+
// Encoder tensors - 4 convolutional layers
|
4386
|
+
struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
|
4387
|
+
struct ggml_tensor * encoder_0_bias; // [128]
|
4388
|
+
|
4389
|
+
// Second encoder layer
|
4390
|
+
struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
|
4391
|
+
struct ggml_tensor * encoder_1_bias; // [64]
|
4392
|
+
|
4393
|
+
// Third encoder layer
|
4394
|
+
struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
|
4395
|
+
struct ggml_tensor * encoder_2_bias; // [64]
|
4396
|
+
|
4397
|
+
// Fourth encoder layer
|
4398
|
+
struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
|
4399
|
+
struct ggml_tensor * encoder_3_bias; // [128]
|
4400
|
+
|
4401
|
+
// LSTM decoder tensors
|
4402
|
+
struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
|
4403
|
+
struct ggml_tensor * lstm_ih_bias; // [512]
|
4404
|
+
struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
|
4405
|
+
struct ggml_tensor * lstm_hh_bias; // [512]
|
4406
|
+
|
4407
|
+
// Final conv layer
|
4408
|
+
struct ggml_tensor * final_conv_weight; // [128]
|
4409
|
+
struct ggml_tensor * final_conv_bias; // [1]
|
4410
|
+
|
4411
|
+
// ggml contexts
|
4412
|
+
std::vector<ggml_context *> ctxs;
|
4413
|
+
|
4414
|
+
// buffer for the model tensors
|
4415
|
+
std::vector<ggml_backend_buffer_t> buffers;
|
4416
|
+
|
4417
|
+
// tensors
|
4418
|
+
int n_loaded;
|
4419
|
+
std::map<std::string, struct ggml_tensor *> tensors;
|
4420
|
+
};
|
4421
|
+
|
4422
|
+
struct whisper_vad_segment {
|
4423
|
+
float start; // Start time in seconds
|
4424
|
+
float end; // End time in seconds
|
4425
|
+
};
|
4426
|
+
|
4427
|
+
struct whisper_vad_segments {
|
4428
|
+
std::vector<whisper_vad_segment> data;
|
4429
|
+
};
|
4430
|
+
|
4431
|
+
struct whisper_vad_context {
|
4432
|
+
int64_t t_vad_us = 0;
|
4433
|
+
|
4434
|
+
int n_window;
|
4435
|
+
int n_context;
|
4436
|
+
int n_threads;
|
4437
|
+
|
4438
|
+
std::vector<ggml_backend_t> backends;
|
4439
|
+
ggml_backend_buffer_t buffer = nullptr;
|
4440
|
+
whisper_context_params params;
|
4441
|
+
std::vector<uint8_t> ctx_buf;
|
4442
|
+
whisper_sched sched;
|
4443
|
+
|
4444
|
+
whisper_vad_model model;
|
4445
|
+
std::string path_model;
|
4446
|
+
struct ggml_tensor * h_state;
|
4447
|
+
struct ggml_tensor * c_state;
|
4448
|
+
std::vector<float> probs;
|
4449
|
+
};
|
4450
|
+
|
4451
|
+
struct whisper_vad_context_params whisper_vad_default_context_params(void) {
|
4452
|
+
whisper_vad_context_params result = {
|
4453
|
+
/*.n_thread = */ 4,
|
4454
|
+
/*.use_gpu = */ false,
|
4455
|
+
/*.gpu_device = */ 0,
|
4456
|
+
};
|
4457
|
+
return result;
|
4458
|
+
}
|
4459
|
+
|
4460
|
+
struct whisper_vad_params whisper_vad_default_params(void) {
|
4461
|
+
whisper_vad_params result = {
|
4462
|
+
/* threshold = */ 0.5f,
|
4463
|
+
/* min_speech_duration_ms = */ 250,
|
4464
|
+
/* min_silence_duration_ms = */ 100,
|
4465
|
+
/* max_speech_duration_s = */ FLT_MAX,
|
4466
|
+
/* speech_pad_ms = */ 30,
|
4467
|
+
/* samples_overlap = */ 0.1,
|
4468
|
+
};
|
4469
|
+
return result;
|
4470
|
+
}
|
4471
|
+
|
4472
|
+
// Check whether the backend device `dev` can execute operation `op` on weight
// tensor `w` when the weight lives in buffer type `buft`.
//
// GPU devices and the default CPU buffer type are assumed to support all ops;
// for extra buffer types only GGML_OP_MUL_MAT is probed, by building a dummy
// mul_mat against a temporary zero-sized buffer and asking the device.
//
// NOTE(review): `w` is temporarily given a dummy buffer and must enter with
// w->buffer == nullptr (asserted below); the buffer is freed and reset before
// returning. Throws std::runtime_error if a scratch ggml context cannot be created.
static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
    bool op_supported = true;

    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
        (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
        // GPU and default CPU backend support all operators
        op_supported = true;
    } else {
        switch (op) {
            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
            case GGML_OP_MUL_MAT: {
                // scratch context holding only tensor metadata (no_alloc), enough for 2 tensors
                ggml_init_params params = {
                    /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
                    /*.mem_buffer =*/ nullptr,
                    /*.no_alloc   =*/ true,
                };

                ggml_context_ptr ctx_ptr { ggml_init(params) };
                if (!ctx_ptr) {
                    throw std::runtime_error("failed to create ggml context");
                }
                ggml_context * ctx = ctx_ptr.get();

                ggml_tensor * op_tensor = nullptr;

                // build a representative mul_mat: w x b, with b shaped from the LSTM hidden size
                int64_t n_ctx = hparams.lstm_hidden_size;
                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
                op_tensor = ggml_mul_mat(ctx, w, b);

                // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
                GGML_ASSERT(w->buffer == nullptr);
                w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
                op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
                ggml_backend_buffer_free(w->buffer);
                w->buffer = nullptr;
                break;
            }
            default: {
                // any other op is not probed - treat as unsupported for non-default buffer types
                op_supported = false;
                break;
            }
        };
    }
    return op_supported;
}
|
4517
|
+
|
4518
|
+
// Walk the (device, buffer type) priority list and return the first buffer
// type that can run `op` on weight `w`; nullptr when no candidate qualifies.
static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
    GGML_ASSERT(!buft_list.empty());

    for (auto it = buft_list.begin(); it != buft_list.end(); ++it) {
        ggml_backend_dev_t         cur_dev  = it->first;
        ggml_backend_buffer_type_t cur_buft = it->second;

        if (weight_buft_supported(hparams, w, op, cur_buft, cur_dev)) {
            return cur_buft;
        }
    }

    return nullptr;
}
|
4530
|
+
|
4531
|
+
// Build the STFT front-end of the VAD graph: reflect-pad the input frame, run
// the precomputed forward basis as a conv1d, and return the magnitude
// spectrum sqrt(re^2 + im^2).
static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
    const whisper_vad_model & model, ggml_tensor * cur) {
    // Reflective padding of 64 samples on each side of the input
    ggml_tensor * reflected = ggml_pad_reflect_1d(ctx0, cur, 64, 64);

    struct ggml_tensor * spec = ggml_conv_1d(ctx0, model.stft_forward_basis, reflected, model.hparams.lstm_input_size, 0, 1);

    // First half of the output channels is the real part, second half the imaginary part.
    int half = model.stft_forward_basis->ne[2] / 2;

    struct ggml_tensor * re = ggml_view_2d(ctx0, spec, 4, half, spec->nb[1], 0);
    struct ggml_tensor * im = ggml_view_2d(ctx0, spec, 4, half, spec->nb[1], half * spec->nb[1]);

    // magnitude = sqrt(re*re + im*im)
    struct ggml_tensor * power = ggml_add(ctx0,
        ggml_mul(ctx0, re, re),
        ggml_mul(ctx0, im, im));

    return ggml_sqrt(ctx0, power);
}
|
4553
|
+
|
4554
|
+
// Build the four-stage convolutional encoder: each stage is a Conv1D followed
// by a broadcast bias add and ReLU. Channel progression: 128 -> 64 -> 64 -> 128.
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
    const whisper_vad_model & model, ggml_tensor * cur) {
    // one Conv1D + bias + ReLU stage; n_ch is the output channel count used
    // to reshape the bias for broadcasting
    auto conv_relu = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b, int stride, int n_ch) {
        x = ggml_conv_1d(ctx0, w, x, stride, 1, 1);
        x = ggml_add(ctx0, x, ggml_reshape_3d(ctx0, b, 1, n_ch, 1));
        return ggml_relu(ctx0, x);
    };

    cur = conv_relu(cur, model.encoder_0_weight, model.encoder_0_bias, 1, 128); // expand to 128 channels
    cur = conv_relu(cur, model.encoder_1_weight, model.encoder_1_bias, 2, 64);  // reduce to 64 channels
    cur = conv_relu(cur, model.encoder_2_weight, model.encoder_2_bias, 2, 64);  // keep 64 channels
    cur = conv_relu(cur, model.encoder_3_weight, model.encoder_3_bias, 1, 128); // expand back to 128 channels

    return cur;
}
|
4578
|
+
|
4579
|
+
// Build a single LSTM cell step. The four gate preactivations are computed in
// one fused matmul per weight matrix (4*hdim rows), then sliced into the
// input/forget/cell/output gates. The new cell and hidden states are copied
// back into vctx.c_state / vctx.h_state as graph side effects, so state
// persists across per-chunk graph evaluations.
static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
    const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
    const whisper_vad_model & model = vctx.model;
    const int hdim = model.hparams.lstm_hidden_size;

    struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);

    // Input-to-hidden contribution: W_ih * x + b_ih (all four gates at once).
    struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
    inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);

    // Hidden-to-hidden contribution: W_hh * h_prev + b_hh.
    struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
    hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);

    // Sum both contributions to get the preactivations for all gates.
    struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);

    // Byte size of one gate slice (hdim elements) inside the fused preactivations.
    const size_t hdim_size = ggml_row_size(out_gate->type, hdim);

    // Input gate: sigmoid over the first hdim elements of the preactivations.
    struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));

    // Forget gate: sigmoid over the second hdim elements.
    struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));

    // Cell (candidate) gate: tanh - not sigmoid - over the third hdim elements.
    struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));

    // Output gate: sigmoid over the fourth hdim elements.
    struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));

    // Update cell state: c = f * c_prev + i * g
    struct ggml_tensor * c_out = ggml_add(ctx0,
        ggml_mul(ctx0, f_t, vctx.c_state),
        ggml_mul(ctx0, i_t, g_t));
    // persist the new cell state for the next invocation
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));

    // Update hidden state: h = o * tanh(c)
    struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
    // persist the new hidden state for the next invocation
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));

    return out;
}
|
4623
|
+
|
4624
|
+
// Build the full VAD compute graph for a single audio frame:
// STFT -> encoder convolutions -> LSTM cell -> final conv -> sigmoid probability.
// The input tensor is named "frame" and the output tensor "prob" so that
// callers can locate them via ggml_graph_get_tensor().
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
    const auto & model = vctx.model;

    // graph metadata lives in the pre-sized scheduler buffer; no tensor data is allocated here
    struct ggml_init_params init_params = {
        /*.mem_size   =*/ vctx.sched.meta.size(),
        /*.mem_buffer =*/ vctx.sched.meta.data(),
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx0 = ggml_init(init_params);

    ggml_cgraph * gf = ggml_new_graph(ctx0);

    // input: one window of audio samples
    struct ggml_tensor * inp_frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
    ggml_set_name(inp_frame, "frame");
    ggml_set_input(inp_frame);

    struct ggml_tensor * t = whisper_vad_build_stft_layer(ctx0, model, inp_frame);

    t = whisper_vad_build_encoder_layer(ctx0, model, t);

    // Extract the first element of the first dimension
    // (equivalent to pytorch's [:, :, 0])
    t = ggml_view_2d(ctx0, t, 1, 128, t->nb[1], 0);

    t = whisper_vad_build_lstm_layer(ctx0, vctx, t, gf);
    t = ggml_relu(ctx0, t);
    t = ggml_conv_1d(ctx0, model.final_conv_weight, t, 1, 0, 1);
    t = ggml_add(ctx0, t, model.final_conv_bias);
    t = ggml_sigmoid(ctx0, t);

    // output: speech probability for this frame
    ggml_set_name(t, "prob");
    ggml_set_output(t);

    ggml_build_forward_expand(gf, t);

    ggml_free(ctx0);

    return gf;
}
|
4666
|
+
|
4667
|
+
// Initialize the runtime state of a VAD context: backends, the persistent
// LSTM hidden/cell state tensors, and the graph scheduler. Returns false on
// any failure; the caller is responsible for freeing vctx in that case.
static bool whisper_vad_init_context(whisper_vad_context * vctx) {

    auto whisper_context_params = whisper_context_default_params();
    // TODO: GPU VAD is forced disabled until the performance is improved
    //whisper_context_params.use_gpu = vctx->params.use_gpu;
    whisper_context_params.use_gpu = false;
    whisper_context_params.gpu_device = vctx->params.gpu_device;

    vctx->backends = whisper_backend_init(whisper_context_params);
    if (vctx->backends.empty()) {
        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
        return false;
    }

    const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;

    // metadata-only context for the two state tensors (no_alloc below)
    vctx->ctx_buf.resize(2u*ggml_tensor_overhead());

    struct ggml_init_params params = {
        /*.mem_size =*/ vctx->ctx_buf.size(),
        /*.mem_buffer =*/ vctx->ctx_buf.data(),
        /*.no_alloc =*/ true,
    };

    ggml_context * ctx = ggml_init(params);
    if (!ctx) {
        WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
        return false;
    }

    // LSTM Hidden state
    vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
    ggml_set_name(vctx->h_state, "h_state");

    // LSTM Cell state
    vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
    ggml_set_name(vctx->c_state, "c_state");

    // allocate the state tensors on the primary backend; vctx->buffer owns their data
    vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
    if (!vctx->buffer) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
        return false;
    }

    {
        // size the scheduler by building a throwaway instance of the VAD graph
        bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
                [&]() {
                    return whisper_vad_build_graph(*vctx);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
            return false;
        }

        WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
    }

    return true;
}
|
4727
|
+
|
4728
|
+
struct whisper_vad_context * whisper_vad_init_from_file_with_params(
|
4729
|
+
const char * path_model,
|
4730
|
+
struct whisper_vad_context_params params) {
|
4731
|
+
WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
|
4732
|
+
#ifdef _MSC_VER
|
4733
|
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
4734
|
+
std::wstring path_model_wide = converter.from_bytes(path_model);
|
4735
|
+
auto fin = std::ifstream(path_model_wide, std::ios::binary);
|
4736
|
+
#else
|
4737
|
+
auto fin = std::ifstream(path_model, std::ios::binary);
|
4738
|
+
#endif
|
4739
|
+
if (!fin) {
|
4740
|
+
WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
|
4741
|
+
return nullptr;
|
4742
|
+
}
|
4743
|
+
|
4744
|
+
whisper_model_loader loader = {};
|
4745
|
+
loader.context = &fin;
|
4746
|
+
|
4747
|
+
loader.read = [](void * ctx, void * output, size_t read_size) {
|
4748
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4749
|
+
fin->read((char *)output, read_size);
|
4750
|
+
return read_size;
|
4751
|
+
};
|
4752
|
+
|
4753
|
+
loader.eof = [](void * ctx) {
|
4754
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4755
|
+
return fin->eof();
|
4756
|
+
};
|
4757
|
+
|
4758
|
+
loader.close = [](void * ctx) {
|
4759
|
+
std::ifstream * fin = (std::ifstream*)ctx;
|
4760
|
+
fin->close();
|
4761
|
+
};
|
4762
|
+
|
4763
|
+
auto ctx = whisper_vad_init_with_params(&loader, params);
|
4764
|
+
if (!ctx) {
|
4765
|
+
whisper_vad_free(ctx);
|
4766
|
+
return nullptr;
|
4767
|
+
}
|
4768
|
+
ctx->path_model = path_model;
|
4769
|
+
return ctx;
|
4770
|
+
}
|
4771
|
+
|
4772
|
+
struct whisper_vad_context * whisper_vad_init_with_params(
|
4773
|
+
struct whisper_model_loader * loader,
|
4774
|
+
struct whisper_vad_context_params params) {
|
4775
|
+
// Read the VAD model
|
4776
|
+
{
|
4777
|
+
uint32_t magic;
|
4778
|
+
read_safe(loader, magic);
|
4779
|
+
if (magic != GGML_FILE_MAGIC) {
|
4780
|
+
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
|
4781
|
+
return nullptr;
|
4782
|
+
}
|
4783
|
+
}
|
4784
|
+
|
4785
|
+
whisper_vad_context * vctx = new whisper_vad_context;
|
4786
|
+
vctx->n_threads = params.n_threads;
|
4787
|
+
vctx->params.use_gpu = params.use_gpu;
|
4788
|
+
vctx->params.gpu_device = params.gpu_device;
|
4789
|
+
|
4790
|
+
auto & model = vctx->model;
|
4791
|
+
auto & hparams = model.hparams;
|
4792
|
+
|
4793
|
+
// load model context params.
|
4794
|
+
{
|
4795
|
+
int32_t str_len;
|
4796
|
+
read_safe(loader, str_len);
|
4797
|
+
std::vector<char> buffer(str_len + 1, 0);
|
4798
|
+
loader->read(loader->context, buffer.data(), str_len);
|
4799
|
+
std::string model_type(buffer.data(), str_len);
|
4800
|
+
model.type = model_type;
|
4801
|
+
WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
|
4802
|
+
|
4803
|
+
int32_t major, minor, patch;
|
4804
|
+
read_safe(loader, major);
|
4805
|
+
read_safe(loader, minor);
|
4806
|
+
read_safe(loader, patch);
|
4807
|
+
std::string version_str = std::to_string(major) + "." +
|
4808
|
+
std::to_string(minor) + "." +
|
4809
|
+
std::to_string(patch);
|
4810
|
+
model.version = version_str;
|
4811
|
+
WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
|
4812
|
+
|
4813
|
+
read_safe(loader, vctx->n_window);
|
4814
|
+
read_safe(loader, vctx->n_context);
|
4815
|
+
}
|
4816
|
+
|
4817
|
+
// load model hyper params (hparams).
|
4818
|
+
{
|
4819
|
+
read_safe(loader, hparams.n_encoder_layers);
|
4820
|
+
|
4821
|
+
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
|
4822
|
+
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
|
4823
|
+
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
|
4824
|
+
|
4825
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4826
|
+
read_safe(loader, hparams.encoder_in_channels[i]);
|
4827
|
+
read_safe(loader, hparams.encoder_out_channels[i]);
|
4828
|
+
read_safe(loader, hparams.kernel_sizes[i]);
|
4829
|
+
}
|
4830
|
+
|
4831
|
+
read_safe(loader, hparams.lstm_input_size);
|
4832
|
+
read_safe(loader, hparams.lstm_hidden_size);
|
4833
|
+
read_safe(loader, hparams.final_conv_in);
|
4834
|
+
read_safe(loader, hparams.final_conv_out);
|
4835
|
+
|
4836
|
+
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
|
4837
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4838
|
+
WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
|
4839
|
+
}
|
4840
|
+
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
|
4841
|
+
WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
|
4842
|
+
}
|
4843
|
+
WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
|
4844
|
+
WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
|
4845
|
+
WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
|
4846
|
+
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
|
4847
|
+
}
|
4848
|
+
|
4849
|
+
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
|
4850
|
+
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
|
4851
|
+
|
4852
|
+
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
4853
|
+
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
4854
|
+
auto it = ctx_map.find(buft);
|
4855
|
+
if (it == ctx_map.end()) {
|
4856
|
+
ggml_init_params params = {
|
4857
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4858
|
+
/*.mem_buffer =*/ nullptr,
|
4859
|
+
/*.no_alloc =*/ true,
|
4860
|
+
};
|
4861
|
+
|
4862
|
+
ggml_context * ctx = ggml_init(params);
|
4863
|
+
if (!ctx) {
|
4864
|
+
throw std::runtime_error("failed to create ggml context");
|
4865
|
+
}
|
4866
|
+
|
4867
|
+
ctx_map[buft] = ctx;
|
4868
|
+
model.ctxs.emplace_back(ctx);
|
4869
|
+
|
4870
|
+
return ctx;
|
4871
|
+
}
|
4872
|
+
|
4873
|
+
return it->second;
|
4874
|
+
};
|
4875
|
+
|
4876
|
+
whisper_context_params wparams = whisper_context_default_params();
|
4877
|
+
wparams.use_gpu = params.use_gpu;
|
4878
|
+
wparams.gpu_device = params.gpu_device;
|
4879
|
+
buft_list_t buft_list = make_buft_list(wparams);
|
4880
|
+
|
4881
|
+
auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
|
4882
|
+
ggml_op op = VAD_TENSOR_OPS.at(type);
|
4883
|
+
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
4884
|
+
if (!buft) {
|
4885
|
+
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
|
4886
|
+
}
|
4887
|
+
ggml_context * ctx = get_ctx(buft);
|
4888
|
+
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
4889
|
+
model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
|
4890
|
+
|
4891
|
+
return tensor;
|
4892
|
+
};
|
4893
|
+
|
4894
|
+
// create tensors
|
4895
|
+
{
|
4896
|
+
ggml_init_params params = {
|
4897
|
+
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
|
4898
|
+
/*.mem_buffer =*/ nullptr,
|
4899
|
+
/*.no_alloc =*/ true,
|
4900
|
+
};
|
4901
|
+
|
4902
|
+
ggml_context * ctx = ggml_init(params);
|
4903
|
+
const auto & hparams = model.hparams;
|
4904
|
+
|
4905
|
+
// SFTF precomputed basis matrix
|
4906
|
+
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
|
4907
|
+
ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
|
4908
|
+
|
4909
|
+
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
|
4910
|
+
ggml_new_tensor_3d(
|
4911
|
+
ctx,
|
4912
|
+
GGML_TYPE_F16,
|
4913
|
+
hparams.kernel_sizes[0],
|
4914
|
+
hparams.encoder_in_channels[0],
|
4915
|
+
hparams.encoder_out_channels[0]
|
4916
|
+
));
|
4917
|
+
model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
|
4918
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
|
4919
|
+
|
4920
|
+
model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
|
4921
|
+
ggml_new_tensor_3d(
|
4922
|
+
ctx,
|
4923
|
+
GGML_TYPE_F16,
|
4924
|
+
hparams.kernel_sizes[1],
|
4925
|
+
hparams.encoder_in_channels[1],
|
4926
|
+
hparams.encoder_out_channels[1]
|
4927
|
+
));
|
4928
|
+
model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
|
4929
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
|
4930
|
+
|
4931
|
+
model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
|
4932
|
+
ggml_new_tensor_3d(
|
4933
|
+
ctx,
|
4934
|
+
GGML_TYPE_F16,
|
4935
|
+
hparams.kernel_sizes[2],
|
4936
|
+
hparams.encoder_in_channels[2],
|
4937
|
+
hparams.encoder_out_channels[2]
|
4938
|
+
));
|
4939
|
+
model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
|
4940
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
|
4941
|
+
|
4942
|
+
model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
|
4943
|
+
ggml_new_tensor_3d(
|
4944
|
+
ctx,
|
4945
|
+
GGML_TYPE_F16,
|
4946
|
+
hparams.kernel_sizes[3],
|
4947
|
+
hparams.encoder_in_channels[3],
|
4948
|
+
hparams.encoder_out_channels[3]
|
4949
|
+
));
|
4950
|
+
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
|
4951
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
|
4952
|
+
|
4953
|
+
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
|
4954
|
+
const int hstate_dim = hparams.lstm_hidden_size * 4;
|
4955
|
+
|
4956
|
+
// LSTM weights - input to hidden
|
4957
|
+
model.lstm_ih_weight = create_tensor(
|
4958
|
+
VAD_TENSOR_LSTM_WEIGHT_IH,
|
4959
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4960
|
+
);
|
4961
|
+
model.lstm_ih_bias = create_tensor(
|
4962
|
+
VAD_TENSOR_LSTM_BIAS_IH,
|
4963
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4964
|
+
);
|
4965
|
+
|
4966
|
+
// LSTM weights - hidden to hidden
|
4967
|
+
model.lstm_hh_weight = create_tensor(
|
4968
|
+
VAD_TENSOR_LSTM_WEIGHT_HH,
|
4969
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
|
4970
|
+
);
|
4971
|
+
model.lstm_hh_bias = create_tensor(
|
4972
|
+
VAD_TENSOR_LSTM_BIAS_HH,
|
4973
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
|
4974
|
+
);
|
4975
|
+
|
4976
|
+
// Final conv layer weight
|
4977
|
+
model.final_conv_weight = create_tensor(
|
4978
|
+
VAD_TENSOR_FINAL_CONV_WEIGHT,
|
4979
|
+
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
|
4980
|
+
);
|
4981
|
+
model.final_conv_bias = create_tensor(
|
4982
|
+
VAD_TENSOR_FINAL_CONV_BIAS,
|
4983
|
+
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
|
4984
|
+
);
|
4985
|
+
|
4986
|
+
ggml_free(ctx);
|
4987
|
+
}
|
4988
|
+
|
4989
|
+
// allocate tensors in the backend buffers
|
4990
|
+
for (auto & p : ctx_map) {
|
4991
|
+
ggml_backend_buffer_type_t buft = p.first;
|
4992
|
+
ggml_context * ctx = p.second;
|
4993
|
+
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
4994
|
+
if (buf) {
|
4995
|
+
model.buffers.emplace_back(buf);
|
4996
|
+
|
4997
|
+
size_t size_main = ggml_backend_buffer_get_size(buf);
|
4998
|
+
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
|
4999
|
+
}
|
5000
|
+
}
|
5001
|
+
|
5002
|
+
// load weights
|
5003
|
+
{
|
5004
|
+
size_t total_size = 0;
|
5005
|
+
model.n_loaded = 0;
|
5006
|
+
std::vector<char> read_buf;
|
5007
|
+
|
5008
|
+
while (true) {
|
5009
|
+
int32_t n_dims;
|
5010
|
+
int32_t length;
|
5011
|
+
int32_t ttype;
|
5012
|
+
|
5013
|
+
read_safe(loader, n_dims);
|
5014
|
+
read_safe(loader, length);
|
5015
|
+
read_safe(loader, ttype);
|
5016
|
+
|
5017
|
+
if (loader->eof(loader->context)) {
|
5018
|
+
break;
|
5019
|
+
}
|
5020
|
+
|
5021
|
+
int32_t nelements = 1;
|
5022
|
+
int32_t ne[4] = { 1, 1, 1, 1 };
|
5023
|
+
for (int i = 0; i < n_dims; ++i) {
|
5024
|
+
read_safe(loader, ne[i]);
|
5025
|
+
nelements *= ne[i];
|
5026
|
+
}
|
5027
|
+
|
5028
|
+
std::string name;
|
5029
|
+
std::vector<char> tmp(length);
|
5030
|
+
loader->read(loader->context, &tmp[0], tmp.size());
|
5031
|
+
name.assign(&tmp[0], tmp.size());
|
5032
|
+
|
5033
|
+
if (model.tensors.find(name) == model.tensors.end()) {
|
5034
|
+
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
5035
|
+
return nullptr;
|
5036
|
+
}
|
5037
|
+
|
5038
|
+
auto tensor = model.tensors[name.data()];
|
5039
|
+
|
5040
|
+
if (ggml_nelements(tensor) != nelements) {
|
5041
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
5042
|
+
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
5043
|
+
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
5044
|
+
return nullptr;
|
5045
|
+
}
|
5046
|
+
|
5047
|
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
|
5048
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
|
5049
|
+
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
|
5050
|
+
return nullptr;
|
5051
|
+
}
|
5052
|
+
|
5053
|
+
const size_t bpe = ggml_type_size(ggml_type(ttype));
|
5054
|
+
|
5055
|
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
5056
|
+
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
5057
|
+
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
5058
|
+
return nullptr;
|
5059
|
+
}
|
5060
|
+
|
5061
|
+
if (ggml_backend_buffer_is_host(tensor->buffer)) {
|
5062
|
+
// for the CPU and Metal backend, we can read directly into the tensor
|
5063
|
+
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
5064
|
+
BYTESWAP_TENSOR(tensor);
|
5065
|
+
} else {
|
5066
|
+
// read into a temporary buffer first, then copy to device memory
|
5067
|
+
read_buf.resize(ggml_nbytes(tensor));
|
5068
|
+
|
5069
|
+
loader->read(loader->context, read_buf.data(), read_buf.size());
|
5070
|
+
|
5071
|
+
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
5072
|
+
}
|
5073
|
+
|
5074
|
+
total_size += ggml_nbytes(tensor);
|
5075
|
+
model.n_loaded++;
|
5076
|
+
}
|
5077
|
+
|
5078
|
+
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
|
5079
|
+
|
5080
|
+
if (model.n_loaded == 0) {
|
5081
|
+
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
5082
|
+
} else if (model.n_loaded != (int) model.tensors.size()) {
|
5083
|
+
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
5084
|
+
return nullptr;
|
5085
|
+
}
|
5086
|
+
|
5087
|
+
}
|
5088
|
+
|
5089
|
+
if (!whisper_vad_init_context(vctx)) {
|
5090
|
+
whisper_vad_free(vctx);
|
5091
|
+
return nullptr;
|
5092
|
+
}
|
5093
|
+
|
5094
|
+
return vctx;
|
5095
|
+
}
|
5096
|
+
|
5097
|
+
bool whisper_vad_detect_speech(
|
5098
|
+
struct whisper_vad_context * vctx,
|
5099
|
+
const float * samples,
|
5100
|
+
int n_samples) {
|
5101
|
+
int n_chunks = n_samples / vctx->n_window;
|
5102
|
+
if (n_samples % vctx->n_window != 0) {
|
5103
|
+
n_chunks += 1; // Add one more chunk for remaining samples.
|
5104
|
+
}
|
5105
|
+
|
5106
|
+
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
|
5107
|
+
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
|
5108
|
+
|
5109
|
+
// Reset LSTM hidden/cell states
|
5110
|
+
ggml_backend_buffer_clear(vctx->buffer, 0);
|
5111
|
+
|
5112
|
+
vctx->probs.resize(n_chunks);
|
5113
|
+
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
|
5114
|
+
|
5115
|
+
std::vector<float> window(vctx->n_window, 0.0f);
|
5116
|
+
|
5117
|
+
auto & sched = vctx->sched.sched;
|
5118
|
+
|
5119
|
+
ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
|
5120
|
+
|
5121
|
+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
5122
|
+
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
|
5123
|
+
return false;
|
5124
|
+
}
|
5125
|
+
|
5126
|
+
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
|
5127
|
+
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
|
5128
|
+
|
5129
|
+
// we are going to reuse the graph multiple times for each chunk
|
5130
|
+
const int64_t t_start_vad_us = ggml_time_us();
|
5131
|
+
|
5132
|
+
for (int i = 0; i < n_chunks; i++) {
|
5133
|
+
const int idx_start = i * vctx->n_window;
|
5134
|
+
const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
|
5135
|
+
|
5136
|
+
const int chunk_len = idx_end - idx_start;
|
5137
|
+
|
5138
|
+
if (chunk_len < vctx->n_window) {
|
5139
|
+
WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
|
5140
|
+
std::vector<float> partial_chunk(vctx->n_window, 0.0f);
|
5141
|
+
std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
|
5142
|
+
|
5143
|
+
// Copy the zero-padded chunk to the window.
|
5144
|
+
const int samples_to_copy_max = vctx->n_window;
|
5145
|
+
const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
|
5146
|
+
std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
|
5147
|
+
if (samples_to_copy_cur < samples_to_copy_max) {
|
5148
|
+
std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
|
5149
|
+
}
|
5150
|
+
} else {
|
5151
|
+
// Copy current frame samples to the window.
|
5152
|
+
const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
|
5153
|
+
std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
|
5154
|
+
}
|
5155
|
+
|
5156
|
+
// Set the frame tensor data with the samples.
|
5157
|
+
ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
|
5158
|
+
|
5159
|
+
// do not reset the scheduler - we will reuse the graph in the next chunk
|
5160
|
+
if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
|
5161
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
|
5162
|
+
break;
|
5163
|
+
}
|
5164
|
+
|
5165
|
+
// Get the probability for this chunk.
|
5166
|
+
ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
|
5167
|
+
|
5168
|
+
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
|
5169
|
+
}
|
5170
|
+
|
5171
|
+
vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
|
5172
|
+
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
|
5173
|
+
|
5174
|
+
ggml_backend_sched_reset(sched);
|
5175
|
+
|
5176
|
+
return true;
|
5177
|
+
}
|
5178
|
+
|
5179
|
+
int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
|
5180
|
+
return segments->data.size();
|
5181
|
+
}
|
5182
|
+
|
5183
|
+
float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
|
5184
|
+
return segments->data[i_segment].start;
|
5185
|
+
}
|
5186
|
+
|
5187
|
+
float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
|
5188
|
+
return segments->data[i_segment].end;
|
5189
|
+
}
|
5190
|
+
|
5191
|
+
int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
|
5192
|
+
return vctx->probs.size();
|
5193
|
+
}
|
5194
|
+
|
5195
|
+
float * whisper_vad_probs(struct whisper_vad_context * vctx) {
|
5196
|
+
return vctx->probs.data();
|
5197
|
+
}
|
5198
|
+
|
5199
|
+
struct whisper_vad_segments * whisper_vad_segments_from_probs(
|
5200
|
+
struct whisper_vad_context * vctx,
|
5201
|
+
whisper_vad_params params) {
|
5202
|
+
WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
|
5203
|
+
|
5204
|
+
int n_probs = whisper_vad_n_probs(vctx);
|
5205
|
+
float * probs = whisper_vad_probs(vctx);
|
5206
|
+
float threshold = params.threshold;
|
5207
|
+
int min_speech_duration_ms = params.min_speech_duration_ms;
|
5208
|
+
int min_silence_duration_ms = params.min_silence_duration_ms;
|
5209
|
+
float max_speech_duration_s = params.max_speech_duration_s;
|
5210
|
+
int speech_pad_ms = params.speech_pad_ms;
|
5211
|
+
int n_window = vctx->n_window;
|
5212
|
+
int sample_rate = WHISPER_SAMPLE_RATE;
|
5213
|
+
int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
|
5214
|
+
int audio_length_samples = n_probs * n_window;
|
5215
|
+
|
5216
|
+
// Min number of samples to be considered valid speech.
|
5217
|
+
int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
|
5218
|
+
int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
|
5219
|
+
|
5220
|
+
// Max number of samples that a speech segment can contain before it is
|
5221
|
+
// split into multiple segments.
|
5222
|
+
int max_speech_samples;
|
5223
|
+
if (max_speech_duration_s > 100000.0f) {
|
5224
|
+
max_speech_samples = INT_MAX / 2;
|
5225
|
+
} else {
|
5226
|
+
int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
|
5227
|
+
max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
|
5228
|
+
if (max_speech_samples < 0) {
|
5229
|
+
max_speech_samples = INT_MAX / 2;
|
5230
|
+
}
|
5231
|
+
}
|
5232
|
+
// Detect silence period that exceeds this value, then that location (sample)
|
5233
|
+
// is marked as a potential place where the segment could be split if
|
5234
|
+
// max_speech_samples is reached. The value 98 was taken from the original
|
5235
|
+
// silaro-vad python implementation:
|
5236
|
+
//https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
|
5237
|
+
int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
|
5238
|
+
|
5239
|
+
// Calculate lower threshold for detecting end of speech segments.
|
5240
|
+
float neg_threshold = threshold - 0.15f;
|
5241
|
+
if (neg_threshold < 0.01f) {
|
5242
|
+
neg_threshold = 0.01f;
|
5243
|
+
}
|
5244
|
+
|
5245
|
+
struct speech_segment_t {
|
5246
|
+
int start;
|
5247
|
+
int end;
|
5248
|
+
};
|
5249
|
+
|
5250
|
+
std::vector<speech_segment_t> speeches;
|
5251
|
+
speeches.reserve(256);
|
5252
|
+
|
5253
|
+
bool is_speech_segment = false;
|
5254
|
+
int temp_end = 0;
|
5255
|
+
int prev_end = 0;
|
5256
|
+
int next_start = 0;
|
5257
|
+
int curr_speech_start = 0;
|
5258
|
+
bool has_curr_speech = false;
|
5259
|
+
|
5260
|
+
for (int i = 0; i < n_probs; i++) {
|
5261
|
+
float curr_prob = probs[i];
|
5262
|
+
int curr_sample = n_window * i;
|
5263
|
+
|
5264
|
+
// Reset temp_end when we get back to speech
|
5265
|
+
if ((curr_prob >= threshold) && temp_end) {
|
5266
|
+
temp_end = 0;
|
5267
|
+
if (next_start < prev_end) {
|
5268
|
+
next_start = curr_sample;
|
5269
|
+
}
|
5270
|
+
}
|
5271
|
+
|
5272
|
+
// Start a new speech segment when probability exceeds threshold and not already in speech
|
5273
|
+
if ((curr_prob >= threshold) && !is_speech_segment) {
|
5274
|
+
is_speech_segment = true;
|
5275
|
+
curr_speech_start = curr_sample;
|
5276
|
+
has_curr_speech = true;
|
5277
|
+
continue;
|
5278
|
+
}
|
5279
|
+
|
5280
|
+
// Handle maximum speech duration
|
5281
|
+
if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
|
5282
|
+
if (prev_end) {
|
5283
|
+
speeches.push_back({ curr_speech_start, prev_end });
|
5284
|
+
has_curr_speech = true;
|
5285
|
+
|
5286
|
+
if (next_start < prev_end) { // Previously reached silence and is still not speech
|
5287
|
+
is_speech_segment = false;
|
5288
|
+
has_curr_speech = false;
|
5289
|
+
} else {
|
5290
|
+
curr_speech_start = next_start;
|
5291
|
+
}
|
5292
|
+
prev_end = next_start = temp_end = 0;
|
5293
|
+
} else {
|
5294
|
+
speeches.push_back({ curr_speech_start, curr_sample });
|
5295
|
+
|
5296
|
+
prev_end = next_start = temp_end = 0;
|
5297
|
+
is_speech_segment = false;
|
5298
|
+
has_curr_speech = false;
|
5299
|
+
continue;
|
5300
|
+
}
|
5301
|
+
}
|
5302
|
+
|
5303
|
+
// Handle silence after speech
|
5304
|
+
if ((curr_prob < neg_threshold) && is_speech_segment) {
|
5305
|
+
if (!temp_end) {
|
5306
|
+
temp_end = curr_sample;
|
5307
|
+
}
|
5308
|
+
|
5309
|
+
// Track potential segment ends for max_speech handling
|
5310
|
+
if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
|
5311
|
+
prev_end = temp_end;
|
5312
|
+
}
|
5313
|
+
|
5314
|
+
// Check if silence is long enough to end the segment
|
5315
|
+
if ((curr_sample - temp_end) < min_silence_samples) {
|
5316
|
+
continue;
|
5317
|
+
} else {
|
5318
|
+
// End the segment if it's long enough
|
5319
|
+
if ((temp_end - curr_speech_start) > min_speech_samples) {
|
5320
|
+
speeches.push_back({ curr_speech_start, temp_end });
|
5321
|
+
}
|
3969
5322
|
|
3970
|
-
|
3971
|
-
|
3972
|
-
|
5323
|
+
prev_end = next_start = temp_end = 0;
|
5324
|
+
is_speech_segment = false;
|
5325
|
+
has_curr_speech = false;
|
5326
|
+
continue;
|
5327
|
+
}
|
5328
|
+
}
|
5329
|
+
}
|
3973
5330
|
|
3974
|
-
|
3975
|
-
|
3976
|
-
}
|
5331
|
+
// Handle the case if we're still in a speech segment at the end
|
5332
|
+
if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
|
5333
|
+
speeches.push_back({ curr_speech_start, audio_length_samples });
|
5334
|
+
}
|
3977
5335
|
|
3978
|
-
|
3979
|
-
|
3980
|
-
|
5336
|
+
// Merge adjacent segments with small gaps in between (post-processing)
|
5337
|
+
if (speeches.size() > 1) {
|
5338
|
+
int merged_count = 0;
|
5339
|
+
for (int i = 0; i < (int) speeches.size() - 1; i++) {
|
5340
|
+
// Define maximum gap allowed for merging (e.g., 200ms converted to samples)
|
5341
|
+
int max_merge_gap_samples = sample_rate * 200 / 1000;
|
3981
5342
|
|
3982
|
-
|
3983
|
-
|
3984
|
-
|
5343
|
+
// If the gap between this segment and the next is small enough
|
5344
|
+
if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
|
5345
|
+
// Merge by extending current segment to the end of next segment
|
5346
|
+
speeches[i].end = speeches[i+1].end;
|
5347
|
+
speeches.erase(speeches.begin() + i + 1);
|
3985
5348
|
|
3986
|
-
|
3987
|
-
|
3988
|
-
}
|
5349
|
+
i--;
|
5350
|
+
merged_count++;
|
5351
|
+
}
|
5352
|
+
}
|
5353
|
+
WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
|
5354
|
+
__func__, merged_count, (int) speeches.size());
|
5355
|
+
}
|
3989
5356
|
|
3990
|
-
|
3991
|
-
|
3992
|
-
|
5357
|
+
// Double-check for minimum speech duration
|
5358
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
5359
|
+
if (speeches[i].end - speeches[i].start < min_speech_samples) {
|
5360
|
+
WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
|
5361
|
+
__func__, i, speeches[i].end - speeches[i].start);
|
3993
5362
|
|
3994
|
-
|
3995
|
-
|
3996
|
-
}
|
5363
|
+
speeches.erase(speeches.begin() + i);
|
5364
|
+
i--;
|
5365
|
+
}
|
5366
|
+
}
|
3997
5367
|
|
3998
|
-
|
3999
|
-
return ctx->vocab.token_beg;
|
4000
|
-
}
|
5368
|
+
WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
|
4001
5369
|
|
4002
|
-
|
4003
|
-
|
4004
|
-
|
5370
|
+
// Allocate final segments
|
5371
|
+
std::vector<whisper_vad_segment> segments;
|
5372
|
+
if (speeches.size() > 0) {
|
5373
|
+
try {
|
5374
|
+
segments.resize(speeches.size());
|
5375
|
+
} catch (const std::bad_alloc &) {
|
5376
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
|
5377
|
+
return nullptr;
|
5378
|
+
}
|
5379
|
+
}
|
4005
5380
|
|
4006
|
-
|
4007
|
-
|
4008
|
-
|
5381
|
+
// Apply padding to segments and copy to final segments
|
5382
|
+
for (int i = 0; i < (int) speeches.size(); i++) {
|
5383
|
+
// Apply padding to the start of the first segment
|
5384
|
+
if (i == 0) {
|
5385
|
+
speeches[i].start =
|
5386
|
+
(speeches[i].start > speech_pad_samples) ?
|
5387
|
+
(speeches[i].start - speech_pad_samples) : 0;
|
5388
|
+
}
|
4009
5389
|
|
4010
|
-
|
4011
|
-
|
4012
|
-
|
5390
|
+
// Handle spacing between segments
|
5391
|
+
if (i < (int) speeches.size() - 1) {
|
5392
|
+
int silence_duration = speeches[i+1].start - speeches[i].end;
|
4013
5393
|
|
4014
|
-
|
4015
|
-
|
5394
|
+
if (silence_duration < 2 * speech_pad_samples) {
|
5395
|
+
// If segments are close, split the difference
|
5396
|
+
speeches[i].end += silence_duration / 2;
|
5397
|
+
speeches[i+1].start =
|
5398
|
+
(speeches[i+1].start > silence_duration / 2) ?
|
5399
|
+
(speeches[i+1].start - silence_duration / 2) : 0;
|
5400
|
+
} else {
|
5401
|
+
// Otherwise, apply full padding to both
|
5402
|
+
speeches[i].end =
|
5403
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
5404
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
5405
|
+
speeches[i+1].start =
|
5406
|
+
(speeches[i+1].start > speech_pad_samples) ?
|
5407
|
+
(speeches[i+1].start - speech_pad_samples) : 0;
|
5408
|
+
}
|
5409
|
+
} else {
|
5410
|
+
// Apply padding to the end of the last segment
|
5411
|
+
speeches[i].end =
|
5412
|
+
(speeches[i].end + speech_pad_samples < audio_length_samples) ?
|
5413
|
+
(speeches[i].end + speech_pad_samples) : audio_length_samples;
|
5414
|
+
}
|
4016
5415
|
|
4017
|
-
|
4018
|
-
|
4019
|
-
|
5416
|
+
// Convert from samples to seconds and copy to final segments
|
5417
|
+
segments[i].start = (float)speeches[i].start / sample_rate;
|
5418
|
+
segments[i].end = (float)speeches[i].end / sample_rate;
|
4020
5419
|
|
4021
|
-
|
4022
|
-
|
4023
|
-
|
4024
|
-
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
|
4025
|
-
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
|
5420
|
+
WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
|
5421
|
+
__func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
|
5422
|
+
}
|
4026
5423
|
|
4027
|
-
|
4028
|
-
|
4029
|
-
|
4030
|
-
|
4031
|
-
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
4032
|
-
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
|
4033
|
-
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
|
5424
|
+
whisper_vad_segments * vad_segments = new whisper_vad_segments;
|
5425
|
+
if (vad_segments == NULL) {
|
5426
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
|
5427
|
+
return nullptr;
|
4034
5428
|
}
|
4035
|
-
|
5429
|
+
|
5430
|
+
vad_segments->data = std::move(segments);
|
5431
|
+
|
5432
|
+
return vad_segments;
|
4036
5433
|
}
|
4037
5434
|
|
4038
|
-
|
4039
|
-
|
4040
|
-
|
4041
|
-
|
4042
|
-
|
4043
|
-
|
4044
|
-
|
4045
|
-
|
4046
|
-
|
4047
|
-
ctx->state->n_sample = 0;
|
4048
|
-
ctx->state->n_encode = 0;
|
4049
|
-
ctx->state->n_decode = 0;
|
4050
|
-
ctx->state->n_batchd = 0;
|
4051
|
-
ctx->state->n_prompt = 0;
|
5435
|
+
struct whisper_vad_segments * whisper_vad_segments_from_samples(
|
5436
|
+
whisper_vad_context * vctx,
|
5437
|
+
whisper_vad_params params,
|
5438
|
+
const float * samples,
|
5439
|
+
int n_samples) {
|
5440
|
+
WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
|
5441
|
+
if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
|
5442
|
+
WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
|
5443
|
+
return nullptr;
|
4052
5444
|
}
|
5445
|
+
return whisper_vad_segments_from_probs(vctx, params);
|
4053
5446
|
}
|
4054
5447
|
|
4055
|
-
|
4056
|
-
|
4057
|
-
|
4058
|
-
|
4059
|
-
|
4060
|
-
#endif
|
4061
|
-
}
|
5448
|
+
void whisper_vad_free(whisper_vad_context * ctx) {
|
5449
|
+
if (ctx) {
|
5450
|
+
for (ggml_context * context : ctx->model.ctxs) {
|
5451
|
+
ggml_free(context);
|
5452
|
+
}
|
4062
5453
|
|
4063
|
-
|
4064
|
-
|
4065
|
-
|
4066
|
-
#else
|
4067
|
-
return 0;
|
4068
|
-
#endif
|
4069
|
-
}
|
5454
|
+
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
|
5455
|
+
ggml_backend_buffer_free(buf);
|
5456
|
+
}
|
4070
5457
|
|
4071
|
-
|
4072
|
-
static std::string s;
|
5458
|
+
ggml_backend_sched_free(ctx->sched.sched);
|
4073
5459
|
|
4074
|
-
|
4075
|
-
|
4076
|
-
|
4077
|
-
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
|
4078
|
-
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
4079
|
-
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
4080
|
-
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
4081
|
-
s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
|
4082
|
-
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
4083
|
-
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
4084
|
-
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
4085
|
-
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
|
4086
|
-
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
|
4087
|
-
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
|
4088
|
-
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
|
4089
|
-
s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
|
4090
|
-
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
|
4091
|
-
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;
|
5460
|
+
for (auto & backend : ctx->backends) {
|
5461
|
+
ggml_backend_free(backend);
|
5462
|
+
}
|
4092
5463
|
|
4093
|
-
|
5464
|
+
|
5465
|
+
delete ctx;
|
5466
|
+
}
|
5467
|
+
}
|
5468
|
+
|
5469
|
+
void whisper_vad_free_segments(whisper_vad_segments * segments) {
|
5470
|
+
if (segments) {
|
5471
|
+
delete segments;
|
5472
|
+
}
|
4094
5473
|
}
|
4095
5474
|
|
4096
5475
|
//////////////////////////////////
|
@@ -4099,7 +5478,7 @@ const char * whisper_print_system_info(void) {
|
|
4099
5478
|
|
4100
5479
|
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
4101
5480
|
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
|
4102
|
-
std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
5481
|
+
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
|
4103
5482
|
const char * src,
|
4104
5483
|
whisper_partial_utf8 partial_start) {
|
4105
5484
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
@@ -4248,7 +5627,7 @@ static void whisper_grammar_advance_stack(
|
|
4248
5627
|
std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
|
4249
5628
|
|
4250
5629
|
if (stack.empty()) {
|
4251
|
-
new_stacks.
|
5630
|
+
new_stacks.emplace_back();
|
4252
5631
|
return;
|
4253
5632
|
}
|
4254
5633
|
|
@@ -4513,7 +5892,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
|
|
4513
5892
|
|
4514
5893
|
////////////////////////////////////////////////////////////////////////////
|
4515
5894
|
|
4516
|
-
struct whisper_context_params * whisper_context_default_params_by_ref() {
|
5895
|
+
struct whisper_context_params * whisper_context_default_params_by_ref(void) {
|
4517
5896
|
struct whisper_context_params params = whisper_context_default_params();
|
4518
5897
|
|
4519
5898
|
struct whisper_context_params* result = new whisper_context_params();
|
@@ -4554,7 +5933,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4554
5933
|
/*.split_on_word =*/ false,
|
4555
5934
|
/*.max_tokens =*/ 0,
|
4556
5935
|
|
4557
|
-
/*.speed_up =*/ false,
|
4558
5936
|
/*.debug_mode =*/ false,
|
4559
5937
|
/*.audio_ctx =*/ 0,
|
4560
5938
|
|
@@ -4570,7 +5948,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4570
5948
|
/*.detect_language =*/ false,
|
4571
5949
|
|
4572
5950
|
/*.suppress_blank =*/ true,
|
4573
|
-
/*.
|
5951
|
+
/*.suppress_nst =*/ false,
|
4574
5952
|
|
4575
5953
|
/*.temperature =*/ 0.0f,
|
4576
5954
|
/*.max_initial_ts =*/ 1.0f,
|
@@ -4610,6 +5988,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
4610
5988
|
/*.n_grammar_rules =*/ 0,
|
4611
5989
|
/*.i_start_rule =*/ 0,
|
4612
5990
|
/*.grammar_penalty =*/ 100.0f,
|
5991
|
+
|
5992
|
+
/*.vad =*/ false,
|
5993
|
+
/*.vad_model_path =*/ nullptr,
|
5994
|
+
|
5995
|
+
/* vad_params =*/ whisper_vad_default_params(),
|
4613
5996
|
};
|
4614
5997
|
|
4615
5998
|
switch (strategy) {
|
@@ -4720,6 +6103,42 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
4720
6103
|
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
|
4721
6104
|
};
|
4722
6105
|
|
6106
|
+
static void whisper_compute_logprobs(
|
6107
|
+
const std::vector<float> & logits,
|
6108
|
+
const int n_logits,
|
6109
|
+
std::vector<float> & logprobs) {
|
6110
|
+
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
6111
|
+
float logsumexp = 0.0f;
|
6112
|
+
for (int i = 0; i < n_logits; ++i) {
|
6113
|
+
if (logits[i] > -INFINITY) {
|
6114
|
+
logsumexp += expf(logits[i] - logit_max);
|
6115
|
+
}
|
6116
|
+
}
|
6117
|
+
logsumexp = logf(logsumexp) + logit_max;
|
6118
|
+
|
6119
|
+
for (int i = 0; i < n_logits; ++i) {
|
6120
|
+
if (logits[i] > -INFINITY) {
|
6121
|
+
logprobs[i] = logits[i] - logsumexp;
|
6122
|
+
} else {
|
6123
|
+
logprobs[i] = -INFINITY;
|
6124
|
+
}
|
6125
|
+
}
|
6126
|
+
}
|
6127
|
+
|
6128
|
+
static void whisper_compute_probs(
|
6129
|
+
const std::vector<float> & logits,
|
6130
|
+
const int n_logits,
|
6131
|
+
const std::vector<float> & logprobs,
|
6132
|
+
std::vector<float> & probs) {
|
6133
|
+
for (int i = 0; i < n_logits; ++i) {
|
6134
|
+
if (logits[i] == -INFINITY) {
|
6135
|
+
probs[i] = 0.0f;
|
6136
|
+
} else {
|
6137
|
+
probs[i] = expf(logprobs[i]);
|
6138
|
+
}
|
6139
|
+
}
|
6140
|
+
}
|
6141
|
+
|
4723
6142
|
// process the logits for the selected decoder
|
4724
6143
|
// - applies logit filters
|
4725
6144
|
// - computes logprobs and probs
|
@@ -4781,7 +6200,7 @@ static void whisper_process_logits(
|
|
4781
6200
|
|
4782
6201
|
// suppress sot and nosp tokens
|
4783
6202
|
logits[vocab.token_sot] = -INFINITY;
|
4784
|
-
logits[vocab.token_nosp] = -INFINITY;
|
6203
|
+
logits[vocab.token_nosp] = -INFINITY;
|
4785
6204
|
|
4786
6205
|
// [TDRZ] when tinydiarize is disabled, suppress solm token
|
4787
6206
|
if (params.tdrz_enable == false) {
|
@@ -4818,7 +6237,7 @@ static void whisper_process_logits(
|
|
4818
6237
|
|
4819
6238
|
// suppress non-speech tokens
|
4820
6239
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
4821
|
-
if (params.
|
6240
|
+
if (params.suppress_nst) {
|
4822
6241
|
for (const std::string & token : non_speech_tokens) {
|
4823
6242
|
const std::string suppress_tokens[] = {token, " " + token};
|
4824
6243
|
for (const std::string & suppress_token : suppress_tokens) {
|
@@ -4880,24 +6299,7 @@ static void whisper_process_logits(
|
|
4880
6299
|
}
|
4881
6300
|
|
4882
6301
|
// populate the logprobs array (log_softmax)
|
4883
|
-
|
4884
|
-
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
4885
|
-
float logsumexp = 0.0f;
|
4886
|
-
for (int i = 0; i < n_logits; ++i) {
|
4887
|
-
if (logits[i] > -INFINITY) {
|
4888
|
-
logsumexp += expf(logits[i] - logit_max);
|
4889
|
-
}
|
4890
|
-
}
|
4891
|
-
logsumexp = logf(logsumexp) + logit_max;
|
4892
|
-
|
4893
|
-
for (int i = 0; i < n_logits; ++i) {
|
4894
|
-
if (logits[i] > -INFINITY) {
|
4895
|
-
logprobs[i] = logits[i] - logsumexp;
|
4896
|
-
} else {
|
4897
|
-
logprobs[i] = -INFINITY;
|
4898
|
-
}
|
4899
|
-
}
|
4900
|
-
}
|
6302
|
+
whisper_compute_logprobs(logits, n_logits, logprobs);
|
4901
6303
|
|
4902
6304
|
// if sum of probability over timestamps is above any other token, sample timestamp
|
4903
6305
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
|
@@ -4955,15 +6357,7 @@ static void whisper_process_logits(
|
|
4955
6357
|
}
|
4956
6358
|
|
4957
6359
|
// compute probs
|
4958
|
-
|
4959
|
-
for (int i = 0; i < n_logits; ++i) {
|
4960
|
-
if (logits[i] == -INFINITY) {
|
4961
|
-
probs[i] = 0.0f;
|
4962
|
-
} else {
|
4963
|
-
probs[i] = expf(logprobs[i]);
|
4964
|
-
}
|
4965
|
-
}
|
4966
|
-
}
|
6360
|
+
whisper_compute_probs(logits, n_logits, logprobs, probs);
|
4967
6361
|
|
4968
6362
|
#if 0
|
4969
6363
|
// print first 100 logits - token string : logit
|
@@ -5215,6 +6609,121 @@ static void whisper_sequence_score(
|
|
5215
6609
|
}
|
5216
6610
|
}
|
5217
6611
|
|
6612
|
+
static bool whisper_vad(
|
6613
|
+
struct whisper_context * ctx,
|
6614
|
+
struct whisper_state * state,
|
6615
|
+
struct whisper_full_params params,
|
6616
|
+
const float * samples,
|
6617
|
+
int n_samples,
|
6618
|
+
std::vector<float> & filtered_samples,
|
6619
|
+
int & filtered_n_samples) {
|
6620
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
|
6621
|
+
filtered_n_samples = 0;
|
6622
|
+
|
6623
|
+
if (state->vad_context == nullptr) {
|
6624
|
+
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
|
6625
|
+
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
|
6626
|
+
if (vctx == nullptr) {
|
6627
|
+
WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
|
6628
|
+
return false;
|
6629
|
+
}
|
6630
|
+
state->vad_context = vctx;
|
6631
|
+
}
|
6632
|
+
auto vctx = state->vad_context;
|
6633
|
+
|
6634
|
+
const whisper_vad_params & vad_params = params.vad_params;
|
6635
|
+
|
6636
|
+
whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
|
6637
|
+
|
6638
|
+
if (vad_segments->data.size() > 0) {
|
6639
|
+
state->has_vad_segments = true;
|
6640
|
+
ctx->state->vad_segments.clear();
|
6641
|
+
ctx->state->vad_segments.reserve(vad_segments->data.size());
|
6642
|
+
|
6643
|
+
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
|
6644
|
+
float overlap_seconds = vad_params.samples_overlap;
|
6645
|
+
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
|
6646
|
+
|
6647
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6648
|
+
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
|
6649
|
+
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
|
6650
|
+
|
6651
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6652
|
+
segment_end_samples += overlap_samples;
|
6653
|
+
}
|
6654
|
+
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
|
6655
|
+
filtered_n_samples += (segment_end_samples - segment_start_samples);
|
6656
|
+
|
6657
|
+
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
|
6658
|
+
__func__, i, vad_segments->data[i].start,
|
6659
|
+
vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
|
6660
|
+
(vad_segments->data[i].end - vad_segments->data[i].start) +
|
6661
|
+
(i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
|
6662
|
+
}
|
6663
|
+
|
6664
|
+
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
|
6665
|
+
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
|
6666
|
+
int total_samples_needed = filtered_n_samples + total_silence_samples;
|
6667
|
+
|
6668
|
+
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
|
6669
|
+
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
|
6670
|
+
|
6671
|
+
try {
|
6672
|
+
filtered_samples.resize(total_samples_needed);
|
6673
|
+
} catch (const std::bad_alloc & /* e */) {
|
6674
|
+
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
|
6675
|
+
whisper_vad_free_segments(vad_segments);
|
6676
|
+
whisper_vad_free(vctx);
|
6677
|
+
return false;
|
6678
|
+
}
|
6679
|
+
|
6680
|
+
int offset = 0;
|
6681
|
+
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
|
6682
|
+
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
|
6683
|
+
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
|
6684
|
+
|
6685
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6686
|
+
segment_end_samples += overlap_samples;
|
6687
|
+
}
|
6688
|
+
|
6689
|
+
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
|
6690
|
+
segment_end_samples = std::min(segment_end_samples, n_samples);
|
6691
|
+
int segment_length = segment_end_samples - segment_start_samples;
|
6692
|
+
|
6693
|
+
if (segment_length > 0) {
|
6694
|
+
whisper_state::vad_segment_info segment;
|
6695
|
+
|
6696
|
+
segment.orig_start = vad_segments->data[i].start;
|
6697
|
+
segment.orig_end = vad_segments->data[i].end;
|
6698
|
+
|
6699
|
+
segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
|
6700
|
+
segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
|
6701
|
+
|
6702
|
+
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
|
6703
|
+
__func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
|
6704
|
+
ctx->state->vad_segments.push_back(segment);
|
6705
|
+
|
6706
|
+
// Copy this speech segment
|
6707
|
+
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
|
6708
|
+
offset += segment_length;
|
6709
|
+
|
6710
|
+
// Add silence after this segment (except after the last segment)
|
6711
|
+
if (i < (int)vad_segments->data.size() - 1) {
|
6712
|
+
// Fill with zeros (silence)
|
6713
|
+
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
|
6714
|
+
offset += silence_samples;
|
6715
|
+
}
|
6716
|
+
}
|
6717
|
+
}
|
6718
|
+
|
6719
|
+
filtered_n_samples = offset;
|
6720
|
+
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
|
6721
|
+
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
|
6722
|
+
}
|
6723
|
+
|
6724
|
+
return true;
|
6725
|
+
}
|
6726
|
+
|
5218
6727
|
int whisper_full_with_state(
|
5219
6728
|
struct whisper_context * ctx,
|
5220
6729
|
struct whisper_state * state,
|
@@ -5226,17 +6735,29 @@ int whisper_full_with_state(
|
|
5226
6735
|
|
5227
6736
|
result_all.clear();
|
5228
6737
|
|
5229
|
-
|
6738
|
+
const float * process_samples = samples;
|
6739
|
+
int n_process_samples = n_samples;
|
6740
|
+
std::vector<float> vad_samples;
|
6741
|
+
|
6742
|
+
if (params.vad) {
|
6743
|
+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
|
6744
|
+
int vad_n_samples;
|
6745
|
+
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
|
6746
|
+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
|
6747
|
+
return -1;
|
6748
|
+
}
|
6749
|
+
if (vad_n_samples == 0) {
|
6750
|
+
return 0;
|
6751
|
+
}
|
6752
|
+
process_samples = vad_samples.data();
|
6753
|
+
n_process_samples = vad_n_samples;
|
6754
|
+
}
|
6755
|
+
|
6756
|
+
if (n_process_samples > 0) {
|
5230
6757
|
// compute log mel spectrogram
|
5231
|
-
if (params.
|
5232
|
-
// TODO: Replace PV with more advanced algorithm
|
6758
|
+
if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
|
5233
6759
|
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
5234
|
-
return -
|
5235
|
-
} else {
|
5236
|
-
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
5237
|
-
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
|
5238
|
-
return -2;
|
5239
|
-
}
|
6760
|
+
return -2;
|
5240
6761
|
}
|
5241
6762
|
}
|
5242
6763
|
|
@@ -5270,11 +6791,13 @@ int whisper_full_with_state(
|
|
5270
6791
|
const int seek_start = params.offset_ms/10;
|
5271
6792
|
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
|
5272
6793
|
|
5273
|
-
// if length of spectrogram is less than
|
5274
|
-
// basically don't process anything that is less than
|
5275
|
-
//
|
5276
|
-
|
5277
|
-
|
6794
|
+
// if length of spectrogram is less than 100ms (10 frames), then return
|
6795
|
+
// basically don't process anything that is less than 100ms
|
6796
|
+
// ref: https://github.com/ggml-org/whisper.cpp/issues/2065
|
6797
|
+
const int delta_min = 10;
|
6798
|
+
|
6799
|
+
if (seek_end < seek_start + delta_min) {
|
6800
|
+
WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
|
5278
6801
|
return 0;
|
5279
6802
|
}
|
5280
6803
|
|
@@ -5321,7 +6844,7 @@ int whisper_full_with_state(
|
|
5321
6844
|
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
5322
6845
|
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
5323
6846
|
|
5324
|
-
decoder.rng = std::mt19937(
|
6847
|
+
decoder.rng = std::mt19937(j);
|
5325
6848
|
}
|
5326
6849
|
|
5327
6850
|
// the accumulated text context so far
|
@@ -5418,8 +6941,8 @@ int whisper_full_with_state(
|
|
5418
6941
|
ctx, state, progress_cur, params.progress_callback_user_data);
|
5419
6942
|
}
|
5420
6943
|
|
5421
|
-
// if only
|
5422
|
-
if (seek +
|
6944
|
+
// if only 100ms left, then stop
|
6945
|
+
if (seek + delta_min >= seek_end) {
|
5423
6946
|
break;
|
5424
6947
|
}
|
5425
6948
|
|
@@ -5518,13 +7041,46 @@ int whisper_full_with_state(
|
|
5518
7041
|
}
|
5519
7042
|
WHISPER_LOG_DEBUG("\n\n");
|
5520
7043
|
|
7044
|
+
// recreate the KV cache if the number of decoders has changed
|
7045
|
+
if (state->kv_self_n_dec < n_decoders_cur) {
|
7046
|
+
WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
|
7047
|
+
|
7048
|
+
whisper_kv_cache_free(state->kv_self);
|
7049
|
+
|
7050
|
+
// overallocate to workaround KV cache fragmentation issues
|
7051
|
+
const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
|
7052
|
+
|
7053
|
+
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
7054
|
+
ctx->model.hparams.n_text_state,
|
7055
|
+
ctx->model.hparams.n_text_layer,
|
7056
|
+
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
7057
|
+
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
7058
|
+
whisper_free_state(state);
|
7059
|
+
return -7;
|
7060
|
+
}
|
7061
|
+
|
7062
|
+
state->kv_self_n_dec = n_decoders_cur;
|
7063
|
+
}
|
7064
|
+
|
5521
7065
|
whisper_kv_cache_clear(state->kv_self);
|
5522
7066
|
|
5523
7067
|
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
5524
7068
|
|
5525
7069
|
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
5526
7070
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
5527
|
-
return -
|
7071
|
+
return -8;
|
7072
|
+
}
|
7073
|
+
|
7074
|
+
// Calculate no_speech probability after first decode.
|
7075
|
+
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
|
7076
|
+
{
|
7077
|
+
const int n_logits = ctx->vocab.id_to_token.size();
|
7078
|
+
std::vector<float> logprobs(n_logits);
|
7079
|
+
std::vector<float> probs(n_logits);
|
7080
|
+
|
7081
|
+
whisper_compute_logprobs(state->logits, n_logits, logprobs);
|
7082
|
+
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
|
7083
|
+
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
|
5528
7084
|
}
|
5529
7085
|
|
5530
7086
|
{
|
@@ -5733,10 +7289,10 @@ int whisper_full_with_state(
|
|
5733
7289
|
// end of segment
|
5734
7290
|
if (token.id == whisper_token_eot(ctx) || // end of text token
|
5735
7291
|
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
5736
|
-
(has_ts && seek + seek_delta +
|
7292
|
+
(has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
|
5737
7293
|
) {
|
5738
7294
|
if (result_len == 0 && !params.no_timestamps) {
|
5739
|
-
if (seek + seek_delta +
|
7295
|
+
if (seek + seek_delta + delta_min >= seek_end) {
|
5740
7296
|
result_len = i + 1;
|
5741
7297
|
} else {
|
5742
7298
|
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
@@ -5824,7 +7380,7 @@ int whisper_full_with_state(
|
|
5824
7380
|
|
5825
7381
|
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
5826
7382
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
5827
|
-
return -
|
7383
|
+
return -9;
|
5828
7384
|
}
|
5829
7385
|
|
5830
7386
|
const int64_t t_start_sample_us = ggml_time_us();
|
@@ -5918,8 +7474,9 @@ int whisper_full_with_state(
|
|
5918
7474
|
if (it != (int) temperatures.size() - 1) {
|
5919
7475
|
const auto & decoder = state->decoders[best_decoder_id];
|
5920
7476
|
|
5921
|
-
if (decoder.failed ||
|
5922
|
-
|
7477
|
+
if (decoder.failed ||
|
7478
|
+
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
|
7479
|
+
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
|
5923
7480
|
success = false;
|
5924
7481
|
state->n_fail_p++;
|
5925
7482
|
}
|
@@ -5940,7 +7497,7 @@ int whisper_full_with_state(
|
|
5940
7497
|
{
|
5941
7498
|
const auto & best_decoder = state->decoders[best_decoder_id];
|
5942
7499
|
|
5943
|
-
|
7500
|
+
auto seek_delta = best_decoder.seek_delta;
|
5944
7501
|
const auto result_len = best_decoder.sequence.result_len;
|
5945
7502
|
|
5946
7503
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
@@ -5948,6 +7505,9 @@ int whisper_full_with_state(
|
|
5948
7505
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
5949
7506
|
const auto n_segments_before = state->result_all.size();
|
5950
7507
|
|
7508
|
+
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
|
7509
|
+
best_decoder.sequence.avg_logprobs < params.logprob_thold);
|
7510
|
+
|
5951
7511
|
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
5952
7512
|
|
5953
7513
|
// update prompt_past
|
@@ -5956,11 +7516,11 @@ int whisper_full_with_state(
|
|
5956
7516
|
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
5957
7517
|
}
|
5958
7518
|
|
5959
|
-
for (int i = 0; i < result_len; ++i) {
|
7519
|
+
for (int i = 0; i < result_len && !is_no_speech; ++i) {
|
5960
7520
|
prompt_past.push_back(tokens_cur[i].id);
|
5961
7521
|
}
|
5962
7522
|
|
5963
|
-
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
|
7523
|
+
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
|
5964
7524
|
int i0 = 0;
|
5965
7525
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
5966
7526
|
|
@@ -5985,8 +7545,8 @@ int whisper_full_with_state(
|
|
5985
7545
|
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
|
5986
7546
|
|
5987
7547
|
if (!text.empty()) {
|
5988
|
-
const auto tt0 =
|
5989
|
-
const auto tt1 =
|
7548
|
+
const auto tt0 = t0;
|
7549
|
+
const auto tt1 = t1;
|
5990
7550
|
|
5991
7551
|
if (params.print_realtime) {
|
5992
7552
|
if (params.print_timestamps) {
|
@@ -5999,7 +7559,7 @@ int whisper_full_with_state(
|
|
5999
7559
|
|
6000
7560
|
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
|
6001
7561
|
|
6002
|
-
result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
|
7562
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6003
7563
|
for (int j = i0; j <= i; j++) {
|
6004
7564
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6005
7565
|
}
|
@@ -6014,7 +7574,7 @@ int whisper_full_with_state(
|
|
6014
7574
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
6015
7575
|
}
|
6016
7576
|
}
|
6017
|
-
if (params.new_segment_callback) {
|
7577
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
6018
7578
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
6019
7579
|
}
|
6020
7580
|
}
|
@@ -6032,8 +7592,8 @@ int whisper_full_with_state(
|
|
6032
7592
|
if (!text.empty()) {
|
6033
7593
|
const auto t1 = seek + seek_delta;
|
6034
7594
|
|
6035
|
-
const auto tt0 =
|
6036
|
-
const auto tt1 =
|
7595
|
+
const auto tt0 = t0;
|
7596
|
+
const auto tt1 = t1;
|
6037
7597
|
|
6038
7598
|
if (params.print_realtime) {
|
6039
7599
|
if (params.print_timestamps) {
|
@@ -6044,7 +7604,7 @@ int whisper_full_with_state(
|
|
6044
7604
|
}
|
6045
7605
|
}
|
6046
7606
|
|
6047
|
-
result_all.push_back({ tt0, tt1, text, {}
|
7607
|
+
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
|
6048
7608
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
6049
7609
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
6050
7610
|
}
|
@@ -6059,7 +7619,7 @@ int whisper_full_with_state(
|
|
6059
7619
|
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
6060
7620
|
}
|
6061
7621
|
}
|
6062
|
-
if (params.new_segment_callback) {
|
7622
|
+
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
|
6063
7623
|
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
6064
7624
|
}
|
6065
7625
|
}
|
@@ -6068,14 +7628,28 @@ int whisper_full_with_state(
|
|
6068
7628
|
// FIXME: will timestamp offsets be correct?
|
6069
7629
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
6070
7630
|
{
|
6071
|
-
const
|
7631
|
+
const int n_segments = state->result_all.size() - n_segments_before;
|
6072
7632
|
if (ctx->params.dtw_token_timestamps && n_segments) {
|
6073
7633
|
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
|
6074
7634
|
whisper_exp_compute_token_level_timestamps_dtw(
|
6075
7635
|
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
|
7636
|
+
if (params.new_segment_callback) {
|
7637
|
+
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
|
7638
|
+
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
|
7639
|
+
}
|
7640
|
+
}
|
6076
7641
|
}
|
6077
7642
|
}
|
6078
7643
|
|
7644
|
+
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
|
7645
|
+
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
|
7646
|
+
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
|
7647
|
+
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
|
7648
|
+
if (single_timestamp_ending) {
|
7649
|
+
WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
|
7650
|
+
seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
|
7651
|
+
}
|
7652
|
+
|
6079
7653
|
// update audio window
|
6080
7654
|
seek += seek_delta;
|
6081
7655
|
|
@@ -6226,19 +7800,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
|
|
6226
7800
|
}
|
6227
7801
|
|
6228
7802
|
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
6229
|
-
return
|
7803
|
+
// If VAD wasn't used, return the original timestamp
|
7804
|
+
if (!state->has_vad_segments || state->vad_segments.empty()) {
|
7805
|
+
return state->result_all[i_segment].t0;
|
7806
|
+
}
|
7807
|
+
|
7808
|
+
// Get the start timestamp produced by whisper_full. whisper_full processes
|
7809
|
+
// only the speech segments in this case so we need to map these timestamps
|
7810
|
+
// back to the original audio.
|
7811
|
+
float t0 = state->result_all[i_segment].t0 / 100.0f;
|
7812
|
+
|
7813
|
+
// Find which VAD segment this timestamp belongs.
|
7814
|
+
// TODO(danbev) This could be optimized by using a binary search if the number
|
7815
|
+
// of segments exceed a certain limit. Also we might be able to assume that
|
7816
|
+
// the access pattern is sequential and optimized for that too.
|
7817
|
+
for (size_t i = 0; i < state->vad_segments.size(); i++) {
|
7818
|
+
const auto & segment = state->vad_segments[i];
|
7819
|
+
|
7820
|
+
// Check if the timestamp falls within this segment.
|
7821
|
+
if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
|
7822
|
+
float proportion = 0.0f;
|
7823
|
+
if (segment.vad_end > segment.vad_start) {
|
7824
|
+
proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
|
7825
|
+
}
|
7826
|
+
float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
|
7827
|
+
return (int64_t)(orig_t0 * 100);
|
7828
|
+
}
|
7829
|
+
}
|
7830
|
+
|
7831
|
+
// Check if the timestamp falls between two segments.
|
7832
|
+
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
|
7833
|
+
const auto & curr = state->vad_segments[i];
|
7834
|
+
const auto & next = state->vad_segments[i + 1];
|
7835
|
+
|
7836
|
+
if (t0 > curr.vad_end && t0 < next.vad_start) {
|
7837
|
+
// Calculate how far we are through the gap as a proportion
|
7838
|
+
float gap_proportion = 0.0f;
|
7839
|
+
if (next.vad_start > curr.vad_end) {
|
7840
|
+
gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
|
7841
|
+
}
|
7842
|
+
float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
|
7843
|
+
return (int64_t)(orig_t0 * 100);
|
7844
|
+
}
|
7845
|
+
}
|
7846
|
+
|
7847
|
+
// Handle the case where the timestamp is after the last segment.
|
7848
|
+
if (t0 > state->vad_segments.back().vad_end) {
|
7849
|
+
// For timestamps after the last segment, add the extra time to the end of the last segment
|
7850
|
+
const auto& last = state->vad_segments.back();
|
7851
|
+
// Calculate how far beyond the last segment
|
7852
|
+
float extra_time = t0 - last.vad_end;
|
7853
|
+
// Add this extra time to the original end time
|
7854
|
+
float orig_t0 = last.orig_end + extra_time;
|
7855
|
+
return (int64_t)(orig_t0 * 100);
|
7856
|
+
}
|
7857
|
+
|
7858
|
+
WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
|
7859
|
+
return t0;
|
6230
7860
|
}
|
6231
7861
|
|
6232
7862
|
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
6233
|
-
return ctx->state
|
7863
|
+
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
|
6234
7864
|
}
|
6235
7865
|
|
6236
7866
|
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
6237
|
-
return
|
7867
|
+
// If VAD wasn't used, return the original timestamp
|
7868
|
+
if (!state->has_vad_segments || state->vad_segments.empty()) {
|
7869
|
+
return state->result_all[i_segment].t1;
|
7870
|
+
}
|
7871
|
+
|
7872
|
+
// Get the end timestamp produced by whisper_full. whisper_full processes
|
7873
|
+
// only the speech segments in this case so we need to map these timestamps
|
7874
|
+
// back to the original audio.
|
7875
|
+
float t1 = state->result_all[i_segment].t1 / 100.0f;
|
7876
|
+
|
7877
|
+
// Find which VAD segment this timestamp belongs.
|
7878
|
+
// TODO(danbev) This could be optimized by using a binary search if the number
|
7879
|
+
// of segments exceed a certain limit. Also we might be able to assume that
|
7880
|
+
// the access pattern is sequential and optimized for that too.
|
7881
|
+
for (size_t i = 0; i < state->vad_segments.size(); i++) {
|
7882
|
+
const auto& segment = state->vad_segments[i];
|
7883
|
+
|
7884
|
+
// Check if the timestamp falls within this segment.
|
7885
|
+
if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
|
7886
|
+
// Calculate the proportion through the filtered segment.
|
7887
|
+
float proportion = 0.0f;
|
7888
|
+
if (segment.vad_end > segment.vad_start) {
|
7889
|
+
proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
|
7890
|
+
}
|
7891
|
+
float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
|
7892
|
+
return (int64_t)(orig_t1 * 100);
|
7893
|
+
}
|
7894
|
+
}
|
7895
|
+
|
7896
|
+
// Check if the timestamp falls between two segments.
|
7897
|
+
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
|
7898
|
+
const auto & curr = state->vad_segments[i];
|
7899
|
+
const auto & next = state->vad_segments[i + 1];
|
7900
|
+
|
7901
|
+
if (t1 > curr.vad_end && t1 < next.vad_start) {
|
7902
|
+
// Calculate how far we are through the gap as a proportion
|
7903
|
+
float gap_proportion = 0.0f;
|
7904
|
+
if (next.vad_start > curr.vad_end) {
|
7905
|
+
gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
|
7906
|
+
}
|
7907
|
+
// Map to the corresponding position in the original gap
|
7908
|
+
float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
|
7909
|
+
return (int64_t)(orig_t1 * 100);
|
7910
|
+
}
|
7911
|
+
}
|
7912
|
+
|
7913
|
+
// Handle the case where the timestamp is after the last segment
|
7914
|
+
if (t1 > state->vad_segments.back().vad_end) {
|
7915
|
+
// For the last segment, use the end of the last VAD segment
|
7916
|
+
const auto& last = state->vad_segments.back();
|
7917
|
+
// Calculate how far beyond the last segment
|
7918
|
+
float extra_time = t1 - last.vad_end;
|
7919
|
+
// Add this extra time to the original end time
|
7920
|
+
float orig_t1 = last.orig_end + extra_time;
|
7921
|
+
return (int64_t)(orig_t1 * 100);
|
7922
|
+
}
|
7923
|
+
|
7924
|
+
WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
|
7925
|
+
return t1;
|
6238
7926
|
}
|
6239
7927
|
|
6240
7928
|
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
6241
|
-
return ctx->state
|
7929
|
+
return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
|
6242
7930
|
}
|
6243
7931
|
|
6244
7932
|
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
@@ -6297,6 +7985,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
|
|
6297
7985
|
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
6298
7986
|
}
|
6299
7987
|
|
7988
|
+
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
|
7989
|
+
return ctx->state->result_all[i_segment].no_speech_prob;
|
7990
|
+
}
|
7991
|
+
|
7992
|
+
float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
|
7993
|
+
return state->result_all[i_segment].no_speech_prob;
|
7994
|
+
}
|
7995
|
+
|
6300
7996
|
// =================================================================================================
|
6301
7997
|
|
6302
7998
|
//
|
@@ -6458,6 +8154,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
|
|
6458
8154
|
}
|
6459
8155
|
|
6460
8156
|
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
8157
|
+
whisper_load_backends();
|
8158
|
+
|
6461
8159
|
static std::string s;
|
6462
8160
|
s = "";
|
6463
8161
|
char strbuf[256];
|
@@ -6477,7 +8175,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6477
8175
|
// c: N*N*sizeof(float)
|
6478
8176
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
6479
8177
|
std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
|
6480
|
-
std::vector<uint8_t> work;
|
6481
8178
|
|
6482
8179
|
// put a bunch of random data in the buffer
|
6483
8180
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
@@ -6534,12 +8231,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
6534
8231
|
double tsum = 0.0;
|
6535
8232
|
|
6536
8233
|
// heat-up
|
6537
|
-
ggml_graph_compute_helper(gf,
|
8234
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6538
8235
|
|
6539
8236
|
for (int i = 0; i < n_max; ++i) {
|
6540
8237
|
const int64_t t0 = ggml_time_us();
|
6541
8238
|
|
6542
|
-
ggml_graph_compute_helper(gf,
|
8239
|
+
ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
|
6543
8240
|
|
6544
8241
|
const int64_t t1 = ggml_time_us();
|
6545
8242
|
|
@@ -6700,12 +8397,6 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6700
8397
|
|
6701
8398
|
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
|
6702
8399
|
|
6703
|
-
tokens[j].id = token.id;
|
6704
|
-
tokens[j].tid = token.tid;
|
6705
|
-
tokens[j].p = token.p;
|
6706
|
-
tokens[j].pt = token.pt;
|
6707
|
-
tokens[j].ptsum = token.ptsum;
|
6708
|
-
|
6709
8400
|
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
|
6710
8401
|
|
6711
8402
|
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
|
@@ -6835,7 +8526,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
6835
8526
|
k++;
|
6836
8527
|
}
|
6837
8528
|
tokens[j].t1 = sample_to_timestamp(k);
|
6838
|
-
if (j <
|
8529
|
+
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
|
6839
8530
|
tokens[j].t1 = tokens[j + 1].t0;
|
6840
8531
|
} else {
|
6841
8532
|
s1 = k;
|
@@ -6916,18 +8607,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
6916
8607
|
struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
|
6917
8608
|
struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
|
6918
8609
|
|
6919
|
-
cost =
|
6920
|
-
trace =
|
6921
|
-
|
8610
|
+
cost = whisper_set_f32(cost, INFINITY);
|
8611
|
+
trace = whisper_set_i32(trace, -1);
|
8612
|
+
whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
|
6922
8613
|
|
6923
8614
|
// dtw
|
6924
8615
|
// supposedly can be optmized by computing diagonals in parallel ?
|
6925
8616
|
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
|
6926
8617
|
for (int64_t j = 1; j < M + 1; ++j) {
|
6927
8618
|
for (int64_t i = 1; i < N + 1; ++i) {
|
6928
|
-
float c0 =
|
6929
|
-
float c1 =
|
6930
|
-
float c2 =
|
8619
|
+
float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
|
8620
|
+
float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
|
8621
|
+
float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
|
6931
8622
|
|
6932
8623
|
float c;
|
6933
8624
|
int32_t t;
|
@@ -6942,9 +8633,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
6942
8633
|
t = 2;
|
6943
8634
|
}
|
6944
8635
|
|
6945
|
-
c =
|
6946
|
-
|
6947
|
-
|
8636
|
+
c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
|
8637
|
+
whisper_set_f32_nd(cost, i, j, 0, 0, c);
|
8638
|
+
whisper_set_i32_nd(trace, i, j, 0, 0, t);
|
6948
8639
|
}
|
6949
8640
|
}
|
6950
8641
|
|
@@ -6953,19 +8644,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
6953
8644
|
struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
|
6954
8645
|
// trace[0, :] = 2;
|
6955
8646
|
for (int64_t i = 0; i < M + 1; ++i)
|
6956
|
-
|
8647
|
+
whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
|
6957
8648
|
//trace[:, 0] = 1;
|
6958
8649
|
for (int64_t i = 0; i < N + 1; ++i)
|
6959
|
-
|
8650
|
+
whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
|
6960
8651
|
int bt_row_idx = BT_MAX_ROWS - 1;
|
6961
8652
|
int64_t i = N;
|
6962
8653
|
int64_t j = M;
|
6963
8654
|
while (i > 0 || j > 0) {
|
6964
|
-
|
6965
|
-
|
8655
|
+
whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
|
8656
|
+
whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
|
6966
8657
|
--bt_row_idx;
|
6967
8658
|
|
6968
|
-
int32_t t =
|
8659
|
+
int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
|
6969
8660
|
if (t == 0) {
|
6970
8661
|
--i;
|
6971
8662
|
--j;
|
@@ -6986,8 +8677,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
|
|
6986
8677
|
ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
|
6987
8678
|
for (int64_t i = 0; i < 2; ++i) {
|
6988
8679
|
for (int64_t j = 0; j < result_n_cols; ++j) {
|
6989
|
-
int32_t v =
|
6990
|
-
|
8680
|
+
int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
|
8681
|
+
whisper_set_i32_nd(r, i, j, 0, 0, v);
|
6991
8682
|
}
|
6992
8683
|
}
|
6993
8684
|
|
@@ -6998,10 +8689,11 @@ struct median_filter_user_data {
|
|
6998
8689
|
int filter_width;
|
6999
8690
|
};
|
7000
8691
|
|
7001
|
-
static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth
|
8692
|
+
static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
|
8693
|
+
if (ith != 0) {
|
8694
|
+
return;
|
8695
|
+
}
|
7002
8696
|
int filter_width = ((median_filter_user_data *) userdata)->filter_width;
|
7003
|
-
WHISPER_ASSERT(nth == 1);
|
7004
|
-
WHISPER_ASSERT(ith == 0);
|
7005
8697
|
WHISPER_ASSERT(filter_width < a->ne[2]);
|
7006
8698
|
WHISPER_ASSERT(filter_width % 2);
|
7007
8699
|
WHISPER_ASSERT(ggml_n_dims(a) == 3);
|
@@ -7021,11 +8713,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
|
|
7021
8713
|
idx = 2*(a->ne[2] - 1) - idx;
|
7022
8714
|
}
|
7023
8715
|
|
7024
|
-
filter.push_back(
|
8716
|
+
filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
|
7025
8717
|
}
|
7026
8718
|
std::sort(filter.begin(), filter.end());
|
7027
8719
|
const float v = filter[filter.size()/2];
|
7028
|
-
|
8720
|
+
whisper_set_f32_nd(dst, i, j, k, 0, v);
|
7029
8721
|
filter.clear();
|
7030
8722
|
}
|
7031
8723
|
}
|
@@ -7124,7 +8816,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7124
8816
|
// operation (after median filter)
|
7125
8817
|
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
|
7126
8818
|
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
|
7127
|
-
w = ggml_norm(gctx, w, 1e-
|
8819
|
+
w = ggml_norm(gctx, w, 1e-9f);
|
7128
8820
|
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
|
7129
8821
|
|
7130
8822
|
// Pass median filter - this is done over AUDIO_TOKENS dimension.
|
@@ -7147,7 +8839,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7147
8839
|
// Compute
|
7148
8840
|
struct ggml_cgraph * gf = ggml_new_graph(gctx);
|
7149
8841
|
ggml_build_forward_expand(gf, w);
|
7150
|
-
|
8842
|
+
|
8843
|
+
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
8844
|
+
ggml_backend_graph_compute(backend.get(), gf);
|
7151
8845
|
|
7152
8846
|
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
|
7153
8847
|
|
@@ -7156,9 +8850,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7156
8850
|
auto seg_i = state->result_all.begin() + i_segment;
|
7157
8851
|
auto tok_i = seg_i->tokens.begin();
|
7158
8852
|
for (int i = 0; i < alignment->ne[1]; ++i) {
|
7159
|
-
int32_t v =
|
8853
|
+
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
|
7160
8854
|
if (v != last_v) {
|
7161
|
-
int32_t time_index =
|
8855
|
+
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
|
7162
8856
|
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
|
7163
8857
|
last_v = v;
|
7164
8858
|
|
@@ -7196,6 +8890,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
|
|
7196
8890
|
void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
|
7197
8891
|
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
|
7198
8892
|
g_state.log_callback_user_data = user_data;
|
8893
|
+
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
7199
8894
|
}
|
7200
8895
|
|
7201
8896
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
@@ -7219,6 +8914,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
|
|
7219
8914
|
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
|
7220
8915
|
(void) level;
|
7221
8916
|
(void) user_data;
|
8917
|
+
#ifndef WHISPER_DEBUG
|
8918
|
+
if (level == GGML_LOG_LEVEL_DEBUG) {
|
8919
|
+
return;
|
8920
|
+
}
|
8921
|
+
#endif
|
7222
8922
|
fputs(text, stderr);
|
7223
8923
|
fflush(stderr);
|
7224
8924
|
}
|