whispercpp 1.3.0 → 1.3.2

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (787)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -0
  3. data/LICENSE +1 -1
  4. data/README.md +216 -424
  5. data/Rakefile +79 -11
  6. data/ext/.gitignore +11 -0
  7. data/ext/dependencies.rb +61 -0
  8. data/ext/extconf.rb +18 -26
  9. data/ext/options.rb +221 -0
  10. data/ext/ruby_whisper.c +159 -0
  11. data/ext/ruby_whisper.h +27 -2
  12. data/ext/ruby_whisper_context.c +641 -0
  13. data/ext/ruby_whisper_error.c +52 -0
  14. data/ext/ruby_whisper_model.c +232 -0
  15. data/ext/ruby_whisper_params.c +1301 -0
  16. data/ext/ruby_whisper_segment.c +143 -0
  17. data/ext/ruby_whisper_transcribe.cpp +87 -0
  18. data/ext/ruby_whisper_vad_params.c +288 -0
  19. data/ext/sources/.dockerignore +3 -0
  20. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  21. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  22. data/ext/sources/CMakeLists.txt +251 -0
  23. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  24. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  25. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  26. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  27. data/ext/sources/bindings/javascript/package.json +26 -0
  28. data/ext/sources/bindings/javascript/whisper.js +19 -0
  29. data/ext/sources/build-xcframework.sh +547 -0
  30. data/ext/sources/ci/run.sh +336 -0
  31. data/ext/sources/close-issue.yml +28 -0
  32. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  33. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  34. data/ext/sources/cmake/build-info.cmake +60 -0
  35. data/ext/sources/cmake/git-vars.cmake +22 -0
  36. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  37. data/ext/sources/cmake/whisper.pc.in +10 -0
  38. data/ext/sources/examples/CMakeLists.txt +124 -0
  39. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  40. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  41. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  42. data/ext/sources/examples/addon.node/index.js +54 -0
  43. data/ext/sources/examples/addon.node/package.json +16 -0
  44. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  45. data/ext/sources/examples/bench/bench.cpp +175 -0
  46. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  47. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  48. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  49. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  50. data/ext/sources/examples/cli/cli.cpp +1294 -0
  51. data/ext/sources/examples/coi-serviceworker.js +146 -0
  52. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  53. data/ext/sources/examples/command/command.cpp +776 -0
  54. data/ext/sources/examples/command/commands.txt +9 -0
  55. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  56. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  57. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  58. data/ext/sources/examples/common-ggml.cpp +238 -0
  59. data/ext/sources/examples/common-ggml.h +18 -0
  60. data/ext/sources/examples/common-sdl.cpp +227 -0
  61. data/ext/sources/examples/common-sdl.h +49 -0
  62. data/ext/sources/examples/common-whisper.cpp +168 -0
  63. data/ext/sources/examples/common-whisper.h +24 -0
  64. data/ext/sources/examples/common.cpp +675 -0
  65. data/ext/sources/examples/common.h +322 -0
  66. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  67. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  68. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  69. data/ext/sources/examples/generate-karaoke.sh +57 -0
  70. data/ext/sources/examples/grammar-parser.cpp +423 -0
  71. data/ext/sources/examples/grammar-parser.h +29 -0
  72. data/ext/sources/examples/helpers.js +191 -0
  73. data/ext/sources/examples/json.hpp +24596 -0
  74. data/ext/sources/examples/livestream.sh +112 -0
  75. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  76. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  77. data/ext/sources/examples/lsp/whisper.vim +362 -0
  78. data/ext/sources/examples/miniaudio.h +93468 -0
  79. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  80. data/ext/sources/examples/python/whisper_processor.py +54 -0
  81. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  82. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  83. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  84. data/ext/sources/examples/server/bench.js +29 -0
  85. data/ext/sources/examples/server/httplib.h +10497 -0
  86. data/ext/sources/examples/server/server.cpp +1091 -0
  87. data/ext/sources/examples/server.py +115 -0
  88. data/ext/sources/examples/stb_vorbis.c +5584 -0
  89. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  90. data/ext/sources/examples/stream/stream.cpp +429 -0
  91. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  92. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  93. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  94. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  95. data/ext/sources/examples/sycl/build.sh +22 -0
  96. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  97. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  98. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  99. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  101. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  103. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  105. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  107. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  108. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  109. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  111. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  113. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  115. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  117. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  119. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  120. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  121. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  124. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  126. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  128. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  130. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  132. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  133. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  134. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  136. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  138. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  140. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  141. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  142. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  143. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  144. data/ext/sources/examples/talk-llama/speak +40 -0
  145. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  146. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  147. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  149. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  150. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  151. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  152. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  153. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  154. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  155. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  157. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  159. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  160. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  162. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  163. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  164. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  165. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  166. data/ext/sources/ggml/CMakeLists.txt +390 -0
  167. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  168. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  169. data/ext/sources/ggml/cmake/common.cmake +26 -0
  170. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  171. data/ext/sources/ggml/include/ggml-alloc.h +76 -0
  172. data/ext/sources/ggml/include/ggml-backend.h +354 -0
  173. data/ext/sources/ggml/include/ggml-blas.h +25 -0
  174. data/ext/sources/ggml/include/ggml-cann.h +123 -0
  175. data/ext/sources/ggml/include/ggml-cpp.h +39 -0
  176. data/ext/sources/ggml/include/ggml-cpu.h +143 -0
  177. data/ext/sources/ggml/include/ggml-cuda.h +47 -0
  178. data/ext/sources/ggml/include/ggml-kompute.h +50 -0
  179. data/ext/sources/ggml/include/ggml-metal.h +66 -0
  180. data/ext/sources/ggml/include/ggml-opencl.h +26 -0
  181. data/ext/sources/ggml/include/ggml-opt.h +237 -0
  182. data/ext/sources/ggml/include/ggml-rpc.h +33 -0
  183. data/ext/sources/ggml/include/ggml-sycl.h +49 -0
  184. data/ext/sources/ggml/include/ggml-vulkan.h +29 -0
  185. data/ext/{ggml.h → sources/ggml/include/ggml.h} +621 -821
  186. data/ext/sources/ggml/include/gguf.h +202 -0
  187. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  188. data/ext/sources/ggml/src/ggml-alloc.c +1042 -0
  189. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  190. data/ext/sources/ggml/src/ggml-amx/common.h +94 -0
  191. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  192. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +2510 -0
  193. data/ext/sources/ggml/src/ggml-amx/mmq.h +17 -0
  194. data/ext/sources/ggml/src/ggml-backend-impl.h +255 -0
  195. data/ext/sources/ggml/src/ggml-backend-reg.cpp +586 -0
  196. data/ext/sources/ggml/src/ggml-backend.cpp +2011 -0
  197. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  198. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  199. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  200. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  201. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +181 -0
  202. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +258 -0
  203. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +3193 -0
  204. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  205. data/ext/sources/ggml/src/ggml-cann/common.h +420 -0
  206. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +2606 -0
  207. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  208. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  209. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +234 -0
  210. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  211. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  212. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  213. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  214. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  215. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  216. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  217. data/ext/sources/ggml/src/ggml-common.h +1857 -0
  218. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  219. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +221 -0
  220. data/ext/sources/ggml/src/ggml-cpu/amx/amx.h +8 -0
  221. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +91 -0
  222. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  223. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  224. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  225. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  226. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  227. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  228. data/ext/sources/ggml/src/ggml-cpu/cpu-feats-x86.cpp +327 -0
  229. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  232. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  233. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +508 -0
  234. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +13747 -0
  235. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  236. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  237. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  238. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +671 -0
  240. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  241. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  244. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  246. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  247. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  248. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  249. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  250. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  251. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  252. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  253. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  254. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  255. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  257. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  259. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  260. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  261. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  262. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  263. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  264. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  265. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  267. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  268. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  270. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  271. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  272. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  273. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  274. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  276. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  277. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  278. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  280. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  281. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  282. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  285. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  294. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  295. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  296. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  298. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  299. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  300. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  302. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  304. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  305. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  309. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  310. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  311. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  312. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  313. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  314. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  316. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  317. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  318. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  320. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  321. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  322. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  323. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  324. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  326. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  327. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  328. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  329. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  330. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  331. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  332. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  334. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  430. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  432. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  433. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  434. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  436. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  437. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  438. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  439. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  440. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  441. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  442. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  443. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  444. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  445. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  446. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  447. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  448. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  449. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  450. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  451. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  452. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  453. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  454. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  455. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  456. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  457. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  458. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  459. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  460. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  461. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  462. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  463. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  464. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  465. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  466. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +15 -0
  467. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +243 -0
  468. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +140 -0
  469. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  470. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  471. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  472. data/ext/sources/ggml/src/ggml-impl.h +601 -0
  473. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  474. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  481. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  482. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  483. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  484. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  485. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  486. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  487. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  488. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  489. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  490. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  491. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  492. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  493. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  494. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  495. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  496. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  497. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  498. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  499. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  500. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  501. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  502. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  503. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  504. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  505. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  506. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  507. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  508. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  509. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  510. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  511. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  512. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  513. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  514. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +5998 -0
  515. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +7089 -0
  516. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  517. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  518. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  519. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  520. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  526. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  527. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  528. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  529. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  530. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  531. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  532. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  533. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  534. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  535. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  536. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  537. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  538. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  539. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  540. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  541. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  542. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  543. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  544. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  545. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  546. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  547. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  548. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  549. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  550. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  551. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  552. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  553. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  554. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  555. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  556. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  557. data/ext/sources/ggml/src/ggml-opt.cpp +1037 -0
  558. data/ext/sources/ggml/src/ggml-quants.c +5232 -0
  559. data/ext/sources/ggml/src/ggml-quants.h +100 -0
  560. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  561. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +1813 -0
  562. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  563. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  564. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  565. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  566. data/ext/sources/ggml/src/ggml-sycl/common.cpp +83 -0
  567. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  568. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +195 -0
  569. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  570. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +101 -0
  571. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +623 -0
  573. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  574. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  575. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  576. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  577. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +1162 -0
  578. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  579. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  580. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  581. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  582. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  583. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  584. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  585. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +4493 -0
  586. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  587. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  588. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  589. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  590. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +3030 -0
  591. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  592. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1110 -0
  593. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  594. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +501 -0
  595. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  596. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +47 -0
  597. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  598. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  599. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  600. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  601. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  602. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +261 -0
  603. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  604. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  605. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  606. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  607. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  608. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  609. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  610. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  611. data/ext/sources/ggml/src/ggml-threading.cpp +12 -0
  612. data/ext/sources/ggml/src/ggml-threading.h +14 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +10700 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +751 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  729. data/ext/sources/ggml/src/ggml.c +6550 -0
  730. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  731. data/ext/{whisper.h → sources/include/whisper.h} +91 -24
  732. data/ext/sources/src/CMakeLists.txt +143 -0
  733. data/ext/sources/src/coreml/whisper-decoder-impl.h +158 -0
  734. data/ext/sources/src/coreml/whisper-decoder-impl.m +226 -0
  735. data/ext/sources/src/coreml/whisper-encoder-impl.h +154 -0
  736. data/ext/sources/src/coreml/whisper-encoder-impl.m +222 -0
  737. data/ext/sources/src/coreml/whisper-encoder.h +26 -0
  738. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  739. data/ext/sources/src/openvino/whisper-openvino-encoder.cpp +108 -0
  740. data/ext/sources/src/openvino/whisper-openvino-encoder.h +31 -0
  741. data/ext/sources/src/whisper-arch.h +197 -0
  742. data/ext/{whisper.cpp → sources/src/whisper.cpp} +2535 -835
  743. data/ext/sources/tests/CMakeLists.txt +105 -0
  744. data/ext/sources/tests/earnings21/eval.mk +58 -0
  745. data/ext/sources/tests/earnings21/eval.py +68 -0
  746. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  747. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  748. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  749. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  750. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  751. data/ext/sources/tests/en-0-ref.txt +1 -0
  752. data/ext/sources/tests/en-1-ref.txt +1 -0
  753. data/ext/sources/tests/en-2-ref.txt +1 -0
  754. data/ext/sources/tests/es-0-ref.txt +1 -0
  755. data/ext/sources/tests/librispeech/eval.mk +39 -0
  756. data/ext/sources/tests/librispeech/eval.py +47 -0
  757. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  758. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  759. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  760. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  761. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  762. data/ext/sources/tests/run-tests.sh +130 -0
  763. data/ext/sources/tests/test-c.c +3 -0
  764. data/ext/sources/tests/test-vad-full.cpp +54 -0
  765. data/ext/sources/tests/test-vad.cpp +83 -0
  766. data/ext/sources/tests/test-whisper.js +58 -0
  767. data/extsources.rb +34 -0
  768. data/lib/whisper/model/uri.rb +178 -0
  769. data/sig/whisper.rbs +480 -0
  770. data/tests/helper.rb +35 -0
  771. data/tests/jfk_reader/.gitignore +5 -0
  772. data/tests/jfk_reader/extconf.rb +3 -0
  773. data/tests/jfk_reader/jfk_reader.c +68 -0
  774. data/tests/test_callback.rb +202 -0
  775. data/tests/test_error.rb +20 -0
  776. data/tests/test_model.rb +109 -0
  777. data/tests/test_package.rb +46 -0
  778. data/tests/test_params.rb +297 -0
  779. data/tests/test_segment.rb +74 -0
  780. data/tests/test_vad.rb +19 -0
  781. data/tests/test_vad_params.rb +103 -0
  782. data/tests/test_whisper.rb +212 -124
  783. data/whispercpp.gemspec +37 -0
  784. metadata +794 -13
  785. data/ext/dr_wav.h +0 -6434
  786. data/ext/ggml.c +0 -21755
  787. data/ext/ruby_whisper.cpp +0 -426
@@ -1,62 +1,52 @@
  #include "whisper.h"
+ #include "whisper-arch.h"
+
+ #include "ggml.h"
+ #include "ggml-cpp.h"
+ #include "ggml-alloc.h"
+ #include "ggml-backend.h"

  #ifdef WHISPER_USE_COREML
  #include "coreml/whisper-encoder.h"
  #endif

- #ifdef GGML_USE_METAL
- #include "ggml-metal.h"
- #endif
-
- #ifdef GGML_USE_CUDA
- #include "ggml-cuda.h"
- #endif
-
- #ifdef GGML_USE_SYCL
- #include "ggml-sycl.h"
- #endif
-
  #ifdef WHISPER_USE_OPENVINO
  #include "openvino/whisper-openvino-encoder.h"
  #endif

- #include "ggml.h"
- #include "ggml-alloc.h"
- #include "ggml-backend.h"
-
  #include <atomic>
  #include <algorithm>
  #include <cassert>
+ #include <cfloat>
  #define _USE_MATH_DEFINES
  #include <cmath>
- #include <cstdio>
+ #include <climits>
+ #include <codecvt>
  #include <cstdarg>
+ #include <cstdio>
  #include <cstring>
  #include <fstream>
+ #include <functional>
  #include <map>
+ #include <mutex>
+ #include <random>
+ #include <regex>
  #include <set>
  #include <string>
  #include <thread>
  #include <vector>
- #include <regex>
- #include <random>
- #include <functional>
-
- #if defined(_MSC_VER)
- #pragma warning(disable: 4244 4267) // possible loss of data
- #endif
-
- #if defined(GGML_BIG_ENDIAN)
- #include <bit>

+ #if defined(WHISPER_BIG_ENDIAN)
  template<typename T>
  static T byteswap(T value) {
-     return std::byteswap(value);
- }
-
- template<>
- float byteswap(float value) {
-     return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+     T value_swapped;
+     char * source = reinterpret_cast<char *>(&value);
+     char * target = reinterpret_cast<char *>(&value_swapped);
+     int size = sizeof(T);
+     for (int i = 0; i < size; i++) {
+         target[size - 1 - i] = source[i];
+     }
+     return value_swapped;
  }

  template<typename T>
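The rewritten byteswap above drops the C++23 std::byteswap / std::bit_cast path in favor of a plain byte reversal that works for any trivially copyable T, including float, with no specialization. A minimal self-contained sketch of the same idiom (the main driver is illustrative, not part of whisper.cpp):

```cpp
#include <cstdint>
#include <cstdio>

// Portable byteswap: reverse the bytes of any trivially copyable value,
// the same technique as the WHISPER_BIG_ENDIAN helper above.
template<typename T>
static T byteswap(T value) {
    T value_swapped;
    char * source = reinterpret_cast<char *>(&value);
    char * target = reinterpret_cast<char *>(&value_swapped);
    int size = sizeof(T);
    for (int i = 0; i < size; i++) {
        target[size - 1 - i] = source[i]; // reverse the byte order
    }
    return value_swapped;
}

int main() {
    uint32_t x = 0x11223344u;
    printf("0x%08x -> 0x%08x\n", x, byteswap(x)); // 0x11223344 -> 0x44332211
    return 0;
}
```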
@@ -92,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
  }

  #define BYTESWAP_VALUE(d) d = byteswap(d)
- #define BYTESWAP_FILTERS(f)           \
+ #define BYTESWAP_FILTERS(f)       \
      do {                              \
          for (auto & datum : f.data) { \
              datum = byteswap(datum);  \
          }                             \
      } while (0)
- #define BYTESWAP_TENSOR(t)      \
-     do {                        \
+ #define BYTESWAP_TENSOR(t)  \
+     do {                    \
          byteswap_tensor(t); \
      } while (0)
  #else
@@ -147,47 +137,128 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
          }                                \
      } while (0)

- //#define WHISPER_USE_FLASH_ATTN
- //#define WHISPER_USE_FLASH_FF
  #define WHISPER_MAX_DECODERS 8
  #define WHISPER_MAX_NODES 4096

+ static std::string format(const char * fmt, ...) {
+     va_list ap;
+     va_list ap2;
+     va_start(ap, fmt);
+     va_copy(ap2, ap);
+     int size = vsnprintf(NULL, 0, fmt, ap);
+     GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+     std::vector<char> buf(size + 1);
+     int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+     GGML_ASSERT(size2 == size);
+     va_end(ap2);
+     va_end(ap);
+     return std::string(buf.data(), size);
+ }
+
  //
  // ggml helpers
  //

  static bool ggml_graph_compute_helper(
            struct ggml_cgraph * graph,
-         std::vector<uint8_t> & buf,
                           int   n_threads,
           ggml_abort_callback   abort_callback,
                          void * abort_callback_data) {
-     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+     ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };

-     plan.abort_callback      = abort_callback;
-     plan.abort_callback_data = abort_callback_data;
+     auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));

-     if (plan.work_size > 0) {
-         buf.resize(plan.work_size);
-         plan.work_data = buf.data();
+     auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+     if (set_abort_callback_fn) {
+         set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
      }

-     return ggml_graph_compute(graph, &plan);
+     auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+     if (ggml_backend_set_n_threads_fn) {
+         ggml_backend_set_n_threads_fn(backend.get(), n_threads);
+     }
+
+     return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
  }

  static bool ggml_graph_compute_helper(
-        struct ggml_backend * backend,
+        ggml_backend_sched_t   sched,
          struct ggml_cgraph * graph,
-                         int   n_threads) {
-     if (ggml_backend_is_cpu(backend)) {
-         ggml_backend_cpu_set_n_threads(backend, n_threads);
+                         int   n_threads,
+                        bool   sched_reset = true) {
+     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+         ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+         ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+         auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+         if (fn_set_n_threads) {
+             fn_set_n_threads(backend, n_threads);
+         }
      }
- #ifdef GGML_USE_METAL
-     if (ggml_backend_is_metal(backend)) {
-         ggml_backend_metal_set_n_cb(backend, n_threads);
+
+     const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
+
+     if (!t || sched_reset) {
+         ggml_backend_sched_reset(sched);
      }
+
+     return t;
+ }
+
+ static void whisper_load_backends() {
+ #ifdef GGML_BACKEND_DL
+     static std::once_flag flag;
+     std::call_once(flag, []() {
+         ggml_backend_load_all();
+     });
  #endif
-     return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
+ }
+
+ // TODO: move these functions to ggml-base with support for ggml-backend?
+
+ static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
+     GGML_ASSERT(t->type == GGML_TYPE_F32);
+     GGML_ASSERT(ggml_is_contiguous(t));
+     size_t nels = ggml_nelements(t);
+     for (size_t i = 0; i < nels; ++i) {
+         ((float *) t->data)[i] = v;
+     }
+     return t;
+ }
+
+ static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
+     GGML_ASSERT(t->type == GGML_TYPE_I32);
+     GGML_ASSERT(ggml_is_contiguous(t));
+     size_t nels = ggml_nelements(t);
+     for (size_t i = 0; i < nels; ++i) {
+         ((int32_t *) t->data)[i] = v;
+     }
+     return t;
+ }
+
+ static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+     GGML_ASSERT(t->type == GGML_TYPE_F32);
+     void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+     return *(float *) data;
+ }
+
+ static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
+     GGML_ASSERT(t->type == GGML_TYPE_F32);
+     void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+     *(float *) data = v;
+ }
+
+ static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+     GGML_ASSERT(t->type == GGML_TYPE_I32);
+     void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+     return *(int32_t *) data;
+ }
+
+ static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
+     GGML_ASSERT(t->type == GGML_TYPE_I32);
+     void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+     *(int32_t *) data = v;
  }

  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
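The new format() helper uses the classic two-pass vsnprintf idiom: the first call with a NULL buffer only measures, the second writes into an exactly-sized buffer. A standalone sketch of the same function (assert stands in for GGML_ASSERT so the snippet compiles on its own; the main driver is illustrative):

```cpp
#include <cassert>
#include <climits>
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

static std::string format(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);                       // ap is consumed by the first pass
    int size = vsnprintf(NULL, 0, fmt, ap); // pass 1: measure only
    assert(size >= 0 && size < INT_MAX);
    std::vector<char> buf(size + 1);
    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); // pass 2: write
    assert(size2 == size);
    va_end(ap2);
    va_end(ap);
    return std::string(buf.data(), size);
}

int main() {
    std::string s = format("decoder.blocks.%d.attn.query.weight", 7);
    printf("%s\n", s.c_str()); // decoder.blocks.7.attn.query.weight
    return 0;
}
```

Later in this diff, create_tensor feeds exactly such per-layer name patterns through this helper when registering weights.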
@@ -363,6 +434,7 @@ static const whisper_ahead g_aheads_medium[]   = { {13, 15}, {15, 4}, {15, 15},
  static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
  static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
  static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+ static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };

  static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
      { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
@@ -376,6 +448,7 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
      { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
      { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
      { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
+     { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
  };

  static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
@@ -431,6 +504,7 @@ struct whisper_segment {
      int64_t t1;

      std::string text;
+     float no_speech_prob;

      std::vector<whisper_token_data> tokens;

@@ -502,33 +576,41 @@ struct whisper_pair {
      whisper_pair() : first(A()), second(B()) {}
  };

- // ggml_allocr wrapper for whisper usage
- struct whisper_allocr {
-     ggml_gallocr_t alloc = nullptr;
+ // ggml_backend_sched wrapper for whisper usage
+ struct whisper_sched {
+     ggml_backend_sched_t sched = nullptr;

      std::vector<uint8_t> meta;
  };

- static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-     return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0);
+ static size_t whisper_sched_size(struct whisper_sched & allocr) {
+     size_t size = allocr.meta.size();
+     for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+         ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+         size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+     }
+     return size;
  }

  // measure the memory usage of a graph and prepare the allocr's internal data buffer
- static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-     auto & alloc = allocr.alloc;
+ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+     auto & sched = allocr.sched;
      auto & meta = allocr.meta;

-     alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+     sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);

      meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());

      // since there are dependencies between the different graphs,
      // we need to allocate them instead of only reserving to get the correct compute buffer size
-     if (!ggml_gallocr_alloc_graph(alloc, get_graph())) {
+     if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
          // failed to allocate the compute buffer
          WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
          return false;
      }
+
+     ggml_backend_sched_reset(sched);
+
      return true;
  }

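whisper_sched_graph_init allocates the graph once up front (rather than only reserving) because the conv/encode/cross/decode graphs depend on each other's buffers. A sketch of the intended call order, using only functions defined in this hunk; whisper_build_graph_encoder, backends, wctx and wstate are assumed to be in scope as elsewhere in this file, so this fragment is not standalone:

```cpp
// Sketch: initialize a scheduler for the encoder graph and report its
// footprint, mirroring how the state-initialization code uses this API.
whisper_sched sched_encode;
if (!whisper_sched_graph_init(sched_encode, backends,
        [&]() { return whisper_build_graph_encoder(wctx, wstate); })) {
    WHISPER_LOG_ERROR("%s: failed to init encoder scheduler\n", __func__);
} else {
    WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__,
            whisper_sched_size(sched_encode) / 1e6);
}
```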
@@ -671,9 +753,9 @@ struct whisper_kv_cache {
      struct ggml_tensor * k;
      struct ggml_tensor * v;

-     struct ggml_context * ctx = nullptr;
-
      ggml_backend_buffer_t buffer = nullptr;
+
+     std::vector<uint8_t> ctx_buf;
  };

  struct whisper_model {
@@ -711,10 +793,10 @@ struct whisper_model {
      std::vector<whisper_layer_decoder> layers_decoder;

      // ggml context that contains all the meta information about the model tensors
-     struct ggml_context * ctx = nullptr;
+     std::vector<ggml_context *> ctxs;

      // the model backend data is read-only and can be shared between processors
-     ggml_backend_buffer_t buffer = nullptr;
+     std::vector<ggml_backend_buffer_t> buffers;

      // tensors
      int n_loaded;
@@ -802,6 +884,9 @@ struct whisper_state {
      int32_t n_fail_p = 0; // number of logprob threshold failures
      int32_t n_fail_h = 0; // number of entropy threshold failures

+     // number of decoders for which we have constructed the KV cache
+     int32_t kv_self_n_dec = 0;
+
      // unified self-attention KV cache for all decoders
      whisper_kv_cache kv_self;

@@ -809,21 +894,22 @@ struct whisper_state {
      // shared between all decoders
      whisper_kv_cache kv_cross;

+     // padded buffer for flash-attention
+     whisper_kv_cache kv_pad;
+
      whisper_mel mel;

      whisper_batch batch;

      whisper_decoder decoders[WHISPER_MAX_DECODERS];

-     ggml_backend_t backend = nullptr;
+     std::vector<ggml_backend_t> backends;

-     // ggml-alloc:
      // - stores meta info about the intermediate tensors into the `meta` buffers
-     // - stores the actual tensor data into the `data` buffers
-     whisper_allocr alloc_conv;
-     whisper_allocr alloc_encode;
-     whisper_allocr alloc_cross;
-     whisper_allocr alloc_decode;
+     whisper_sched sched_conv;
+     whisper_sched sched_encode;
+     whisper_sched sched_cross;
+     whisper_sched sched_decode;

      // result of the encoder
      struct ggml_tensor * embd_conv = nullptr;
@@ -858,6 +944,7 @@ struct whisper_state {
      whisper_token tid_last;

      std::vector<float> energy; // PCM signal energy
+     float no_speech_prob = 0.0f;

      // [EXPERIMENTAL] Token-level timestamps with DTW
      whisper_aheads_masks aheads_masks;
@@ -866,6 +953,17 @@ struct whisper_state {

      // [EXPERIMENTAL] speed-up techniques
      int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+     whisper_vad_context * vad_context = nullptr;
+
+     struct vad_segment_info {
+         float orig_start;
+         float orig_end;
+         float vad_start;
+         float vad_end;
+     };
+     std::vector<vad_segment_info> vad_segments;
+     bool has_vad_segments = false;
  };

  struct whisper_context {
@@ -882,8 +980,6 @@ struct whisper_context {

      whisper_state * state = nullptr;

-     ggml_backend_t backend = nullptr;
-
      std::string path_model; // populated by whisper_init_from_file_with_params()
  };

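vad_segment_info keeps both coordinate systems for every speech region: where the segment sat in the original audio and where it landed in the VAD-filtered stream. That is what allows timestamps computed on filtered audio to be mapped back. A self-contained sketch of such a mapping; the helper below is hypothetical and only illustrates the idea, it is not whisper.cpp's actual implementation:

```cpp
#include <cstdio>
#include <vector>

// Mirrors the vad_segment_info fields introduced above; times in seconds.
struct vad_segment_info {
    float orig_start;
    float orig_end;
    float vad_start;
    float vad_end;
};

// Hypothetical helper: map a timestamp on the VAD-filtered timeline back to
// the original audio by linear interpolation inside the matching segment.
static float vad_time_to_orig(const std::vector<vad_segment_info> & segs, float t) {
    for (const auto & s : segs) {
        if (t >= s.vad_start && t <= s.vad_end) {
            const float d = s.vad_end - s.vad_start;
            const float a = d > 0.0f ? (t - s.vad_start) / d : 0.0f;
            return s.orig_start + a * (s.orig_end - s.orig_start);
        }
    }
    return t; // outside any segment: leave unchanged
}

int main() {
    std::vector<vad_segment_info> segs = { { 3.0f, 5.0f, 0.0f, 2.0f } };
    printf("%.2f\n", vad_time_to_orig(segs, 1.0f)); // 4.00
    return 0;
}
```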
@@ -901,21 +997,21 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
      BYTESWAP_VALUE(dest);
  }

- static bool kv_cache_init(
-           const struct whisper_hparams & hparams,
+ static bool whisper_kv_cache_init(
              struct whisper_kv_cache & cache,
                       ggml_backend_t   backend,
                            ggml_type   wtype,
+                             int64_t   n_text_state,
+                             int64_t   n_text_layer,
                                  int   n_ctx) {
-     const int64_t n_text_state = hparams.n_text_state;
-     const int64_t n_text_layer = hparams.n_text_layer;
-
      const int64_t n_mem      = n_text_layer*n_ctx;
      const int64_t n_elements = n_text_state*n_mem;

+     cache.ctx_buf.resize(2*ggml_tensor_overhead());
+
      struct ggml_init_params params = {
-         /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-         /*.mem_buffer =*/ nullptr,
+         /*.mem_size   =*/ cache.ctx_buf.size(),
+         /*.mem_buffer =*/ cache.ctx_buf.data(),
          /*.no_alloc   =*/ true,
      };

@@ -925,29 +1021,31 @@ static bool kv_cache_init(
      cache.cells.clear();
      cache.cells.resize(n_ctx);

-     cache.ctx = ggml_init(params);
+     struct ggml_context * ctx = ggml_init(params);

-     if (!cache.ctx) {
+     if (!ctx) {
          WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
          return false;
      }

-     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+     cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+     cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);

-     cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
+     cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
      if (!cache.buffer) {
          WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
          return false;
      }

+     ggml_backend_buffer_clear(cache.buffer, 0);
+
+     ggml_free(ctx);
+
      return true;
  }

- static void kv_cache_free(struct whisper_kv_cache & cache) {
-     ggml_free(cache.ctx);
+ static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
      ggml_backend_buffer_free(cache.buffer);
-     cache.ctx = nullptr;
  }

  static bool whisper_kv_cache_find_slot(
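The cache geometry follows directly from the two products above: n_mem = n_text_layer·n_ctx cells and n_elements = n_text_state·n_mem values, allocated once for K and once for V. A worked example with the published "base" model decoder dimensions (the byte count ignores backend padding, so treat it as an estimate):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // whisper "base" text-decoder hyperparameters (published model card)
    const int64_t n_text_state = 512;
    const int64_t n_text_layer = 6;
    const int64_t n_ctx        = 448; // decoder context
    const int64_t bytes_per_el = 2;   // GGML_TYPE_F16

    const int64_t n_mem      = n_text_layer*n_ctx;  // 2688
    const int64_t n_elements = n_text_state*n_mem;  // 1376256

    // one tensor for K, one for V, as in whisper_kv_cache_init above
    printf("kv self size = %.2f MB\n", 2.0*n_elements*bytes_per_el/1e6); // ~5.51 MB
    return 0;
}
```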
@@ -1018,6 +1116,8 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
          cache.cells[i].seq_id.clear();
      }
      cache.head = 0;
+
+     ggml_backend_buffer_clear(cache.buffer, 0);
  }

  static void whisper_kv_cache_seq_rm(
@@ -1068,6 +1168,26 @@ static void whisper_kv_cache_seq_cp(
      }
  }

+ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
+     if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
+         return 1u;
+     }
+
+ #ifdef GGML_USE_METAL
+     if (wctx.params.use_gpu) {
+         return 32u;
+     }
+ #endif
+
+ #ifdef GGML_USE_CUDA
+     if (wctx.params.use_gpu) {
+         return 256u;
+     }
+ #endif
+
+     return 1u;
+ }
+
  // [EXPERIMENTAL] Token-level timestamps with DTW
  static bool aheads_masks_init(
          const whisper_context_params & cparams,
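The padding returned here is later applied as n_ctx_pad = GGML_PAD(n_ctx, 256) in the graph builders; GGML_PAD rounds up to a multiple of a power of two. A standalone sketch of that rounding (PAD mirrors ggml's GGML_PAD macro; the values 1, 32 and 256 above are all powers of two, as the mask trick requires):

```cpp
#include <cstdint>
#include <cstdio>

// Round x up to a multiple of n, n a power of two (same as GGML_PAD).
#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const uint32_t pad = 256;          // e.g. CUDA + flash_attn
    printf("%u\n", PAD(1500u, pad));   // 1536: the encoder context, padded
    printf("%u\n", PAD(448u,  pad));   // 512
    return 0;
}
```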
@@ -1199,49 +1319,178 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
      return size;
  }

- static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
-     ggml_backend_t backend_gpu = NULL;
+ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

-     // initialize the backends
- #ifdef GGML_USE_CUDA
+     whisper_load_backends();
+
+     ggml_backend_dev_t dev = nullptr;
+
+     int cnt = 0;
      if (params.use_gpu) {
-         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-         backend_gpu = ggml_backend_cuda_init(params.gpu_device);
-         if (!backend_gpu) {
-             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+             ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+             if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                 if (cnt == 0 || cnt == params.gpu_device) {
+                     dev = dev_cur;
+                 }
+
+                 if (++cnt > params.gpu_device) {
+                     break;
+                 }
+             }
          }
      }
- #endif

- #ifdef GGML_USE_METAL
-     if (params.use_gpu) {
-         WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
-         ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-         backend_gpu = ggml_backend_metal_init();
-         if (!backend_gpu) {
-             WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-         } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
-             WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
-             ggml_backend_free(backend_gpu);
-             backend_gpu = NULL;
+     if (dev == nullptr) {
+         WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+         return nullptr;
+     }
+
+     WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+     ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+     if (!result) {
+         WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+     }
+
+     return result;
+ }
+
+ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+     std::vector<ggml_backend_t> result;
+
+     ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
+
+     if (backend_gpu) {
+         result.push_back(backend_gpu);
+     }
+
+     // ACCEL backends
+     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+         if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+             WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+             ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+             if (!backend) {
+                 WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                 continue;
+             }
+             result.push_back(backend);
          }
      }
- #endif

- #ifdef GGML_USE_SYCL
+     ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+     if (backend_cpu == nullptr) {
+         throw std::runtime_error("failed to initialize CPU backend");
+     }
+     result.push_back(backend_cpu);
+
+     return result;
+ }
+
+ using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
+
+ static buft_list_t make_buft_list(whisper_context_params & params) {
+     // Prio order: GPU -> CPU Extra -> CPU
+     buft_list_t buft_list;
+
+     // GPU
      if (params.use_gpu) {
-         WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
-         backend_gpu = ggml_backend_sycl_init(params.gpu_device);
-         if (!backend_gpu) {
-             WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
+         int cnt = 0;
+         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+             ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+             if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                 if (cnt == 0 || cnt == params.gpu_device) {
+                     auto * buft = ggml_backend_dev_buffer_type(dev);
+                     if (buft) {
+                         buft_list.emplace_back(dev, buft);
+                     }
+                 }
+
+                 if (++cnt > params.gpu_device) {
+                     break;
+                 }
+             }
          }
      }
- #endif

-     if (backend_gpu) {
-         return backend_gpu;
+     // CPU Extra
+     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+     auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+     auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+         ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+     if (get_extra_bufts_fn) {
+         ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
+         while (extra_bufts && *extra_bufts) {
+             buft_list.emplace_back(cpu_dev, *extra_bufts);
+             ++extra_bufts;
+         }
+     }
+
+     // CPU
+     buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
+
+     return buft_list;
+ }
+
+ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+     bool op_supported = true;
+
+     if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+         (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+         // GPU and default CPU backend support all operators
+         op_supported = true;
+     } else {
+         switch (op) {
+             // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+             case GGML_OP_MUL_MAT: {
+                 ggml_init_params params = {
+                     /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                     /*.mem_buffer =*/ nullptr,
+                     /*.no_alloc   =*/ true,
+                 };
+
+                 ggml_context_ptr ctx_ptr { ggml_init(params) };
+                 if (!ctx_ptr) {
+                     throw std::runtime_error("failed to create ggml context");
+                 }
+                 ggml_context * ctx = ctx_ptr.get();
+
+                 ggml_tensor * op_tensor = nullptr;
+
+                 int64_t n_ctx = hparams.n_audio_ctx;
+                 ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                 op_tensor = ggml_mul_mat(ctx, w, b);
+
+                 // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                 GGML_ASSERT(w->buffer == nullptr);
+                 w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                 op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                 ggml_backend_buffer_free(w->buffer);
+                 w->buffer = nullptr;
+                 break;
+             }
+             default: {
+                 op_supported = false;
+                 break;
+             }
+         };
+     }
+
+     return op_supported;
+ }
+
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+     GGML_ASSERT(!buft_list.empty());
+     for (const auto & p : buft_list) {
+         ggml_backend_dev_t dev = p.first;
+         ggml_backend_buffer_type_t buft = p.second;
+         if (weight_buft_supported(hparams, w, op, buft, dev)) {
+             return buft;
+         }
      }
-     return ggml_backend_cpu_init();
+
+     return nullptr;
  }

  // load the model from a ggml file
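Backend selection is now a runtime enumeration of registered devices instead of compile-time #ifdef branches: the loop above picks the params.gpu_device-th GPU device, falling back to the first one. A small sketch that walks the same device list (it uses only ggml-backend calls that appear in this diff and must be linked against ggml; with GGML_BACKEND_DL builds the load_all step is required, cf. whisper_load_backends):

```cpp
#include <cstdio>
#include "ggml-backend.h"

// Print every registered device and its type — the enumeration that
// whisper_backend_init_gpu() and make_buft_list() iterate over.
int main() {
    ggml_backend_load_all();
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        const char * type =
            ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU   ? "GPU"   :
            ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL ? "ACCEL" : "CPU";
        printf("%zu: %s (%s)\n", i, ggml_backend_dev_name(dev), type);
    }
    return 0;
}
```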
@@ -1450,31 +1699,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
      const ggml_type wtype = wctx.wtype;
      const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type

-     // create the ggml context
-     {
-         const auto & hparams = model.hparams;
+     const auto & hparams = model.hparams;

-         const int n_audio_layer = hparams.n_audio_layer;
-         const int n_text_layer  = hparams.n_text_layer;
+     const int n_audio_layer = hparams.n_audio_layer;
+     const int n_text_layer  = hparams.n_text_layer;

-         const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
+     const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;

-         struct ggml_init_params params = {
-             /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
-             /*.mem_buffer =*/ nullptr,
-             /*.no_alloc   =*/ true,
-         };
+     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+     auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+         auto it = ctx_map.find(buft);
+         if (it == ctx_map.end()) {
+             ggml_init_params params = {
+                 /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                 /*.mem_buffer =*/ nullptr,
+                 /*.no_alloc   =*/ true,
+             };

-         model.ctx = ggml_init(params);
-         if (!model.ctx) {
-             WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
-             return false;
+             ggml_context * ctx = ggml_init(params);
+             if (!ctx) {
+                 throw std::runtime_error("failed to create ggml context");
+             }
+
+             ctx_map[buft] = ctx;
+             model.ctxs.emplace_back(ctx);
+
+             return ctx;
          }
-     }
+
+         return it->second;
+     };
+
+     // Create a list of available bufts, in priority order
+     buft_list_t buft_list = make_buft_list(wctx.params);
+
+     auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+         ggml_op op = ASR_TENSOR_INFO.at(type);
+         ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+         if (!buft) {
+             throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
+         }
+
+         ggml_context * ctx = get_ctx(buft);
+         ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+         model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+
+         return tensor;
+     };
+

      // prepare tensors for the weights
      {
-         auto & ctx = model.ctx;
+         ggml_init_params params = {
+             /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+             /*.mem_buffer =*/ nullptr,
+             /*.no_alloc   =*/ true,
+         };
+
+         ggml_context * ctx = ggml_init(params);

          const auto & hparams = model.hparams;

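The get_ctx lambda above is a get-or-create cache keyed by buffer type, so the loader ends up with exactly one ggml_context per distinct buffer type rather than one global context. The same pattern in isolation, with placeholder key and value types standing in for buft and ggml_context:

```cpp
#include <cstdio>
#include <map>
#include <string>

// Get-or-create, as in the get_ctx lambda: look up the key, build and
// remember the value only on a miss. K/V here are illustrative placeholders.
template <typename K, typename V, typename F>
V & get_or_create(std::map<K, V> & cache, const K & key, F && make) {
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, make()).first; // factory runs only on a miss
    }
    return it->second;
}

int main() {
    std::map<std::string, int> ctxs;
    int calls = 0;
    auto make = [&] { ++calls; return 42; };
    get_or_create(ctxs, std::string("cpu"), make);
    get_or_create(ctxs, std::string("cpu"), make); // cache hit: factory skipped
    printf("factory calls = %d, contexts = %zu\n", calls, ctxs.size()); // 1, 1
    return 0;
}
```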
@@ -1494,195 +1777,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
          model.layers_decoder.resize(n_text_layer);

          // encoder
-         {
-             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
-             model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
-             model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-             model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
-             model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
-             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-             model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-             // map by name
-             model.tensors["encoder.positional_embedding"] = model.e_pe;
+         model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));

-             model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
-             model.tensors["encoder.conv1.bias"]   = model.e_conv_1_b;
+         model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
+         model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

-             model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
-             model.tensors["encoder.conv2.bias"]   = model.e_conv_2_b;
+         model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
+         model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

-             model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
-             model.tensors["encoder.ln_post.bias"]   = model.e_ln_b;
+         model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+         model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));

-             for (int i = 0; i < n_audio_layer; ++i) {
-                 auto & layer = model.layers_encoder[i];
+         for (int i = 0; i < n_audio_layer; ++i) {
+             auto & layer = model.layers_encoder[i];

-                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+             layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

-                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
-                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+             layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+             layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);

-                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
-                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+             layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

-                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+             layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

-                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+             layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

-                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+             layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);

-                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+             layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+             layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

-                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                 // map by name
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                 model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
-             }
+             layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+             layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
          }

          // decoder
-         {
-             model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
-             model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
-             model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-             model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-             // map by name
-             model.tensors["decoder.positional_embedding"] = model.d_pe;
-
-             model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
-             model.tensors["decoder.ln.weight"] = model.d_ln_w;
-             model.tensors["decoder.ln.bias"]   = model.d_ln_b;
-
-             for (int i = 0; i < n_text_layer; ++i) {
-                 auto & layer = model.layers_decoder[i];
+         model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));

-                 layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+         model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));

-                 layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
-                 layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
+         model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
+         model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));

-                 layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
-                 layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+         for (int i = 0; i < n_text_layer; ++i) {
+             auto & layer = model.layers_decoder[i];

-                 layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+             layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
+             layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);

-                 layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+             layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
+             layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+             layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                 layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

-                 layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+             layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+             layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                 layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+             layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 // map by name
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
+             layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
+             layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]   = layer.cross_attn_ln_0_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"]   = layer.cross_attn_q_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"]   = layer.cross_attn_v_b;
-
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
-                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]   = layer.cross_attn_ln_1_b;
-             }
+             layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+             layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
          }
-     }

-     wctx.backend = whisper_backend_init(wctx.params);
-     if (!wctx.backend) {
-         WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
-         return false;
+         ggml_free(ctx);
      }

      // allocate tensors in the backend buffers
-     model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend);
-     if (!model.buffer) {
-         WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
-         return false;
-     }
+     for (auto & p : ctx_map) {
+         ggml_backend_buffer_type_t buft = p.first;
+         ggml_context * ctx = p.second;
+         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+         if (buf) {
+             model.buffers.emplace_back(buf);

-     size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-     WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+             size_t size_main = ggml_backend_buffer_get_size(buf);
+             WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+         }
+     }

      // load weights
      {
@@ -1745,11 +1941,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
              return false;
          }

-         //ggml_backend_t backend = wctx.backend;
-
-         //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
-
-         if (ggml_backend_buffer_is_host(model.buffer)) {
+         if (ggml_backend_buffer_is_host(tensor->buffer)) {
              // for the CPU and Metal backend, we can read directly into the tensor
              loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
              BYTESWAP_TENSOR(tensor);
@@ -1762,7 +1954,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
              ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
          }

-         //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
          total_size += ggml_nbytes(tensor);
          model.n_loaded++;
      }
@@ -1777,6 +1968,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
          }
      }

+     for (auto & buf : model.buffers) {
+         ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+     }
+
      wctx.t_load_us = ggml_time_us() - t_start_us;

      return true;
@@ -1812,8 +2007,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1812
2007
  const int n_mels = hparams.n_mels;
1813
2008
 
1814
2009
  struct ggml_init_params params = {
1815
- /*.mem_size =*/ wstate.alloc_conv.meta.size(),
1816
- /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
2010
+ /*.mem_size =*/ wstate.sched_conv.meta.size(),
2011
+ /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
1817
2012
  /*.no_alloc =*/ true,
1818
2013
  };
1819
2014
 
@@ -1847,6 +2042,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1847
2042
  ggml_build_forward_expand(gf, mel);
1848
2043
 
1849
2044
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
2045
+ ggml_set_input(cur); // the external encoder will write into this tensor
1850
2046
 
1851
2047
  ggml_set_name(cur, "embd_enc");
1852
2048
  wstate.embd_enc = cur;
@@ -1872,9 +2068,17 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1872
2068
  const int n_head = hparams.n_audio_head;
1873
2069
  const int n_layer = hparams.n_audio_layer;
1874
2070
 
2071
+ const int n_state_head = n_state/n_head;
2072
+
2073
+ auto & kv_pad = wstate.kv_pad;
2074
+
2075
+ WHISPER_ASSERT(!!kv_pad.buffer);
2076
+
2077
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
2078
+
1875
2079
  struct ggml_init_params params = {
1876
- /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1877
- /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
2080
+ /*.mem_size =*/ wstate.sched_encode.meta.size(),
2081
+ /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
1878
2082
  /*.no_alloc =*/ true,
1879
2083
  };
1880
2084
 
@@ -1884,7 +2088,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1884
2088
 
1885
2089
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1886
2090
 
1887
- const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
2091
+ const float KQscale = 1.0f/sqrtf(float(n_state_head));
1888
2092
 
1889
2093
  // ===================================================================
1890
2094
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +2138,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1934
2138
 
1935
2139
  Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
1936
2140
 
1937
- //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
2141
+ //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
1938
2142
 
1939
2143
  // note: no bias for Key
1940
2144
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1941
2145
  layer.attn_k_w,
1942
2146
  cur);
1943
2147
 
1944
- //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
2148
+ //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
1945
2149
 
1946
2150
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1947
2151
  layer.attn_v_w,
@@ -1951,70 +2155,60 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

  // ------

- #ifdef WHISPER_USE_FLASH_ATTN
  struct ggml_tensor * Q =
          ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
          0, 2, 1, 3);

- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ if (wctx.params.flash_attn) {
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));

- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+ struct ggml_tensor * K =
+ ggml_view_3d(ctx0, kv_pad.k,
+ n_state_head, n_ctx_pad, n_head,
+ ggml_element_size(kv_pad.k)*n_state,
+ ggml_element_size(kv_pad.k)*n_state_head,
+ 0);

- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
- #else
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_pad.v,
+ n_state_head, n_ctx_pad, n_head,
+ ggml_element_size(kv_pad.v)*n_state,
+ ggml_element_size(kv_pad.v)*n_state_head,
+ 0);

- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
+
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
+ } else {
+ struct ggml_tensor * K =
+ ggml_permute(ctx0,
+ ggml_cast(ctx0,
+ ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+ wctx.itype),
+ 0, 2, 1, 3);

- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

- struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);

- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+ struct ggml_tensor * V =
+ ggml_cast(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ Vcur,
+ n_state_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ wctx.itype);

- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
- );
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
- #endif
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
+ }
  }

  // projection
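The hunk above replaces the old compile-time WHISPER_USE_FLASH_ATTN path with a runtime wctx.params.flash_attn branch. A minimal sketch of the resulting tensor shapes, assuming the base model (n_state = 512, n_head = 8, n_ctx = 1500); GGML_PAD rounds its first argument up to a multiple of the second:

    // n_state_head = 512/8               = 64
    // n_ctx_pad    = GGML_PAD(1500, 256) = 1536
    //
    // Q : [64, 1500, 8]  view of Qcur (reshaped, permuted)
    // K : [64, 1536, 8]  view into the padded kv_pad.k
    // V : [64, 1536, 8]  view into the padded kv_pad.v
    //
    // ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f)
    // yields [64, 8, 1500], which ggml_reshape_2d folds back to [512, 1500].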
@@ -2043,11 +2237,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
          layer.mlp_ln_b);
  }

- #ifdef WHISPER_USE_FLASH_FF
- cur = ggml_flash_ff(ctx0,
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
- #else
  // fully connected
  cur = ggml_mul_mat(ctx0,
          layer.mlp_0_w,
@@ -2064,7 +2253,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
          cur);

  cur = ggml_add(ctx0, cur, layer.mlp_1_b);
- #endif
  }

  inpL = ggml_add(ctx0, cur, inpFF);
@@ -2113,9 +2301,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
  const int n_state = hparams.n_audio_state;
  const int n_head  = hparams.n_audio_head;

+ const int n_state_head = n_state/n_head;
+
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
+
  struct ggml_init_params params = {
- /*.mem_size =*/ wstate.alloc_cross.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+ /*.mem_size =*/ wstate.sched_cross.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
  /*.no_alloc =*/ true,
  };

@@ -2125,18 +2317,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(

  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);

- const float Kscale = pow(float(n_state) / n_head, -0.25);
+ const float Kscale = pow(float(n_state_head), -0.25);

  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
  auto & layer = model.layers_decoder[il];

- struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
          layer.cross_attn_k_w,
          cur);

  Kcross = ggml_scale(ctx0, Kcross, Kscale);

- struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
          layer.cross_attn_v_w,
          cur);

@@ -2144,15 +2336,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
          Vcross,
          layer.cross_attn_v_b);

- Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+ struct ggml_tensor * k;
+ struct ggml_tensor * v;
+
+ if (wctx.params.flash_attn) {
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));

- struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
- n_state*n_ctx,
- (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+ v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
+ } else {
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));

- struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
- ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
- (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+ v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+ }

  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2186,11 +2388,11 @@ static bool whisper_encode_internal(

  // conv
  {
- auto & alloc = wstate.alloc_conv.alloc;
+ auto & sched = wstate.sched_conv.sched;

  ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);

- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
@@ -2223,7 +2425,7 @@ static bool whisper_encode_internal(
  }

  if (!whisper_encode_external(wstate)) {
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  } else {
@@ -2237,32 +2439,32 @@ static bool whisper_encode_internal(

  // encoder
  if (!whisper_encode_external(wstate)) {
- auto & alloc = wstate.alloc_encode.alloc;
+ auto & sched = wstate.sched_encode.sched;

  ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);

- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }

- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }

  // cross
  {
- auto & alloc = wstate.alloc_cross.alloc;
+ auto & sched = wstate.sched_cross.sched;

  ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);

- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }

- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }
@@ -2284,24 +2486,28 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

  auto & kv_self = wstate.kv_self;

- WHISPER_ASSERT(!!kv_self.ctx);
+ WHISPER_ASSERT(!!kv_self.buffer);

  const int n_ctx   = kv_self.size;
  const int n_state = hparams.n_text_state;
  const int n_head  = hparams.n_text_head;
  const int n_layer = hparams.n_text_layer;

+ const int n_state_head = n_state/n_head;
+
  const int n_tokens    = batch.n_tokens;
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;

- const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
- const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+ const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
+
+ const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
+ const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;

  //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);

  struct ggml_init_params params = {
- /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
- /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+ /*.mem_size   =*/ wstate.sched_decode.meta.size(),
+ /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
  /*.no_alloc   =*/ true,
  };

@@ -2317,12 +2523,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
  ggml_set_name(position, "position");
  ggml_set_input(position);

- const float KQscale = pow(float(n_state)/n_head, -0.25);
+ const float KQscale = pow(float(n_state_head), -0.25);

- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
  ggml_set_name(KQ_mask, "KQ_mask");
  ggml_set_input(KQ_mask);

+ struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
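KQ_mask is now allocated with its row count rounded up to GGML_KQ_MASK_PAD (an alignment requirement of ggml's flash-attention kernels), and KQ_mask_f16 exists because ggml_flash_attn_ext expects a half-precision mask. A sketch of the layout for one head slice, with hypothetical n_kv = 512 and n_tokens = 5:

    // rows 0..4                               : 0.0f where a cell is visible, -INFINITY where it is not
    // rows 5..GGML_PAD(5, GGML_KQ_MASK_PAD)-1 : all -INFINITY (pure padding; see the fill loop
    //                                           added later in this diff in whisper_decode_internal)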
  // token encoding + position encoding
  struct ggml_tensor * cur =
          ggml_add(ctx0,
@@ -2378,12 +2586,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
          Vcur,
          layer.attn_v_b);

- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+ struct ggml_tensor * k;
+ struct ggml_tensor * v;
+
+ if (wctx.params.flash_attn) {
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+
+ v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
+ (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
+ } else {
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
+
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));

- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
- (   n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+ v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
+ (   n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
+ }

  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2393,40 +2614,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

  struct ggml_tensor * Q =
          ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
          0, 2, 1, 3);

  struct ggml_tensor * K =
          ggml_view_3d(ctx0, kv_self.k,
- n_state/n_head, n_kv, n_head,
+ n_state_head, n_kv, n_head,
          ggml_element_size(kv_self.k)*n_state,
- ggml_element_size(kv_self.k)*n_state/n_head,
+ ggml_element_size(kv_self.k)*n_state_head,
          ggml_element_size(kv_self.k)*n_state*n_ctx*il);

- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ if (wctx.params.flash_attn) {
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_state_head, n_kv, n_head,
+ ggml_element_size(kv_self.v)*n_state,
+ ggml_element_size(kv_self.v)*n_state_head,
+ ggml_element_size(kv_self.v)*n_state*n_ctx*il);

- //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);

- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
- struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+ } else {
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);

- struct ggml_tensor * V =
- ggml_view_3d(ctx0, kv_self.v,
- n_kv, n_state/n_head, n_head,
- n_ctx*ggml_element_size(kv_self.v),
- n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
- n_ctx*ggml_element_size(kv_self.v)*n_state*il);
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_kv, n_state_head, n_head,
+ n_ctx*ggml_element_size(kv_self.v),
+ n_ctx*ggml_element_size(kv_self.v)*n_state_head,
+ n_ctx*ggml_element_size(kv_self.v)*n_state*il);

- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+ }
  }

  // projection
@@ -2465,80 +2692,75 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
          Qcur,
          layer.cross_attn_q_b);

- Qcur = ggml_scale(ctx0, Qcur, KQscale);
-
- // Kcross is already scaled
- struct ggml_tensor * Kcross =
- ggml_view_3d(ctx0, wstate.kv_cross.k,
- n_state/n_head, n_audio_ctx, n_head,
- ggml_element_size(wstate.kv_cross.k)*n_state,
- ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
- ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
-
- //struct ggml_tensor * Vcross =
- //    ggml_reshape_3d(ctx0,
- //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
- //            n_state/n_head, n_head, n_audio_ctx);
-
- //struct ggml_tensor * V_trans =
- //    ggml_cpy(ctx0,
- //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
- //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
-
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, wstate.kv_cross.v,
- n_audio_ctx, n_state/n_head, n_head,
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
-
- // ------
-
  struct ggml_tensor * Q =
          ggml_permute(ctx0,
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
          0, 2, 1, 3);

- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+ if (wctx.params.flash_attn) {
+ struct ggml_tensor * Kcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
+ n_state_head, n_audio_ctx_pad, n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state,
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);

- //struct ggml_tensor * KQ_scaled =
- //    ggml_scale(ctx0,
- //            KQ,
- //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
- //            );
+ struct ggml_tensor * Vcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ n_state_head, n_audio_ctx_pad, n_head,
+ ggml_element_size(wstate.kv_cross.v)*n_state,
+ ggml_element_size(wstate.kv_cross.v)*n_state_head,
+ ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);

- // no masking for cross-attention
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+ cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);

- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
-
- // [EXPERIMENTAL] Token-level timestamps with DTW
- if (wctx.params.dtw_token_timestamps) {
- if (wstate.aheads_masks.m[il] != nullptr) {
- struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
- aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
- aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
- if (aheads_cross_QKs == NULL) {
- aheads_cross_QKs = aheads_KQs;
- } else {
- aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
+ } else {
+ struct ggml_tensor * Kcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
+ n_state_head, n_audio_ctx, n_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state,
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
+
+ struct ggml_tensor * Vcross =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ n_audio_ctx, n_state_head, n_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
+
+ // ------
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
+
+ // [EXPERIMENTAL] Token-level timestamps with DTW
+ if (wctx.params.dtw_token_timestamps) {
+ if (wstate.aheads_masks.m[il] != nullptr) {
+ struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+ aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+ aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+ if (aheads_cross_QKs == NULL) {
+ aheads_cross_QKs = aheads_KQs;
+ } else {
+ aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
+ }
  }
  }
- }

- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);

- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

- // cur = KQV_merged.contiguous().view(n_state, n_tokens)
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
+ }
  }

  // projection
@@ -2671,18 +2893,20 @@ static bool whisper_decode_internal(
  return false;
  }

- kv_self.n = whisper_kv_cache_cell_max(kv_self);
+ const uint32_t pad = whisper_kv_cache_get_padding(wctx);
+ kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
+
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
  }

  // decoder
  {
- auto & alloc = wstate.alloc_decode.alloc;
+ auto & sched = wstate.sched_decode.sched;

  ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);

- if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
  // should never happen as we pre-allocate the memory
  return false;
  }
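The kv_self.n rounding above keeps the KV length a multiple of the backend padding. A worked example, assuming whisper_kv_cache_get_padding() returns 256 when flash_attn is enabled (GGML_PAD(x, n) expands to ((x + n - 1)/n)*n):

    // cell_max =   0  ->  max(256, GGML_PAD(0,   256)) = max(256,   0) = 256
    // cell_max =  37  ->  max(256, GGML_PAD(37,  256)) = max(256, 256) = 256
    // cell_max = 300  ->  max(256, GGML_PAD(300, 256)) = max(256, 512) = 512
    // ...the result is then clamped to kv_self.size by std::min.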
@@ -2705,9 +2929,10 @@ static bool whisper_decode_internal(
  struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");

  auto & kv_self = wstate.kv_self;
- const int32_t n_kv = kv_self.n;

- wstate.inp_mask.resize(n_kv*n_tokens);
+ const int32_t n_kv = kv_self.n;
+
+ wstate.inp_mask.resize(ggml_nelements(KQ_mask));

  float * data = wstate.inp_mask.data();
  memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2723,14 +2948,20 @@ static bool whisper_decode_internal(
  }
  }
  }
+
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_kv; ++j) {
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+ }
+ }
  }

  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
  }

- logits = gf->nodes[gf->n_nodes - 1];
+ logits = ggml_graph_node(gf, -1);

- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+ if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
  return false;
  }
  }
@@ -2784,29 +3015,47 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
  }

  #define SIN_COS_N_COUNT WHISPER_N_FFT
- static float sin_vals[SIN_COS_N_COUNT];
- static float cos_vals[SIN_COS_N_COUNT];
+ namespace {
+ struct whisper_global_cache {
+ // In FFT, we frequently use sine and cosine operations with the same values.
+ // We can use precalculated values to speed up the process.
+ float sin_vals[SIN_COS_N_COUNT];
+ float cos_vals[SIN_COS_N_COUNT];
+
+ // Hann window (Use cosf to eliminate difference)
+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+ float hann_window[WHISPER_N_FFT];
+
+ whisper_global_cache() {
+ fill_sin_cos_table();
+ fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+ }
+
+ void fill_sin_cos_table() {
+ for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+ double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+ sin_vals[i] = sinf(theta);
+ cos_vals[i] = cosf(theta);
+ }
+ }

- // In FFT, we frequently use sine and cosine operations with the same values.
- // We can use precalculated values to speed up the process.
- static void fill_sin_cos_table() {
- static bool is_filled = false;
- if (is_filled) return;
- for (int i = 0; i < SIN_COS_N_COUNT; i++) {
- double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
- sin_vals[i] = sinf(theta);
- cos_vals[i] = cosf(theta);
+ void fill_hann_window(int length, bool periodic, float * output) {
+ int offset = -1;
+ if (periodic) {
+ offset = 0;
+ }
+ for (int i = 0; i < length; i++) {
+ output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+ }
  }
- is_filled = true;
+ } global_cache;
  }

  // naive Discrete Fourier Transform
  // input is real-valued
  // output is complex-valued
- static void dft(const std::vector<float> & in, std::vector<float> & out) {
- int N = in.size();
-
- out.resize(N*2);
+ static void dft(const float* in, int N, float* out) {
  const int sin_cos_step = SIN_COS_N_COUNT / N;

  for (int k = 0; k < N; k++) {
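For reference, fill_hann_window above matches torch.hann_window(length, periodic=True); a self-contained check (expected output computed by hand from the formula):

    #include <cmath>
    #include <cstdio>

    #ifndef M_PI
    #define M_PI 3.14159265358979323846
    #endif

    // mirrors whisper_global_cache::fill_hann_window above
    static void fill_hann_window(int length, bool periodic, float * output) {
        const int offset = periodic ? 0 : -1;
        for (int i = 0; i < length; i++) {
            output[i] = 0.5f * (1.0f - cosf((2.0f * (float) M_PI * i) / (length + offset)));
        }
    }

    int main() {
        float w[8];
        fill_hann_window(8, true, w); // periodic: denominator is length itself
        for (float x : w) printf("%.4f ", x);
        // prints: 0.0000 0.1464 0.5000 0.8536 1.0000 0.8536 0.5000 0.1464
        printf("\n");
        return 0;
    }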
@@ -2815,8 +3064,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {

  for (int n = 0; n < N; n++) {
  int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
- re += in[n]*cos_vals[idx]; // cos(t)
- im -= in[n]*sin_vals[idx]; // sin(t)
+ re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+ im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
  }

  out[k*2 + 0] = re;
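The table lookup in dft() is exact rather than approximate: with step = SIN_COS_N_COUNT / N,

    2*pi*k*n / N = 2*pi*(k*n*step) / SIN_COS_N_COUNT

so (k*n*step) mod SIN_COS_N_COUNT indexes precisely the angle theta that the naive transform out[k] = sum_n in[n]*(cos(theta) - i*sin(theta)) needs. This relies on N dividing SIN_COS_N_COUNT; with WHISPER_N_FFT = 400 (the value in whisper.h) the fft() recursion below only reaches N in {400, 200, 100, 50, 25}, all divisors of 400.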
@@ -2828,47 +3077,38 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
  // poor man's implementation - use something better
  // input is real-valued
  // output is complex-valued
- static void fft(const std::vector<float> & in, std::vector<float> & out) {
- out.resize(in.size()*2);
-
- int N = in.size();
-
+ static void fft(float* in, int N, float* out) {
  if (N == 1) {
  out[0] = in[0];
  out[1] = 0;
  return;
  }

- if (N%2 == 1) {
- dft(in, out);
+ const int half_N = N / 2;
+ if (N - half_N*2 == 1) {
+ dft(in, N, out);
  return;
  }

- std::vector<float> even;
- std::vector<float> odd;
-
- even.reserve(N/2);
- odd.reserve(N/2);
-
- for (int i = 0; i < N; i++) {
- if (i % 2 == 0) {
- even.push_back(in[i]);
- } else {
- odd.push_back(in[i]);
- }
+ float* even = in + N;
+ for (int i = 0; i < half_N; ++i) {
+ even[i]= in[2*i];
  }
+ float* even_fft = out + 2 * N;
+ fft(even, half_N, even_fft);

- std::vector<float> even_fft;
- std::vector<float> odd_fft;
-
- fft(even, even_fft);
- fft(odd, odd_fft);
+ float* odd = even;
+ for (int i = 0; i < half_N; ++i) {
+ odd[i] = in[2*i + 1];
+ }
+ float* odd_fft = even_fft + N;
+ fft(odd, half_N, odd_fft);

  const int sin_cos_step = SIN_COS_N_COUNT / N;
- for (int k = 0; k < N/2; k++) {
+ for (int k = 0; k < half_N; k++) {
  int idx = k * sin_cos_step; // t = 2*M_PI*k/N
- float re = cos_vals[idx]; // cos(t)
- float im = -sin_vals[idx]; // sin(t)
+ float re = global_cache.cos_vals[idx]; // cos(t)
+ float im = -global_cache.sin_vals[idx]; // sin(t)

  float re_odd = odd_fft[2*k + 0];
  float im_odd = odd_fft[2*k + 1];
@@ -2876,52 +3116,39 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
  out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
  out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

- out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
- out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
- }
- }
-
- static bool hann_window(int length, bool periodic, std::vector<float> & output) {
- if (output.size() < static_cast<size_t>(length)) {
- output.resize(length);
+ out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+ out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
  }
- int offset = -1;
- if (periodic) {
- offset = 0;
- }
- for (int i = 0; i < length; i++) {
- output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
- }
-
- return true;
  }

- static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
+ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
          int n_samples, int frame_size, int frame_step, int n_threads,
          const whisper_filters & filters, whisper_mel & mel) {
- std::vector<float> fft_in(frame_size, 0.0);
- std::vector<float> fft_out(2 * frame_size);
+ std::vector<float> fft_in(frame_size * 2, 0.0);
+ std::vector<float> fft_out(frame_size * 2 * 2 * 2);
+
  int n_fft = filters.n_fft;
  int i = ith;

  // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
- assert( n_fft == 1 + (frame_size / 2) );
-
+ assert(n_fft == 1 + (frame_size / 2));
+
  // calculate FFT only when fft_in are not all zero
  for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
  const int offset = i * frame_step;

- // apply Hanning window (~10% faster)
+ // apply Hann window (~10% faster)
  for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
  fft_in[j] = hann[j] * samples[offset + j];
  }
+
  // fill the rest with zeros
  if (n_samples - offset < frame_size) {
  std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
  }

  // FFT
- fft(fft_in, fft_out);
+ fft(fft_in.data(), frame_size, fft_out.data());

  // Calculate modulus^2 of complex numbers
  // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
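The rewritten fft() is allocation-free: it stores the even/odd halves just past the input (even = in + N) and the recursive sub-transforms just past the output (even_fft = out + 2*N), so the worker thread above over-allocates both scratch vectors once instead of letting every recursion level build std::vectors. A sketch of the call, assuming frame_size = WHISPER_N_FFT:

    std::vector<float> fft_in (WHISPER_N_FFT * 2, 0.0f);   // N input samples + N floats of even/odd scratch
    std::vector<float> fft_out(WHISPER_N_FFT * 2 * 2 * 2); // 2*N complex outputs + recursion scratch
    fft(fft_in.data(), WHISPER_N_FFT, fft_out.data());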
@@ -2932,7 +3159,6 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
  // mel spectrogram
  for (int j = 0; j < mel.n_mel; j++) {
  double sum = 0.0;
-
  // unroll loop (suggested by GH user @lunixbochs)
  int k = 0;
  for (k = 0; k < n_fft - 3; k += 4) {
@@ -2942,14 +3168,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
          fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
          fft_out[k + 3] * filters.data[j * n_fft + k + 3];
  }
-
  // handle n_fft remainder
  for (; k < n_fft; k++) {
  sum += fft_out[k] * filters.data[j * n_fft + k];
  }
-
  sum = log10(std::max(sum, 1e-10));
-
  mel.data[j * mel.n_len + i] = sum;
  }
  }
@@ -2978,12 +3201,9 @@ static bool log_mel_spectrogram(
          whisper_mel & mel) {
  const int64_t t_start_us = ggml_time_us();

- // Hanning window (Use cosf to eliminate difference)
- // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
- std::vector<float> hann;
- hann_window(frame_size, true, hann);
-
+ // Hann window
+ WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+ const float * hann = global_cache.hann_window;

  // Calculate the length of padding
  int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
@@ -3008,12 +3228,11 @@ static bool log_mel_spectrogram(
  mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
  mel.data.resize(mel.n_mel * mel.n_len);

-
  {
  std::vector<std::thread> workers(n_threads - 1);
  for (int iw = 0; iw < n_threads - 1; ++iw) {
  workers[iw] = std::thread(
- log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
+ log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
          n_samples + stage_2_pad, frame_size, frame_step, n_threads,
          std::cref(filters), std::ref(mel));
  }
@@ -3173,23 +3392,23 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
  #endif

  struct whisper_state * whisper_init_state(whisper_context * ctx) {
- fill_sin_cos_table();
-
  whisper_state * state = new whisper_state;

- state->backend = whisper_backend_init(ctx->params);
- if (!state->backend) {
+ state->backends = whisper_backend_init(ctx->params);
+ if (state->backends.empty()) {
  WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }

- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
- const int factor = 3;
-
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+ // at this point, we don't know yet how many decoders will be used
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+ state->kv_self_n_dec = 1;
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+         ctx->model.hparams.n_text_state,
+         ctx->model.hparams.n_text_layer,
+         GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
@@ -3199,8 +3418,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
  }

- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
- WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+ if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
+         ctx->model.hparams.n_text_state,
+         ctx->model.hparams.n_text_layer,
+         GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
  whisper_free_state(state);
  return nullptr;
  }
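Both caches are now sized to 256-padded context lengths so the flash-attention views always cover whole blocks. With the standard hparams (n_text_ctx = 448, n_audio_ctx = 1500):

    // self  cache: GGML_PAD(448,  256) = 512  positions per text layer
    // cross cache: GGML_PAD(1500, 256) = 1536 positions per text layer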
@@ -3210,9 +3432,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
  }

+ if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
+         ctx->model.hparams.n_audio_state,
+         1,
+         GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return nullptr;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
+ WHISPER_LOG_INFO("%s: kv pad  size  = %7.2f MB\n", __func__, memory_size / 1e6);
+ }
+
  // [EXPERIMENTAL] Token-level timestamps with DTW
  if (ctx->params.dtw_token_timestamps) {
- if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+ if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
  WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
  whisper_free_state(state);
  return nullptr;
@@ -3255,7 +3491,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

  // conv allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
          [&]() {
          return whisper_build_graph_conv(*ctx, *state);
          });
@@ -3266,12 +3502,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
  }

  // encoder allocator
  if (!whisper_encode_external(*state)) {
- bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
          [&]() {
          return whisper_build_graph_encoder(*ctx, *state);
          });
@@ -3282,12 +3518,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
  }

  // cross allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
          [&]() {
          return whisper_build_graph_cross(*ctx, *state);
          });
@@ -3298,12 +3534,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
  }

  // decoder allocator
  {
- bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+ bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
          [&]() {
          const auto & hparams = ctx->model.hparams;

@@ -3322,19 +3558,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
  return nullptr;
  }

- WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
  }

  return state;
  }

- int whisper_ctx_init_openvino_encoder(
+ int whisper_ctx_init_openvino_encoder_with_state(
          struct whisper_context * ctx,
+         struct whisper_state * state,
          const char * model_path,
          const char * device,
          const char * cache_dir) {
  #ifndef WHISPER_USE_OPENVINO
  (void)(ctx);
+ (void)(state);
  (void)(model_path);
  (void)(device);
  (void)(cache_dir);
@@ -3365,8 +3603,8 @@ int whisper_ctx_init_openvino_encoder(
  WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
  WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

- ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
- if (!ctx->state->ctx_openvino) {
+ state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
+ if (!state->ctx_openvino) {
  WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
  return 1;
  } else {
@@ -3377,9 +3615,18 @@ int whisper_ctx_init_openvino_encoder(
  #endif
  }

+ int whisper_ctx_init_openvino_encoder(
+         struct whisper_context * ctx,
+         const char * model_path,
+         const char * device,
+         const char * cache_dir) {
+ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir);
+ }
+
  struct whisper_context_params whisper_context_default_params() {
  struct whisper_context_params result = {
  /*.use_gpu              =*/ true,
+ /*.flash_attn           =*/ false,
  /*.gpu_device           =*/ 0,

  /*.dtw_token_timestamps =*/ false,
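A minimal usage sketch for the new flash_attn flag (the model path is a placeholder; note that a later hunk in this diff disables dtw_token_timestamps when both are requested):

    #include "whisper.h"

    int main() {
        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.flash_attn = true; // defaults to false, per the initializer above

        struct whisper_context * wctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
        if (wctx == nullptr) {
            return 1;
        }
        // ... run whisper_full() as usual ...
        whisper_free(wctx);
        return 0;
    }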
@@ -3396,8 +3643,14 @@ struct whisper_context_params whisper_context_default_params() {

  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
  WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
-
+ #ifdef _MSC_VER
+ // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+ std::wstring path_model_wide = converter.from_bytes(path_model);
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
+ #else
  auto fin = std::ifstream(path_model, std::ios::binary);
+ #endif
  if (!fin) {
  WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
  return nullptr;
@@ -3472,6 +3725,18 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
  ggml_time_init();

+ if (params.flash_attn && params.dtw_token_timestamps) {
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+ params.dtw_token_timestamps = false;
+ }
+
+ WHISPER_LOG_INFO("%s: use gpu    = %d\n", __func__, params.use_gpu);
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+ WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
+ WHISPER_LOG_INFO("%s: devices    = %zu\n", __func__, ggml_backend_dev_count());
+ WHISPER_LOG_INFO("%s: backends   = %zu\n", __func__, ggml_backend_reg_count());
+
  whisper_context * ctx = new whisper_context;
  ctx->params = params;

@@ -3558,8 +3823,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa

  void whisper_free_state(struct whisper_state * state) {
  if (state) {
- kv_cache_free(state->kv_self);
- kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_self);
+ whisper_kv_cache_free(state->kv_cross);
+ whisper_kv_cache_free(state->kv_pad);

  #ifdef WHISPER_USE_COREML
  if (state->ctx_coreml != nullptr) {
@@ -3577,30 +3843,39 @@ void whisper_free_state(struct whisper_state * state) {

  whisper_batch_free(state->batch);

- ggml_gallocr_free(state->alloc_conv.alloc);
- ggml_gallocr_free(state->alloc_encode.alloc);
- ggml_gallocr_free(state->alloc_cross.alloc);
- ggml_gallocr_free(state->alloc_decode.alloc);
+ ggml_backend_sched_free(state->sched_conv.sched);
+ ggml_backend_sched_free(state->sched_encode.sched);
+ ggml_backend_sched_free(state->sched_cross.sched);
+ ggml_backend_sched_free(state->sched_decode.sched);

- ggml_backend_free(state->backend);
+ for (auto & backend : state->backends) {
+ ggml_backend_free(backend);
+ }

  // [EXPERIMENTAL] Token-level timestamps with DTW
  aheads_masks_free(state->aheads_masks);

+ if (state->vad_context != nullptr) {
+ whisper_vad_free(state->vad_context);
+ state->vad_context = nullptr;
+ }
+
  delete state;
  }
  }

  void whisper_free(struct whisper_context * ctx) {
  if (ctx) {
- ggml_free(ctx->model.ctx);
+ for (ggml_context * context : ctx->model.ctxs) {
+ ggml_free(context);
+ }

- ggml_backend_buffer_free(ctx->model.buffer);
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+ ggml_backend_buffer_free(buf);
+ }

  whisper_free_state(ctx->state);

- ggml_backend_free(ctx->backend);
-
  delete ctx;
  }
  }
@@ -3630,30 +3905,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
  return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
  }

- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
- WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
- return -1;
- }
-
- return 0;
- }
-
- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
- int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
- return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
- }
-
- // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
- // TODO
-
- // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
- // TODO
-
  int whisper_set_mel_with_state(
          struct whisper_context * ctx,
          struct whisper_state * state,
@@ -3742,7 +3993,7 @@ int whisper_token_count(struct whisper_context * ctx, const char * text) {
  return -whisper_tokenize(ctx, text, NULL, 0);
  }

- int whisper_lang_max_id() {
+ int whisper_lang_max_id(void) {
  auto max_id = 0;
  for (const auto & kv : g_lang) {
  max_id = std::max(max_id, kv.second.first);
@@ -3963,134 +4214,1262 @@ float * whisper_get_logits(struct whisper_context * ctx) {
3963
4214
  return ctx->state->logits.data();
3964
4215
  }
3965
4216
 
3966
- float * whisper_get_logits_from_state(struct whisper_state * state) {
3967
- return state->logits.data();
3968
- }
4217
+ float * whisper_get_logits_from_state(struct whisper_state * state) {
4218
+ return state->logits.data();
4219
+ }
4220
+
4221
+ const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
4222
+ return ctx->vocab.id_to_token.at(token).c_str();
4223
+ }
4224
+
4225
+ whisper_token whisper_token_eot(struct whisper_context * ctx) {
4226
+ return ctx->vocab.token_eot;
4227
+ }
4228
+
4229
+ whisper_token whisper_token_sot(struct whisper_context * ctx) {
4230
+ return ctx->vocab.token_sot;
4231
+ }
4232
+
4233
+ whisper_token whisper_token_solm(struct whisper_context * ctx) {
4234
+ return ctx->vocab.token_solm;
4235
+ }
4236
+
4237
+ whisper_token whisper_token_prev(struct whisper_context * ctx) {
4238
+ return ctx->vocab.token_prev;
4239
+ }
4240
+
4241
+ whisper_token whisper_token_nosp(struct whisper_context * ctx) {
4242
+ return ctx->vocab.token_nosp;
4243
+ }
4244
+
4245
+ whisper_token whisper_token_not(struct whisper_context * ctx) {
4246
+ return ctx->vocab.token_not;
4247
+ }
4248
+
4249
+ whisper_token whisper_token_beg(struct whisper_context * ctx) {
4250
+ return ctx->vocab.token_beg;
4251
+ }
4252
+
4253
+ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
4254
+ return whisper_token_sot(ctx) + 1 + lang_id;
4255
+ }
4256
+
4257
+ whisper_token whisper_token_translate(struct whisper_context * ctx) {
4258
+ return ctx->vocab.token_translate;
4259
+ }
4260
+
4261
+ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
4262
+ return ctx->vocab.token_transcribe;
4263
+ }
4264
+
4265
+ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
4266
+ if (ctx->state == nullptr) {
4267
+ return nullptr;
4268
+ }
4269
+ whisper_timings * timings = new whisper_timings;
4270
+ timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
4271
+ timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
4272
+ timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
4273
+ timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
4274
+ timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
4275
+ return timings;
4276
+ }
4277
+
4278
+ void whisper_print_timings(struct whisper_context * ctx) {
4279
+ const int64_t t_end_us = ggml_time_us();
4280
+
4281
+ WHISPER_LOG_INFO("\n");
4282
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
4283
+ if (ctx->state != nullptr) {
4284
+
4285
+ const int32_t n_sample = std::max(1, ctx->state->n_sample);
4286
+ const int32_t n_encode = std::max(1, ctx->state->n_encode);
4287
+ const int32_t n_decode = std::max(1, ctx->state->n_decode);
4288
+ const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
4289
+ const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
4290
+
4291
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
4292
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
4293
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
4294
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
4295
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
4296
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
4297
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4298
+ }
4299
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4300
+ }
4301
+
4302
+ void whisper_reset_timings(struct whisper_context * ctx) {
4303
+ ctx->t_start_us = ggml_time_us();
4304
+ if (ctx->state != nullptr) {
4305
+ ctx->state->t_mel_us = 0;
4306
+ ctx->state->t_sample_us = 0;
4307
+ ctx->state->t_encode_us = 0;
4308
+ ctx->state->t_decode_us = 0;
4309
+ ctx->state->t_batchd_us = 0;
4310
+ ctx->state->t_prompt_us = 0;
4311
+ ctx->state->n_sample = 0;
4312
+ ctx->state->n_encode = 0;
4313
+ ctx->state->n_decode = 0;
4314
+ ctx->state->n_batchd = 0;
4315
+ ctx->state->n_prompt = 0;
4316
+ }
4317
+ }
4318
+
4319
+ static int whisper_has_coreml(void) {
4320
+ #ifdef WHISPER_USE_COREML
4321
+ return 1;
4322
+ #else
4323
+ return 0;
4324
+ #endif
4325
+ }
4326
+
4327
+ static int whisper_has_openvino(void) {
4328
+ #ifdef WHISPER_USE_OPENVINO
4329
+ return 1;
4330
+ #else
4331
+ return 0;
4332
+ #endif
4333
+ }
4334
+
4335
+ const char * whisper_print_system_info(void) {
4336
+ static std::string s;
4337
+
4338
+ whisper_load_backends();
4339
+
4340
+ s = "";
4341
+ s += "WHISPER : ";
4342
+ s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
4343
+ s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
4344
+
4345
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
4346
+ auto * reg = ggml_backend_reg_get(i);
4347
+ auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
4348
+ if (get_features_fn) {
4349
+ ggml_backend_feature * features = get_features_fn(reg);
4350
+ s += ggml_backend_reg_name(reg);
4351
+ s += " : ";
4352
+ for (; features->name; features++) {
4353
+ s += features->name;
4354
+ s += " = ";
4355
+ s += features->value;
4356
+ s += " | ";
4357
+ }
4358
+ }
4359
+ }
4360
+ return s.c_str();
4361
+ }
4362
+
+ //////////////////////////////////
+ // Voice Activity Detection (VAD)
+ //////////////////////////////////
+
+ struct whisper_vad_hparams {
+     int32_t   n_encoder_layers;
+     int32_t * encoder_in_channels;
+     int32_t * encoder_out_channels;
+     int32_t * kernel_sizes;
+     int32_t   lstm_input_size;
+     int32_t   lstm_hidden_size;
+     int32_t   final_conv_in;
+     int32_t   final_conv_out;
+ };
+
+ struct whisper_vad_model {
+     std::string type;
+     std::string version;
+     whisper_vad_hparams hparams;
+
+     struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
+
+     // Encoder tensors - 4 convolutional layers
+     struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
+     struct ggml_tensor * encoder_0_bias;   // [128]
+
+     // Second encoder layer
+     struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
+     struct ggml_tensor * encoder_1_bias;   // [64]
+
+     // Third encoder layer
+     struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
+     struct ggml_tensor * encoder_2_bias;   // [64]
+
+     // Fourth encoder layer
+     struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
+     struct ggml_tensor * encoder_3_bias;   // [128]
+
+     // LSTM decoder tensors
+     struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
+     struct ggml_tensor * lstm_ih_bias;   // [512]
+     struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
+     struct ggml_tensor * lstm_hh_bias;   // [512]
+
+     // Final conv layer
+     struct ggml_tensor * final_conv_weight; // [128]
+     struct ggml_tensor * final_conv_bias;   // [1]
+
+     // ggml contexts
+     std::vector<ggml_context *> ctxs;
+
+     // buffer for the model tensors
+     std::vector<ggml_backend_buffer_t> buffers;
+
+     // tensors
+     int n_loaded;
+     std::map<std::string, struct ggml_tensor *> tensors;
+ };
+
+ struct whisper_vad_segment {
+     float start; // Start time in seconds
+     float end;   // End time in seconds
+ };
+
+ struct whisper_vad_segments {
+     std::vector<whisper_vad_segment> data;
+ };
+
+ struct whisper_vad_context {
+     int64_t t_vad_us = 0;
+
+     int n_window;
+     int n_context;
+     int n_threads;
+
+     std::vector<ggml_backend_t> backends;
+     ggml_backend_buffer_t buffer = nullptr;
+     whisper_context_params params;
+     std::vector<uint8_t> ctx_buf;
+     whisper_sched sched;
+
+     whisper_vad_model model;
+     std::string path_model;
+     struct ggml_tensor * h_state;
+     struct ggml_tensor * c_state;
+     std::vector<float> probs;
+ };
+
+ struct whisper_vad_context_params whisper_vad_default_context_params(void) {
+     whisper_vad_context_params result = {
+         /*.n_threads  = */ 4,
+         /*.use_gpu    = */ false,
+         /*.gpu_device = */ 0,
+     };
+     return result;
+ }
+
+ struct whisper_vad_params whisper_vad_default_params(void) {
+     whisper_vad_params result = {
+         /* threshold               = */ 0.5f,
+         /* min_speech_duration_ms  = */ 250,
+         /* min_silence_duration_ms = */ 100,
+         /* max_speech_duration_s   = */ FLT_MAX,
+         /* speech_pad_ms           = */ 30,
+         /* samples_overlap         = */ 0.1,
+     };
+     return result;
+ }
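+ // For reference, a typical way to override these defaults (the values below
+ // are purely illustrative, not recommendations):
+ //
+ //     whisper_vad_params p = whisper_vad_default_params();
+ //     p.threshold               = 0.6f; // require higher speech probability
+ //     p.min_silence_duration_ms = 200;  // split only on longer pauses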
+
+ static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+     bool op_supported = true;
+
+     if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+         (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+         // GPU and default CPU backend support all operators
+         op_supported = true;
+     } else {
+         switch (op) {
+             // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+             case GGML_OP_MUL_MAT: {
+                 ggml_init_params params = {
+                     /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                     /*.mem_buffer =*/ nullptr,
+                     /*.no_alloc   =*/ true,
+                 };
+
+                 ggml_context_ptr ctx_ptr { ggml_init(params) };
+                 if (!ctx_ptr) {
+                     throw std::runtime_error("failed to create ggml context");
+                 }
+                 ggml_context * ctx = ctx_ptr.get();
+
+                 ggml_tensor * op_tensor = nullptr;
+
+                 int64_t n_ctx = hparams.lstm_hidden_size;
+                 ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                 op_tensor = ggml_mul_mat(ctx, w, b);
+
+                 // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                 GGML_ASSERT(w->buffer == nullptr);
+                 w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                 op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                 ggml_backend_buffer_free(w->buffer);
+                 w->buffer = nullptr;
+                 break;
+             }
+             default: {
+                 op_supported = false;
+                 break;
+             }
+         };
+     }
+     return op_supported;
+ }
+
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+     GGML_ASSERT(!buft_list.empty());
+     for (const auto & p : buft_list) {
+         ggml_backend_dev_t dev = p.first;
+         ggml_backend_buffer_type_t buft = p.second;
+         if (weight_buft_supported(hparams, w, op, buft, dev)) {
+             return buft;
+         }
+     }
+
+     return nullptr;
+ }
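+ // In short: walk the ordered (device, buffer type) preference list and return
+ // the first buffer type whose device can actually run `op` on this weight;
+ // a nullptr result means no candidate was usable.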
+
+ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
+         const whisper_vad_model & model, ggml_tensor * cur) {
+     // Apply reflective padding to the input tensor
+     ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
+
+     struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
+
+     // Calculate cutoff for real/imaginary parts
+     int cutoff = model.stft_forward_basis->ne[2] / 2;
+
+     // Extract real part (first half of the STFT output).
+     struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
+     // Extract imaginary part (second half of the STFT output).
+     struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
+
+     // Calculate magnitude: sqrt(real^2 + imag^2)
+     struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
+     struct ggml_tensor * img_squared  = ggml_mul(ctx0, img_part, img_part);
+     struct ggml_tensor * sum_squares  = ggml_add(ctx0, real_squared, img_squared);
+     struct ggml_tensor * magnitude    = ggml_sqrt(ctx0, sum_squares);
+     return magnitude;
+ }
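+ // For reference: with the precomputed forward basis W stored as the real rows
+ // followed by the imaginary rows, the layer above computes the magnitude
+ // spectrogram |STFT(x)| = sqrt((W_re * x)^2 + (W_im * x)^2) - the two half
+ // views, the squares, the add and the sqrt implement exactly this.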
+
+ static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
+         const whisper_vad_model & model, ggml_tensor * cur) {
+     // First Conv1D: expands to 128 channels.
+     cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
+     cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
+     cur = ggml_relu(ctx0, cur);
+
+     // Second Conv1D: reduces to 64 channels.
+     cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
+     cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
+     cur = ggml_relu(ctx0, cur);
+
+     // Third Conv1D: maintains 64 channels.
+     cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
+     cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
+     cur = ggml_relu(ctx0, cur);
+
+     // Fourth Conv1D: expands to 128 channels.
+     cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
+     cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
+     cur = ggml_relu(ctx0, cur);
+
+     return cur;
+ }
+
+ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
+         const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
+     const whisper_vad_model & model = vctx.model;
+     const int hdim = model.hparams.lstm_hidden_size;
+
+     struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
+
+     // Create operations using the input-to-hidden weights.
+     struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
+     inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
+
+     // Create operations using the hidden-to-hidden weights.
+     struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
+     hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
+
+     // Create add operation to get preactivations for all gates.
+     struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
+
+     const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
+
+     // Sigmoid for the input gate (the first hdim elements of the preactivations).
+     struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
+
+     // Sigmoid for the forget gate (the second hdim elements of the preactivations).
+     struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
+
+     // Tanh for the cell gate (the third hdim elements of the preactivations).
+     struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
+
+     // Sigmoid for the output gate (the fourth hdim elements of the preactivations).
+     struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
+
+     // Update cell state
+     struct ggml_tensor * c_out = ggml_add(ctx0,
+             ggml_mul(ctx0, f_t, vctx.c_state),
+             ggml_mul(ctx0, i_t, g_t));
+     ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
+
+     // Update hidden state
+     struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
+     ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
+
+     return out;
+ }
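+ // For reference, the block above is one step of a standard LSTM cell, with
+ // the gate preactivations packed in [i | f | g | o] order:
+ //
+ //     i_t = sigmoid(W_ih_i x_t + W_hh_i h_{t-1} + b_i)
+ //     f_t = sigmoid(W_ih_f x_t + W_hh_f h_{t-1} + b_f)
+ //     g_t = tanh   (W_ih_g x_t + W_hh_g h_{t-1} + b_g)
+ //     o_t = sigmoid(W_ih_o x_t + W_hh_o h_{t-1} + b_o)
+ //     c_t = f_t * c_{t-1} + i_t * g_t
+ //     h_t = o_t * tanh(c_t)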
+
+ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
+     const auto & model = vctx.model;
+
+     struct ggml_init_params params = {
+         /*.mem_size   =*/ vctx.sched.meta.size(),
+         /*.mem_buffer =*/ vctx.sched.meta.data(),
+         /*.no_alloc   =*/ true,
+     };
+
+     struct ggml_context * ctx0 = ggml_init(params);
+
+     ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+     struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
+     ggml_set_name(frame, "frame");
+     ggml_set_input(frame);
+
+     struct ggml_tensor * cur = nullptr;
+     {
+         cur = whisper_vad_build_stft_layer(ctx0, model, frame);
+
+         cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
+
+         // Extract the first element of the first dimension
+         // (equivalent to pytorch's [:, :, 0])
+         cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
+
+         cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
+         cur = ggml_relu(ctx0, cur);
+         cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
+         cur = ggml_add(ctx0, cur, model.final_conv_bias);
+         cur = ggml_sigmoid(ctx0, cur);
+         ggml_set_name(cur, "prob");
+         ggml_set_output(cur);
+     }
+
+     ggml_build_forward_expand(gf, cur);
+
+     ggml_free(ctx0);
+
+     return gf;
+ }
+
+ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
+
+     auto whisper_context_params = whisper_context_default_params();
+     // TODO: GPU VAD is forced disabled until the performance is improved
+     //whisper_context_params.use_gpu = vctx->params.use_gpu;
+     whisper_context_params.use_gpu    = false;
+     whisper_context_params.gpu_device = vctx->params.gpu_device;
+
+     vctx->backends = whisper_backend_init(whisper_context_params);
+     if (vctx->backends.empty()) {
+         WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+         return false;
+     }
+
+     const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
+
+     vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
+
+     struct ggml_init_params params = {
+         /*.mem_size   =*/ vctx->ctx_buf.size(),
+         /*.mem_buffer =*/ vctx->ctx_buf.data(),
+         /*.no_alloc   =*/ true,
+     };
+
+     ggml_context * ctx = ggml_init(params);
+     if (!ctx) {
+         WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
+         return false;
+     }
+
+     // LSTM Hidden state
+     vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+     ggml_set_name(vctx->h_state, "h_state");
+
+     // LSTM Cell state
+     vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
+     ggml_set_name(vctx->c_state, "c_state");
+
+     vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
+     if (!vctx->buffer) {
+         WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
+         return false;
+     }
+
+     {
+         bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
+                 [&]() {
+                     return whisper_vad_build_graph(*vctx);
+                 });
+
+         if (!ok) {
+             WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
+             return false;
+         }
+
+         WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
+     }
+
+     return true;
+ }
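+ // Note: h_state and c_state live in `buffer` and persist from one chunk to
+ // the next; whisper_vad_detect_speech clears this buffer once per audio
+ // stream (ggml_backend_buffer_clear), so the LSTM state carries across
+ // chunks within a stream but not between streams.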
+
+ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+         const char * path_model,
+         struct whisper_vad_context_params params) {
+     WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
+ #ifdef _MSC_VER
+     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+     std::wstring path_model_wide = converter.from_bytes(path_model);
+     auto fin = std::ifstream(path_model_wide, std::ios::binary);
+ #else
+     auto fin = std::ifstream(path_model, std::ios::binary);
+ #endif
+     if (!fin) {
+         WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
+         return nullptr;
+     }
+
+     whisper_model_loader loader = {};
+     loader.context = &fin;
+
+     loader.read = [](void * ctx, void * output, size_t read_size) {
+         std::ifstream * fin = (std::ifstream*)ctx;
+         fin->read((char *)output, read_size);
+         return read_size;
+     };
+
+     loader.eof = [](void * ctx) {
+         std::ifstream * fin = (std::ifstream*)ctx;
+         return fin->eof();
+     };
+
+     loader.close = [](void * ctx) {
+         std::ifstream * fin = (std::ifstream*)ctx;
+         fin->close();
+     };
+
+     auto ctx = whisper_vad_init_with_params(&loader, params);
+     if (!ctx) {
+         whisper_vad_free(ctx);
+         return nullptr;
+     }
+     ctx->path_model = path_model;
+     return ctx;
+ }
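+ // A minimal usage sketch (the model filename is only an example):
+ //
+ //     whisper_vad_context_params cparams = whisper_vad_default_context_params();
+ //     whisper_vad_context * vctx =
+ //         whisper_vad_init_from_file_with_params("ggml-silero-vad.bin", cparams);
+ //     if (vctx) { /* ... use vctx ... */ whisper_vad_free(vctx); }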
+
+ struct whisper_vad_context * whisper_vad_init_with_params(
+         struct whisper_model_loader * loader,
+         struct whisper_vad_context_params params) {
+     // Read the VAD model
+     {
+         uint32_t magic;
+         read_safe(loader, magic);
+         if (magic != GGML_FILE_MAGIC) {
+             WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
+             return nullptr;
+         }
+     }
+
+     whisper_vad_context * vctx = new whisper_vad_context;
+     vctx->n_threads         = params.n_threads;
+     vctx->params.use_gpu    = params.use_gpu;
+     vctx->params.gpu_device = params.gpu_device;
+
+     auto & model   = vctx->model;
+     auto & hparams = model.hparams;
+
+     // load model context params.
+     {
+         int32_t str_len;
+         read_safe(loader, str_len);
+         std::vector<char> buffer(str_len + 1, 0);
+         loader->read(loader->context, buffer.data(), str_len);
+         std::string model_type(buffer.data(), str_len);
+         model.type = model_type;
+         WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
+
+         int32_t major, minor, patch;
+         read_safe(loader, major);
+         read_safe(loader, minor);
+         read_safe(loader, patch);
+         std::string version_str = std::to_string(major) + "." +
+                                   std::to_string(minor) + "." +
+                                   std::to_string(patch);
+         model.version = version_str;
+         WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
+
+         read_safe(loader, vctx->n_window);
+         read_safe(loader, vctx->n_context);
+     }
+
+     // load model hyper params (hparams).
+     {
+         read_safe(loader, hparams.n_encoder_layers);
+
+         hparams.encoder_in_channels  = new int32_t[hparams.n_encoder_layers];
+         hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
+         hparams.kernel_sizes         = new int32_t[hparams.n_encoder_layers];
+
+         for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+             read_safe(loader, hparams.encoder_in_channels[i]);
+             read_safe(loader, hparams.encoder_out_channels[i]);
+             read_safe(loader, hparams.kernel_sizes[i]);
+         }
+
+         read_safe(loader, hparams.lstm_input_size);
+         read_safe(loader, hparams.lstm_hidden_size);
+         read_safe(loader, hparams.final_conv_in);
+         read_safe(loader, hparams.final_conv_out);
+
+         WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
+         for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+             WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
+         }
+         for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
+             WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
+         }
+         WHISPER_LOG_INFO("%s: lstm_input_size  = %d\n", __func__, hparams.lstm_input_size);
+         WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
+         WHISPER_LOG_INFO("%s: final_conv_in    = %d\n", __func__, hparams.final_conv_in);
+         WHISPER_LOG_INFO("%s: final_conv_out   = %d\n", __func__, hparams.final_conv_out);
+     }
+
+     // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
+     const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
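+     // For the default 4-layer model this works out to 4*2 + 4 + 2 + 1 = 15.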
+
+     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+     auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+         auto it = ctx_map.find(buft);
+         if (it == ctx_map.end()) {
+             ggml_init_params params = {
+                 /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                 /*.mem_buffer =*/ nullptr,
+                 /*.no_alloc   =*/ true,
+             };
+
+             ggml_context * ctx = ggml_init(params);
+             if (!ctx) {
+                 throw std::runtime_error("failed to create ggml context");
+             }
+
+             ctx_map[buft] = ctx;
+             model.ctxs.emplace_back(ctx);
+
+             return ctx;
+         }
+
+         return it->second;
+     };
+
+     whisper_context_params wparams = whisper_context_default_params();
+     wparams.use_gpu    = params.use_gpu;
+     wparams.gpu_device = params.gpu_device;
+     buft_list_t buft_list = make_buft_list(wparams);
+
+     auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
+         ggml_op op = VAD_TENSOR_OPS.at(type);
+         ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+         if (!buft) {
+             throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
+         }
+         ggml_context * ctx = get_ctx(buft);
+         ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+         model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
+
+         return tensor;
+     };
+
+     // create tensors
+     {
+         ggml_init_params params = {
+             /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+             /*.mem_buffer =*/ nullptr,
+             /*.no_alloc   =*/ true,
+         };
+
+         ggml_context * ctx = ggml_init(params);
+         const auto & hparams = model.hparams;
+
+         // STFT precomputed basis matrix
+         model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
+             ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
+
+         model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
+             ggml_new_tensor_3d(
+                 ctx,
+                 GGML_TYPE_F16,
+                 hparams.kernel_sizes[0],
+                 hparams.encoder_in_channels[0],
+                 hparams.encoder_out_channels[0]
+             ));
+         model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
+
+         model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
+             ggml_new_tensor_3d(
+                 ctx,
+                 GGML_TYPE_F16,
+                 hparams.kernel_sizes[1],
+                 hparams.encoder_in_channels[1],
+                 hparams.encoder_out_channels[1]
+             ));
+         model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
+
+         model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
+             ggml_new_tensor_3d(
+                 ctx,
+                 GGML_TYPE_F16,
+                 hparams.kernel_sizes[2],
+                 hparams.encoder_in_channels[2],
+                 hparams.encoder_out_channels[2]
+             ));
+         model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
+
+         model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
+             ggml_new_tensor_3d(
+                 ctx,
+                 GGML_TYPE_F16,
+                 hparams.kernel_sizes[3],
+                 hparams.encoder_in_channels[3],
+                 hparams.encoder_out_channels[3]
+             ));
+         model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
+
+         // Hidden State dimension (input gate, forget gate, cell gate, output gate)
+         const int hstate_dim = hparams.lstm_hidden_size * 4;
+
+         // LSTM weights - input to hidden
+         model.lstm_ih_weight = create_tensor(
+             VAD_TENSOR_LSTM_WEIGHT_IH,
+             ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+         );
+         model.lstm_ih_bias = create_tensor(
+             VAD_TENSOR_LSTM_BIAS_IH,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+         );
+
+         // LSTM weights - hidden to hidden
+         model.lstm_hh_weight = create_tensor(
+             VAD_TENSOR_LSTM_WEIGHT_HH,
+             ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
+         );
+         model.lstm_hh_bias = create_tensor(
+             VAD_TENSOR_LSTM_BIAS_HH,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
+         );
+
+         // Final conv layer weight
+         model.final_conv_weight = create_tensor(
+             VAD_TENSOR_FINAL_CONV_WEIGHT,
+             ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
+         );
+         model.final_conv_bias = create_tensor(
+             VAD_TENSOR_FINAL_CONV_BIAS,
+             ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
+         );
+
+         ggml_free(ctx);
+     }
+
+     // allocate tensors in the backend buffers
+     for (auto & p : ctx_map) {
+         ggml_backend_buffer_type_t buft = p.first;
+         ggml_context * ctx = p.second;
+         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+         if (buf) {
+             model.buffers.emplace_back(buf);
+
+             size_t size_main = ggml_backend_buffer_get_size(buf);
+             WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+         }
+     }
+
+     // load weights
+     {
+         size_t total_size = 0;
+         model.n_loaded = 0;
+         std::vector<char> read_buf;
+
+         while (true) {
+             int32_t n_dims;
+             int32_t length;
+             int32_t ttype;
+
+             read_safe(loader, n_dims);
+             read_safe(loader, length);
+             read_safe(loader, ttype);
+
+             if (loader->eof(loader->context)) {
+                 break;
+             }
+
+             int32_t nelements = 1;
+             int32_t ne[4] = { 1, 1, 1, 1 };
+             for (int i = 0; i < n_dims; ++i) {
+                 read_safe(loader, ne[i]);
+                 nelements *= ne[i];
+             }
+
+             std::string name;
+             std::vector<char> tmp(length);
+             loader->read(loader->context, &tmp[0], tmp.size());
+             name.assign(&tmp[0], tmp.size());
+
+             if (model.tensors.find(name) == model.tensors.end()) {
+                 WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                 return nullptr;
+             }
+
+             auto tensor = model.tensors[name.data()];
+
+             if (ggml_nelements(tensor) != nelements) {
+                 WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                 WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                         __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                 return nullptr;
+             }
+
+             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                 WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                         __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                 return nullptr;
+             }
+
+             const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+             if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                 WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                 return nullptr;
+             }
+
+             if (ggml_backend_buffer_is_host(tensor->buffer)) {
+                 // for the CPU and Metal backend, we can read directly into the tensor
+                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                 BYTESWAP_TENSOR(tensor);
+             } else {
+                 // read into a temporary buffer first, then copy to device memory
+                 read_buf.resize(ggml_nbytes(tensor));
+
+                 loader->read(loader->context, read_buf.data(), read_buf.size());
+
+                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+             }
+
+             total_size += ggml_nbytes(tensor);
+             model.n_loaded++;
+         }
+
+         WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
+
+         if (model.n_loaded == 0) {
+             WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+         } else if (model.n_loaded != (int) model.tensors.size()) {
+             WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+             return nullptr;
+         }
+
+     }
+
+     if (!whisper_vad_init_context(vctx)) {
+         whisper_vad_free(vctx);
+         return nullptr;
+     }
+
+     return vctx;
+ }
+
+ bool whisper_vad_detect_speech(
+         struct whisper_vad_context * vctx,
+         const float * samples,
+         int n_samples) {
+     int n_chunks = n_samples / vctx->n_window;
+     if (n_samples % vctx->n_window != 0) {
+         n_chunks += 1; // Add one more chunk for remaining samples.
+     }
+
+     WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
+     WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
+
+     // Reset LSTM hidden/cell states
+     ggml_backend_buffer_clear(vctx->buffer, 0);
+
+     vctx->probs.resize(n_chunks);
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
5114
+
+     std::vector<float> window(vctx->n_window, 0.0f);
+
+     auto & sched = vctx->sched.sched;
+
+     ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
+
+     if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+         WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
+         return false;
+     }
+
+     struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
+     struct ggml_tensor * prob  = ggml_graph_get_tensor(gf, "prob");
+
+     // we are going to reuse the graph multiple times for each chunk
+     const int64_t t_start_vad_us = ggml_time_us();
+
+     for (int i = 0; i < n_chunks; i++) {
+         const int idx_start = i * vctx->n_window;
+         const int idx_end   = std::min(idx_start + vctx->n_window, n_samples);
+
+         const int chunk_len = idx_end - idx_start;
+
+         if (chunk_len < vctx->n_window) {
+             WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
+             std::vector<float> partial_chunk(vctx->n_window, 0.0f);
+             std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
+
+             // Copy the zero-padded chunk to the window.
+             const int samples_to_copy_max = vctx->n_window;
+             const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
+             std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
+             if (samples_to_copy_cur < samples_to_copy_max) {
+                 std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
+             }
+         } else {
+             // Copy current frame samples to the window.
+             const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
+             std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
+         }
+
+         // Set the frame tensor data with the samples.
+         ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
+
+         // do not reset the scheduler - we will reuse the graph in the next chunk
+         if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
+             WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
+             break;
+         }
+
+         // Get the probability for this chunk.
+         ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
+
+         //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
+     }
+
+     vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
+     WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
+
+     ggml_backend_sched_reset(sched);
+
+     return true;
+ }
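+ // A minimal sketch of inspecting the per-window results (accessors below):
+ //
+ //     if (whisper_vad_detect_speech(vctx, pcm, n)) {
+ //         const int     n_probs = whisper_vad_n_probs(vctx);
+ //         const float * probs   = whisper_vad_probs(vctx);
+ //         // probs[i] is the speech probability of window i
+ //     }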
+
+ int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
+     return segments->data.size();
+ }
+
+ float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
+     return segments->data[i_segment].start;
+ }
+
+ float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
+     return segments->data[i_segment].end;
+ }
+
+ int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
+     return vctx->probs.size();
+ }
+
+ float * whisper_vad_probs(struct whisper_vad_context * vctx) {
+     return vctx->probs.data();
+ }
+
+ struct whisper_vad_segments * whisper_vad_segments_from_probs(
+         struct whisper_vad_context * vctx,
+         whisper_vad_params params) {
+     WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
+
+     int n_probs = whisper_vad_n_probs(vctx);
+     float * probs = whisper_vad_probs(vctx);
+     float threshold = params.threshold;
+     int min_speech_duration_ms = params.min_speech_duration_ms;
+     int min_silence_duration_ms = params.min_silence_duration_ms;
+     float max_speech_duration_s = params.max_speech_duration_s;
+     int speech_pad_ms = params.speech_pad_ms;
+     int n_window = vctx->n_window;
+     int sample_rate = WHISPER_SAMPLE_RATE;
+     int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
+     int audio_length_samples = n_probs * n_window;
+
+     // Min number of samples to be considered valid speech.
+     int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
+     int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
+
+     // Max number of samples that a speech segment can contain before it is
+     // split into multiple segments.
+     int max_speech_samples;
+     if (max_speech_duration_s > 100000.0f) {
+         max_speech_samples = INT_MAX / 2;
+     } else {
+         int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
+         max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
+         if (max_speech_samples < 0) {
+             max_speech_samples = INT_MAX / 2;
+         }
+     }
+     // Detect silence period that exceeds this value, then that location (sample)
+     // is marked as a potential place where the segment could be split if
+     // max_speech_samples is reached. The value 98 was taken from the original
+     // silero-vad python implementation:
+     // https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
+     int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+
+     // Calculate lower threshold for detecting end of speech segments.
+     float neg_threshold = threshold - 0.15f;
+     if (neg_threshold < 0.01f) {
+         neg_threshold = 0.01f;
+     }
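+     // This lower threshold gives the state machine below hysteresis: a
+     // segment starts only when the probability reaches `threshold`, but ends
+     // only after it drops below `neg_threshold`, so values hovering around
+     // the main threshold do not toggle speech on and off every window.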
+
+     struct speech_segment_t {
+         int start;
+         int end;
+     };
+
+     std::vector<speech_segment_t> speeches;
+     speeches.reserve(256);
+
+     bool is_speech_segment = false;
+     int temp_end = 0;
+     int prev_end = 0;
+     int next_start = 0;
+     int curr_speech_start = 0;
+     bool has_curr_speech = false;
+
+     for (int i = 0; i < n_probs; i++) {
+         float curr_prob = probs[i];
+         int curr_sample = n_window * i;
+
+         // Reset temp_end when we get back to speech
+         if ((curr_prob >= threshold) && temp_end) {
+             temp_end = 0;
+             if (next_start < prev_end) {
+                 next_start = curr_sample;
+             }
+         }
+
+         // Start a new speech segment when probability exceeds threshold and not already in speech
+         if ((curr_prob >= threshold) && !is_speech_segment) {
+             is_speech_segment = true;
+             curr_speech_start = curr_sample;
+             has_curr_speech = true;
+             continue;
+         }
+
+         // Handle maximum speech duration
+         if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
+             if (prev_end) {
+                 speeches.push_back({ curr_speech_start, prev_end });
+                 has_curr_speech = true;
+
+                 if (next_start < prev_end) { // Previously reached silence and is still not speech
+                     is_speech_segment = false;
+                     has_curr_speech = false;
+                 } else {
+                     curr_speech_start = next_start;
+                 }
+                 prev_end = next_start = temp_end = 0;
+             } else {
+                 speeches.push_back({ curr_speech_start, curr_sample });
+
+                 prev_end = next_start = temp_end = 0;
+                 is_speech_segment = false;
+                 has_curr_speech = false;
+                 continue;
+             }
+         }
+
+         // Handle silence after speech
+         if ((curr_prob < neg_threshold) && is_speech_segment) {
+             if (!temp_end) {
+                 temp_end = curr_sample;
+             }
+
+             // Track potential segment ends for max_speech handling
+             if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
+                 prev_end = temp_end;
+             }
+
+             // Check if silence is long enough to end the segment
+             if ((curr_sample - temp_end) < min_silence_samples) {
+                 continue;
+             } else {
+                 // End the segment if it's long enough
+                 if ((temp_end - curr_speech_start) > min_speech_samples) {
+                     speeches.push_back({ curr_speech_start, temp_end });
+                 }
 
- const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
-     return ctx->vocab.id_to_token.at(token).c_str();
- }
+                 prev_end = next_start = temp_end = 0;
+                 is_speech_segment = false;
+                 has_curr_speech = false;
+                 continue;
+             }
+         }
+     }
 
- whisper_token whisper_token_eot(struct whisper_context * ctx) {
-     return ctx->vocab.token_eot;
- }
+     // Handle the case if we're still in a speech segment at the end
+     if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
+         speeches.push_back({ curr_speech_start, audio_length_samples });
+     }
 
- whisper_token whisper_token_sot(struct whisper_context * ctx) {
-     return ctx->vocab.token_sot;
- }
+     // Merge adjacent segments with small gaps in between (post-processing)
+     if (speeches.size() > 1) {
+         int merged_count = 0;
+         for (int i = 0; i < (int) speeches.size() - 1; i++) {
+             // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
+             int max_merge_gap_samples = sample_rate * 200 / 1000;
 
- whisper_token whisper_token_solm(struct whisper_context * ctx) {
-     return ctx->vocab.token_solm;
- }
+             // If the gap between this segment and the next is small enough
+             if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
+                 // Merge by extending current segment to the end of next segment
+                 speeches[i].end = speeches[i+1].end;
+                 speeches.erase(speeches.begin() + i + 1);
 
- whisper_token whisper_token_prev(struct whisper_context * ctx) {
-     return ctx->vocab.token_prev;
- }
+                 i--;
+                 merged_count++;
+             }
+         }
+         WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
+                 __func__, merged_count, (int) speeches.size());
+     }
 
- whisper_token whisper_token_nosp(struct whisper_context * ctx) {
-     return ctx->vocab.token_nosp;
- }
+     // Double-check for minimum speech duration
+     for (int i = 0; i < (int) speeches.size(); i++) {
+         if (speeches[i].end - speeches[i].start < min_speech_samples) {
+             WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
+                     __func__, i, speeches[i].end - speeches[i].start);
 
- whisper_token whisper_token_not(struct whisper_context * ctx) {
-     return ctx->vocab.token_not;
- }
+             speeches.erase(speeches.begin() + i);
+             i--;
+         }
+     }
 
- whisper_token whisper_token_beg(struct whisper_context * ctx) {
-     return ctx->vocab.token_beg;
- }
+     WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
 
- whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
-     return whisper_token_sot(ctx) + 1 + lang_id;
- }
+     // Allocate final segments
+     std::vector<whisper_vad_segment> segments;
+     if (speeches.size() > 0) {
+         try {
+             segments.resize(speeches.size());
+         } catch (const std::bad_alloc &) {
+             WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
+             return nullptr;
+         }
+     }
 
- whisper_token whisper_token_translate(struct whisper_context * ctx) {
-     return ctx->vocab.token_translate;
- }
+     // Apply padding to segments and copy to final segments
+     for (int i = 0; i < (int) speeches.size(); i++) {
+         // Apply padding to the start of the first segment
+         if (i == 0) {
+             speeches[i].start =
+                 (speeches[i].start > speech_pad_samples) ?
+                 (speeches[i].start - speech_pad_samples) : 0;
+         }
 
- whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
-     return ctx->vocab.token_transcribe;
- }
+         // Handle spacing between segments
+         if (i < (int) speeches.size() - 1) {
+             int silence_duration = speeches[i+1].start - speeches[i].end;
 
- void whisper_print_timings(struct whisper_context * ctx) {
-     const int64_t t_end_us = ggml_time_us();
+             if (silence_duration < 2 * speech_pad_samples) {
+                 // If segments are close, split the difference
+                 speeches[i].end += silence_duration / 2;
+                 speeches[i+1].start =
+                     (speeches[i+1].start > silence_duration / 2) ?
+                     (speeches[i+1].start - silence_duration / 2) : 0;
+             } else {
+                 // Otherwise, apply full padding to both
+                 speeches[i].end =
+                     (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+                     (speeches[i].end + speech_pad_samples) : audio_length_samples;
+                 speeches[i+1].start =
+                     (speeches[i+1].start > speech_pad_samples) ?
+                     (speeches[i+1].start - speech_pad_samples) : 0;
+             }
+         } else {
+             // Apply padding to the end of the last segment
+             speeches[i].end =
+                 (speeches[i].end + speech_pad_samples < audio_length_samples) ?
+                 (speeches[i].end + speech_pad_samples) : audio_length_samples;
+         }
 
-     WHISPER_LOG_INFO("\n");
-     WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
-     if (ctx->state != nullptr) {
+         // Convert from samples to seconds and copy to final segments
+         segments[i].start = (float)speeches[i].start / sample_rate;
+         segments[i].end = (float)speeches[i].end / sample_rate;
 
-         const int32_t n_sample = std::max(1, ctx->state->n_sample);
-         const int32_t n_encode = std::max(1, ctx->state->n_encode);
-         const int32_t n_decode = std::max(1, ctx->state->n_decode);
-         const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
-         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
+         WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
+                 __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
+     }
 
-         WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-         WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-         WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-         WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-         WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-         WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
-         WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+     whisper_vad_segments * vad_segments = new whisper_vad_segments;
+     if (vad_segments == NULL) {
+         WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
+         return nullptr;
      }
-     WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+
+     vad_segments->data = std::move(segments);
+
+     return vad_segments;
  }
 
- void whisper_reset_timings(struct whisper_context * ctx) {
-     ctx->t_start_us = ggml_time_us();
-     if (ctx->state != nullptr) {
-         ctx->state->t_mel_us = 0;
-         ctx->state->t_sample_us = 0;
-         ctx->state->t_encode_us = 0;
-         ctx->state->t_decode_us = 0;
-         ctx->state->t_batchd_us = 0;
-         ctx->state->t_prompt_us = 0;
-         ctx->state->n_sample = 0;
-         ctx->state->n_encode = 0;
-         ctx->state->n_decode = 0;
-         ctx->state->n_batchd = 0;
-         ctx->state->n_prompt = 0;
+ struct whisper_vad_segments * whisper_vad_segments_from_samples(
+         whisper_vad_context * vctx,
+         whisper_vad_params params,
+         const float * samples,
+         int n_samples) {
+     WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
+     if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
+         WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
+         return nullptr;
      }
+     return whisper_vad_segments_from_probs(vctx, params);
  }
 
- static int whisper_has_coreml(void) {
- #ifdef WHISPER_USE_COREML
-     return 1;
- #else
-     return 0;
- #endif
- }
+ void whisper_vad_free(whisper_vad_context * ctx) {
+     if (ctx) {
+         for (ggml_context * context : ctx->model.ctxs) {
+             ggml_free(context);
+         }
 
- static int whisper_has_openvino(void) {
- #ifdef WHISPER_USE_OPENVINO
-     return 1;
- #else
-     return 0;
- #endif
- }
+         for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+             ggml_backend_buffer_free(buf);
+         }
 
- const char * whisper_print_system_info(void) {
-     static std::string s;
+         ggml_backend_sched_free(ctx->sched.sched);
 
-     s = "";
-     s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-     s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-     s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
-     s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
-     s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
-     s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-     s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
-     s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
-     s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
-     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-     s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-     s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
-     s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
-     s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-     s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
-     s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
-     s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;
+         for (auto & backend : ctx->backends) {
+             ggml_backend_free(backend);
+         }
 
-     return s.c_str();
+
+         delete ctx;
+     }
+ }
+
+ void whisper_vad_free_segments(whisper_vad_segments * segments) {
+     if (segments) {
+         delete segments;
+     }
  }
 
  //////////////////////////////////
@@ -4099,7 +5478,7 @@ const char * whisper_print_system_info(void) {
 
  // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
  // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
- std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
          const char * src,
          whisper_partial_utf8 partial_start) {
      static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -4248,7 +5627,7 @@ static void whisper_grammar_advance_stack(
          std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
 
      if (stack.empty()) {
-         new_stacks.push_back(stack);
+         new_stacks.emplace_back();
          return;
      }
 
@@ -4513,7 +5892,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
 
  ////////////////////////////////////////////////////////////////////////////
 
- struct whisper_context_params * whisper_context_default_params_by_ref() {
+ struct whisper_context_params * whisper_context_default_params_by_ref(void) {
      struct whisper_context_params params = whisper_context_default_params();
 
      struct whisper_context_params* result = new whisper_context_params();
@@ -4554,7 +5933,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
          /*.split_on_word =*/ false,
          /*.max_tokens    =*/ 0,
 
-         /*.speed_up      =*/ false,
          /*.debug_mode    =*/ false,
          /*.audio_ctx     =*/ 0,
 
@@ -4570,7 +5948,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
          /*.detect_language =*/ false,
 
          /*.suppress_blank  =*/ true,
-         /*.suppress_non_speech_tokens =*/ false,
+         /*.suppress_nst    =*/ false,
 
          /*.temperature     =*/ 0.0f,
          /*.max_initial_ts  =*/ 1.0f,
@@ -4610,6 +5988,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
          /*.n_grammar_rules =*/ 0,
          /*.i_start_rule    =*/ 0,
          /*.grammar_penalty =*/ 100.0f,
+
+         /*.vad             =*/ false,
+         /*.vad_model_path  =*/ nullptr,
+
+         /* vad_params      =*/ whisper_vad_default_params(),
      };
 
      switch (strategy) {
@@ -4720,6 +6103,42 @@ static const std::vector<std::string> non_speech_tokens = {
      "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
  };
 
+ static void whisper_compute_logprobs(
+         const std::vector<float> & logits,
+         const int n_logits,
+         std::vector<float> & logprobs) {
+     const float logit_max = *std::max_element(logits.begin(), logits.end());
+     float logsumexp = 0.0f;
+     for (int i = 0; i < n_logits; ++i) {
+         if (logits[i] > -INFINITY) {
+             logsumexp += expf(logits[i] - logit_max);
+         }
+     }
+     logsumexp = logf(logsumexp) + logit_max;
+
+     for (int i = 0; i < n_logits; ++i) {
+         if (logits[i] > -INFINITY) {
+             logprobs[i] = logits[i] - logsumexp;
+         } else {
+             logprobs[i] = -INFINITY;
+         }
+     }
+ }
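+ // The function above is the numerically stable log-softmax:
+ // logprob[i] = logit[i] - (log(sum_j exp(logit[j] - max)) + max);
+ // subtracting the max before exponentiating avoids overflow in expf.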
+
+ static void whisper_compute_probs(
+         const std::vector<float> & logits,
+         const int n_logits,
+         const std::vector<float> & logprobs,
+         std::vector<float> & probs) {
+     for (int i = 0; i < n_logits; ++i) {
+         if (logits[i] == -INFINITY) {
+             probs[i] = 0.0f;
+         } else {
+             probs[i] = expf(logprobs[i]);
+         }
+     }
+ }
+
  // process the logits for the selected decoder
  // - applies logit filters
  // - computes logprobs and probs
@@ -4781,7 +6200,7 @@ static void whisper_process_logits(
 
      // suppress sot and nosp tokens
      logits[vocab.token_sot] = -INFINITY;
-     logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+     logits[vocab.token_nosp] = -INFINITY;
 
      // [TDRZ] when tinydiarize is disabled, suppress solm token
      if (params.tdrz_enable == false) {
@@ -4818,7 +6237,7 @@ static void whisper_process_logits(
 
      // suppress non-speech tokens
      // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
-     if (params.suppress_non_speech_tokens) {
+     if (params.suppress_nst) {
          for (const std::string & token : non_speech_tokens) {
              const std::string suppress_tokens[] = {token, " " + token};
              for (const std::string & suppress_token : suppress_tokens) {
@@ -4880,24 +6299,7 @@ static void whisper_process_logits(
      }
 
      // populate the logprobs array (log_softmax)
-     {
-         const float logit_max = *std::max_element(logits.begin(), logits.end());
-         float logsumexp = 0.0f;
-         for (int i = 0; i < n_logits; ++i) {
-             if (logits[i] > -INFINITY) {
-                 logsumexp += expf(logits[i] - logit_max);
-             }
-         }
-         logsumexp = logf(logsumexp) + logit_max;
-
-         for (int i = 0; i < n_logits; ++i) {
-             if (logits[i] > -INFINITY) {
-                 logprobs[i] = logits[i] - logsumexp;
-             } else {
-                 logprobs[i] = -INFINITY;
-             }
-         }
-     }
+     whisper_compute_logprobs(logits, n_logits, logprobs);
 
      // if sum of probability over timestamps is above any other token, sample timestamp
      // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@@ -4955,15 +6357,7 @@ static void whisper_process_logits(
      }
 
      // compute probs
-     {
-         for (int i = 0; i < n_logits; ++i) {
-             if (logits[i] == -INFINITY) {
-                 probs[i] = 0.0f;
-             } else {
-                 probs[i] = expf(logprobs[i]);
-             }
-         }
-     }
+     whisper_compute_probs(logits, n_logits, logprobs, probs);
 
  #if 0
      // print first 100 logits - token string : logit
@@ -5215,6 +6609,121 @@ static void whisper_sequence_score(
      }
  }
 
+ static bool whisper_vad(
+         struct whisper_context * ctx,
+         struct whisper_state * state,
+         struct whisper_full_params params,
+         const float * samples,
+         int n_samples,
+         std::vector<float> & filtered_samples,
+         int & filtered_n_samples) {
+     WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+     filtered_n_samples = 0;
+
+     if (state->vad_context == nullptr) {
+         struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
+         struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
+         if (vctx == nullptr) {
+             WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
+             return false;
+         }
+         state->vad_context = vctx;
+     }
+     auto vctx = state->vad_context;
+
+     const whisper_vad_params & vad_params = params.vad_params;
+
+     whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
+
+     if (vad_segments->data.size() > 0) {
+         state->has_vad_segments = true;
+         ctx->state->vad_segments.clear();
+         ctx->state->vad_segments.reserve(vad_segments->data.size());
+
+         WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
+         float overlap_seconds = vad_params.samples_overlap;
+         int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
+
+         for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+             int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+             int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
+
+             if (i < (int)vad_segments->data.size() - 1) {
+                 segment_end_samples += overlap_samples;
+             }
+             segment_end_samples = std::min(segment_end_samples, n_samples - 1);
+             filtered_n_samples += (segment_end_samples - segment_start_samples);
+
+             WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
+                     __func__, i, vad_segments->data[i].start,
+                     vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
+                     (vad_segments->data[i].end - vad_segments->data[i].start) +
+                     (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
+         }
+
+         int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
+         int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
+         int total_samples_needed = filtered_n_samples + total_silence_samples;
+
+         WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
+                 __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
+
+         try {
+             filtered_samples.resize(total_samples_needed);
+         } catch (const std::bad_alloc & /* e */) {
+             WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
+             whisper_vad_free_segments(vad_segments);
+             whisper_vad_free(vctx);
+             return false;
+         }
+
+         int offset = 0;
+         for (int i = 0; i < (int)vad_segments->data.size(); i++) {
+             int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
+             int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
+
+             if (i < (int)vad_segments->data.size() - 1) {
+                 segment_end_samples += overlap_samples;
+             }
+
+             segment_start_samples = std::min(segment_start_samples, n_samples - 1);
+             segment_end_samples = std::min(segment_end_samples, n_samples);
+             int segment_length = segment_end_samples - segment_start_samples;
+
+             if (segment_length > 0) {
+                 whisper_state::vad_segment_info segment;
+
+                 segment.orig_start = vad_segments->data[i].start;
+                 segment.orig_end = vad_segments->data[i].end;
+
+                 segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
+                 segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
+
+                 WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
+                         __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
+                 ctx->state->vad_segments.push_back(segment);
+
+                 // Copy this speech segment
+                 memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
+                 offset += segment_length;
+
+                 // Add silence after this segment (except after the last segment)
+                 if (i < (int)vad_segments->data.size() - 1) {
+                     // Fill with zeros (silence)
+                     memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
+                     offset += silence_samples;
+                 }
+             }
+         }
+
+         filtered_n_samples = offset;
+         WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
+                 __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
+     }
+
+     return true;
+ }
+
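+ // Each vad_segment_info recorded above keeps two timelines: where a segment
+ // sat in the original audio (orig_start/orig_end, seconds) and where it sits
+ // in the concatenated filtered audio (vad_start/vad_end), so timestamps
+ // produced on the filtered audio can later be mapped back to the original.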
  int whisper_full_with_state(
5219
6728
  struct whisper_context * ctx,
5220
6729
  struct whisper_state * state,
@@ -5226,17 +6735,29 @@ int whisper_full_with_state(
 
     result_all.clear();
 
-    if (n_samples > 0) {
+    const float * process_samples = samples;
+    int n_process_samples = n_samples;
+    std::vector<float> vad_samples;
+
+    if (params.vad) {
+        WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+        int vad_n_samples;
+        if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
+            WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+            return -1;
+        }
+        if (vad_n_samples == 0) {
+            return 0;
+        }
+        process_samples = vad_samples.data();
+        n_process_samples = vad_n_samples;
+    }
+
+    if (n_process_samples > 0) {
         // compute log mel spectrogram
-        if (params.speed_up) {
-            // TODO: Replace PV with more advanced algorithm
+        if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
             WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-            return -1;
-        } else {
-            if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
-                return -2;
-            }
+            return -2;
         }
     }
 
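Caller-side, the new pre-pass is opt-in. A sketch of enabling it, assuming the whisper.cpp parameter names bundled with this release (`vad` appears in this diff; `vad_model_path` is an assumption here, and the model path is hypothetical):

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.vad            = true;                            // run VAD before transcription
    wparams.vad_model_path = "models/ggml-silero-v5.1.2.bin"; // hypothetical model path

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        fprintf(stderr, "failed to process audio\n");
    }
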
@@ -5270,11 +6791,13 @@ int whisper_full_with_state(
     const int seek_start = params.offset_ms/10;
     const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
 
-    // if length of spectrogram is less than 1.0s (100 frames), then return
-    // basically don't process anything that is less than 1.0s
-    // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
-    if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
-        WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
+    // if length of spectrogram is less than 100ms (10 frames), then return
+    // basically don't process anything that is less than 100ms
+    // ref: https://github.com/ggml-org/whisper.cpp/issues/2065
+    const int delta_min = 10;
+
+    if (seek_end < seek_start + delta_min) {
+        WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
         return 0;
     }
 
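The unit here is mel frames, one frame per 10 ms of audio, so the new `delta_min = 10` is the 100 ms minimum named in the comment (the old threshold of 100 frames was a full second). The arithmetic, spelled out:

    // one mel frame covers 10 ms of audio, so:
    const int delta_min = 10;             // minimum spectrogram length, in frames
    const int min_ms    = delta_min * 10; // = 100 ms, the value printed in the warning
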
@@ -5321,7 +6844,7 @@ int whisper_full_with_state(
         decoder.logprobs.resize(ctx->vocab.n_vocab);
         decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
-        decoder.rng = std::mt19937(0);
+        decoder.rng = std::mt19937(j);
     }
 
     // the accumulated text context so far
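Seeding each decoder with its index `j` instead of the constant 0 gives every sampling decoder its own deterministic random stream; with a shared seed of 0, all decoders drew identical samples. A sketch of the difference:

    #include <random>
    #include <vector>

    const int n_decoders = 5; // example count
    std::vector<std::mt19937> rngs;
    for (int j = 0; j < n_decoders; ++j) {
        rngs.emplace_back(j); // per-decoder stream; std::mt19937(0) for every j made them identical
    }
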
@@ -5418,8 +6941,8 @@ int whisper_full_with_state(
                     ctx, state, progress_cur, params.progress_callback_user_data);
         }
 
-        // if only 1 second left, then stop
-        if (seek + 100 >= seek_end) {
+        // if only 100ms left, then stop
+        if (seek + delta_min >= seek_end) {
             break;
         }
 
@@ -5518,13 +7041,46 @@ int whisper_full_with_state(
         }
         WHISPER_LOG_DEBUG("\n\n");
 
+        // recreate the KV cache if the number of decoders has changed
+        if (state->kv_self_n_dec < n_decoders_cur) {
+            WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+            whisper_kv_cache_free(state->kv_self);
+
+            // overallocate to workaround KV cache fragmentation issues
+            const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+            if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                        ctx->model.hparams.n_text_state,
+                        ctx->model.hparams.n_text_layer,
+                        GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                whisper_free_state(state);
+                return -7;
+            }
+
+            state->kv_self_n_dec = n_decoders_cur;
+        }
+
         whisper_kv_cache_clear(state->kv_self);
 
         whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
         if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
             WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-            return -7;
+            return -8;
+        }
+
+        // Calculate the no_speech probability after the first decode.
+        // This has to be done before any logit filtering; hence we cannot use the probs from whisper_process_logits.
+        {
+            const int n_logits = ctx->vocab.id_to_token.size();
+            std::vector<float> logprobs(n_logits);
+            std::vector<float> probs(n_logits);
+
+            whisper_compute_logprobs(state->logits, n_logits, logprobs);
+            whisper_compute_probs(state->logits, n_logits, logprobs, probs);
+            state->no_speech_prob = probs[whisper_token_nosp(ctx)];
         }
 
         {
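The probability block above is a plain softmax over the first decode's logits, read off at the no-speech token (whisper_token_nosp). A self-contained sketch of the computation, standing in for the internal whisper_compute_logprobs/whisper_compute_probs helpers:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // p(token_id) under a numerically stable softmax of the raw logits
    static float token_prob(const std::vector<float> & logits, int token_id) {
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        double sum = 0.0;
        for (float l : logits) {
            sum += std::exp(l - max_logit); // subtract the max to avoid overflow
        }
        return (float) (std::exp(logits[token_id] - max_logit) / sum);
    }
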
@@ -5733,10 +7289,10 @@ int whisper_full_with_state(
                 // end of segment
                 if (token.id == whisper_token_eot(ctx) ||                // end of text token
                     (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
-                    (has_ts && seek + seek_delta + 100 >= seek_end)       // end of audio reached
+                    (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
                    ) {
                     if (result_len == 0 && !params.no_timestamps) {
-                        if (seek + seek_delta + 100 >= seek_end) {
+                        if (seek + seek_delta + delta_min >= seek_end) {
                             result_len = i + 1;
                         } else {
                             WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
@@ -5824,7 +7380,7 @@ int whisper_full_with_state(
 
                 if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                    return -8;
+                    return -9;
                 }
 
                 const int64_t t_start_sample_us = ggml_time_us();
@@ -5918,8 +7474,9 @@ int whisper_full_with_state(
             if (it != (int) temperatures.size() - 1) {
                 const auto & decoder = state->decoders[best_decoder_id];
 
-                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
-                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
+                if (decoder.failed ||
+                    (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
+                    WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
                     success = false;
                     state->n_fail_p++;
                 }
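With the extra condition, a low average log-probability alone no longer forces a retry at the next temperature: the chunk must also look like speech (`no_speech_prob` below the threshold), which keeps silence from being re-decoded over and over. The predicate in isolation (names mirror the diff; illustration only):

    static bool needs_fallback(bool failed,
                               float avg_logprobs, float logprob_thold,
                               float no_speech_prob, float no_speech_thold) {
        // retry when decoding failed outright, or when confidence is low
        // *and* the model does not think the chunk is silence
        return failed || (avg_logprobs < logprob_thold && no_speech_prob < no_speech_thold);
    }
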
@@ -5940,7 +7497,7 @@ int whisper_full_with_state(
         {
             const auto & best_decoder = state->decoders[best_decoder_id];
 
-            const auto seek_delta = best_decoder.seek_delta;
+            auto seek_delta = best_decoder.seek_delta;
             const auto result_len = best_decoder.sequence.result_len;
 
             const auto & tokens_cur = best_decoder.sequence.tokens;
@@ -5948,6 +7505,9 @@ int whisper_full_with_state(
             // [EXPERIMENTAL] Token-level timestamps with DTW
             const auto n_segments_before = state->result_all.size();
 
+            const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
+                                       best_decoder.sequence.avg_logprobs < params.logprob_thold);
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
@@ -5956,11 +7516,11 @@ int whisper_full_with_state(
                 prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
             }
 
-            for (int i = 0; i < result_len; ++i) {
+            for (int i = 0; i < result_len && !is_no_speech; ++i) {
                 prompt_past.push_back(tokens_cur[i].id);
             }
 
-            if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
+            if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
                 int i0 = 0;
                 auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
 
@@ -5985,8 +7545,8 @@ int whisper_full_with_state(
                         const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
 
                         if (!text.empty()) {
-                            const auto tt0 = params.speed_up ? 2*t0 : t0;
-                            const auto tt1 = params.speed_up ? 2*t1 : t1;
+                            const auto tt0 = t0;
+                            const auto tt1 = t1;
 
                             if (params.print_realtime) {
                                 if (params.print_timestamps) {
@@ -5999,7 +7559,7 @@ int whisper_full_with_state(
 
                             //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
 
-                            result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
+                            result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
                             for (int j = i0; j <= i; j++) {
                                 result_all.back().tokens.push_back(tokens_cur[j]);
                             }
@@ -6014,7 +7574,7 @@ int whisper_full_with_state(
                                     n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                                 }
                             }
-                            if (params.new_segment_callback) {
+                            if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                                 params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                             }
                         }
@@ -6032,8 +7592,8 @@ int whisper_full_with_state(
                     if (!text.empty()) {
                         const auto t1 = seek + seek_delta;
 
-                        const auto tt0 = params.speed_up ? 2*t0 : t0;
-                        const auto tt1 = params.speed_up ? 2*t1 : t1;
+                        const auto tt0 = t0;
+                        const auto tt1 = t1;
 
                         if (params.print_realtime) {
                             if (params.print_timestamps) {
@@ -6044,7 +7604,7 @@ int whisper_full_with_state(
                             }
                         }
 
-                        result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
+                        result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
                         for (int j = i0; j < (int) tokens_cur.size(); j++) {
                             result_all.back().tokens.push_back(tokens_cur[j]);
                         }
@@ -6059,7 +7619,7 @@ int whisper_full_with_state(
                                 n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
                             }
                         }
-                        if (params.new_segment_callback) {
+                        if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
                             params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
                         }
                     }
@@ -6068,14 +7628,28 @@ int whisper_full_with_state(
             // FIXME: will timestamp offsets be correct?
             // [EXPERIMENTAL] Token-level timestamps with DTW
             {
-                const auto n_segments = state->result_all.size() - n_segments_before;
+                const int n_segments = state->result_all.size() - n_segments_before;
                 if (ctx->params.dtw_token_timestamps && n_segments) {
                     const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
                     whisper_exp_compute_token_level_timestamps_dtw(
                             ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                    if (params.new_segment_callback) {
+                        for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
+                            params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
+                        }
+                    }
                 }
             }
 
+            // ref: https://github.com/ggml-org/whisper.cpp/pull/2629
+            const bool single_timestamp_ending = tokens_cur.size() > 1 &&
+                    tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
+                    tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
+            if (single_timestamp_ending) {
+                WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
+                seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
+            }
+
             // update audio window
             seek += seek_delta;
 
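Token ids below `whisper_token_beg(ctx)` are text tokens and ids above it are timestamp tokens, so a sequence ending in text followed by exactly one timestamp indicates the model saw no further speech and the rest of the window can be skipped. The shape of the check, pulled out (illustration only):

    #include <vector>

    static bool ends_with_single_timestamp(const std::vector<int> & ids, int beg_id) {
        const size_t n = ids.size();
        // "... text-token, timestamp-token" with no paired second timestamp
        return n > 1 && ids[n - 2] < beg_id && ids[n - 1] > beg_id;
    }
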
@@ -6226,19 +7800,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
 }
 
 int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t0;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t0;
+    }
+
+    // Get the start timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case so we need to map these timestamps
+    // back to the original audio.
+    float t0 = state->result_all[i_segment].t0 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs to.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceeds a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimize for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto & segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t)(orig_t0 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t0 > curr.vad_end && t0 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t)(orig_t0 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment.
+    if (t0 > state->vad_segments.back().vad_end) {
+        // For timestamps after the last segment, add the extra time to the end of the last segment
+        const auto & last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t0 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t0 = last.orig_end + extra_time;
+        return (int64_t)(orig_t0 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
+    return t0;
 }
 
 int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t0;
+    return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
 }
 
 int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
-    return state->result_all[i_segment].t1;
+    // If VAD wasn't used, return the original timestamp
+    if (!state->has_vad_segments || state->vad_segments.empty()) {
+        return state->result_all[i_segment].t1;
+    }
+
+    // Get the end timestamp produced by whisper_full. whisper_full processes
+    // only the speech segments in this case so we need to map these timestamps
+    // back to the original audio.
+    float t1 = state->result_all[i_segment].t1 / 100.0f;
+
+    // Find which VAD segment this timestamp belongs to.
+    // TODO(danbev) This could be optimized by using a binary search if the number
+    // of segments exceeds a certain limit. Also we might be able to assume that
+    // the access pattern is sequential and optimize for that too.
+    for (size_t i = 0; i < state->vad_segments.size(); i++) {
+        const auto & segment = state->vad_segments[i];
+
+        // Check if the timestamp falls within this segment.
+        if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
+            // Calculate the proportion through the filtered segment.
+            float proportion = 0.0f;
+            if (segment.vad_end > segment.vad_start) {
+                proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
+            }
+            float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
+            return (int64_t)(orig_t1 * 100);
+        }
+    }
+
+    // Check if the timestamp falls between two segments.
+    for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
+        const auto & curr = state->vad_segments[i];
+        const auto & next = state->vad_segments[i + 1];
+
+        if (t1 > curr.vad_end && t1 < next.vad_start) {
+            // Calculate how far we are through the gap as a proportion
+            float gap_proportion = 0.0f;
+            if (next.vad_start > curr.vad_end) {
+                gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
+            }
+            // Map to the corresponding position in the original gap
+            float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
+            return (int64_t)(orig_t1 * 100);
+        }
+    }
+
+    // Handle the case where the timestamp is after the last segment
+    if (t1 > state->vad_segments.back().vad_end) {
+        // For the last segment, use the end of the last VAD segment
+        const auto & last = state->vad_segments.back();
+        // Calculate how far beyond the last segment
+        float extra_time = t1 - last.vad_end;
+        // Add this extra time to the original end time
+        float orig_t1 = last.orig_end + extra_time;
+        return (int64_t)(orig_t1 * 100);
+    }
+
+    WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
+    return t1;
 }
 
 int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
-    return ctx->state->result_all[i_segment].t1;
+    return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
 }
 
 bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
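Both mappings above are the same linear interpolation: compute how far `t` sits through a filtered (VAD) segment or gap as a fraction, then apply that fraction to the corresponding span of the original audio. The core rule, isolated (illustration only; the library functions also handle the between-segments and past-the-end cases):

    // t, vad_start/vad_end and orig_start/orig_end are all in seconds
    static float map_to_original(float t, float vad_start, float vad_end,
                                 float orig_start, float orig_end) {
        float p = (vad_end > vad_start) ? (t - vad_start) / (vad_end - vad_start) : 0.0f;
        return orig_start + p * (orig_end - orig_start);
    }
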
@@ -6297,6 +7985,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
     return ctx->state->result_all[i_segment].tokens[i_token].p;
 }
 
+float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
+    return ctx->state->result_all[i_segment].no_speech_prob;
+}
+
+float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
+    return state->result_all[i_segment].no_speech_prob;
+}
+
 // =================================================================================================
 
 //
@@ -6458,6 +8154,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
 }
 
 WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
+    whisper_load_backends();
+
     static std::string s;
     s = "";
     char strbuf[256];
@@ -6477,7 +8175,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
     std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
-    std::vector<uint8_t> work;
 
     // put a bunch of random data in the buffer
     for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -6534,12 +8231,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
         double tsum = 0.0;
 
         // heat-up
-        ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
+        ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
 
         for (int i = 0; i < n_max; ++i) {
             const int64_t t0 = ggml_time_us();
 
-            ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
+            ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
 
             const int64_t t1 = ggml_time_us();
 
@@ -6700,12 +8397,6 @@ static void whisper_exp_compute_token_level_timestamps(
 
         const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
 
-        tokens[j].id    = token.id;
-        tokens[j].tid   = token.tid;
-        tokens[j].p     = token.p;
-        tokens[j].pt    = token.pt;
-        tokens[j].ptsum = token.ptsum;
-
         tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
 
         if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
@@ -6835,7 +8526,7 @@ static void whisper_exp_compute_token_level_timestamps(
                     k++;
                 }
                 tokens[j].t1 = sample_to_timestamp(k);
-                if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+                if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                     tokens[j].t1 = tokens[j + 1].t0;
                 } else {
                     s1 = k;
@@ -6916,18 +8607,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     struct ggml_tensor * cost  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
     struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
 
-    cost = ggml_set_f32(cost, INFINITY);
-    trace = ggml_set_f32(trace, -1);
-    ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+    cost = whisper_set_f32(cost, INFINITY);
+    trace = whisper_set_i32(trace, -1);
+    whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
 
     // dtw
     // supposedly can be optimized by computing diagonals in parallel ?
     // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
     for (int64_t j = 1; j < M + 1; ++j) {
         for (int64_t i = 1; i < N + 1; ++i) {
-            float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
-            float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
-            float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+            float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
 
             float c;
             int32_t t;
@@ -6942,9 +8633,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
                 t = 2;
             }
 
-            c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
-            ggml_set_f32_nd(cost, i, j, 0, 0, c);
-            ggml_set_i32_nd(trace, i, j, 0, 0, t);
+            c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            whisper_set_f32_nd(cost, i, j, 0, 0, c);
+            whisper_set_i32_nd(trace, i, j, 0, 0, t);
         }
     }
 
@@ -6953,19 +8644,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
     // trace[0, :] = 2;
     for (int64_t i = 0; i < M + 1; ++i)
-        ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+        whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
     //trace[:, 0] = 1;
     for (int64_t i = 0; i < N + 1; ++i)
-        ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+        whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
     int bt_row_idx = BT_MAX_ROWS - 1;
     int64_t i = N;
     int64_t j = M;
     while (i > 0 || j > 0) {
-        ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
-        ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+        whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
         --bt_row_idx;
 
-        int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
+        int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
         if (t == 0) {
             --i;
             --j;
@@ -6986,8 +8677,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
     for (int64_t i = 0; i < 2; ++i) {
         for (int64_t j = 0; j < result_n_cols; ++j) {
-            int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
-            ggml_set_i32_nd(r, i, j, 0, 0, v);
+            int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
+            whisper_set_i32_nd(r, i, j, 0, 0, v);
         }
     }
 
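The loops above implement the classic DTW recurrence, cost(i, j) = x(i-1, j-1) + min(diagonal, up, left), with `trace` recording which neighbor won so the alignment can be recovered by backtracking. The same recurrence on plain arrays (a sketch; the real code keeps both matrices in ggml tensors):

    #include <cmath>
    #include <vector>

    // x is an N x M distance matrix; cost/trace become (N+1) x (M+1)
    static void dtw_cost(const std::vector<std::vector<float>> & x,
                         std::vector<std::vector<float>> & cost,
                         std::vector<std::vector<int>> & trace) {
        const size_t N = x.size(), M = x[0].size();
        cost.assign(N + 1, std::vector<float>(M + 1, INFINITY));
        trace.assign(N + 1, std::vector<int>(M + 1, -1));
        cost[0][0] = 0.0f;
        for (size_t j = 1; j <= M; ++j) {
            for (size_t i = 1; i <= N; ++i) {
                const float c0 = cost[i - 1][j - 1]; // 0 = came from the diagonal
                const float c1 = cost[i - 1][j];     // 1 = came from above
                const float c2 = cost[i][j - 1];     // 2 = came from the left
                int t = 0; float c = c0;
                if (c1 < c) { c = c1; t = 1; }
                if (c2 < c) { c = c2; t = 2; }
                cost[i][j]  = x[i - 1][j - 1] + c;
                trace[i][j] = t;
            }
        }
    }
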
@@ -6998,10 +8689,11 @@ struct median_filter_user_data {
     int filter_width;
 };
 
-static void median_filter(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+static void median_filter(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
+    if (ith != 0) {
+        return;
+    }
     int filter_width = ((median_filter_user_data *) userdata)->filter_width;
-    WHISPER_ASSERT(nth == 1);
-    WHISPER_ASSERT(ith == 0);
     WHISPER_ASSERT(filter_width < a->ne[2]);
     WHISPER_ASSERT(filter_width % 2);
     WHISPER_ASSERT(ggml_n_dims(a) == 3);
@@ -7021,11 +8713,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
                     idx = 2*(a->ne[2] - 1) - idx;
                 }
 
-                filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+                filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
             }
             std::sort(filter.begin(), filter.end());
             const float v = filter[filter.size()/2];
-            ggml_set_f32_nd(dst, i, j, k, 0, v);
+            whisper_set_f32_nd(dst, i, j, k, 0, v);
             filter.clear();
         }
     }
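The filter gathers `filter_width` neighbors around each position, reflecting indices at both borders (index -1 reads element 1; index n reads element n-2), then takes the middle element after sorting. A 1-D version of the same scheme:

    #include <algorithm>
    #include <vector>

    static std::vector<float> median_filter_1d(const std::vector<float> & a, int width) {
        const int n = (int) a.size();
        std::vector<float> out(n), window;
        for (int i = 0; i < n; ++i) {
            window.clear();
            for (int k = 0; k < width; ++k) {
                int idx = i - width/2 + k;
                if (idx < 0)  idx = -idx;            // reflect at the left edge
                if (idx >= n) idx = 2*(n - 1) - idx; // reflect at the right edge
                window.push_back(a[idx]);
            }
            std::sort(window.begin(), window.end());
            out[i] = window[width/2];                // median of the window
        }
        return out;
    }
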
@@ -7124,7 +8816,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // operation (after median filter)
     // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
-    w = ggml_norm(gctx, w, 1e-9);
+    w = ggml_norm(gctx, w, 1e-9f);
     w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);
 
     // Pass median filter - this is done over AUDIO_TOKENS dimension.
@@ -7147,7 +8839,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // Compute
     struct ggml_cgraph * gf = ggml_new_graph(gctx);
     ggml_build_forward_expand(gf, w);
-    ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+    ggml_backend_graph_compute(backend.get(), gf);
 
     ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
 
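Graph execution now goes through an explicit ggml backend object instead of the removed ggml_graph_compute_with_ctx helper; `ggml_backend_ptr` is a smart pointer that releases the backend when it goes out of scope. The general shape of the new compute path (a sketch, assuming the ggml-backend headers that ship with this release):

    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };

    struct ggml_cgraph * gf = ggml_new_graph(gctx);
    ggml_build_forward_expand(gf, w);

    if (ggml_backend_graph_compute(backend.get(), gf) != GGML_STATUS_SUCCESS) {
        // handle the failure (the code above does not check the status)
    }
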
@@ -7156,9 +8850,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     auto seg_i = state->result_all.begin() + i_segment;
     auto tok_i = seg_i->tokens.begin();
     for (int i = 0; i < alignment->ne[1]; ++i) {
-        int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
+        int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
         if (v != last_v) {
-            int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
+            int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
             int64_t timestamp = (time_index * 2) + seek; // each index on the DTW result = 20 ms of audio
             last_v = v;
 
@@ -7196,6 +8890,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
+    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 }
 
 GGML_ATTRIBUTE_FORMAT(2, 3)
@@ -7219,6 +8914,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
 static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
+#ifndef WHISPER_DEBUG
+    if (level == GGML_LOG_LEVEL_DEBUG) {
+        return;
+    }
+#endif
     fputs(text, stderr);
     fflush(stderr);
 }
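Since whisper_log_set now forwards the callback to ggml_log_set, a single sink receives output from both libraries, and the default sink drops GGML_LOG_LEVEL_DEBUG messages unless WHISPER_DEBUG is defined. A caller-side sketch of installing a custom sink:

    #include <cstdio>

    static void my_log_sink(ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        if (level == GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "[whisper] %s", text); // route errors to stderr
        } else {
            fputs(text, stdout);
        }
    }

    whisper_log_set(my_log_sink, nullptr); // also reaches ggml after this change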