whispercpp 1.3.1 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (857)
  1. checksums.yaml +4 -4
  2. data/.gitignore +7 -3
  3. data/README.md +161 -43
  4. data/Rakefile +45 -13
  5. data/ext/.gitignore +4 -8
  6. data/ext/dependencies.rb +73 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +85 -0
  9. data/ext/ruby_whisper.c +177 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +672 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1303 -0
  15. data/ext/ruby_whisper_segment.c +220 -0
  16. data/ext/ruby_whisper_transcribe.cpp +93 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  19. data/ext/sources/CMakeLists.txt +255 -0
  20. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  21. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  22. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  23. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  24. data/ext/sources/bindings/javascript/package.json +26 -0
  25. data/ext/sources/bindings/javascript/whisper.js +19 -0
  26. data/ext/sources/build-xcframework.sh +547 -0
  27. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  28. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  29. data/ext/sources/cmake/build-info.cmake +60 -0
  30. data/ext/sources/cmake/git-vars.cmake +22 -0
  31. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  32. data/ext/sources/cmake/whisper.pc.in +10 -0
  33. data/ext/sources/examples/CMakeLists.txt +124 -0
  34. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  35. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +133 -0
  36. data/ext/sources/examples/addon.node/addon.cpp +557 -0
  37. data/ext/sources/examples/addon.node/index.js +57 -0
  38. data/ext/sources/examples/addon.node/package.json +16 -0
  39. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  40. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  41. data/ext/sources/examples/bench/bench.cpp +176 -0
  42. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  43. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  44. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  45. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  46. data/ext/sources/examples/cli/cli.cpp +1295 -0
  47. data/ext/sources/examples/coi-serviceworker.js +146 -0
  48. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  49. data/ext/sources/examples/command/command.cpp +800 -0
  50. data/ext/sources/examples/command/commands.txt +9 -0
  51. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  52. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  53. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  54. data/ext/sources/examples/common-ggml.cpp +238 -0
  55. data/ext/sources/examples/common-ggml.h +18 -0
  56. data/ext/sources/examples/common-sdl.cpp +227 -0
  57. data/ext/sources/examples/common-sdl.h +49 -0
  58. data/ext/sources/examples/common-whisper.cpp +175 -0
  59. data/ext/sources/examples/common-whisper.h +24 -0
  60. data/ext/sources/examples/common.cpp +675 -0
  61. data/ext/sources/examples/common.h +322 -0
  62. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  63. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  64. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  65. data/ext/sources/examples/generate-karaoke.sh +57 -0
  66. data/ext/sources/examples/grammar-parser.cpp +423 -0
  67. data/ext/sources/examples/grammar-parser.h +29 -0
  68. data/ext/sources/examples/helpers.js +191 -0
  69. data/ext/sources/examples/json.hpp +24596 -0
  70. data/ext/sources/examples/livestream.sh +112 -0
  71. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  72. data/ext/sources/examples/lsp/lsp.cpp +469 -0
  73. data/ext/sources/examples/lsp/whisper.vim +362 -0
  74. data/ext/sources/examples/miniaudio.h +93468 -0
  75. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  76. data/ext/sources/examples/python/whisper_processor.py +54 -0
  77. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  78. data/ext/sources/examples/quantize/quantize.cpp +226 -0
  79. data/ext/sources/examples/server/CMakeLists.txt +15 -0
  80. data/ext/sources/examples/server/bench.js +29 -0
  81. data/ext/sources/examples/server/httplib.h +10497 -0
  82. data/ext/sources/examples/server/server.cpp +1238 -0
  83. data/ext/sources/examples/server.py +115 -0
  84. data/ext/sources/examples/stb_vorbis.c +5584 -0
  85. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  86. data/ext/sources/examples/stream/stream.cpp +435 -0
  87. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  88. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  89. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  90. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  91. data/ext/sources/examples/sycl/build.sh +22 -0
  92. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  93. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  94. data/ext/sources/examples/talk-llama/CMakeLists.txt +43 -0
  95. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  96. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  97. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  98. data/ext/sources/examples/talk-llama/llama-arch.cpp +1914 -0
  99. data/ext/sources/examples/talk-llama/llama-arch.h +464 -0
  100. data/ext/sources/examples/talk-llama/llama-batch.cpp +843 -0
  101. data/ext/sources/examples/talk-llama/llama-batch.h +147 -0
  102. data/ext/sources/examples/talk-llama/llama-chat.cpp +685 -0
  103. data/ext/sources/examples/talk-llama/llama-chat.h +59 -0
  104. data/ext/sources/examples/talk-llama/llama-context.cpp +2845 -0
  105. data/ext/sources/examples/talk-llama/llama-context.h +297 -0
  106. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  107. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  108. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  109. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  110. data/ext/sources/examples/talk-llama/llama-graph.cpp +1693 -0
  111. data/ext/sources/examples/talk-llama/llama-graph.h +710 -0
  112. data/ext/sources/examples/talk-llama/llama-hparams.cpp +103 -0
  113. data/ext/sources/examples/talk-llama/llama-hparams.h +207 -0
  114. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  115. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  116. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  117. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  118. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  119. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  120. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +44 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +439 -0
  124. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  125. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  126. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  127. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  128. data/ext/sources/examples/talk-llama/llama-memory.cpp +59 -0
  129. data/ext/sources/examples/talk-llama/llama-memory.h +116 -0
  130. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  131. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  132. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1163 -0
  133. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  134. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +282 -0
  135. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  136. data/ext/sources/examples/talk-llama/llama-model.cpp +15114 -0
  137. data/ext/sources/examples/talk-llama/llama-model.h +452 -0
  138. data/ext/sources/examples/talk-llama/llama-quant.cpp +1049 -0
  139. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  140. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  141. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  142. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3377 -0
  143. data/ext/sources/examples/talk-llama/llama-vocab.h +132 -0
  144. data/ext/sources/examples/talk-llama/llama.cpp +358 -0
  145. data/ext/sources/examples/talk-llama/llama.h +1484 -0
  146. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  147. data/ext/sources/examples/talk-llama/speak +40 -0
  148. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  149. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  150. data/ext/sources/examples/talk-llama/talk-llama.cpp +810 -0
  151. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  152. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  153. data/ext/sources/examples/talk-llama/unicode.cpp +854 -0
  154. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  155. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  156. data/ext/sources/examples/vad-speech-segments/speech.cpp +149 -0
  157. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  158. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  159. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  160. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  161. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  162. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  163. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  164. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  165. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +251 -0
  166. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  167. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  168. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  169. data/ext/sources/ggml/CMakeLists.txt +435 -0
  170. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  171. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  172. data/ext/sources/ggml/cmake/common.cmake +50 -0
  173. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  174. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +10 -8
  176. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +11 -1
  178. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  179. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  180. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  181. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  182. data/ext/{ggml → sources/ggml}/include/ggml.h +325 -269
  183. data/ext/sources/ggml/include/gguf.h +202 -0
  184. data/ext/sources/ggml/src/CMakeLists.txt +404 -0
  185. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  186. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  187. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  188. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +92 -53
  189. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +69 -34
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  191. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +75 -0
  192. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  195. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  196. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  197. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +140 -1
  198. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +588 -146
  199. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  200. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  201. data/ext/{ggml → sources/ggml}/src/ggml-common.h +16 -8
  202. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +597 -0
  203. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +3 -2
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +11 -10
  205. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  208. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  209. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  210. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  211. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  212. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  213. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  214. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  215. data/ext/{ggml/src/ggml-cpu/cpu-feats-x86.cpp → sources/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp} +5 -1
  216. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  217. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +3285 -0
  218. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  219. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  220. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  221. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  222. data/ext/sources/ggml/src/ggml-cpu/common.h +73 -0
  223. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +172 -41
  224. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3551 -0
  225. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +78 -25
  226. data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.cpp → sources/ggml/src/ggml-cpu/hbm.cpp} +1 -1
  227. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  228. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  229. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  230. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  231. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3594 -0
  232. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +19 -0
  233. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +9786 -0
  234. data/ext/sources/ggml/src/ggml-cpu/ops.h +118 -0
  235. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  236. data/ext/{ggml/src/ggml-cpu/ggml-cpu-quants.h → sources/ggml/src/ggml-cpu/quants.h} +26 -0
  237. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  238. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  239. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +1184 -0
  240. data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.cpp → sources/ggml/src/ggml-cpu/traits.cpp} +1 -1
  241. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  242. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  243. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +345 -0
  244. data/ext/sources/ggml/src/ggml-cpu/vec.h +1027 -0
  245. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  246. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  247. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  248. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  249. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  250. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  251. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  252. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  253. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  254. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  255. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  256. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  257. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/common.cuh +851 -0
  259. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  260. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  262. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  264. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  266. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  267. data/ext/sources/ggml/src/ggml-cuda/convert.cu +752 -0
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +31 -0
  269. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  270. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  271. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  273. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  275. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  276. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  277. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  278. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1474 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  287. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +638 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  289. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  290. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  291. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  292. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  293. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3647 -0
  294. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  295. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  296. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  297. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  298. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  299. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  300. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  301. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  302. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +506 -0
  304. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +11 -0
  305. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  307. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  308. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  309. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  310. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  312. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  313. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  314. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  315. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  316. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  317. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  318. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  319. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  320. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  321. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  322. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  323. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  324. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  325. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  326. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  327. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +155 -0
  328. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  329. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  330. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +26 -0
  332. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +4 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  430. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  432. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  433. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  434. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  436. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  437. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  438. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  439. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  440. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  441. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  442. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  443. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  444. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  445. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  446. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  447. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  448. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  449. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  450. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  451. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  452. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  453. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  454. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  455. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  456. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  457. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  458. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  459. data/ext/sources/ggml/src/ggml-cuda/unary.cu +378 -0
  460. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +66 -0
  461. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  462. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  463. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  464. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  465. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  466. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  467. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  468. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  469. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +135 -0
  470. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +147 -158
  471. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  481. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  482. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  483. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  484. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  485. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  486. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  487. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  488. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  489. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  490. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  491. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  492. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  493. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  494. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  495. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  496. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  497. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  498. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  499. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  500. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  501. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  502. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  503. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  504. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  505. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  506. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  507. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  508. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  509. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +121 -0
  510. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +649 -0
  511. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2504 -1108
  512. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +2102 -1463
  513. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  514. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  515. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  516. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +110 -0
  517. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +6494 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  526. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  527. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  528. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  529. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  530. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  531. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  532. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  533. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  534. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  535. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  536. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  537. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  538. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  539. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  540. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  541. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  542. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  543. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  544. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  545. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  546. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  547. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  548. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  549. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  550. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  551. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  552. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  553. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  554. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  555. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  556. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  557. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  558. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  559. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  560. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  561. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  562. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  563. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  564. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  565. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  566. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  567. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  568. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  569. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +120 -128
  570. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  571. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +494 -84
  572. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  573. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  574. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +344 -0
  575. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  576. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  577. data/ext/sources/ggml/src/ggml-sycl/common.hpp +561 -0
  578. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +56 -70
  579. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  580. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +8 -12
  581. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  582. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +575 -0
  583. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  584. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +839 -0
  585. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  586. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +823 -0
  587. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +188 -67
  588. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  589. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2987 -0
  590. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1120 -0
  591. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +84 -0
  592. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +102 -0
  593. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +212 -0
  594. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  595. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1197 -1295
  596. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  597. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  598. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  599. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  600. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +60 -81
  601. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  602. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1065 -0
  603. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  604. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +482 -0
  605. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  606. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  607. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  608. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  609. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +111 -0
  610. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +472 -0
  611. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  612. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +38 -28
  613. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  614. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +15 -0
  615. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +26 -0
  616. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +6 -11
  617. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  618. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1307 -0
  619. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +289 -0
  620. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +200 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  623. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3822 -1335
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +31 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +61 -0
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  740. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +203 -36
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  743. data/ext/{ggml → sources/ggml}/src/ggml.c +918 -1782
  744. data/ext/sources/ggml/src/ggml.cpp +26 -0
  745. data/ext/sources/ggml/src/gguf.cpp +1351 -0
  746. data/ext/{include → sources/include}/whisper.h +70 -2
  747. data/ext/sources/src/CMakeLists.txt +145 -0
  748. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  749. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  750. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  751. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +36 -10
  752. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  753. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +29 -3
  754. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  755. data/ext/sources/src/whisper-arch.h +197 -0
  756. data/ext/{src → sources/src}/whisper.cpp +1966 -386
  757. data/ext/sources/tests/CMakeLists.txt +105 -0
  758. data/ext/sources/tests/earnings21/eval.mk +58 -0
  759. data/ext/sources/tests/earnings21/eval.py +68 -0
  760. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  761. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  762. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  763. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  764. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  765. data/ext/sources/tests/en-0-ref.txt +1 -0
  766. data/ext/sources/tests/en-1-ref.txt +1 -0
  767. data/ext/sources/tests/en-2-ref.txt +1 -0
  768. data/ext/sources/tests/es-0-ref.txt +1 -0
  769. data/ext/sources/tests/librispeech/eval.mk +39 -0
  770. data/ext/sources/tests/librispeech/eval.py +47 -0
  771. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  772. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  773. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  774. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  775. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  776. data/ext/sources/tests/run-tests.sh +130 -0
  777. data/ext/sources/tests/test-c.c +3 -0
  778. data/ext/sources/tests/test-vad-full.cpp +54 -0
  779. data/ext/sources/tests/test-vad.cpp +83 -0
  780. data/ext/sources/tests/test-whisper.js +58 -0
  781. data/extsources.rb +39 -5
  782. data/lib/whisper/context.rb +15 -0
  783. data/lib/whisper/model/uri.rb +202 -126
  784. data/lib/whisper/segment.rb +58 -0
  785. data/sig/whisper.rbs +510 -0
  786. data/test/helper.rb +24 -0
  787. data/{tests → test}/test_callback.rb +45 -3
  788. data/{tests → test}/test_error.rb +2 -2
  789. data/{tests → test}/test_model.rb +47 -0
  790. data/test/test_package.rb +51 -0
  791. data/test/test_params.rb +297 -0
  792. data/test/test_segment.rb +146 -0
  793. data/test/test_vad.rb +19 -0
  794. data/test/test_vad_params.rb +103 -0
  795. data/{tests → test}/test_whisper.rb +106 -36
  796. data/whispercpp.gemspec +5 -5
  797. metadata +837 -134
  798. data/ext/cpu.mk +0 -9
  799. data/ext/examples/dr_wav.h +0 -8815
  800. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  801. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  802. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  803. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -10835
  804. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  805. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  806. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  807. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  808. data/ext/ggml/src/ggml-sycl/convert.cpp +0 -547
  809. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  810. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  811. data/ext/ggml/src/ggml-sycl/mmvq.cpp +0 -1015
  812. data/ext/ggml/src/ggml-sycl/norm.cpp +0 -378
  813. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  814. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  815. data/ext/metal-embed.mk +0 -17
  816. data/ext/metal.mk +0 -6
  817. data/ext/ruby_whisper.cpp +0 -1909
  818. data/ext/scripts/get-flags.mk +0 -38
  819. data/lib/whisper.rb +0 -2
  820. data/tests/helper.rb +0 -7
  821. data/tests/test_package.rb +0 -31
  822. data/tests/test_params.rb +0 -160
  823. data/tests/test_segment.rb +0 -83
  824. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  825. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  826. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  827. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  828. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  829. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  830. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  831. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  832. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  833. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  834. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  835. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  836. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  837. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  838. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  839. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  840. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  841. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  842. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  843. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  844. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  845. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  846. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-hbm.h → sources/ggml/src/ggml-cpu/hbm.h} +0 -0
  847. /data/ext/{ggml/src/ggml-cpu/ggml-cpu-traits.h → sources/ggml/src/ggml-cpu/traits.h} +0 -0
  848. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  849. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  850. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  851. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  852. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  853. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  854. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
  855. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  856. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  857. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
@@ -1,8 +1,8 @@
  #include "whisper.h"
-
- #include "ggml-cpu.h"
+ #include "whisper-arch.h"

  #include "ggml.h"
+ #include "ggml-cpp.h"
  #include "ggml-alloc.h"
  #include "ggml-backend.h"

@@ -17,37 +17,36 @@
  #include <atomic>
  #include <algorithm>
  #include <cassert>
+ #include <cfloat>
  #define _USE_MATH_DEFINES
  #include <cmath>
- #include <cstdio>
+ #include <climits>
+ #include <codecvt>
  #include <cstdarg>
+ #include <cstdio>
  #include <cstring>
  #include <fstream>
+ #include <functional>
  #include <map>
+ #include <mutex>
+ #include <random>
+ #include <regex>
  #include <set>
  #include <string>
  #include <thread>
  #include <vector>
- #include <regex>
- #include <random>
- #include <functional>
- #include <codecvt>
-
- #if defined(_MSC_VER)
- #pragma warning(disable: 4244 4267) // possible loss of data
- #endif
-
- #if defined(GGML_BIG_ENDIAN)
- #include <bit>

+ #if defined(WHISPER_BIG_ENDIAN)
  template<typename T>
  static T byteswap(T value) {
- return std::byteswap(value);
- }
-
- template<>
- float byteswap(float value) {
- return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+ T value_swapped;
+ char * source = reinterpret_cast<char *>(&value);
+ char * target = reinterpret_cast<char *>(&value_swapped);
+ int size = sizeof(T);
+ for (int i = 0; i < size; i++) {
+ target[size - 1 - i] = source[i];
+ }
+ return value_swapped;
  }

  template<typename T>
@@ -83,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
  }

  #define BYTESWAP_VALUE(d) d = byteswap(d)
- #define BYTESWAP_FILTERS(f) \
+ #define BYTESWAP_FILTERS(f) \
  do { \
  for (auto & datum : f.data) { \
  datum = byteswap(datum); \
  } \
  } while (0)
- #define BYTESWAP_TENSOR(t) \
- do { \
+ #define BYTESWAP_TENSOR(t) \
+ do { \
  byteswap_tensor(t); \
  } while (0)
  #else
@@ -141,34 +140,52 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
  #define WHISPER_MAX_DECODERS 8
  #define WHISPER_MAX_NODES 4096

+ static std::string format(const char * fmt, ...) {
+ va_list ap;
+ va_list ap2;
+ va_start(ap, fmt);
+ va_copy(ap2, ap);
+ int size = vsnprintf(NULL, 0, fmt, ap);
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+ std::vector<char> buf(size + 1);
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+ GGML_ASSERT(size2 == size);
+ va_end(ap2);
+ va_end(ap);
+ return std::string(buf.data(), size);
+ }
+
  //
  // ggml helpers
  //

  static bool ggml_graph_compute_helper(
  struct ggml_cgraph * graph,
- std::vector<uint8_t> & buf,
  int n_threads,
  ggml_abort_callback abort_callback,
  void * abort_callback_data) {
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
+ ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));

- plan.abort_callback = abort_callback;
- plan.abort_callback_data = abort_callback_data;
+ auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+ if (set_abort_callback_fn) {
+ set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
+ }

- if (plan.work_size > 0) {
- buf.resize(plan.work_size);
- plan.work_data = buf.data();
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+ if (ggml_backend_set_n_threads_fn) {
+ ggml_backend_set_n_threads_fn(backend.get(), n_threads);
  }

- return ggml_graph_compute(graph, &plan);
+ return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
  }

  static bool ggml_graph_compute_helper(
  ggml_backend_sched_t sched,
  struct ggml_cgraph * graph,
- int n_threads) {
-
+ int n_threads,
+ bool sched_reset = true) {
  for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
  ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
  ggml_backend_dev_t dev = ggml_backend_get_device(backend);
@@ -180,11 +197,61 @@ static bool ggml_graph_compute_helper(
  }
  }

- bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
- ggml_backend_sched_reset(sched);
+ const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
+
+ if (!t || sched_reset) {
+ ggml_backend_sched_reset(sched);
+ }
+
+ return t;
+ }
+
+ // TODO: move these functions to ggml-base with support for ggml-backend?
+
+ static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(t));
+ size_t nels = ggml_nelements(t);
+ for (size_t i = 0; i < nels; ++i) {
+ ((float *) t->data)[i] = v;
+ }
+ return t;
+ }
+
+ static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(t));
+ size_t nels = ggml_nelements(t);
+ for (size_t i = 0; i < nels; ++i) {
+ ((int32_t *) t->data)[i] = v;
+ }
  return t;
  }

+ static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+ return *(float *) data;
+ }
+
+ static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+ *(float *) data = v;
+ }
+
+ static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+ return *(int32_t *) data;
+ }
+
+ static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
+ *(int32_t *) data = v;
+ }
+
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
  // the idea is to represent the original matrix multiplication:
  //
@@ -428,6 +495,7 @@ struct whisper_segment {
  int64_t t1;

  std::string text;
+ float no_speech_prob;

  std::vector<whisper_token_data> tokens;

@@ -520,7 +588,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
  auto & sched = allocr.sched;
  auto & meta = allocr.meta;

- sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
+ sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);

  meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());

@@ -716,10 +784,10 @@ struct whisper_model {
  std::vector<whisper_layer_decoder> layers_decoder;

  // ggml context that contains all the meta information about the model tensors
- struct ggml_context * ctx = nullptr;
+ std::vector<ggml_context *> ctxs;

  // the model backend data is read-only and can be shared between processors
- ggml_backend_buffer_t buffer = nullptr;
+ std::vector<ggml_backend_buffer_t> buffers;

  // tensors
  int n_loaded;
@@ -791,6 +859,11 @@ struct whisper_aheads_masks {
  ggml_backend_buffer_t buffer = nullptr;
  };

+ struct vad_time_mapping {
+ int64_t processed_time; // Time in processed (VAD) audio
+ int64_t original_time; // Corresponding time in original audio
+ };
+
  struct whisper_state {
  int64_t t_sample_us = 0;
  int64_t t_encode_us = 0;
@@ -876,6 +949,19 @@ struct whisper_state {

  // [EXPERIMENTAL] speed-up techniques
  int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+ whisper_vad_context * vad_context = nullptr;
+
+ struct vad_segment_info {
+ int64_t orig_start;
+ int64_t orig_end;
+ int64_t vad_start;
+ int64_t vad_end;
+ };
+ std::vector<vad_segment_info> vad_segments;
+ bool has_vad_segments = false;
+
+ std::vector<vad_time_mapping> vad_mapping_table;
  };

  struct whisper_context {
@@ -1234,21 +1320,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
  static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
  ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

+ ggml_backend_dev_t dev = nullptr;
+
+ int cnt = 0;
  if (params.use_gpu) {
  for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
- WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
- ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
- if (!result) {
- WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+ if (cnt == 0 || cnt == params.gpu_device) {
+ dev = dev_cur;
+ }
+
+ if (++cnt > params.gpu_device) {
+ break;
  }
- return result;
  }
  }
  }

- return nullptr;
+ if (dev == nullptr) {
+ WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
+ return nullptr;
+ }
+
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+ ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+ if (!result) {
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+ }
+
+ return result;
  }

  static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@@ -1274,28 +1375,118 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
  }
  }

- GGML_UNUSED(params);
-
- result.push_back(ggml_backend_cpu_init());
+ ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+ if (backend_cpu == nullptr) {
+ throw std::runtime_error("failed to initialize CPU backend");
+ }
+ result.push_back(backend_cpu);

  return result;
  }

- static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
- if (!params.use_gpu) {
- return ggml_backend_cpu_buffer_type();
+ using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
+
+ static buft_list_t make_buft_list(whisper_context_params & params) {
+ // Prio order: GPU -> CPU Extra -> CPU
+ buft_list_t buft_list;
+
+ // GPU
+ if (params.use_gpu) {
+ int cnt = 0;
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+ if (cnt == 0 || cnt == params.gpu_device) {
+ auto * buft = ggml_backend_dev_buffer_type(dev);
+ if (buft) {
+ buft_list.emplace_back(dev, buft);
+ }
+ }
+
+ if (++cnt > params.gpu_device) {
+ break;
+ }
+ }
+ }
+ }
+
+ // CPU Extra
+ auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+ auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+ auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+ ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+ if (get_extra_bufts_fn) {
+ ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
+ while (extra_bufts && *extra_bufts) {
+ buft_list.emplace_back(cpu_dev, *extra_bufts);
+ ++extra_bufts;
+ }
  }

- // if we have a GPU device - use it
- for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
- WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
- return ggml_backend_dev_buffer_type(dev);
+ // CPU
+ buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
+
+ return buft_list;
+ }
+
+ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+ bool op_supported = true;
+
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+ // GPU and default CPU backend support all operators
+ op_supported = true;
+ } else {
+ switch (op) {
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+ case GGML_OP_MUL_MAT: {
+ ggml_init_params params = {
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
+ if (!ctx_ptr) {
+ throw std::runtime_error("failed to create ggml context");
+ }
+ ggml_context * ctx = ctx_ptr.get();
+
+ ggml_tensor * op_tensor = nullptr;
+
+ int64_t n_ctx = hparams.n_audio_ctx;
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+ op_tensor = ggml_mul_mat(ctx, w, b);
+
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+ GGML_ASSERT(w->buffer == nullptr);
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+ ggml_backend_buffer_free(w->buffer);
+ w->buffer = nullptr;
+ break;
+ }
+ default: {
+ op_supported = false;
+ break;
+ }
+ };
+ }
+
+ return op_supported;
+ }
+
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+ GGML_ASSERT(!buft_list.empty());
+ for (const auto & p : buft_list) {
+ ggml_backend_dev_t dev = p.first;
+ ggml_backend_buffer_type_t buft = p.second;
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
+ return buft;
  }
  }

- return ggml_backend_cpu_buffer_type();
+ return nullptr;
  }

  // load the model from a ggml file
@@ -1504,31 +1695,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  const ggml_type wtype = wctx.wtype;
  const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type

- // create the ggml context
- {
- const auto & hparams = model.hparams;
+ const auto & hparams = model.hparams;

- const int n_audio_layer = hparams.n_audio_layer;
- const int n_text_layer = hparams.n_text_layer;
+ const int n_audio_layer = hparams.n_audio_layer;
+ const int n_text_layer = hparams.n_text_layer;

- const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
+ const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;

- struct ggml_init_params params = {
- /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
- /*.mem_buffer =*/ nullptr,
- /*.no_alloc =*/ true,
- };
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+ auto it = ctx_map.find(buft);
+ if (it == ctx_map.end()) {
+ ggml_init_params params = {
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };

- model.ctx = ggml_init(params);
- if (!model.ctx) {
- WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
- return false;
+ ggml_context * ctx = ggml_init(params);
+ if (!ctx) {
+ throw std::runtime_error("failed to create ggml context");
+ }
+
+ ctx_map[buft] = ctx;
+ model.ctxs.emplace_back(ctx);
+
+ return ctx;
  }
- }
+
+ return it->second;
+ };
+
+ // Create a list of available bufts, in priority order
+ buft_list_t buft_list = make_buft_list(wctx.params);
+
+ auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+ ggml_op op = ASR_TENSOR_INFO.at(type);
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+ if (!buft) {
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
+ }
+
+ ggml_context * ctx = get_ctx(buft);
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+ model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+
+ return tensor;
+ };
+

  // prepare tensors for the weights
  {
- auto & ctx = model.ctx;
+ ggml_init_params params = {
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ctx = ggml_init(params);

  const auto & hparams = model.hparams;

@@ -1548,189 +1773,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  model.layers_decoder.resize(n_text_layer);

  // encoder
- {
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
-
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- // map by name
- model.tensors["encoder.positional_embedding"] = model.e_pe;
-
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
-
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
-
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
-
- for (int i = 0; i < n_audio_layer; ++i) {
- auto & layer = model.layers_encoder[i];
-
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+ model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));

- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
+ model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
+ model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+ model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));

- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ for (int i = 0; i < n_audio_layer; ++i) {
+ auto & layer = model.layers_encoder[i];

- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);

- // map by name
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);

- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
-
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
- }
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
  }

  // decoder
- {
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
- model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- // map by name
- model.tensors["decoder.positional_embedding"] = model.d_pe;
-
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
-
- for (int i = 0; i < n_text_layer; ++i) {
- auto & layer = model.layers_decoder[i];
-
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));

- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));

- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
+ model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));

- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ for (int i = 0; i < n_text_layer; ++i) {
+ auto & layer = model.layers_decoder[i];

- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);

- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- // map by name
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+ layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+ layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+ layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
+ layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
+ layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
-
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
- }
+ layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+ layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
  }
+
+ ggml_free(ctx);
  }

  // allocate tensors in the backend buffers
- model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
- if (!model.buffer) {
- WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
- return false;
- }
+ for (auto & p : ctx_map) {
+ ggml_backend_buffer_type_t buft = p.first;
+ ggml_context * ctx = p.second;
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+ if (buf) {
+ model.buffers.emplace_back(buf);

- size_t size_main = ggml_backend_buffer_get_size(model.buffer);
- WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
+ size_t size_main = ggml_backend_buffer_get_size(buf);
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+ }
+ }

  // load weights
  {
@@ -1793,11 +1937,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  return false;
  }

- //ggml_backend_t backend = wctx.backend;
-
- //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
-
- if (ggml_backend_buffer_is_host(model.buffer)) {
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
  // for the CPU and Metal backend, we can read directly into the tensor
  loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
  BYTESWAP_TENSOR(tensor);
@@ -1810,7 +1950,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
  }

- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
  total_size += ggml_nbytes(tensor);
  model.n_loaded++;
  }
@@ -1825,7 +1964,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  }
  }

- ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+ for (auto & buf : model.buffers) {
+ ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+ }

  wctx.t_load_us = ggml_time_us() - t_start_us;

@@ -3710,15 +3851,24 @@ void whisper_free_state(struct whisper_state * state) {
  // [EXPERIMENTAL] Token-level timestamps with DTW
  aheads_masks_free(state->aheads_masks);

+ if (state->vad_context != nullptr) {
+ whisper_vad_free(state->vad_context);
+ state->vad_context = nullptr;
+ }
+
  delete state;
  }
  }

  void whisper_free(struct whisper_context * ctx) {
  if (ctx) {
- ggml_free(ctx->model.ctx);
+ for (ggml_context * context : ctx->model.ctxs) {
+ ggml_free(context);
+ }

- ggml_backend_buffer_free(ctx->model.buffer);
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+ ggml_backend_buffer_free(buf);
+ }

  whisper_free_state(ctx->state);

@@ -4136,11 +4286,11 @@ void whisper_print_timings(struct whisper_context * ctx) {

  WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
  WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
- WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
- WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
- WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
- WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
- WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
  }
  WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
  }
@@ -4182,113 +4332,1238 @@ const char * whisper_print_system_info(void) {
4182
4332
  static std::string s;
4183
4333
 
4184
4334
  s = "";
4185
- s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
4186
- s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
4187
- s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
4188
- s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
4189
- s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
4190
- s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
4191
- s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
4192
- s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
4193
- s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
4194
- s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
4195
- s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
4196
- s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
4335
+ s += "WHISPER : ";
4197
4336
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
4198
4337
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
4199
4338
 
4339
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
4340
+ auto * reg = ggml_backend_reg_get(i);
4341
+ auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
4342
+ if (get_features_fn) {
4343
+ ggml_backend_feature * features = get_features_fn(reg);
4344
+ s += ggml_backend_reg_name(reg);
4345
+ s += " : ";
4346
+ for (; features->name; features++) {
4347
+ s += features->name;
4348
+ s += " = ";
4349
+ s += features->value;
4350
+ s += " | ";
4351
+ }
4352
+ }
4353
+ }
4200
4354
  return s.c_str();
4201
4355
  }
4202
4356
 
4203
4357
  //////////////////////////////////
4204
- // Grammar - ported from llama.cpp
4358
+ // Voice Activity Detection (VAD)
4205
4359
  //////////////////////////////////
4206
4360
 
4207
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
4208
- // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
4209
- static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4210
- const char * src,
4211
- whisper_partial_utf8 partial_start) {
4212
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
4213
- const char * pos = src;
4214
- std::vector<uint32_t> code_points;
4215
- uint32_t value = partial_start.value;
4216
- int n_remain = partial_start.n_remain;
4361
+ struct whisper_vad_hparams {
4362
+ int32_t n_encoder_layers;
4363
+ int32_t * encoder_in_channels;
4364
+ int32_t * encoder_out_channels;
4365
+ int32_t * kernel_sizes;
4366
+ int32_t lstm_input_size;
4367
+ int32_t lstm_hidden_size;
4368
+ int32_t final_conv_in;
4369
+ int32_t final_conv_out;
4370
+ };
4217
4371
 
4218
- // continue previous decode, if applicable
4219
- while (*pos != 0 && n_remain > 0) {
4220
- uint8_t next_byte = static_cast<uint8_t>(*pos);
4221
- if ((next_byte >> 6) != 2) {
4222
- // invalid sequence, abort
4223
- code_points.push_back(0);
4224
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
4225
- }
4226
- value = (value << 6) + (next_byte & 0x3F);
4227
- ++pos;
4228
- --n_remain;
4229
- }
4372
+ struct whisper_vad_model {
4373
+ std::string type;
4374
+ std::string version;
4375
+ whisper_vad_hparams hparams;
4230
4376
 
4231
- if (partial_start.n_remain > 0 && n_remain == 0) {
4232
- code_points.push_back(value);
4233
- }
4377
+ struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
4234
4378
 
4235
- // decode any subsequent utf-8 sequences, which may end in an incomplete one
4236
- while (*pos != 0) {
4237
- uint8_t first_byte = static_cast<uint8_t>(*pos);
4238
- uint8_t highbits = first_byte >> 4;
4239
- n_remain = lookup[highbits] - 1;
4379
+ // Encoder tensors - 4 convolutional layers
4380
+ struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
4381
+ struct ggml_tensor * encoder_0_bias; // [128]
4240
4382
 
4241
- if (n_remain < 0) {
4242
- // invalid sequence, abort
4243
- code_points.clear();
4244
- code_points.push_back(0);
4245
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
4246
- }
4383
+ // Second encoder layer
4384
+ struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
4385
+ struct ggml_tensor * encoder_1_bias; // [64]
4247
4386
 
4248
- uint8_t mask = (1 << (7 - n_remain)) - 1;
4249
- value = first_byte & mask;
4250
- ++pos;
4251
- while (*pos != 0 && n_remain > 0) {
4252
- value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
4253
- ++pos;
4254
- --n_remain;
4255
- }
4256
- if (n_remain == 0) {
4257
- code_points.push_back(value);
4258
- }
4259
- }
4260
- code_points.push_back(0);
4387
+ // Third encoder layer
4388
+ struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
4389
+ struct ggml_tensor * encoder_2_bias; // [64]
4261
4390
 
4262
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
4263
- }
4391
+ // Fourth encoder layer
4392
+ struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
4393
+ struct ggml_tensor * encoder_3_bias; // [128]
4264
4394
 
4265
- // returns true iff pos points to the end of one of the definitions of a rule
4266
- static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
4267
- switch (pos->type) {
4268
- case WHISPER_GRETYPE_END: return true; // NOLINT
4269
- case WHISPER_GRETYPE_ALT: return true; // NOLINT
4270
- default: return false;
4271
- }
4272
- }
4395
+ // LSTM decoder tensors
4396
+ struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
4397
+ struct ggml_tensor * lstm_ih_bias; // [512]
4398
+ struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
4399
+ struct ggml_tensor * lstm_hh_bias; // [512]
4273
4400
 
4274
- // returns true iff chr satisfies the char range at pos (regular or inverse range)
4275
- // asserts that pos is pointing to a char range element
4276
- static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
4277
- const whisper_grammar_element * pos,
4278
- const uint32_t chr) {
4401
+ // Final conv layer
4402
+ struct ggml_tensor * final_conv_weight; // [128]
4403
+ struct ggml_tensor * final_conv_bias; // [1]
4279
4404
 
4280
- bool found = false;
4281
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
4405
+ // ggml contexts
4406
+ std::vector<ggml_context *> ctxs;
4282
4407
 
4283
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
4408
+ // buffer for the model tensors
4409
+ std::vector<ggml_backend_buffer_t> buffers;
4284
4410
 
4285
- do {
4286
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
4287
- // inclusive range, e.g. [a-z]
4288
- found = found || (pos->value <= chr && chr <= pos[1].value);
4289
- pos += 2;
4290
- } else {
4291
- // exact char match, e.g. [a] or "a"
4411
+ // tensors
4412
+ int n_loaded;
4413
+ std::map<std::string, struct ggml_tensor *> tensors;
4414
+ };
4415
+
4416
+ struct whisper_vad_segment {
4417
+ int64_t start;
4418
+ int64_t end;
4419
+ };
4420
+
4421
+ struct whisper_vad_segments {
4422
+ std::vector<whisper_vad_segment> data;
4423
+ };
4424
+
4425
+ struct whisper_vad_context {
4426
+ int64_t t_vad_us = 0;
4427
+
4428
+ int n_window;
4429
+ int n_context;
4430
+ int n_threads;
4431
+
4432
+ std::vector<ggml_backend_t> backends;
4433
+ ggml_backend_buffer_t buffer = nullptr;
4434
+ whisper_context_params params;
4435
+ std::vector<uint8_t> ctx_buf;
4436
+ whisper_sched sched;
4437
+
4438
+ whisper_vad_model model;
4439
+ std::string path_model;
4440
+ struct ggml_tensor * h_state;
4441
+ struct ggml_tensor * c_state;
4442
+ std::vector<float> probs;
4443
+ };
4444
+
4445
+ struct whisper_vad_context_params whisper_vad_default_context_params(void) {
4446
+ whisper_vad_context_params result = {
4447
+ /*.n_thread = */ 4,
4448
+ /*.use_gpu = */ false,
4449
+ /*.gpu_device = */ 0,
4450
+ };
4451
+ return result;
4452
+ }
4453
+
4454
+ struct whisper_vad_params whisper_vad_default_params(void) {
4455
+ whisper_vad_params result = {
4456
+ /* threshold = */ 0.5f,
4457
+ /* min_speech_duration_ms = */ 250,
4458
+ /* min_silence_duration_ms = */ 100,
4459
+ /* max_speech_duration_s = */ FLT_MAX,
4460
+ /* speech_pad_ms = */ 30,
4461
+ /* samples_overlap = */ 0.1,
4462
+ };
4463
+ return result;
4464
+ }
4465
+
4466
+ // Time conversion utility functions for whisper VAD
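+ // (e.g. at WHISPER_SAMPLE_RATE = 16000, 150 cs = 1.5 s corresponds to 24000 samples, and vice versa)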
4467
+ static int cs_to_samples(int64_t cs) {
4468
+ return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5);
4469
+ }
4470
+
4471
+ static int64_t samples_to_cs(int samples) {
4472
+ return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 0.5);
4473
+ }
4474
+
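+ // returns true if device `dev` can run `op` on weight `w` when the weight is stored in buffer type `buft`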
4475
+ static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
4476
+ bool op_supported = true;
4477
+
4478
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
4479
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
4480
+ // GPU and default CPU backend support all operators
4481
+ op_supported = true;
4482
+ } else {
4483
+ switch (op) {
4484
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
4485
+ case GGML_OP_MUL_MAT: {
4486
+ ggml_init_params params = {
4487
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
4488
+ /*.mem_buffer =*/ nullptr,
4489
+ /*.no_alloc =*/ true,
4490
+ };
4491
+
4492
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
4493
+ if (!ctx_ptr) {
4494
+ throw std::runtime_error("failed to create ggml context");
4495
+ }
4496
+ ggml_context * ctx = ctx_ptr.get();
4497
+
4498
+ ggml_tensor * op_tensor = nullptr;
4499
+
4500
+ int64_t n_ctx = hparams.lstm_hidden_size;
4501
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
4502
+ op_tensor = ggml_mul_mat(ctx, w, b);
4503
+
4504
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
4505
+ GGML_ASSERT(w->buffer == nullptr);
4506
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
4507
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
4508
+ ggml_backend_buffer_free(w->buffer);
4509
+ w->buffer = nullptr;
4510
+ break;
4511
+ }
4512
+ default: {
4513
+ op_supported = false;
4514
+ break;
4515
+ }
4516
+ };
4517
+ }
4518
+ return op_supported;
4519
+ }
4520
+
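+ // pick the first buffer type from the priority list whose device supports the op for this weight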
4521
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
4522
+ GGML_ASSERT(!buft_list.empty());
4523
+ for (const auto & p : buft_list) {
4524
+ ggml_backend_dev_t dev = p.first;
4525
+ ggml_backend_buffer_type_t buft = p.second;
4526
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
4527
+ return buft;
4528
+ }
4529
+ }
4530
+
4531
+ return nullptr;
4532
+ }
4533
+
4534
+ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
4535
+ const whisper_vad_model & model, ggml_tensor * cur) {
4536
+ // Apply reflective padding to the input tensor
4537
+ ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
4538
+
4539
+ struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
4540
+
4541
+ // Calculate cutoff for real/imaginary parts
4542
+ int cutoff = model.stft_forward_basis->ne[2] / 2;
4543
+
4544
+ // Extract real part (first half of the STFT output).
4545
+ struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
4546
+ // Extract imaginary part (second half of the STFT output).
4547
+ struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
4548
+
4549
+ // Calculate magnitude: sqrt(real^2 + imag^2)
4550
+ struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
4551
+ struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
4552
+ struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
4553
+ struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
4554
+ return magnitude;
4555
+ }
4556
+
4557
+ static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
4558
+ const whisper_vad_model & model, ggml_tensor * cur) {
4559
+ // First Conv1D: expands to 128 channels.
4560
+ cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
4561
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
4562
+ cur = ggml_relu(ctx0, cur);
4563
+
4564
+ // Second Conv1D: reduces to 64 channels.
4565
+ cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
4566
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
4567
+ cur = ggml_relu(ctx0, cur);
4568
+
4569
+ // Third Conv1D: maintains 64 channels.
4570
+ cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
4571
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
4572
+ cur = ggml_relu(ctx0, cur);
4573
+
4574
+ // Fourth Conv1D: expands to 128 channels
4575
+ cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
4576
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
4577
+ cur = ggml_relu(ctx0, cur);
4578
+
4579
+ return cur;
4580
+ }
4581
+
4582
+ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
4583
+ const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
4584
+ const whisper_vad_model & model = vctx.model;
4585
+ const int hdim = model.hparams.lstm_hidden_size;
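+ // single LSTM cell step: gates = W_ih*x + b_ih + W_hh*h_prev + b_hh, then
+ // c_t = f*c_prev + i*g and h_t = o*tanh(c_t), as implemented below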
4586
+
4587
+ struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
4588
+
4589
+ // Create operations using the input-to-hidden weights.
4590
+ struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
4591
+ inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
4592
+
4593
+ // Create operations using the hidden-to-hidden weights.
4594
+ struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
4595
+ hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
4596
+
4597
+ // Create add operation to get preactivations for all gates.
4598
+ struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
4599
+
4600
+ const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
4601
+
4602
+ // Create sigmoid for the input gate (using the first block of hdim values from the preactivations).
4603
+ struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
4604
+
4605
+ // Create sigmoid for the forget gate (using the second block of hdim values from the preactivations).
4606
+ struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
4607
+
4608
+ // Create tanh for the cell gate (using the third block of hdim values from the preactivations).
4609
+ struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
4610
+
4611
+ // Create sigmoid for the output gate (using the fourth block of hdim values from the preactivations).
4612
+ struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
4613
+
4614
+ // Update cell state
4615
+ struct ggml_tensor * c_out = ggml_add(ctx0,
4616
+ ggml_mul(ctx0, f_t, vctx.c_state),
4617
+ ggml_mul(ctx0, i_t, g_t));
4618
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
4619
+
4620
+ // Update hidden state
4621
+ struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
4622
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
4623
+
4624
+ return out;
4625
+ }
4626
+
4627
+ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
4628
+ const auto & model = vctx.model;
4629
+
4630
+ struct ggml_init_params params = {
4631
+ /*.mem_size =*/ vctx.sched.meta.size(),
4632
+ /*.mem_buffer =*/ vctx.sched.meta.data(),
4633
+ /*.no_alloc =*/ true,
4634
+ };
4635
+
4636
+ struct ggml_context * ctx0 = ggml_init(params);
4637
+
4638
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
4639
+
4640
+ struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
4641
+ ggml_set_name(frame, "frame");
4642
+ ggml_set_input(frame);
4643
+
4644
+ struct ggml_tensor * cur = nullptr;
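+ // pipeline: STFT magnitude -> conv encoder -> LSTM cell -> final conv + sigmoid => speech probability for this frame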
4645
+ {
4646
+ cur = whisper_vad_build_stft_layer(ctx0, model, frame);
4647
+
4648
+ cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4649
+
4650
+ // Extract the first element of the first dimension
4651
+ // (equivalent to pytorch's [:, :, 0])
4652
+ cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
4653
+
4654
+ cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
4655
+ cur = ggml_relu(ctx0, cur);
4656
+ cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
4657
+ cur = ggml_add(ctx0, cur, model.final_conv_bias);
4658
+ cur = ggml_sigmoid(ctx0, cur);
4659
+ ggml_set_name(cur, "prob");
4660
+ ggml_set_output(cur);
4661
+ }
4662
+
4663
+ ggml_build_forward_expand(gf, cur);
4664
+
4665
+ ggml_free(ctx0);
4666
+
4667
+ return gf;
4668
+ }
4669
+
4670
+ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
4671
+
4672
+ auto whisper_context_params = whisper_context_default_params();
4673
+ // TODO: GPU VAD is force-disabled until its performance is improved
4674
+ //whisper_context_params.use_gpu = vctx->params.use_gpu;
4675
+ whisper_context_params.use_gpu = false;
4676
+ whisper_context_params.gpu_device = vctx->params.gpu_device;
4677
+
4678
+ vctx->backends = whisper_backend_init(whisper_context_params);
4679
+ if (vctx->backends.empty()) {
4680
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
4681
+ return false;
4682
+ }
4683
+
4684
+ const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
4685
+
4686
+ vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
4687
+
4688
+ struct ggml_init_params params = {
4689
+ /*.mem_size =*/ vctx->ctx_buf.size(),
4690
+ /*.mem_buffer =*/ vctx->ctx_buf.data(),
4691
+ /*.no_alloc =*/ true,
4692
+ };
4693
+
4694
+ ggml_context * ctx = ggml_init(params);
4695
+ if (!ctx) {
4696
+ WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
4697
+ return false;
4698
+ }
4699
+
4700
+ // LSTM Hidden state
4701
+ vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4702
+ ggml_set_name(vctx->h_state, "h_state");
4703
+
4704
+ // LSTM Cell state
4705
+ vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4706
+ ggml_set_name(vctx->c_state, "c_state");
4707
+
4708
+ vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
4709
+ if (!vctx->buffer) {
4710
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
4711
+ return false;
4712
+ }
4713
+
4714
+ {
4715
+ bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
4716
+ [&]() {
4717
+ return whisper_vad_build_graph(*vctx);
4718
+ });
4719
+
4720
+ if (!ok) {
4721
+ WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
4722
+ return false;
4723
+ }
4724
+
4725
+ WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
4726
+ }
4727
+
4728
+ return true;
4729
+ }
4730
+
4731
+ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
4732
+ const char * path_model,
4733
+ struct whisper_vad_context_params params) {
4734
+ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
4735
+ #ifdef _MSC_VER
4736
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
4737
+ std::wstring path_model_wide = converter.from_bytes(path_model);
4738
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
4739
+ #else
4740
+ auto fin = std::ifstream(path_model, std::ios::binary);
4741
+ #endif
4742
+ if (!fin) {
4743
+ WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
4744
+ return nullptr;
4745
+ }
4746
+
4747
+ whisper_model_loader loader = {};
4748
+ loader.context = &fin;
4749
+
4750
+ loader.read = [](void * ctx, void * output, size_t read_size) {
4751
+ std::ifstream * fin = (std::ifstream*)ctx;
4752
+ fin->read((char *)output, read_size);
4753
+ return read_size;
4754
+ };
4755
+
4756
+ loader.eof = [](void * ctx) {
4757
+ std::ifstream * fin = (std::ifstream*)ctx;
4758
+ return fin->eof();
4759
+ };
4760
+
4761
+ loader.close = [](void * ctx) {
4762
+ std::ifstream * fin = (std::ifstream*)ctx;
4763
+ fin->close();
4764
+ };
4765
+
4766
+ auto ctx = whisper_vad_init_with_params(&loader, params);
4767
+ if (!ctx) {
4768
+ whisper_vad_free(ctx);
4769
+ return nullptr;
4770
+ }
4771
+ ctx->path_model = path_model;
4772
+ return ctx;
4773
+ }
4774
+
4775
+ struct whisper_vad_context * whisper_vad_init_with_params(
4776
+ struct whisper_model_loader * loader,
4777
+ struct whisper_vad_context_params params) {
4778
+ // Read the VAD model
4779
+ {
4780
+ uint32_t magic;
4781
+ read_safe(loader, magic);
4782
+ if (magic != GGML_FILE_MAGIC) {
4783
+ WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
4784
+ return nullptr;
4785
+ }
4786
+ }
4787
+
4788
+ whisper_vad_context * vctx = new whisper_vad_context;
4789
+ vctx->n_threads = params.n_threads;
4790
+ vctx->params.use_gpu = params.use_gpu;
4791
+ vctx->params.gpu_device = params.gpu_device;
4792
+
4793
+ auto & model = vctx->model;
4794
+ auto & hparams = model.hparams;
4795
+
4796
+ // load model context params.
4797
+ {
4798
+ int32_t str_len;
4799
+ read_safe(loader, str_len);
4800
+ std::vector<char> buffer(str_len + 1, 0);
4801
+ loader->read(loader->context, buffer.data(), str_len);
4802
+ std::string model_type(buffer.data(), str_len);
4803
+ model.type = model_type;
4804
+ WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
4805
+
4806
+ int32_t major, minor, patch;
4807
+ read_safe(loader, major);
4808
+ read_safe(loader, minor);
4809
+ read_safe(loader, patch);
4810
+ std::string version_str = std::to_string(major) + "." +
4811
+ std::to_string(minor) + "." +
4812
+ std::to_string(patch);
4813
+ model.version = version_str;
4814
+ WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
4815
+
4816
+ read_safe(loader, vctx->n_window);
4817
+ read_safe(loader, vctx->n_context);
4818
+ }
4819
+
4820
+ // load model hyper params (hparams).
4821
+ {
4822
+ read_safe(loader, hparams.n_encoder_layers);
4823
+
4824
+ hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
4825
+ hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
4826
+ hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
4827
+
4828
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4829
+ read_safe(loader, hparams.encoder_in_channels[i]);
4830
+ read_safe(loader, hparams.encoder_out_channels[i]);
4831
+ read_safe(loader, hparams.kernel_sizes[i]);
4832
+ }
4833
+
4834
+ read_safe(loader, hparams.lstm_input_size);
4835
+ read_safe(loader, hparams.lstm_hidden_size);
4836
+ read_safe(loader, hparams.final_conv_in);
4837
+ read_safe(loader, hparams.final_conv_out);
4838
+
4839
+ WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
4840
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4841
+ WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
4842
+ }
4843
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4844
+ WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
4845
+ }
4846
+ WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
4847
+ WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
4848
+ WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
4849
+ WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
4850
+ }
4851
+
4852
+ // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
4853
+ const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
4854
+
4855
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
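+ // lazily create one no_alloc ggml context per buffer type; tensor data is allocated per buffer type later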
4856
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
4857
+ auto it = ctx_map.find(buft);
4858
+ if (it == ctx_map.end()) {
4859
+ ggml_init_params params = {
4860
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
4861
+ /*.mem_buffer =*/ nullptr,
4862
+ /*.no_alloc =*/ true,
4863
+ };
4864
+
4865
+ ggml_context * ctx = ggml_init(params);
4866
+ if (!ctx) {
4867
+ throw std::runtime_error("failed to create ggml context");
4868
+ }
4869
+
4870
+ ctx_map[buft] = ctx;
4871
+ model.ctxs.emplace_back(ctx);
4872
+
4873
+ return ctx;
4874
+ }
4875
+
4876
+ return it->second;
4877
+ };
4878
+
4879
+ whisper_context_params wparams = whisper_context_default_params();
4880
+ wparams.use_gpu = params.use_gpu;
4881
+ wparams.gpu_device = params.gpu_device;
4882
+ buft_list_t buft_list = make_buft_list(wparams);
4883
+
4884
+ auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
4885
+ ggml_op op = VAD_TENSOR_OPS.at(type);
4886
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
4887
+ if (!buft) {
4888
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
4889
+ }
4890
+ ggml_context * ctx = get_ctx(buft);
4891
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
4892
+ model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
4893
+
4894
+ return tensor;
4895
+ };
4896
+
4897
+ // create tensors
4898
+ {
4899
+ ggml_init_params params = {
4900
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
4901
+ /*.mem_buffer =*/ nullptr,
4902
+ /*.no_alloc =*/ true,
4903
+ };
4904
+
4905
+ ggml_context * ctx = ggml_init(params);
4906
+ const auto & hparams = model.hparams;
4907
+
4908
+ // STFT precomputed basis matrix
4909
+ model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
4910
+ ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
4911
+
4912
+ model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
4913
+ ggml_new_tensor_3d(
4914
+ ctx,
4915
+ GGML_TYPE_F16,
4916
+ hparams.kernel_sizes[0],
4917
+ hparams.encoder_in_channels[0],
4918
+ hparams.encoder_out_channels[0]
4919
+ ));
4920
+ model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
4921
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
4922
+
4923
+ model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
4924
+ ggml_new_tensor_3d(
4925
+ ctx,
4926
+ GGML_TYPE_F16,
4927
+ hparams.kernel_sizes[1],
4928
+ hparams.encoder_in_channels[1],
4929
+ hparams.encoder_out_channels[1]
4930
+ ));
4931
+ model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
4932
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
4933
+
4934
+ model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
4935
+ ggml_new_tensor_3d(
4936
+ ctx,
4937
+ GGML_TYPE_F16,
4938
+ hparams.kernel_sizes[2],
4939
+ hparams.encoder_in_channels[2],
4940
+ hparams.encoder_out_channels[2]
4941
+ ));
4942
+ model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
4943
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
4944
+
4945
+ model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
4946
+ ggml_new_tensor_3d(
4947
+ ctx,
4948
+ GGML_TYPE_F16,
4949
+ hparams.kernel_sizes[3],
4950
+ hparams.encoder_in_channels[3],
4951
+ hparams.encoder_out_channels[3]
4952
+ ));
4953
+ model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
4954
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
4955
+
4956
+ // Combined gate dimension: 4 * hidden size (input gate, forget gate, cell gate, output gate)
4957
+ const int hstate_dim = hparams.lstm_hidden_size * 4;
4958
+
4959
+ // LSTM weights - input to hidden
4960
+ model.lstm_ih_weight = create_tensor(
4961
+ VAD_TENSOR_LSTM_WEIGHT_IH,
4962
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4963
+ );
4964
+ model.lstm_ih_bias = create_tensor(
4965
+ VAD_TENSOR_LSTM_BIAS_IH,
4966
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
4967
+ );
4968
+
4969
+ // LSTM weights - hidden to hidden
4970
+ model.lstm_hh_weight = create_tensor(
4971
+ VAD_TENSOR_LSTM_WEIGHT_HH,
4972
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4973
+ );
4974
+ model.lstm_hh_bias = create_tensor(
4975
+ VAD_TENSOR_LSTM_BIAS_HH,
4976
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
4977
+ );
4978
+
4979
+ // Final conv layer weight
4980
+ model.final_conv_weight = create_tensor(
4981
+ VAD_TENSOR_FINAL_CONV_WEIGHT,
4982
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
4983
+ );
4984
+ model.final_conv_bias = create_tensor(
4985
+ VAD_TENSOR_FINAL_CONV_BIAS,
4986
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
4987
+ );
4988
+
4989
+ ggml_free(ctx);
4990
+ }
4991
+
4992
+ // allocate tensors in the backend buffers
4993
+ for (auto & p : ctx_map) {
4994
+ ggml_backend_buffer_type_t buft = p.first;
4995
+ ggml_context * ctx = p.second;
4996
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
4997
+ if (buf) {
4998
+ model.buffers.emplace_back(buf);
4999
+
5000
+ size_t size_main = ggml_backend_buffer_get_size(buf);
5001
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
5002
+ }
5003
+ }
5004
+
5005
+ // load weights
5006
+ {
5007
+ size_t total_size = 0;
5008
+ model.n_loaded = 0;
5009
+ std::vector<char> read_buf;
5010
+
5011
+ while (true) {
5012
+ int32_t n_dims;
5013
+ int32_t length;
5014
+ int32_t ttype;
5015
+
5016
+ read_safe(loader, n_dims);
5017
+ read_safe(loader, length);
5018
+ read_safe(loader, ttype);
5019
+
5020
+ if (loader->eof(loader->context)) {
5021
+ break;
5022
+ }
5023
+
5024
+ int32_t nelements = 1;
5025
+ int32_t ne[4] = { 1, 1, 1, 1 };
5026
+ for (int i = 0; i < n_dims; ++i) {
5027
+ read_safe(loader, ne[i]);
5028
+ nelements *= ne[i];
5029
+ }
5030
+
5031
+ std::string name;
5032
+ std::vector<char> tmp(length);
5033
+ loader->read(loader->context, &tmp[0], tmp.size());
5034
+ name.assign(&tmp[0], tmp.size());
5035
+
5036
+ if (model.tensors.find(name) == model.tensors.end()) {
5037
+ WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
5038
+ return nullptr;
5039
+ }
5040
+
5041
+ auto tensor = model.tensors[name.data()];
5042
+
5043
+ if (ggml_nelements(tensor) != nelements) {
5044
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
5045
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
5046
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
5047
+ return nullptr;
5048
+ }
5049
+
5050
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
5051
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
5052
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
5053
+ return nullptr;
5054
+ }
5055
+
5056
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
5057
+
5058
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
5059
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
5060
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
5061
+ return nullptr;
5062
+ }
5063
+
5064
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
5065
+ // for the CPU and Metal backends, we can read directly into the tensor
5066
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
5067
+ BYTESWAP_TENSOR(tensor);
5068
+ } else {
5069
+ // read into a temporary buffer first, then copy to device memory
5070
+ read_buf.resize(ggml_nbytes(tensor));
5071
+
5072
+ loader->read(loader->context, read_buf.data(), read_buf.size());
5073
+
5074
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
5075
+ }
5076
+
5077
+ total_size += ggml_nbytes(tensor);
5078
+ model.n_loaded++;
5079
+ }
5080
+
5081
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
5082
+
5083
+ if (model.n_loaded == 0) {
5084
+ WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
5085
+ } else if (model.n_loaded != (int) model.tensors.size()) {
5086
+ WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
5087
+ return nullptr;
5088
+ }
5089
+
5090
+ }
5091
+
5092
+ if (!whisper_vad_init_context(vctx)) {
5093
+ whisper_vad_free(vctx);
5094
+ return nullptr;
5095
+ }
5096
+
5097
+ return vctx;
5098
+ }
5099
+
5100
+ bool whisper_vad_detect_speech(
5101
+ struct whisper_vad_context * vctx,
5102
+ const float * samples,
5103
+ int n_samples) {
5104
+ int n_chunks = n_samples / vctx->n_window;
5105
+ if (n_samples % vctx->n_window != 0) {
5106
+ n_chunks += 1; // Add one more chunk for remaining samples.
5107
+ }
5108
+
5109
+ WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
5110
+ WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
5111
+
5112
+ // Reset LSTM hidden/cell states
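+ // (h_state and c_state live in this buffer, so both are zeroed once here and then carry over between chunks)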
5113
+ ggml_backend_buffer_clear(vctx->buffer, 0);
5114
+
5115
+ vctx->probs.resize(n_chunks);
5116
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
5117
+
5118
+ std::vector<float> window(vctx->n_window, 0.0f);
5119
+
5120
+ auto & sched = vctx->sched.sched;
5121
+
5122
+ ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
5123
+
5124
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
5125
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
5126
+ return false;
5127
+ }
5128
+
5129
+ struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
5130
+ struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
5131
+
5132
+ // the graph is built once and reused for each chunk
5133
+ const int64_t t_start_vad_us = ggml_time_us();
5134
+
5135
+ for (int i = 0; i < n_chunks; i++) {
5136
+ const int idx_start = i * vctx->n_window;
5137
+ const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
5138
+
5139
+ const int chunk_len = idx_end - idx_start;
5140
+
5141
+ if (chunk_len < vctx->n_window) {
5142
+ WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
5143
+ std::vector<float> partial_chunk(vctx->n_window, 0.0f);
5144
+ std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
5145
+
5146
+ // Copy the zero-padded chunk to the window.
5147
+ const int samples_to_copy_max = vctx->n_window;
5148
+ const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
5149
+ std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
5150
+ if (samples_to_copy_cur < samples_to_copy_max) {
5151
+ std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
5152
+ }
5153
+ } else {
5154
+ // Copy current frame samples to the window.
5155
+ const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
5156
+ std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
5157
+ }
5158
+
5159
+ // Set the frame tensor data with the samples.
5160
+ ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
5161
+
5162
+ // do not reset the scheduler - we will reuse the graph in the next chunk
5163
+ if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
5164
+ WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
5165
+ break;
5166
+ }
5167
+
5168
+ // Get the probability for this chunk.
5169
+ ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
5170
+
5171
+ //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
5172
+ }
5173
+
5174
+ vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
5175
+ WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
5176
+
5177
+ ggml_backend_sched_reset(sched);
5178
+
5179
+ return true;
5180
+ }
5181
+
5182
+ int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
5183
+ return segments->data.size();
5184
+ }
5185
+
5186
+ float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
5187
+ return segments->data[i_segment].start;
5188
+ }
5189
+
5190
+ float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
5191
+ return segments->data[i_segment].end;
5192
+ }
5193
+
5194
+ int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
5195
+ return vctx->probs.size();
5196
+ }
5197
+
5198
+ float * whisper_vad_probs(struct whisper_vad_context * vctx) {
5199
+ return vctx->probs.data();
5200
+ }
5201
+
5202
+ struct whisper_vad_segments * whisper_vad_segments_from_probs(
5203
+ struct whisper_vad_context * vctx,
5204
+ whisper_vad_params params) {
5205
+ WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
5206
+
5207
+ int n_probs = whisper_vad_n_probs(vctx);
5208
+ float * probs = whisper_vad_probs(vctx);
5209
+ float threshold = params.threshold;
5210
+ int min_speech_duration_ms = params.min_speech_duration_ms;
5211
+ int min_silence_duration_ms = params.min_silence_duration_ms;
5212
+ float max_speech_duration_s = params.max_speech_duration_s;
5213
+ int speech_pad_ms = params.speech_pad_ms;
5214
+ int n_window = vctx->n_window;
5215
+ int sample_rate = WHISPER_SAMPLE_RATE;
5216
+ int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
5217
+ int audio_length_samples = n_probs * n_window;
5218
+
5219
+ // Min number of samples to be considered valid speech.
5220
+ int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
5221
+ int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
5222
+
5223
+ // Max number of samples that a speech segment can contain before it is
5224
+ // split into multiple segments.
5225
+ int max_speech_samples;
5226
+ if (max_speech_duration_s > 100000.0f) {
5227
+ max_speech_samples = INT_MAX / 2;
5228
+ } else {
5229
+ int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
5230
+ max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
5231
+ if (max_speech_samples < 0) {
5232
+ max_speech_samples = INT_MAX / 2;
5233
+ }
5234
+ }
5235
+ // If a silence period exceeds this value, then that location (sample)
5236
+ // is marked as a potential place where the segment could be split if
5237
+ // max_speech_samples is reached. The value 98 was taken from the original
5238
+ // silero-vad Python implementation:
5239
+ //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
5240
+ int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
5241
+
5242
+ // Calculate lower threshold for detecting end of speech segments.
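+ // (hysteresis: a segment starts above `threshold` but only ends once the probability drops below this lower value)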
5243
+ float neg_threshold = threshold - 0.15f;
5244
+ if (neg_threshold < 0.01f) {
5245
+ neg_threshold = 0.01f;
5246
+ }
5247
+
5248
+ struct speech_segment_t {
5249
+ int start;
5250
+ int end;
5251
+ };
5252
+
5253
+ std::vector<speech_segment_t> speeches;
5254
+ speeches.reserve(256);
5255
+
5256
+ bool is_speech_segment = false;
5257
+ int temp_end = 0;
5258
+ int prev_end = 0;
5259
+ int next_start = 0;
5260
+ int curr_speech_start = 0;
5261
+ bool has_curr_speech = false;
5262
+
5263
+ for (int i = 0; i < n_probs; i++) {
5264
+ float curr_prob = probs[i];
5265
+ int curr_sample = n_window * i;
5266
+
5267
+ // Reset temp_end when we get back to speech
5268
+ if ((curr_prob >= threshold) && temp_end) {
5269
+ temp_end = 0;
5270
+ if (next_start < prev_end) {
5271
+ next_start = curr_sample;
5272
+ }
5273
+ }
5274
+
5275
+ // Start a new speech segment when probability exceeds threshold and not already in speech
5276
+ if ((curr_prob >= threshold) && !is_speech_segment) {
5277
+ is_speech_segment = true;
5278
+ curr_speech_start = curr_sample;
5279
+ has_curr_speech = true;
5280
+ continue;
5281
+ }
5282
+
5283
+ // Handle maximum speech duration
5284
+ if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
5285
+ if (prev_end) {
5286
+ speeches.push_back({ curr_speech_start, prev_end });
5287
+ has_curr_speech = true;
5288
+
5289
+ if (next_start < prev_end) { // Previously reached silence and is still not speech
5290
+ is_speech_segment = false;
5291
+ has_curr_speech = false;
5292
+ } else {
5293
+ curr_speech_start = next_start;
5294
+ }
5295
+ prev_end = next_start = temp_end = 0;
5296
+ } else {
5297
+ speeches.push_back({ curr_speech_start, curr_sample });
5298
+
5299
+ prev_end = next_start = temp_end = 0;
5300
+ is_speech_segment = false;
5301
+ has_curr_speech = false;
5302
+ continue;
5303
+ }
5304
+ }
5305
+
5306
+ // Handle silence after speech
5307
+ if ((curr_prob < neg_threshold) && is_speech_segment) {
5308
+ if (!temp_end) {
5309
+ temp_end = curr_sample;
5310
+ }
5311
+
5312
+ // Track potential segment ends for max_speech handling
5313
+ if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
5314
+ prev_end = temp_end;
5315
+ }
5316
+
5317
+ // Check if silence is long enough to end the segment
5318
+ if ((curr_sample - temp_end) < min_silence_samples) {
5319
+ continue;
5320
+ } else {
5321
+ // End the segment if it's long enough
5322
+ if ((temp_end - curr_speech_start) > min_speech_samples) {
5323
+ speeches.push_back({ curr_speech_start, temp_end });
5324
+ }
5325
+
5326
+ prev_end = next_start = temp_end = 0;
5327
+ is_speech_segment = false;
5328
+ has_curr_speech = false;
5329
+ continue;
5330
+ }
5331
+ }
5332
+ }
5333
+
5334
+ // Handle the case if we're still in a speech segment at the end
5335
+ if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
5336
+ speeches.push_back({ curr_speech_start, audio_length_samples });
5337
+ }
5338
+
5339
+ // Merge adjacent segments with small gaps in between (post-processing)
5340
+ if (speeches.size() > 1) {
5341
+ int merged_count = 0;
5342
+ for (int i = 0; i < (int) speeches.size() - 1; i++) {
5343
+ // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
5344
+ int max_merge_gap_samples = sample_rate * 200 / 1000;
5345
+
5346
+ // If the gap between this segment and the next is small enough
5347
+ if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
5348
+ // Merge by extending current segment to the end of next segment
5349
+ speeches[i].end = speeches[i+1].end;
5350
+ speeches.erase(speeches.begin() + i + 1);
5351
+
5352
+ i--;
5353
+ merged_count++;
5354
+ }
5355
+ }
5356
+ WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
5357
+ __func__, merged_count, (int) speeches.size());
5358
+ }
5359
+
5360
+ // Double-check for minimum speech duration
5361
+ for (int i = 0; i < (int) speeches.size(); i++) {
5362
+ if (speeches[i].end - speeches[i].start < min_speech_samples) {
5363
+ WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
5364
+ __func__, i, speeches[i].end - speeches[i].start);
5365
+
5366
+ speeches.erase(speeches.begin() + i);
5367
+ i--;
5368
+ }
5369
+ }
5370
+
5371
+ WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
5372
+
5373
+ // Allocate final segments
5374
+ std::vector<whisper_vad_segment> segments;
5375
+ if (speeches.size() > 0) {
5376
+ try {
5377
+ segments.resize(speeches.size());
5378
+ } catch (const std::bad_alloc &) {
5379
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
5380
+ return nullptr;
5381
+ }
5382
+ }
5383
+
5384
+ // Apply padding to segments and copy to final segments
5385
+ for (int i = 0; i < (int) speeches.size(); i++) {
5386
+ // Apply padding to the start of the first segment
5387
+ if (i == 0) {
5388
+ speeches[i].start =
5389
+ (speeches[i].start > speech_pad_samples) ?
5390
+ (speeches[i].start - speech_pad_samples) : 0;
5391
+ }
5392
+
5393
+ // Handle spacing between segments
5394
+ if (i < (int) speeches.size() - 1) {
5395
+ int silence_duration = speeches[i+1].start - speeches[i].end;
5396
+
5397
+ if (silence_duration < 2 * speech_pad_samples) {
5398
+ // If segments are close, split the difference
5399
+ speeches[i].end += silence_duration / 2;
5400
+ speeches[i+1].start =
5401
+ (speeches[i+1].start > silence_duration / 2) ?
5402
+ (speeches[i+1].start - silence_duration / 2) : 0;
5403
+ } else {
5404
+ // Otherwise, apply full padding to both
5405
+ speeches[i].end =
5406
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5407
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5408
+ speeches[i+1].start =
5409
+ (speeches[i+1].start > speech_pad_samples) ?
5410
+ (speeches[i+1].start - speech_pad_samples) : 0;
5411
+ }
5412
+ } else {
5413
+ // Apply padding to the end of the last segment
5414
+ speeches[i].end =
5415
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5416
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5417
+ }
5418
+
5419
+ // Convert from samples to centiseconds
5420
+ segments[i].start = samples_to_cs(speeches[i].start);
5421
+ segments[i].end = samples_to_cs(speeches[i].end);
5422
+
5423
+ WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
5424
+ __func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0);
5425
+ }
5426
+
5427
+ whisper_vad_segments * vad_segments = new whisper_vad_segments;
5428
+ if (vad_segments == NULL) {
5429
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
5430
+ return nullptr;
5431
+ }
5432
+
5433
+ vad_segments->data = std::move(segments);
5434
+
5435
+ return vad_segments;
5436
+ }
5437
+
5438
+ struct whisper_vad_segments * whisper_vad_segments_from_samples(
5439
+ whisper_vad_context * vctx,
5440
+ whisper_vad_params params,
5441
+ const float * samples,
5442
+ int n_samples) {
5443
+ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
5444
+ if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
5445
+ WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
5446
+ return nullptr;
5447
+ }
5448
+ return whisper_vad_segments_from_probs(vctx, params);
5449
+ }
5450
+
5451
+ void whisper_vad_free(whisper_vad_context * ctx) {
5452
+ if (ctx) {
5453
+ for (ggml_context * context : ctx->model.ctxs) {
5454
+ ggml_free(context);
5455
+ }
5456
+
5457
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
5458
+ ggml_backend_buffer_free(buf);
5459
+ }
5460
+
5461
+ ggml_backend_sched_free(ctx->sched.sched);
5462
+
5463
+ for (auto & backend : ctx->backends) {
5464
+ ggml_backend_free(backend);
5465
+ }
5466
+
5467
+
5468
+ delete ctx;
5469
+ }
5470
+ }
5471
+
5472
+ void whisper_vad_free_segments(whisper_vad_segments * segments) {
5473
+ if (segments) {
5474
+ delete segments;
5475
+ }
5476
+ }
5477
+
5478
+ //////////////////////////////////
5479
+ // Grammar - ported from llama.cpp
5480
+ //////////////////////////////////
5481
+
5482
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
5483
+ // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
5484
+ static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
5485
+ const char * src,
5486
+ whisper_partial_utf8 partial_start) {
5487
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
5488
+ const char * pos = src;
5489
+ std::vector<uint32_t> code_points;
5490
+ uint32_t value = partial_start.value;
5491
+ int n_remain = partial_start.n_remain;
5492
+
5493
+ // continue previous decode, if applicable
5494
+ while (*pos != 0 && n_remain > 0) {
5495
+ uint8_t next_byte = static_cast<uint8_t>(*pos);
5496
+ if ((next_byte >> 6) != 2) {
5497
+ // invalid sequence, abort
5498
+ code_points.push_back(0);
5499
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
5500
+ }
5501
+ value = (value << 6) + (next_byte & 0x3F);
5502
+ ++pos;
5503
+ --n_remain;
5504
+ }
5505
+
5506
+ if (partial_start.n_remain > 0 && n_remain == 0) {
5507
+ code_points.push_back(value);
5508
+ }
5509
+
5510
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
5511
+ while (*pos != 0) {
5512
+ uint8_t first_byte = static_cast<uint8_t>(*pos);
5513
+ uint8_t highbits = first_byte >> 4;
5514
+ n_remain = lookup[highbits] - 1;
5515
+
5516
+ if (n_remain < 0) {
5517
+ // invalid sequence, abort
5518
+ code_points.clear();
5519
+ code_points.push_back(0);
5520
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
5521
+ }
5522
+
5523
+ uint8_t mask = (1 << (7 - n_remain)) - 1;
5524
+ value = first_byte & mask;
5525
+ ++pos;
5526
+ while (*pos != 0 && n_remain > 0) {
5527
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
5528
+ ++pos;
5529
+ --n_remain;
5530
+ }
5531
+ if (n_remain == 0) {
5532
+ code_points.push_back(value);
5533
+ }
5534
+ }
5535
+ code_points.push_back(0);
5536
+
5537
+ return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
5538
+ }
5539
+
5540
+ // returns true iff pos points to the end of one of the definitions of a rule
5541
+ static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
5542
+ switch (pos->type) {
5543
+ case WHISPER_GRETYPE_END: return true; // NOLINT
5544
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
5545
+ default: return false;
5546
+ }
5547
+ }
5548
+
5549
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
5550
+ // asserts that pos is pointing to a char range element
5551
+ static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
5552
+ const whisper_grammar_element * pos,
5553
+ const uint32_t chr) {
5554
+
5555
+ bool found = false;
5556
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
5557
+
5558
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
5559
+
5560
+ do {
5561
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
5562
+ // inclusive range, e.g. [a-z]
5563
+ found = found || (pos->value <= chr && chr <= pos[1].value);
5564
+ pos += 2;
5565
+ } else {
5566
+ // exact char match, e.g. [a] or "a"
4292
5567
  found = found || pos->value == chr;
4293
5568
  pos += 1;
4294
5569
  }
@@ -4355,7 +5630,7 @@ static void whisper_grammar_advance_stack(
4355
5630
  std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
4356
5631
 
4357
5632
  if (stack.empty()) {
4358
- new_stacks.push_back(stack);
5633
+ new_stacks.emplace_back();
4359
5634
  return;
4360
5635
  }
4361
5636
 
@@ -4676,7 +5951,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4676
5951
  /*.detect_language =*/ false,
4677
5952
 
4678
5953
  /*.suppress_blank =*/ true,
4679
- /*.suppress_non_speech_tokens =*/ false,
5954
+ /*.suppress_nst =*/ false,
4680
5955
 
4681
5956
  /*.temperature =*/ 0.0f,
4682
5957
  /*.max_initial_ts =*/ 1.0f,
@@ -4716,6 +5991,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4716
5991
  /*.n_grammar_rules =*/ 0,
4717
5992
  /*.i_start_rule =*/ 0,
4718
5993
  /*.grammar_penalty =*/ 100.0f,
5994
+
5995
+ /*.vad =*/ false,
5996
+ /*.vad_model_path =*/ nullptr,
5997
+
5998
+ /* vad_params =*/ whisper_vad_default_params(),
4719
5999
  };
4720
6000
 
4721
6001
  switch (strategy) {
@@ -4960,7 +6240,7 @@ static void whisper_process_logits(
4960
6240
 
4961
6241
  // suppress non-speech tokens
4962
6242
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4963
- if (params.suppress_non_speech_tokens) {
6243
+ if (params.suppress_nst) {
4964
6244
  for (const std::string & token : non_speech_tokens) {
4965
6245
  const std::string suppress_tokens[] = {token, " " + token};
4966
6246
  for (const std::string & suppress_token : suppress_tokens) {
@@ -5332,6 +6612,186 @@ static void whisper_sequence_score(
5332
6612
  }
5333
6613
  }
5334
6614
 
6615
+ static bool whisper_vad(
6616
+ struct whisper_context * ctx,
6617
+ struct whisper_state * state,
6618
+ struct whisper_full_params params,
6619
+ const float * samples,
6620
+ int n_samples,
6621
+ std::vector<float> & filtered_samples) {
6622
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6623
+ int filtered_n_samples = 0;
6624
+
6625
+ // Clear any existing mapping table
6626
+ state->vad_mapping_table.clear();
6627
+ state->has_vad_segments = false;
6628
+
6629
+ if (state->vad_context == nullptr) {
6630
+ struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
6631
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
6632
+ if (vctx == nullptr) {
6633
+ WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
6634
+ return false;
6635
+ }
6636
+ state->vad_context = vctx;
6637
+ }
6638
+ auto vctx = state->vad_context;
6639
+
6640
+ const whisper_vad_params & vad_params = params.vad_params;
6641
+
6642
+ whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
6643
+
6644
+ if (vad_segments->data.size() > 0) {
6645
+ state->has_vad_segments = true;
6646
+ ctx->state->vad_segments.clear();
6647
+ ctx->state->vad_segments.reserve(vad_segments->data.size());
6648
+
6649
+ // Initialize the time mapping table
6650
+ state->vad_mapping_table.clear();
6651
+ state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
6652
+
6653
+ WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
6654
+ float overlap_seconds = vad_params.samples_overlap;
6655
+ int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
6656
+
6657
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6658
+ int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
6659
+ int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
6660
+
6661
+ if (i < (int)vad_segments->data.size() - 1) {
6662
+ segment_end_samples += overlap_samples;
6663
+ }
6664
+ segment_end_samples = std::min(segment_end_samples, n_samples - 1);
6665
+ filtered_n_samples += (segment_end_samples - segment_start_samples);
6666
+
6667
+ WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
6668
+ __func__, i, vad_segments->data[i].start/100.0,
6669
+ (vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)),
6670
+ (vad_segments->data[i].end - vad_segments->data[i].start)/100.0 +
6671
+ (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
6672
+ }
6673
+
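+ // 100 ms of silence is inserted between the concatenated speech segments (see the memset below)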
6674
+ int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
6675
+ int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
6676
+ int total_samples_needed = filtered_n_samples + total_silence_samples;
6677
+
6678
+ WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
6679
+ __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
6680
+
6681
+ try {
6682
+ filtered_samples.resize(total_samples_needed);
6683
+ } catch (const std::bad_alloc & /* e */) {
6684
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
6685
+ whisper_vad_free_segments(vad_segments);
6686
+ whisper_vad_free(vctx);
6687
+ return false;
6688
+ }
6689
+
6690
+ int offset = 0;
6691
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6692
+ int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
6693
+ int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
6694
+
6695
+ if (i < (int)vad_segments->data.size() - 1) {
6696
+ segment_end_samples += overlap_samples;
6697
+ }
6698
+
6699
+ segment_start_samples = std::min(segment_start_samples, n_samples - 1);
6700
+ segment_end_samples = std::min(segment_end_samples, n_samples);
6701
+ int segment_length = segment_end_samples - segment_start_samples;
6702
+ if (segment_length > 0) {
6703
+ whisper_state::vad_segment_info segment;
6704
+
6705
+ segment.orig_start = vad_segments->data[i].start;
6706
+ segment.orig_end = vad_segments->data[i].end;
6707
+
6708
+ segment.vad_start = samples_to_cs(offset);
6709
+ segment.vad_end = samples_to_cs(offset + segment_length);
6710
+
6711
+ // Add segment boundaries to mapping table
6712
+ vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
6713
+ vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
6714
+
6715
+ state->vad_mapping_table.push_back(start_mapping);
6716
+ state->vad_mapping_table.push_back(end_mapping);
6717
+
6718
+ // Add intermediate points for longer segments to improve interpolation accuracy
6719
+ const int64_t min_segment_length = 100; // 1 second
6720
+ const int64_t point_interval = 20; // Add a point every 200ms
6721
+
6722
+ if (segment.vad_end - segment.vad_start > min_segment_length) {
6723
+ int64_t segment_duration = segment.vad_end - segment.vad_start;
6724
+ int num_points = (int)(segment_duration / point_interval) - 1;
6725
+
6726
+ for (int j = 1; j <= num_points; j++) {
6727
+ int64_t vad_time = segment.vad_start + j * point_interval;
6728
+
6729
+ if (vad_time >= segment.vad_end) continue;
6730
+
6731
+ int64_t vad_elapsed = vad_time - segment.vad_start;
6732
+ int64_t vad_total = segment.vad_end - segment.vad_start;
6733
+ int64_t orig_total = segment.orig_end - segment.orig_start;
6734
+ int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total;
6735
+
6736
+ vad_time_mapping intermediate_mapping = {vad_time, orig_time};
6737
+ state->vad_mapping_table.push_back(intermediate_mapping);
6738
+ }
6739
+ }
6740
+
6741
+ WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
6742
+ __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
6743
+ ctx->state->vad_segments.push_back(segment);
6744
+
6745
+ // Copy this speech segment
6746
+ memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
6747
+ offset += segment_length;
6748
+
6749
+ // Add silence after this segment (except after the last segment)
6750
+ if (i < (int)vad_segments->data.size() - 1) {
6751
+ // Calculate the start and end time of the silence gap in processed audio
6752
+ int64_t silence_start_vad = samples_to_cs(offset);
6753
+ int64_t silence_end_vad = samples_to_cs(offset + silence_samples);
6754
+ // Calculate the corresponding original times
6755
+ int64_t orig_silence_start = segment.orig_end;
6756
+ int64_t orig_silence_end = vad_segments->data[i+1].start;
6757
+
6758
+ // Add mapping points for silence boundaries
6759
+ state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
6760
+ state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
6761
+
6762
+ // Fill with zeros (silence)
6763
+ memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
6764
+ offset += silence_samples;
6765
+ }
6766
+ }
6767
+ }
6768
+
6769
+ // Sort the mapping table by processed time
6770
+ std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6771
+ [](const vad_time_mapping& a, const vad_time_mapping& b) {
6772
+ return a.processed_time < b.processed_time;
6773
+ });
6774
+
6775
+ // Remove any duplicate processed times to ensure monotonicity which is
6776
+ // needed for binary search and interpolation later.
6777
+ if (!state->vad_mapping_table.empty()) {
6778
+ auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
6779
+ [](const vad_time_mapping& a, const vad_time_mapping& b) {
6780
+ return a.processed_time == b.processed_time;
6781
+ });
6782
+ state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
6783
+ }
6784
+
6785
+ WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
6786
+
6787
+ filtered_n_samples = offset;
6788
+ WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
6789
+ __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6790
+ }
6791
+
6792
+ return true;
6793
+ }
6794
+
5335
6795
  int whisper_full_with_state(
5336
6796
  struct whisper_context * ctx,
5337
6797
  struct whisper_state * state,
@@ -5381,11 +6841,13 @@ int whisper_full_with_state(
5381
6841
  const int seek_start = params.offset_ms/10;
5382
6842
  const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
5383
6843
 
5384
- // if length of spectrogram is less than 1.0s (100 frames), then return
5385
- // basically don't process anything that is less than 1.0s
5386
- // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
5387
- if (seek_end < seek_start + 100) {
5388
- WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
6844
+ // if length of spectrogram is less than 100ms (10 frames), then return
6845
+ // basically don't process anything that is less than 100ms
6846
+ // ref: https://github.com/ggml-org/whisper.cpp/issues/2065
6847
+ const int delta_min = 10;
6848
+
6849
+ if (seek_end < seek_start + delta_min) {
6850
+ WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
5389
6851
  return 0;
5390
6852
  }
5391
6853
 
@@ -5432,7 +6894,7 @@ int whisper_full_with_state(
5432
6894
  decoder.logprobs.resize(ctx->vocab.n_vocab);
5433
6895
  decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
5434
6896
 
5435
- decoder.rng = std::mt19937(0);
6897
+ decoder.rng = std::mt19937(j);
5436
6898
  }
5437
6899
 
5438
6900
  // the accumulated text context so far
@@ -5529,8 +6991,8 @@ int whisper_full_with_state(
  ctx, state, progress_cur, params.progress_callback_user_data);
  }

- // if only 1 second left, then stop
- if (seek + 100 >= seek_end) {
+ // if only 100ms left, then stop
+ if (seek + delta_min >= seek_end) {
  break;
  }

@@ -5877,10 +7339,10 @@ int whisper_full_with_state(
  // end of segment
  if (token.id == whisper_token_eot(ctx) || // end of text token
  (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
+ (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
  ) {
  if (result_len == 0 && !params.no_timestamps) {
- if (seek + seek_delta + 100 >= seek_end) {
+ if (seek + seek_delta + delta_min >= seek_end) {
  result_len = i + 1;
  } else {
  WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
@@ -6147,7 +7609,7 @@ int whisper_full_with_state(

  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);

- result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
  for (int j = i0; j <= i; j++) {
  result_all.back().tokens.push_back(tokens_cur[j]);
  }
@@ -6192,7 +7654,7 @@ int whisper_full_with_state(
  }
  }

- result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
  result_all.back().tokens.push_back(tokens_cur[j]);
  }
@@ -6229,7 +7691,7 @@ int whisper_full_with_state(
  }
  }

- // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
+ // ref: https://github.com/ggml-org/whisper.cpp/pull/2629
  const bool single_timestamp_ending = tokens_cur.size() > 1 &&
  tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
  tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
@@ -6253,6 +7715,21 @@ int whisper_full(
  struct whisper_full_params params,
  const float * samples,
  int n_samples) {
+
+ std::vector<float> vad_samples;
+ if (params.vad) {
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+ if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+ return -1;
+ }
+ if (vad_samples.empty()) {
+ ctx->state->result_all.clear();
+ return 0;
+ }
+ samples = vad_samples.data();
+ n_samples = vad_samples.size();
+ }
  return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
  }

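With this change a caller can have `whisper_full` run VAD first and transcribe only the detected speech. A minimal caller-side sketch; `params.vad` comes from this diff, while the `vad_model_path` field and the model file name are assumptions that should be checked against the shipped whisper.h:

```cpp
// Sketch, not authoritative: enable VAD-gated transcription via whisper_full.
#include "whisper.h"
#include <cstdio>
#include <vector>

int transcribe_with_vad(struct whisper_context * ctx, const std::vector<float> & pcm /* 16 kHz mono */) {
    whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.vad            = true;                      // route audio through whisper_vad first
    params.vad_model_path = "ggml-silero-v5.1.2.bin";  // assumed field name and model file

    if (whisper_full(ctx, params, pcm.data(), (int) pcm.size()) != 0) {
        return -1;
    }

    // Segment timestamps are mapped back onto the original, unfiltered timeline.
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        std::printf("[%lld -> %lld] %s\n",
                    (long long) whisper_full_get_segment_t0(ctx, i),
                    (long long) whisper_full_get_segment_t1(ctx, i),
                    whisper_full_get_segment_text(ctx, i));
    }
    return 0;
}
```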
@@ -6262,9 +7739,24 @@ int whisper_full_parallel(
  const float * samples,
  int n_samples,
  int n_processors) {
+
  if (n_processors == 1) {
  return whisper_full(ctx, params, samples, n_samples);
  }
+
+ std::vector<float> vad_samples;
+ if (params.vad) {
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
+ if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+ return -1;
+ }
+ if (vad_samples.empty()) {
+ return 0;
+ }
+ samples = vad_samples.data();
+ n_samples = vad_samples.size();
+ }
  int ret = 0;

  // prepare separate states for each thread
@@ -6387,20 +7879,93 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
  return ctx->state->lang_id;
  }

- int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
- return state->result_all[i_segment].t0;
+ static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector<vad_time_mapping> & mapping_table) {
+ if (mapping_table.empty()) {
+ return processed_time;
+ }
+
+ if (processed_time <= mapping_table.front().processed_time) {
+ return mapping_table.front().original_time; // Before first mapping point
+ }
+
+ if (processed_time >= mapping_table.back().processed_time) {
+ return mapping_table.back().original_time; // After last mapping point
+ }
+
+ // Binary search over the time map that finds the first entry that has a
+ // processed time greater than or equal to the current processed time.
+ auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time,
+ [](const vad_time_mapping & entry, int64_t time) {
+ return entry.processed_time < time;
+ }
+ );
+
+ // If exact match found
+ if (upper->processed_time == processed_time) {
+ return upper->original_time;
+ }
+
+ // Need to interpolate between two points
+ auto lower = upper - 1;
+
+ int64_t processed_diff = upper->processed_time - lower->processed_time;
+ int64_t original_diff = upper->original_time - lower->original_time;
+ int64_t offset = processed_time - lower->processed_time;
+
+ if (processed_diff == 0) {
+ return lower->original_time;
+ }
+
+ // Perform linear interpolation
+ return lower->original_time + (offset * original_diff) / processed_diff;
  }

- int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
- return ctx->state->result_all[i_segment].t0;
+ // Function to get the starting timestamp of a segment
+ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
+ // If VAD wasn't used, return the original timestamp
+ if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
+ return state->result_all[i_segment].t0;
+ }
+
+ // Get the processed timestamp
+ int64_t t0 = state->result_all[i_segment].t0;
+
+ // Map to original time using the mapping table
+ return map_processed_to_original_time(t0, state->vad_mapping_table);
  }

+ // Function to get the ending timestamp of a segment
  int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
- return state->result_all[i_segment].t1;
+ // If VAD wasn't used, return the original timestamp
+ if (!state->has_vad_segments || state->vad_mapping_table.empty()) {
+ return state->result_all[i_segment].t1;
+ }
+
+ // Get the processed timestamp
+ int64_t t1 = state->result_all[i_segment].t1;
+
+ // Map to original time using the mapping table
+ int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table);
+
+ // Get the corresponding t0 for this segment
+ int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment);
+
+ // Ensure minimum duration to prevent zero-length segments
+ const int64_t min_duration = 10; // 10ms minimum
+ if (orig_t1 - orig_t0 < min_duration) {
+ orig_t1 = orig_t0 + min_duration;
+ }
+
+ return orig_t1;
+ }
+
+
+ int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
+ return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
  }

  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
- return ctx->state->result_all[i_segment].t1;
+ return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
  }

  bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
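`map_processed_to_original_time` translates timestamps from the VAD-filtered timeline back to the original audio: exact anchor points are found with a binary search and anything in between is linearly interpolated in integer arithmetic. A self-contained sketch of the same idea (struct and helper names are illustrative, not the whisper.cpp internals):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct time_pair { int64_t processed; int64_t original; };

// Map a processed-timeline timestamp to the original timeline by
// clamping outside the table and linearly interpolating inside it.
static int64_t to_original(int64_t t, const std::vector<time_pair> & table) {
    if (table.empty())                return t;
    if (t <= table.front().processed) return table.front().original;
    if (t >= table.back().processed)  return table.back().original;

    auto upper = std::lower_bound(table.begin(), table.end(), t,
        [](const time_pair & e, int64_t v) { return e.processed < v; });
    auto lower = upper - 1;

    const int64_t dp = upper->processed - lower->processed;
    if (dp == 0) return lower->original;
    return lower->original + (t - lower->processed) * (upper->original - lower->original) / dp;
}

int main() {
    // Two kept speech regions: 5.0-6.0 s and 20.0-21.0 s of the original audio
    // (all values in 10 ms units, the granularity whisper.cpp reports).
    std::vector<time_pair> table = {
        { 0, 500 }, { 100, 600 },    // first kept region
        { 100, 2000 }, { 200, 2100 } // second kept region
    };
    std::printf("processed 150 -> original %lld\n", (long long) to_original(150, table)); // 2050
    return 0;
}
```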
@@ -6459,6 +8024,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
  return ctx->state->result_all[i_segment].tokens[i_token].p;
  }

+ float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
+ return ctx->state->result_all[i_segment].no_speech_prob;
+ }
+
+ float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
+ return state->result_all[i_segment].no_speech_prob;
+ }
+
  // =================================================================================================

  //
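These getters expose the decoder's per-segment no-speech probability, which a caller can use to drop segments that are likely silence or noise. A short usage sketch; the 0.6 cutoff is an arbitrary example, not a library default:

```cpp
#include "whisper.h"
#include <cstdio>

// Sketch: print only segments the model does not flag as likely non-speech.
void print_speech_segments(struct whisper_context * ctx) {
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        const float p_no_speech = whisper_full_get_segment_no_speech_prob(ctx, i);
        if (p_no_speech > 0.6f) {
            continue; // probably silence or background noise
        }
        std::printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }
}
```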
@@ -6639,7 +8212,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
  // c: N*N*sizeof(float)
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
  std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
- std::vector<uint8_t> work;

  // put a bunch of random data in the buffer
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -6696,12 +8268,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
  double tsum = 0.0;

  // heat-up
- ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
+ ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);

  for (int i = 0; i < n_max; ++i) {
  const int64_t t0 = ggml_time_us();

- ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
+ ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);

  const int64_t t1 = ggml_time_us();

@@ -6754,10 +8326,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
  // token-level timestamps
  //

- static int timestamp_to_sample(int64_t t, int n_samples) {
- return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
- }
-
  static int64_t sample_to_timestamp(int i_sample) {
  return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
  }
@@ -6807,6 +8375,18 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
  return result;
  }

+ static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) {
+ // Convert absolute timestamp to segment-relative timestamp
+ int64_t relative_t = t - segment_t0;
+ int sample = (int)((relative_t * WHISPER_SAMPLE_RATE) / 100);
+ return std::max(0, std::min(n_samples - 1, sample));
+ }
+
+ static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) {
+ int64_t relative_timestamp = (100ll * i_sample) / WHISPER_SAMPLE_RATE;
+ return relative_timestamp + segment_t0;
+ }
+
  static void whisper_exp_compute_token_level_timestamps(
  struct whisper_context & ctx,
  struct whisper_state & state,
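whisper.cpp timestamps are stored in 10 ms units, so the sample index is `t * WHISPER_SAMPLE_RATE / 100`; the new overloads subtract the segment's own `t0` first so the result indexes into the per-segment buffer rather than the whole file. A small self-contained check of that arithmetic (the constant mirrors whisper.h, the helper is a re-derivation, not the library function):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int kSampleRate = 16000; // same value as WHISPER_SAMPLE_RATE

// Segment-relative timestamp (10 ms units) -> sample index, clamped to the buffer.
static int ts_to_sample(int64_t t, int64_t segment_t0, int n_samples) {
    const int64_t rel = t - segment_t0;                    // 10 ms units into the segment
    const int sample  = (int) ((rel * kSampleRate) / 100);
    return std::max(0, std::min(n_samples - 1, sample));
}

int main() {
    // Segment starting at t0 = 500 (5.00 s); a token at t = 650 (6.50 s)
    // is 1.5 s into the segment buffer -> sample 24000 at 16 kHz.
    std::printf("%d\n", ts_to_sample(650, 500, 16000 * 30)); // prints 24000
    return 0;
}
```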
@@ -6862,12 +8442,6 @@ static void whisper_exp_compute_token_level_timestamps(

  const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));

- tokens[j].id = token.id;
- tokens[j].tid = token.tid;
- tokens[j].p = token.p;
- tokens[j].pt = token.pt;
- tokens[j].ptsum = token.ptsum;
-
  tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));

  if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
@@ -6953,8 +8527,8 @@ static void whisper_exp_compute_token_level_timestamps(
  continue;
  }

- int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
- int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
+ int s0 = timestamp_to_sample(tokens[j].t0, segment.t0, n_samples);
+ int s1 = timestamp_to_sample(tokens[j].t1, segment.t0, n_samples);

  const int ss0 = std::max(s0 - hw, 0);
  const int ss1 = std::min(s1 + hw, n_samples);
@@ -6975,7 +8549,7 @@ static void whisper_exp_compute_token_level_timestamps(
  while (k > 0 && state.energy[k] > thold) {
  k--;
  }
- tokens[j].t0 = sample_to_timestamp(k);
+ tokens[j].t0 = sample_to_timestamp(k, segment.t0);
  if (tokens[j].t0 < tokens[j - 1].t1) {
  tokens[j].t0 = tokens[j - 1].t1;
  } else {
@@ -6986,7 +8560,7 @@ static void whisper_exp_compute_token_level_timestamps(
  k++;
  }
  s0 = k;
- tokens[j].t0 = sample_to_timestamp(k);
+ tokens[j].t0 = sample_to_timestamp(k, segment.t0);
  }
  }

@@ -6996,7 +8570,7 @@ static void whisper_exp_compute_token_level_timestamps(
  while (k < n_samples - 1 && state.energy[k] > thold) {
  k++;
  }
- tokens[j].t1 = sample_to_timestamp(k);
+ tokens[j].t1 = sample_to_timestamp(k, segment.t0);
  if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
  tokens[j].t1 = tokens[j + 1].t0;
  } else {
@@ -7007,7 +8581,7 @@ static void whisper_exp_compute_token_level_timestamps(
  k--;
  }
  s1 = k;
- tokens[j].t1 = sample_to_timestamp(k);
+ tokens[j].t1 = sample_to_timestamp(k, segment.t0);
  }
  }
  }
@@ -7078,18 +8652,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
  struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
  struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);

- cost = ggml_set_f32(cost, INFINITY);
- trace = ggml_set_f32(trace, -1);
- ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+ cost = whisper_set_f32(cost, INFINITY);
+ trace = whisper_set_i32(trace, -1);
+ whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);

  // dtw
  // supposedly can be optmized by computing diagonals in parallel ?
  // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
  for (int64_t j = 1; j < M + 1; ++j) {
  for (int64_t i = 1; i < N + 1; ++i) {
- float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
- float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
- float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+ float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+ float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
+ float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);

  float c;
  int32_t t;
@@ -7104,9 +8678,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
  t = 2;
  }

- c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
- ggml_set_f32_nd(cost, i, j, 0, 0, c);
- ggml_set_i32_nd(trace, i, j, 0, 0, t);
+ c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+ whisper_set_f32_nd(cost, i, j, 0, 0, c);
+ whisper_set_i32_nd(trace, i, j, 0, 0, t);
  }
  }

@@ -7115,19 +8689,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
  struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
  // trace[0, :] = 2;
  for (int64_t i = 0; i < M + 1; ++i)
- ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+ whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
  //trace[:, 0] = 1;
  for (int64_t i = 0; i < N + 1; ++i)
- ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+ whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
  int bt_row_idx = BT_MAX_ROWS - 1;
  int64_t i = N;
  int64_t j = M;
  while (i > 0 || j > 0) {
- ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
- ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+ whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+ whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
  --bt_row_idx;

- int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
+ int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
  if (t == 0) {
  --i;
  --j;
@@ -7148,8 +8722,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
  ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
  for (int64_t i = 0; i < 2; ++i) {
  for (int64_t j = 0; j < result_n_cols; ++j) {
- int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
- ggml_set_i32_nd(r, i, j, 0, 0, v);
+ int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
+ whisper_set_i32_nd(r, i, j, 0, 0, v);
  }
  }

@@ -7184,11 +8758,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
  idx = 2*(a->ne[2] - 1) - idx;
  }

- filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+ filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
  }
  std::sort(filter.begin(), filter.end());
  const float v = filter[filter.size()/2];
- ggml_set_f32_nd(dst, i, j, k, 0, v);
+ whisper_set_f32_nd(dst, i, j, k, 0, v);
  filter.clear();
  }
  }
@@ -7310,7 +8884,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  // Compute
  struct ggml_cgraph * gf = ggml_new_graph(gctx);
  ggml_build_forward_expand(gf, w);
- ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+ ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+ ggml_backend_graph_compute(backend.get(), gf);

  ggml_tensor * alignment = dtw_and_backtrace(gctx, w);

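The DTW timestamp path now executes its graph through an explicit ggml CPU backend instead of the removed `ggml_graph_compute_with_ctx` helper. A stripped-down sketch of that pattern, using only the calls visible in the hunk above; `gctx` and `w` stand in for the context and tensor built by the surrounding function:

```cpp
// Sketch of the backend-based graph execution used above (not a complete program).
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };

struct ggml_cgraph * gf = ggml_new_graph(gctx);   // gctx: ggml_context from the caller
ggml_build_forward_expand(gf, w);                 // w: tensor whose value is needed
ggml_backend_graph_compute(backend.get(), gf);    // run the graph on the CPU backend
```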
@@ -7319,9 +8895,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
  auto seg_i = state->result_all.begin() + i_segment;
  auto tok_i = seg_i->tokens.begin();
  for (int i = 0; i < alignment->ne[1]; ++i) {
- int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
+ int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
  if (v != last_v) {
- int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
+ int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
  int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
  last_v = v;

@@ -7362,6 +8938,10 @@ void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
  ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
  }

+ const char * whisper_version(void) {
+ return WHISPER_VERSION;
+ }
+
  GGML_ATTRIBUTE_FORMAT(2, 3)
  static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
  va_list args;