whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -1,8 +1,8 @@
1
1
  #include "whisper.h"
2
-
3
- #include "ggml-cpu.h"
2
+ #include "whisper-arch.h"
4
3
 
5
4
  #include "ggml.h"
5
+ #include "ggml-cpp.h"
6
6
  #include "ggml-alloc.h"
7
7
  #include "ggml-backend.h"
8
8
 
@@ -17,37 +17,36 @@
17
17
  #include <atomic>
18
18
  #include <algorithm>
19
19
  #include <cassert>
20
+ #include <cfloat>
20
21
  #define _USE_MATH_DEFINES
21
22
  #include <cmath>
22
- #include <cstdio>
23
+ #include <climits>
24
+ #include <codecvt>
23
25
  #include <cstdarg>
26
+ #include <cstdio>
24
27
  #include <cstring>
25
28
  #include <fstream>
29
+ #include <functional>
26
30
  #include <map>
31
+ #include <mutex>
32
+ #include <random>
33
+ #include <regex>
27
34
  #include <set>
28
35
  #include <string>
29
36
  #include <thread>
30
37
  #include <vector>
31
- #include <regex>
32
- #include <random>
33
- #include <functional>
34
- #include <codecvt>
35
-
36
- #if defined(_MSC_VER)
37
- #pragma warning(disable: 4244 4267) // possible loss of data
38
- #endif
39
-
40
- #if defined(GGML_BIG_ENDIAN)
41
- #include <bit>
42
38
 
39
+ #if defined(WHISPER_BIG_ENDIAN)
43
40
  template<typename T>
44
41
  static T byteswap(T value) {
45
- return std::byteswap(value);
46
- }
47
-
48
- template<>
49
- float byteswap(float value) {
50
- return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
42
+ T value_swapped;
43
+ char * source = reinterpret_cast<char *>(&value);
44
+ char * target = reinterpret_cast<char *>(&value_swapped);
45
+ int size = sizeof(T);
46
+ for (int i = 0; i < size; i++) {
47
+ target[size - 1 - i] = source[i];
48
+ }
49
+ return value_swapped;
51
50
  }
52
51
 
53
52
  template<typename T>
@@ -83,14 +82,14 @@ static void byteswap_tensor(ggml_tensor * tensor) {
83
82
  }
84
83
 
85
84
  #define BYTESWAP_VALUE(d) d = byteswap(d)
86
- #define BYTESWAP_FILTERS(f) \
85
+ #define BYTESWAP_FILTERS(f) \
87
86
  do { \
88
87
  for (auto & datum : f.data) { \
89
88
  datum = byteswap(datum); \
90
89
  } \
91
90
  } while (0)
92
- #define BYTESWAP_TENSOR(t) \
93
- do { \
91
+ #define BYTESWAP_TENSOR(t) \
92
+ do { \
94
93
  byteswap_tensor(t); \
95
94
  } while (0)
96
95
  #else
@@ -141,34 +140,52 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
141
140
  #define WHISPER_MAX_DECODERS 8
142
141
  #define WHISPER_MAX_NODES 4096
143
142
 
143
+ static std::string format(const char * fmt, ...) {
144
+ va_list ap;
145
+ va_list ap2;
146
+ va_start(ap, fmt);
147
+ va_copy(ap2, ap);
148
+ int size = vsnprintf(NULL, 0, fmt, ap);
149
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
150
+ std::vector<char> buf(size + 1);
151
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
152
+ GGML_ASSERT(size2 == size);
153
+ va_end(ap2);
154
+ va_end(ap);
155
+ return std::string(buf.data(), size);
156
+ }
157
+
144
158
  //
145
159
  // ggml helpers
146
160
  //
147
161
 
148
162
  static bool ggml_graph_compute_helper(
149
163
  struct ggml_cgraph * graph,
150
- std::vector<uint8_t> & buf,
151
164
  int n_threads,
152
165
  ggml_abort_callback abort_callback,
153
166
  void * abort_callback_data) {
154
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
167
+ ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
168
+
169
+ auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
155
170
 
156
- plan.abort_callback = abort_callback;
157
- plan.abort_callback_data = abort_callback_data;
171
+ auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
172
+ if (set_abort_callback_fn) {
173
+ set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
174
+ }
158
175
 
159
- if (plan.work_size > 0) {
160
- buf.resize(plan.work_size);
161
- plan.work_data = buf.data();
176
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
177
+ if (ggml_backend_set_n_threads_fn) {
178
+ ggml_backend_set_n_threads_fn(backend.get(), n_threads);
162
179
  }
163
180
 
164
- return ggml_graph_compute(graph, &plan);
181
+ return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
165
182
  }
166
183
 
167
184
  static bool ggml_graph_compute_helper(
168
185
  ggml_backend_sched_t sched,
169
186
  struct ggml_cgraph * graph,
170
- int n_threads) {
171
-
187
+ int n_threads,
188
+ bool sched_reset = true) {
172
189
  for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
173
190
  ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
174
191
  ggml_backend_dev_t dev = ggml_backend_get_device(backend);
@@ -180,11 +197,70 @@ static bool ggml_graph_compute_helper(
180
197
  }
181
198
  }
182
199
 
183
- bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
184
- ggml_backend_sched_reset(sched);
200
+ const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS);
201
+
202
+ if (!t || sched_reset) {
203
+ ggml_backend_sched_reset(sched);
204
+ }
205
+
206
+ return t;
207
+ }
208
+
209
+ static void whisper_load_backends() {
210
+ #ifdef GGML_BACKEND_DL
211
+ static std::once_flag flag;
212
+ std::call_once(flag, []() {
213
+ ggml_backend_load_all();
214
+ });
215
+ #endif
216
+ }
217
+
218
+ // TODO: move these functions to ggml-base with support for ggml-backend?
219
+
220
+ static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
221
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
222
+ GGML_ASSERT(ggml_is_contiguous(t));
223
+ size_t nels = ggml_nelements(t);
224
+ for (size_t i = 0; i < nels; ++i) {
225
+ ((float *) t->data)[i] = v;
226
+ }
227
+ return t;
228
+ }
229
+
230
+ static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
231
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
232
+ GGML_ASSERT(ggml_is_contiguous(t));
233
+ size_t nels = ggml_nelements(t);
234
+ for (size_t i = 0; i < nels; ++i) {
235
+ ((int32_t *) t->data)[i] = v;
236
+ }
185
237
  return t;
186
238
  }
187
239
 
240
+ static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
241
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
242
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
243
+ return *(float *) data;
244
+ }
245
+
246
+ static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
247
+ GGML_ASSERT(t->type == GGML_TYPE_F32);
248
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
249
+ *(float *) data = v;
250
+ }
251
+
252
+ static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
253
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
254
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
255
+ return *(int32_t *) data;
256
+ }
257
+
258
+ static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
259
+ GGML_ASSERT(t->type == GGML_TYPE_I32);
260
+ void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
261
+ *(int32_t *) data = v;
262
+ }
263
+
188
264
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
189
265
  // the idea is to represent the original matrix multiplication:
190
266
  //
@@ -428,6 +504,7 @@ struct whisper_segment {
428
504
  int64_t t1;
429
505
 
430
506
  std::string text;
507
+ float no_speech_prob;
431
508
 
432
509
  std::vector<whisper_token_data> tokens;
433
510
 
@@ -520,7 +597,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<
520
597
  auto & sched = allocr.sched;
521
598
  auto & meta = allocr.meta;
522
599
 
523
- sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
600
+ sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true);
524
601
 
525
602
  meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
526
603
 
@@ -716,10 +793,10 @@ struct whisper_model {
716
793
  std::vector<whisper_layer_decoder> layers_decoder;
717
794
 
718
795
  // ggml context that contains all the meta information about the model tensors
719
- struct ggml_context * ctx = nullptr;
796
+ std::vector<ggml_context *> ctxs;
720
797
 
721
798
  // the model backend data is read-only and can be shared between processors
722
- ggml_backend_buffer_t buffer = nullptr;
799
+ std::vector<ggml_backend_buffer_t> buffers;
723
800
 
724
801
  // tensors
725
802
  int n_loaded;
@@ -876,6 +953,17 @@ struct whisper_state {
876
953
 
877
954
  // [EXPERIMENTAL] speed-up techniques
878
955
  int32_t exp_n_audio_ctx = 0; // 0 - use default
956
+
957
+ whisper_vad_context * vad_context = nullptr;
958
+
959
+ struct vad_segment_info {
960
+ float orig_start;
961
+ float orig_end;
962
+ float vad_start;
963
+ float vad_end;
964
+ };
965
+ std::vector<vad_segment_info> vad_segments;
966
+ bool has_vad_segments = false;
879
967
  };
880
968
 
881
969
  struct whisper_context {
@@ -1234,21 +1322,38 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
1234
1322
  static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
1235
1323
  ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
1236
1324
 
1325
+ whisper_load_backends();
1326
+
1327
+ ggml_backend_dev_t dev = nullptr;
1328
+
1329
+ int cnt = 0;
1237
1330
  if (params.use_gpu) {
1238
1331
  for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1239
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1240
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1241
- WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
1242
- ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
1243
- if (!result) {
1244
- WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
1332
+ ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
1333
+ if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1334
+ if (cnt == 0 || cnt == params.gpu_device) {
1335
+ dev = dev_cur;
1336
+ }
1337
+
1338
+ if (++cnt > params.gpu_device) {
1339
+ break;
1245
1340
  }
1246
- return result;
1247
1341
  }
1248
1342
  }
1249
1343
  }
1250
1344
 
1251
- return nullptr;
1345
+ if (dev == nullptr) {
1346
+ WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
1347
+ return nullptr;
1348
+ }
1349
+
1350
+ WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
1351
+ ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
1352
+ if (!result) {
1353
+ WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
1354
+ }
1355
+
1356
+ return result;
1252
1357
  }
1253
1358
 
1254
1359
  static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@@ -1274,28 +1379,118 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
1274
1379
  }
1275
1380
  }
1276
1381
 
1277
- GGML_UNUSED(params);
1278
-
1279
- result.push_back(ggml_backend_cpu_init());
1382
+ ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
1383
+ if (backend_cpu == nullptr) {
1384
+ throw std::runtime_error("failed to initialize CPU backend");
1385
+ }
1386
+ result.push_back(backend_cpu);
1280
1387
 
1281
1388
  return result;
1282
1389
  }
1283
1390
 
1284
- static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
1285
- if (!params.use_gpu) {
1286
- return ggml_backend_cpu_buffer_type();
1391
+ using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
1392
+
1393
+ static buft_list_t make_buft_list(whisper_context_params & params) {
1394
+ // Prio order: GPU -> CPU Extra -> CPU
1395
+ buft_list_t buft_list;
1396
+
1397
+ // GPU
1398
+ if (params.use_gpu) {
1399
+ int cnt = 0;
1400
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1401
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1402
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1403
+ if (cnt == 0 || cnt == params.gpu_device) {
1404
+ auto * buft = ggml_backend_dev_buffer_type(dev);
1405
+ if (buft) {
1406
+ buft_list.emplace_back(dev, buft);
1407
+ }
1408
+ }
1409
+
1410
+ if (++cnt > params.gpu_device) {
1411
+ break;
1412
+ }
1413
+ }
1414
+ }
1415
+ }
1416
+
1417
+ // CPU Extra
1418
+ auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
1419
+ auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
1420
+ auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
1421
+ ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
1422
+ if (get_extra_bufts_fn) {
1423
+ ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
1424
+ while (extra_bufts && *extra_bufts) {
1425
+ buft_list.emplace_back(cpu_dev, *extra_bufts);
1426
+ ++extra_bufts;
1427
+ }
1287
1428
  }
1288
1429
 
1289
- // if we have a GPU device - use it
1290
- for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
1291
- ggml_backend_dev_t dev = ggml_backend_dev_get(i);
1292
- if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1293
- WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
1294
- return ggml_backend_dev_buffer_type(dev);
1430
+ // CPU
1431
+ buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
1432
+
1433
+ return buft_list;
1434
+ }
1435
+
1436
+ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
1437
+ bool op_supported = true;
1438
+
1439
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
1440
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
1441
+ // GPU and default CPU backend support all operators
1442
+ op_supported = true;
1443
+ } else {
1444
+ switch (op) {
1445
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
1446
+ case GGML_OP_MUL_MAT: {
1447
+ ggml_init_params params = {
1448
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
1449
+ /*.mem_buffer =*/ nullptr,
1450
+ /*.no_alloc =*/ true,
1451
+ };
1452
+
1453
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1454
+ if (!ctx_ptr) {
1455
+ throw std::runtime_error("failed to create ggml context");
1456
+ }
1457
+ ggml_context * ctx = ctx_ptr.get();
1458
+
1459
+ ggml_tensor * op_tensor = nullptr;
1460
+
1461
+ int64_t n_ctx = hparams.n_audio_ctx;
1462
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1463
+ op_tensor = ggml_mul_mat(ctx, w, b);
1464
+
1465
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1466
+ GGML_ASSERT(w->buffer == nullptr);
1467
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
1468
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
1469
+ ggml_backend_buffer_free(w->buffer);
1470
+ w->buffer = nullptr;
1471
+ break;
1472
+ }
1473
+ default: {
1474
+ op_supported = false;
1475
+ break;
1476
+ }
1477
+ };
1478
+ }
1479
+
1480
+ return op_supported;
1481
+ }
1482
+
1483
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
1484
+ GGML_ASSERT(!buft_list.empty());
1485
+ for (const auto & p : buft_list) {
1486
+ ggml_backend_dev_t dev = p.first;
1487
+ ggml_backend_buffer_type_t buft = p.second;
1488
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
1489
+ return buft;
1295
1490
  }
1296
1491
  }
1297
1492
 
1298
- return ggml_backend_cpu_buffer_type();
1493
+ return nullptr;
1299
1494
  }
1300
1495
 
1301
1496
  // load the model from a ggml file
@@ -1504,31 +1699,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1504
1699
  const ggml_type wtype = wctx.wtype;
1505
1700
  const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
1506
1701
 
1507
- // create the ggml context
1508
- {
1509
- const auto & hparams = model.hparams;
1702
+ const auto & hparams = model.hparams;
1510
1703
 
1511
- const int n_audio_layer = hparams.n_audio_layer;
1512
- const int n_text_layer = hparams.n_text_layer;
1704
+ const int n_audio_layer = hparams.n_audio_layer;
1705
+ const int n_text_layer = hparams.n_text_layer;
1513
1706
 
1514
- const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
1707
+ const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
1515
1708
 
1516
- struct ggml_init_params params = {
1517
- /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
1518
- /*.mem_buffer =*/ nullptr,
1519
- /*.no_alloc =*/ true,
1520
- };
1709
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
1710
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
1711
+ auto it = ctx_map.find(buft);
1712
+ if (it == ctx_map.end()) {
1713
+ ggml_init_params params = {
1714
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
1715
+ /*.mem_buffer =*/ nullptr,
1716
+ /*.no_alloc =*/ true,
1717
+ };
1521
1718
 
1522
- model.ctx = ggml_init(params);
1523
- if (!model.ctx) {
1524
- WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
1525
- return false;
1719
+ ggml_context * ctx = ggml_init(params);
1720
+ if (!ctx) {
1721
+ throw std::runtime_error("failed to create ggml context");
1722
+ }
1723
+
1724
+ ctx_map[buft] = ctx;
1725
+ model.ctxs.emplace_back(ctx);
1726
+
1727
+ return ctx;
1526
1728
  }
1527
- }
1729
+
1730
+ return it->second;
1731
+ };
1732
+
1733
+ // Create a list of available bufts, in priority order
1734
+ buft_list_t buft_list = make_buft_list(wctx.params);
1735
+
1736
+ auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
1737
+ ggml_op op = ASR_TENSOR_INFO.at(type);
1738
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
1739
+ if (!buft) {
1740
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
1741
+ }
1742
+
1743
+ ggml_context * ctx = get_ctx(buft);
1744
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
1745
+
1746
+ model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
1747
+
1748
+ return tensor;
1749
+ };
1750
+
1528
1751
 
1529
1752
  // prepare tensors for the weights
1530
1753
  {
1531
- auto & ctx = model.ctx;
1754
+ ggml_init_params params = {
1755
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
1756
+ /*.mem_buffer =*/ nullptr,
1757
+ /*.no_alloc =*/ true,
1758
+ };
1759
+
1760
+ ggml_context * ctx = ggml_init(params);
1532
1761
 
1533
1762
  const auto & hparams = model.hparams;
1534
1763
 
@@ -1548,189 +1777,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1548
1777
  model.layers_decoder.resize(n_text_layer);
1549
1778
 
1550
1779
  // encoder
1551
- {
1552
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1553
-
1554
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1555
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1556
-
1557
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1558
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1559
-
1560
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1561
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1562
-
1563
- // map by name
1564
- model.tensors["encoder.positional_embedding"] = model.e_pe;
1565
-
1566
- model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
1567
- model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
1568
-
1569
- model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
1570
- model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
1571
-
1572
- model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
1573
- model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
1574
-
1575
- for (int i = 0; i < n_audio_layer; ++i) {
1576
- auto & layer = model.layers_encoder[i];
1577
-
1578
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1579
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1780
+ model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
1580
1781
 
1581
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
1582
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
1782
+ model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
1783
+ model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
1583
1784
 
1584
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
1585
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1785
+ model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
1786
+ model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
1586
1787
 
1587
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1588
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1788
+ model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
1789
+ model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
1589
1790
 
1590
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1591
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1791
+ for (int i = 0; i < n_audio_layer; ++i) {
1792
+ auto & layer = model.layers_encoder[i];
1592
1793
 
1593
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1794
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1795
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1594
1796
 
1595
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1596
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1797
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
1798
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
1597
1799
 
1598
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
1599
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1800
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
1801
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1600
1802
 
1601
- // map by name
1602
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1603
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1803
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1804
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1604
1805
 
1605
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1606
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1806
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1807
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1607
1808
 
1608
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1609
- model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1809
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1610
1810
 
1611
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1612
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1811
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1812
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1613
1813
 
1614
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1615
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1616
-
1617
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1618
-
1619
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1620
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1621
-
1622
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1623
- model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1624
- }
1814
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
1815
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
1625
1816
  }
1626
1817
 
1627
1818
  // decoder
1628
- {
1629
- model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
1630
-
1631
- model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
1632
-
1633
- model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1634
- model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1635
-
1636
- // map by name
1637
- model.tensors["decoder.positional_embedding"] = model.d_pe;
1638
-
1639
- model.tensors["decoder.token_embedding.weight"] = model.d_te;
1640
-
1641
- model.tensors["decoder.ln.weight"] = model.d_ln_w;
1642
- model.tensors["decoder.ln.bias"] = model.d_ln_b;
1643
-
1644
- for (int i = 0; i < n_text_layer; ++i) {
1645
- auto & layer = model.layers_decoder[i];
1646
-
1647
- layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1648
- layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1649
-
1650
- layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
1651
- layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
1652
-
1653
- layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
1654
- layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1655
-
1656
- layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1657
- layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1658
-
1659
- layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1660
- layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1819
+ model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
1661
1820
 
1662
- layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1821
+ model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
1663
1822
 
1664
- layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1665
- layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1823
+ model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
1824
+ model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
1666
1825
 
1667
- layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1668
- layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1826
+ for (int i = 0; i < n_text_layer; ++i) {
1827
+ auto & layer = model.layers_decoder[i];
1669
1828
 
1670
- layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1671
- layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1829
+ layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1830
+ layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1672
1831
 
1673
- layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1674
- layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1832
+ layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
1833
+ layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
1675
1834
 
1676
- layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1835
+ layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
1836
+ layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1677
1837
 
1678
- layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1679
- layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1838
+ layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1839
+ layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1680
1840
 
1681
- layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
1682
- layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
1841
+ layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1842
+ layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1683
1843
 
1684
- // map by name
1685
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
1686
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
1844
+ layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1687
1845
 
1688
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
1689
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
1846
+ layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1847
+ layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1690
1848
 
1691
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
1692
- model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
1849
+ layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1850
+ layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1693
1851
 
1694
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
1695
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
1852
+ layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1853
+ layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1696
1854
 
1697
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
1698
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
1855
+ layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1856
+ layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1699
1857
 
1700
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
1858
+ layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1701
1859
 
1702
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
1703
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
1860
+ layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1861
+ layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1704
1862
 
1705
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
1706
- model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
1707
-
1708
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
1709
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
1710
-
1711
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
1712
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
1713
-
1714
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
1715
-
1716
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
1717
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
1718
-
1719
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
1720
- model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
1721
- }
1863
+ layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
1864
+ layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
1722
1865
  }
1866
+
1867
+ ggml_free(ctx);
1723
1868
  }
1724
1869
 
1725
1870
  // allocate tensors in the backend buffers
1726
- model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
1727
- if (!model.buffer) {
1728
- WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
1729
- return false;
1730
- }
1871
+ for (auto & p : ctx_map) {
1872
+ ggml_backend_buffer_type_t buft = p.first;
1873
+ ggml_context * ctx = p.second;
1874
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
1875
+ if (buf) {
1876
+ model.buffers.emplace_back(buf);
1731
1877
 
1732
- size_t size_main = ggml_backend_buffer_get_size(model.buffer);
1733
- WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
1878
+ size_t size_main = ggml_backend_buffer_get_size(buf);
1879
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
1880
+ }
1881
+ }
1734
1882
 
1735
1883
  // load weights
1736
1884
  {
@@ -1793,11 +1941,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1793
1941
  return false;
1794
1942
  }
1795
1943
 
1796
- //ggml_backend_t backend = wctx.backend;
1797
-
1798
- //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
1799
-
1800
- if (ggml_backend_buffer_is_host(model.buffer)) {
1944
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
1801
1945
  // for the CPU and Metal backend, we can read directly into the tensor
1802
1946
  loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1803
1947
  BYTESWAP_TENSOR(tensor);
@@ -1810,7 +1954,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1810
1954
  ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
1811
1955
  }
1812
1956
 
1813
- //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
1814
1957
  total_size += ggml_nbytes(tensor);
1815
1958
  model.n_loaded++;
1816
1959
  }
@@ -1825,7 +1968,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1825
1968
  }
1826
1969
  }
1827
1970
 
1828
- ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1971
+ for (auto & buf : model.buffers) {
1972
+ ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
1973
+ }
1829
1974
 
1830
1975
  wctx.t_load_us = ggml_time_us() - t_start_us;
1831
1976
 
@@ -3710,15 +3855,24 @@ void whisper_free_state(struct whisper_state * state) {
3710
3855
  // [EXPERIMENTAL] Token-level timestamps with DTW
3711
3856
  aheads_masks_free(state->aheads_masks);
3712
3857
 
3858
+ if (state->vad_context != nullptr) {
3859
+ whisper_vad_free(state->vad_context);
3860
+ state->vad_context = nullptr;
3861
+ }
3862
+
3713
3863
  delete state;
3714
3864
  }
3715
3865
  }
3716
3866
 
3717
3867
  void whisper_free(struct whisper_context * ctx) {
3718
3868
  if (ctx) {
3719
- ggml_free(ctx->model.ctx);
3869
+ for (ggml_context * context : ctx->model.ctxs) {
3870
+ ggml_free(context);
3871
+ }
3720
3872
 
3721
- ggml_backend_buffer_free(ctx->model.buffer);
3873
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
3874
+ ggml_backend_buffer_free(buf);
3875
+ }
3722
3876
 
3723
3877
  whisper_free_state(ctx->state);
3724
3878
 
@@ -4136,11 +4290,11 @@ void whisper_print_timings(struct whisper_context * ctx) {
4136
4290
 
4137
4291
  WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
4138
4292
  WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
4139
- WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
4140
- WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
4141
- WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
4142
- WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
4143
- WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4293
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
4294
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
4295
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
4296
+ WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
4297
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
4144
4298
  }
4145
4299
  WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
4146
4300
  }
@@ -4181,112 +4335,1230 @@ static int whisper_has_openvino(void) {
4181
4335
  const char * whisper_print_system_info(void) {
4182
4336
  static std::string s;
4183
4337
 
4338
+ whisper_load_backends();
4339
+
4184
4340
  s = "";
4185
- s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
4186
- s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
4187
- s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
4188
- s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
4189
- s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
4190
- s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
4191
- s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
4192
- s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
4193
- s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
4194
- s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
4195
- s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
4196
- s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
4341
+ s += "WHISPER : ";
4197
4342
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
4198
4343
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
4199
4344
 
4345
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
4346
+ auto * reg = ggml_backend_reg_get(i);
4347
+ auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
4348
+ if (get_features_fn) {
4349
+ ggml_backend_feature * features = get_features_fn(reg);
4350
+ s += ggml_backend_reg_name(reg);
4351
+ s += " : ";
4352
+ for (; features->name; features++) {
4353
+ s += features->name;
4354
+ s += " = ";
4355
+ s += features->value;
4356
+ s += " | ";
4357
+ }
4358
+ }
4359
+ }
4200
4360
  return s.c_str();
4201
4361
  }
4202
4362
 
4203
4363
  //////////////////////////////////
4204
- // Grammar - ported from llama.cpp
4364
+ // Voice Activity Detection (VAD)
4205
4365
  //////////////////////////////////
4206
4366
 
4207
- // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
4208
- // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
4209
- static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
4210
- const char * src,
4211
- whisper_partial_utf8 partial_start) {
4212
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
4213
- const char * pos = src;
4214
- std::vector<uint32_t> code_points;
4215
- uint32_t value = partial_start.value;
4216
- int n_remain = partial_start.n_remain;
4367
+ struct whisper_vad_hparams {
4368
+ int32_t n_encoder_layers;
4369
+ int32_t * encoder_in_channels;
4370
+ int32_t * encoder_out_channels;
4371
+ int32_t * kernel_sizes;
4372
+ int32_t lstm_input_size;
4373
+ int32_t lstm_hidden_size;
4374
+ int32_t final_conv_in;
4375
+ int32_t final_conv_out;
4376
+ };
4217
4377
 
4218
- // continue previous decode, if applicable
4219
- while (*pos != 0 && n_remain > 0) {
4220
- uint8_t next_byte = static_cast<uint8_t>(*pos);
4221
- if ((next_byte >> 6) != 2) {
4222
- // invalid sequence, abort
4223
- code_points.push_back(0);
4224
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
4225
- }
4226
- value = (value << 6) + (next_byte & 0x3F);
4227
- ++pos;
4228
- --n_remain;
4229
- }
4378
+ struct whisper_vad_model {
4379
+ std::string type;
4380
+ std::string version;
4381
+ whisper_vad_hparams hparams;
4230
4382
 
4231
- if (partial_start.n_remain > 0 && n_remain == 0) {
4232
- code_points.push_back(value);
4233
- }
4383
+ struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
4234
4384
 
4235
- // decode any subsequent utf-8 sequences, which may end in an incomplete one
4236
- while (*pos != 0) {
4237
- uint8_t first_byte = static_cast<uint8_t>(*pos);
4238
- uint8_t highbits = first_byte >> 4;
4239
- n_remain = lookup[highbits] - 1;
4385
+ // Encoder tensors - 4 convolutional layers
4386
+ struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
4387
+ struct ggml_tensor * encoder_0_bias; // [128]
4240
4388
 
4241
- if (n_remain < 0) {
4242
- // invalid sequence, abort
4243
- code_points.clear();
4244
- code_points.push_back(0);
4245
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
4246
- }
4389
+ // Second encoder layer
4390
+ struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
4391
+ struct ggml_tensor * encoder_1_bias; // [64]
4247
4392
 
4248
- uint8_t mask = (1 << (7 - n_remain)) - 1;
4249
- value = first_byte & mask;
4250
- ++pos;
4251
- while (*pos != 0 && n_remain > 0) {
4252
- value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
4253
- ++pos;
4254
- --n_remain;
4255
- }
4256
- if (n_remain == 0) {
4257
- code_points.push_back(value);
4258
- }
4259
- }
4260
- code_points.push_back(0);
4393
+ // Third encoder layer
4394
+ struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
4395
+ struct ggml_tensor * encoder_2_bias; // [64]
4261
4396
 
4262
- return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
4263
- }
4397
+ // Fourth encoder layer
4398
+ struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
4399
+ struct ggml_tensor * encoder_3_bias; // [128]
4264
4400
 
4265
- // returns true iff pos points to the end of one of the definitions of a rule
4266
- static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
4267
- switch (pos->type) {
4268
- case WHISPER_GRETYPE_END: return true; // NOLINT
4269
- case WHISPER_GRETYPE_ALT: return true; // NOLINT
4270
- default: return false;
4271
- }
4272
- }
4401
+ // LSTM decoder tensors
4402
+ struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
4403
+ struct ggml_tensor * lstm_ih_bias; // [512]
4404
+ struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
4405
+ struct ggml_tensor * lstm_hh_bias; // [512]
4273
4406
 
4274
- // returns true iff chr satisfies the char range at pos (regular or inverse range)
4275
- // asserts that pos is pointing to a char range element
4276
- static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
4277
- const whisper_grammar_element * pos,
4278
- const uint32_t chr) {
4407
+ // Final conv layer
4408
+ struct ggml_tensor * final_conv_weight; // [128]
4409
+ struct ggml_tensor * final_conv_bias; // [1]
4279
4410
 
4280
- bool found = false;
4281
- bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
4411
+ // ggml contexts
4412
+ std::vector<ggml_context *> ctxs;
4282
4413
 
4283
- WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
4414
+ // buffer for the model tensors
4415
+ std::vector<ggml_backend_buffer_t> buffers;
4284
4416
 
4285
- do {
4286
- if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
4287
- // inclusive range, e.g. [a-z]
4288
- found = found || (pos->value <= chr && chr <= pos[1].value);
4289
- pos += 2;
4417
+ // tensors
4418
+ int n_loaded;
4419
+ std::map<std::string, struct ggml_tensor *> tensors;
4420
+ };
4421
+
4422
+ struct whisper_vad_segment {
4423
+ float start; // Start time in seconds
4424
+ float end; // End time in seconds
4425
+ };
4426
+
4427
+ struct whisper_vad_segments {
4428
+ std::vector<whisper_vad_segment> data;
4429
+ };
4430
+
4431
+ struct whisper_vad_context {
4432
+ int64_t t_vad_us = 0;
4433
+
4434
+ int n_window;
4435
+ int n_context;
4436
+ int n_threads;
4437
+
4438
+ std::vector<ggml_backend_t> backends;
4439
+ ggml_backend_buffer_t buffer = nullptr;
4440
+ whisper_context_params params;
4441
+ std::vector<uint8_t> ctx_buf;
4442
+ whisper_sched sched;
4443
+
4444
+ whisper_vad_model model;
4445
+ std::string path_model;
4446
+ struct ggml_tensor * h_state;
4447
+ struct ggml_tensor * c_state;
4448
+ std::vector<float> probs;
4449
+ };
4450
+
4451
+ struct whisper_vad_context_params whisper_vad_default_context_params(void) {
4452
+ whisper_vad_context_params result = {
4453
+ /*.n_thread = */ 4,
4454
+ /*.use_gpu = */ false,
4455
+ /*.gpu_device = */ 0,
4456
+ };
4457
+ return result;
4458
+ }
4459
+
4460
+ struct whisper_vad_params whisper_vad_default_params(void) {
4461
+ whisper_vad_params result = {
4462
+ /* threshold = */ 0.5f,
4463
+ /* min_speech_duration_ms = */ 250,
4464
+ /* min_silence_duration_ms = */ 100,
4465
+ /* max_speech_duration_s = */ FLT_MAX,
4466
+ /* speech_pad_ms = */ 30,
4467
+ /* samples_overlap = */ 0.1,
4468
+ };
4469
+ return result;
4470
+ }
4471
+
4472
+ static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
4473
+ bool op_supported = true;
4474
+
4475
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
4476
+ (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
4477
+ // GPU and default CPU backend support all operators
4478
+ op_supported = true;
4479
+ } else {
4480
+ switch (op) {
4481
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
4482
+ case GGML_OP_MUL_MAT: {
4483
+ ggml_init_params params = {
4484
+ /*.mem_size =*/ 2 * ggml_tensor_overhead(),
4485
+ /*.mem_buffer =*/ nullptr,
4486
+ /*.no_alloc =*/ true,
4487
+ };
4488
+
4489
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
4490
+ if (!ctx_ptr) {
4491
+ throw std::runtime_error("failed to create ggml context");
4492
+ }
4493
+ ggml_context * ctx = ctx_ptr.get();
4494
+
4495
+ ggml_tensor * op_tensor = nullptr;
4496
+
4497
+ int64_t n_ctx = hparams.lstm_hidden_size;
4498
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
4499
+ op_tensor = ggml_mul_mat(ctx, w, b);
4500
+
4501
+ // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
4502
+ GGML_ASSERT(w->buffer == nullptr);
4503
+ w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
4504
+ op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
4505
+ ggml_backend_buffer_free(w->buffer);
4506
+ w->buffer = nullptr;
4507
+ break;
4508
+ }
4509
+ default: {
4510
+ op_supported = false;
4511
+ break;
4512
+ }
4513
+ };
4514
+ }
4515
+ return op_supported;
4516
+ }
4517
+
4518
+ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
4519
+ GGML_ASSERT(!buft_list.empty());
4520
+ for (const auto & p : buft_list) {
4521
+ ggml_backend_dev_t dev = p.first;
4522
+ ggml_backend_buffer_type_t buft = p.second;
4523
+ if (weight_buft_supported(hparams, w, op, buft, dev)) {
4524
+ return buft;
4525
+ }
4526
+ }
4527
+
4528
+ return nullptr;
4529
+ }
4530
+
4531
+ static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
4532
+ const whisper_vad_model & model, ggml_tensor * cur) {
4533
+ // Apply reflective padding to the input tensor
4534
+ ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
4535
+
4536
+ struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1);
4537
+
4538
+ // Calculate cutoff for real/imaginary parts
4539
+ int cutoff = model.stft_forward_basis->ne[2] / 2;
4540
+
4541
+ // Extract real part (first half of the STFT output).
4542
+ struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0);
4543
+ // Extract imaginary part (second half of the STFT output).
4544
+ struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]);
4545
+
4546
+ // Calculate magnitude: sqrt(real^2 + imag^2)
4547
+ struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part);
4548
+ struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part);
4549
+ struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared);
4550
+ struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares);
4551
+ return magnitude;
4552
+ }
4553
+
4554
+ static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
4555
+ const whisper_vad_model & model, ggml_tensor * cur) {
4556
+ // First Conv1D: expands to 128 channels.
4557
+ cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1);
4558
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
4559
+ cur = ggml_relu(ctx0, cur);
4560
+
4561
+ // Second Conv1D: reduces to 64 channels.
4562
+ cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 2, 1, 1);
4563
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
4564
+ cur = ggml_relu(ctx0, cur);
4565
+
4566
+ // Third Conv1D: maintains 64 channels
4567
+ cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 2, 1, 1);
4568
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
4569
+ cur = ggml_relu(ctx0, cur);
4570
+
4571
+ // Fourth Conv1D: expands to 128 channels
4572
+ cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
4573
+ cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
4574
+ cur = ggml_relu(ctx0, cur);
4575
+
4576
+ return cur;
4577
+ }
4578
+
4579
+ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
4580
+ const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
4581
+ const whisper_vad_model & model = vctx.model;
4582
+ const int hdim = model.hparams.lstm_hidden_size;
4583
+
4584
+ struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
4585
+
4586
+ // Create operations using the input-to-hidden weights.
4587
+ struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
4588
+ inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
4589
+
4590
+ // Create operations using the hidden-to-hidden weights.
4591
+ struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
4592
+ hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
4593
+
4594
+ // Create add operation to get preactivations for all gates.
4595
+ struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate);
4596
+
4597
+ const size_t hdim_size = ggml_row_size(out_gate->type, hdim);
4598
+
4599
+ // Create sigmoid for input gate (using the first 128 bytes from the preactivations).
4600
+ struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size));
4601
+
4602
+ // Create sigmoid for the forget gate (using the second 128 bytes from the preactivations).
4603
+ struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size));
4604
+
4605
+ // Create sigmoid for the cell gate (using the third 128 bytes from the preactivations).
4606
+ struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size));
4607
+
4608
+ // Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations).
4609
+ struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size));
4610
+
4611
+ // Update cell state
4612
+ struct ggml_tensor * c_out = ggml_add(ctx0,
4613
+ ggml_mul(ctx0, f_t, vctx.c_state),
4614
+ ggml_mul(ctx0, i_t, g_t));
4615
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));
4616
+
4617
+ // Update hidden state
4618
+ struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
4619
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state));
4620
+
4621
+ return out;
4622
+ }
4623
+
4624
+ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
4625
+ const auto & model = vctx.model;
4626
+
4627
+ struct ggml_init_params params = {
4628
+ /*.mem_size =*/ vctx.sched.meta.size(),
4629
+ /*.mem_buffer =*/ vctx.sched.meta.data(),
4630
+ /*.no_alloc =*/ true,
4631
+ };
4632
+
4633
+ struct ggml_context * ctx0 = ggml_init(params);
4634
+
4635
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
4636
+
4637
+ struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
4638
+ ggml_set_name(frame, "frame");
4639
+ ggml_set_input(frame);
4640
+
4641
+ struct ggml_tensor * cur = nullptr;
4642
+ {
4643
+ cur = whisper_vad_build_stft_layer(ctx0, model, frame);
4644
+
4645
+ cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4646
+
4647
+ // Extract the first element of the first dimension
4648
+ // (equivalent to pytorch's [:, :, 0])
4649
+ cur = ggml_view_2d(ctx0, cur, 1, 128, cur->nb[1], 0);
4650
+
4651
+ cur = whisper_vad_build_lstm_layer(ctx0, vctx, cur, gf);
4652
+ cur = ggml_relu(ctx0, cur);
4653
+ cur = ggml_conv_1d(ctx0, model.final_conv_weight, cur, 1, 0, 1);
4654
+ cur = ggml_add(ctx0, cur, model.final_conv_bias);
4655
+ cur = ggml_sigmoid(ctx0, cur);
4656
+ ggml_set_name(cur, "prob");
4657
+ ggml_set_output(cur);
4658
+ }
4659
+
4660
+ ggml_build_forward_expand(gf, cur);
4661
+
4662
+ ggml_free(ctx0);
4663
+
4664
+ return gf;
4665
+ }
4666
+
4667
+ static bool whisper_vad_init_context(whisper_vad_context * vctx) {
4668
+
4669
+ auto whisper_context_params = whisper_context_default_params();
4670
+ // TODO: GPU VAD is forced disabled until the performance is improved
4671
+ //whisper_context_params.use_gpu = vctx->params.use_gpu;
4672
+ whisper_context_params.use_gpu = false;
4673
+ whisper_context_params.gpu_device = vctx->params.gpu_device;
4674
+
4675
+ vctx->backends = whisper_backend_init(whisper_context_params);
4676
+ if (vctx->backends.empty()) {
4677
+ WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
4678
+ return false;
4679
+ }
4680
+
4681
+ const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
4682
+
4683
+ vctx->ctx_buf.resize(2u*ggml_tensor_overhead());
4684
+
4685
+ struct ggml_init_params params = {
4686
+ /*.mem_size =*/ vctx->ctx_buf.size(),
4687
+ /*.mem_buffer =*/ vctx->ctx_buf.data(),
4688
+ /*.no_alloc =*/ true,
4689
+ };
4690
+
4691
+ ggml_context * ctx = ggml_init(params);
4692
+ if (!ctx) {
4693
+ WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
4694
+ return false;
4695
+ }
4696
+
4697
+ // LSTM Hidden state
4698
+ vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4699
+ ggml_set_name(vctx->h_state, "h_state");
4700
+
4701
+ // LSTM Cell state
4702
+ vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4703
+ ggml_set_name(vctx->c_state, "c_state");
4704
+
4705
+ vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);
4706
+ if (!vctx->buffer) {
4707
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
4708
+ return false;
4709
+ }
4710
+
4711
+ {
4712
+ bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
4713
+ [&]() {
4714
+ return whisper_vad_build_graph(*vctx);
4715
+ });
4716
+
4717
+ if (!ok) {
4718
+ WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
4719
+ return false;
4720
+ }
4721
+
4722
+ WHISPER_LOG_INFO("%s: compute buffer (VAD) = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
4723
+ }
4724
+
4725
+ return true;
4726
+ }
4727
+
4728
+ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
4729
+ const char * path_model,
4730
+ struct whisper_vad_context_params params) {
4731
+ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
4732
+ #ifdef _MSC_VER
4733
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
4734
+ std::wstring path_model_wide = converter.from_bytes(path_model);
4735
+ auto fin = std::ifstream(path_model_wide, std::ios::binary);
4736
+ #else
4737
+ auto fin = std::ifstream(path_model, std::ios::binary);
4738
+ #endif
4739
+ if (!fin) {
4740
+ WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
4741
+ return nullptr;
4742
+ }
4743
+
4744
+ whisper_model_loader loader = {};
4745
+ loader.context = &fin;
4746
+
4747
+ loader.read = [](void * ctx, void * output, size_t read_size) {
4748
+ std::ifstream * fin = (std::ifstream*)ctx;
4749
+ fin->read((char *)output, read_size);
4750
+ return read_size;
4751
+ };
4752
+
4753
+ loader.eof = [](void * ctx) {
4754
+ std::ifstream * fin = (std::ifstream*)ctx;
4755
+ return fin->eof();
4756
+ };
4757
+
4758
+ loader.close = [](void * ctx) {
4759
+ std::ifstream * fin = (std::ifstream*)ctx;
4760
+ fin->close();
4761
+ };
4762
+
4763
+ auto ctx = whisper_vad_init_with_params(&loader, params);
4764
+ if (!ctx) {
4765
+ whisper_vad_free(ctx);
4766
+ return nullptr;
4767
+ }
4768
+ ctx->path_model = path_model;
4769
+ return ctx;
4770
+ }
4771
+
4772
+ struct whisper_vad_context * whisper_vad_init_with_params(
4773
+ struct whisper_model_loader * loader,
4774
+ struct whisper_vad_context_params params) {
4775
+ // Read the VAD model
4776
+ {
4777
+ uint32_t magic;
4778
+ read_safe(loader, magic);
4779
+ if (magic != GGML_FILE_MAGIC) {
4780
+ WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
4781
+ return nullptr;
4782
+ }
4783
+ }
4784
+
4785
+ whisper_vad_context * vctx = new whisper_vad_context;
4786
+ vctx->n_threads = params.n_threads;
4787
+ vctx->params.use_gpu = params.use_gpu;
4788
+ vctx->params.gpu_device = params.gpu_device;
4789
+
4790
+ auto & model = vctx->model;
4791
+ auto & hparams = model.hparams;
4792
+
4793
+ // load model context params.
4794
+ {
4795
+ int32_t str_len;
4796
+ read_safe(loader, str_len);
4797
+ std::vector<char> buffer(str_len + 1, 0);
4798
+ loader->read(loader->context, buffer.data(), str_len);
4799
+ std::string model_type(buffer.data(), str_len);
4800
+ model.type = model_type;
4801
+ WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
4802
+
4803
+ int32_t major, minor, patch;
4804
+ read_safe(loader, major);
4805
+ read_safe(loader, minor);
4806
+ read_safe(loader, patch);
4807
+ std::string version_str = std::to_string(major) + "." +
4808
+ std::to_string(minor) + "." +
4809
+ std::to_string(patch);
4810
+ model.version = version_str;
4811
+ WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
4812
+
4813
+ read_safe(loader, vctx->n_window);
4814
+ read_safe(loader, vctx->n_context);
4815
+ }
4816
+
4817
+ // load model hyper params (hparams).
4818
+ {
4819
+ read_safe(loader, hparams.n_encoder_layers);
4820
+
4821
+ hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
4822
+ hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
4823
+ hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
4824
+
4825
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4826
+ read_safe(loader, hparams.encoder_in_channels[i]);
4827
+ read_safe(loader, hparams.encoder_out_channels[i]);
4828
+ read_safe(loader, hparams.kernel_sizes[i]);
4829
+ }
4830
+
4831
+ read_safe(loader, hparams.lstm_input_size);
4832
+ read_safe(loader, hparams.lstm_hidden_size);
4833
+ read_safe(loader, hparams.final_conv_in);
4834
+ read_safe(loader, hparams.final_conv_out);
4835
+
4836
+ WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
4837
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4838
+ WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
4839
+ }
4840
+ for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
4841
+ WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
4842
+ }
4843
+ WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
4844
+ WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
4845
+ WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
4846
+ WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
4847
+ }
4848
+
4849
+ // 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
4850
+ const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
4851
+
4852
+ std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
4853
+ auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
4854
+ auto it = ctx_map.find(buft);
4855
+ if (it == ctx_map.end()) {
4856
+ ggml_init_params params = {
4857
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
4858
+ /*.mem_buffer =*/ nullptr,
4859
+ /*.no_alloc =*/ true,
4860
+ };
4861
+
4862
+ ggml_context * ctx = ggml_init(params);
4863
+ if (!ctx) {
4864
+ throw std::runtime_error("failed to create ggml context");
4865
+ }
4866
+
4867
+ ctx_map[buft] = ctx;
4868
+ model.ctxs.emplace_back(ctx);
4869
+
4870
+ return ctx;
4871
+ }
4872
+
4873
+ return it->second;
4874
+ };
4875
+
4876
+ whisper_context_params wparams = whisper_context_default_params();
4877
+ wparams.use_gpu = params.use_gpu;
4878
+ wparams.gpu_device = params.gpu_device;
4879
+ buft_list_t buft_list = make_buft_list(wparams);
4880
+
4881
+ auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
4882
+ ggml_op op = VAD_TENSOR_OPS.at(type);
4883
+ ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
4884
+ if (!buft) {
4885
+ throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
4886
+ }
4887
+ ggml_context * ctx = get_ctx(buft);
4888
+ ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
4889
+ model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
4890
+
4891
+ return tensor;
4892
+ };
4893
+
4894
+ // create tensors
4895
+ {
4896
+ ggml_init_params params = {
4897
+ /*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
4898
+ /*.mem_buffer =*/ nullptr,
4899
+ /*.no_alloc =*/ true,
4900
+ };
4901
+
4902
+ ggml_context * ctx = ggml_init(params);
4903
+ const auto & hparams = model.hparams;
4904
+
4905
+ // SFTF precomputed basis matrix
4906
+ model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
4907
+ ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
4908
+
4909
+ model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
4910
+ ggml_new_tensor_3d(
4911
+ ctx,
4912
+ GGML_TYPE_F16,
4913
+ hparams.kernel_sizes[0],
4914
+ hparams.encoder_in_channels[0],
4915
+ hparams.encoder_out_channels[0]
4916
+ ));
4917
+ model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
4918
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
4919
+
4920
+ model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
4921
+ ggml_new_tensor_3d(
4922
+ ctx,
4923
+ GGML_TYPE_F16,
4924
+ hparams.kernel_sizes[1],
4925
+ hparams.encoder_in_channels[1],
4926
+ hparams.encoder_out_channels[1]
4927
+ ));
4928
+ model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
4929
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
4930
+
4931
+ model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
4932
+ ggml_new_tensor_3d(
4933
+ ctx,
4934
+ GGML_TYPE_F16,
4935
+ hparams.kernel_sizes[2],
4936
+ hparams.encoder_in_channels[2],
4937
+ hparams.encoder_out_channels[2]
4938
+ ));
4939
+ model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
4940
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
4941
+
4942
+ model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
4943
+ ggml_new_tensor_3d(
4944
+ ctx,
4945
+ GGML_TYPE_F16,
4946
+ hparams.kernel_sizes[3],
4947
+ hparams.encoder_in_channels[3],
4948
+ hparams.encoder_out_channels[3]
4949
+ ));
4950
+ model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
4951
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
4952
+
4953
+ // Hidden State dimension (input gate, forget gate, cell gate, output gate)
4954
+ const int hstate_dim = hparams.lstm_hidden_size * 4;
4955
+
4956
+ // LSTM weights - input to hidden
4957
+ model.lstm_ih_weight = create_tensor(
4958
+ VAD_TENSOR_LSTM_WEIGHT_IH,
4959
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4960
+ );
4961
+ model.lstm_ih_bias = create_tensor(
4962
+ VAD_TENSOR_LSTM_BIAS_IH,
4963
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
4964
+ );
4965
+
4966
+ // LSTM weights - hidden to hidden
4967
+ model.lstm_hh_weight = create_tensor(
4968
+ VAD_TENSOR_LSTM_WEIGHT_HH,
4969
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
4970
+ );
4971
+ model.lstm_hh_bias = create_tensor(
4972
+ VAD_TENSOR_LSTM_BIAS_HH,
4973
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
4974
+ );
4975
+
4976
+ // Final conv layer weight
4977
+ model.final_conv_weight = create_tensor(
4978
+ VAD_TENSOR_FINAL_CONV_WEIGHT,
4979
+ ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
4980
+ );
4981
+ model.final_conv_bias = create_tensor(
4982
+ VAD_TENSOR_FINAL_CONV_BIAS,
4983
+ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
4984
+ );
4985
+
4986
+ ggml_free(ctx);
4987
+ }
4988
+
4989
+ // allocate tensors in the backend buffers
4990
+ for (auto & p : ctx_map) {
4991
+ ggml_backend_buffer_type_t buft = p.first;
4992
+ ggml_context * ctx = p.second;
4993
+ ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
4994
+ if (buf) {
4995
+ model.buffers.emplace_back(buf);
4996
+
4997
+ size_t size_main = ggml_backend_buffer_get_size(buf);
4998
+ WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
4999
+ }
5000
+ }
5001
+
5002
+ // load weights
5003
+ {
5004
+ size_t total_size = 0;
5005
+ model.n_loaded = 0;
5006
+ std::vector<char> read_buf;
5007
+
5008
+ while (true) {
5009
+ int32_t n_dims;
5010
+ int32_t length;
5011
+ int32_t ttype;
5012
+
5013
+ read_safe(loader, n_dims);
5014
+ read_safe(loader, length);
5015
+ read_safe(loader, ttype);
5016
+
5017
+ if (loader->eof(loader->context)) {
5018
+ break;
5019
+ }
5020
+
5021
+ int32_t nelements = 1;
5022
+ int32_t ne[4] = { 1, 1, 1, 1 };
5023
+ for (int i = 0; i < n_dims; ++i) {
5024
+ read_safe(loader, ne[i]);
5025
+ nelements *= ne[i];
5026
+ }
5027
+
5028
+ std::string name;
5029
+ std::vector<char> tmp(length);
5030
+ loader->read(loader->context, &tmp[0], tmp.size());
5031
+ name.assign(&tmp[0], tmp.size());
5032
+
5033
+ if (model.tensors.find(name) == model.tensors.end()) {
5034
+ WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
5035
+ return nullptr;
5036
+ }
5037
+
5038
+ auto tensor = model.tensors[name.data()];
5039
+
5040
+ if (ggml_nelements(tensor) != nelements) {
5041
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
5042
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
5043
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
5044
+ return nullptr;
5045
+ }
5046
+
5047
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
5048
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
5049
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
5050
+ return nullptr;
5051
+ }
5052
+
5053
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
5054
+
5055
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
5056
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
5057
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
5058
+ return nullptr;
5059
+ }
5060
+
5061
+ if (ggml_backend_buffer_is_host(tensor->buffer)) {
5062
+ // for the CPU and Metal backend, we can read directly into the tensor
5063
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
5064
+ BYTESWAP_TENSOR(tensor);
5065
+ } else {
5066
+ // read into a temporary buffer first, then copy to device memory
5067
+ read_buf.resize(ggml_nbytes(tensor));
5068
+
5069
+ loader->read(loader->context, read_buf.data(), read_buf.size());
5070
+
5071
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
5072
+ }
5073
+
5074
+ total_size += ggml_nbytes(tensor);
5075
+ model.n_loaded++;
5076
+ }
5077
+
5078
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
5079
+
5080
+ if (model.n_loaded == 0) {
5081
+ WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
5082
+ } else if (model.n_loaded != (int) model.tensors.size()) {
5083
+ WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
5084
+ return nullptr;
5085
+ }
5086
+
5087
+ }
5088
+
5089
+ if (!whisper_vad_init_context(vctx)) {
5090
+ whisper_vad_free(vctx);
5091
+ return nullptr;
5092
+ }
5093
+
5094
+ return vctx;
5095
+ }
5096
+
5097
+ bool whisper_vad_detect_speech(
5098
+ struct whisper_vad_context * vctx,
5099
+ const float * samples,
5100
+ int n_samples) {
5101
+ int n_chunks = n_samples / vctx->n_window;
5102
+ if (n_samples % vctx->n_window != 0) {
5103
+ n_chunks += 1; // Add one more chunk for remaining samples.
5104
+ }
5105
+
5106
+ WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
5107
+ WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
5108
+
5109
+ // Reset LSTM hidden/cell states
5110
+ ggml_backend_buffer_clear(vctx->buffer, 0);
5111
+
5112
+ vctx->probs.resize(n_chunks);
5113
+ WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
5114
+
5115
+ std::vector<float> window(vctx->n_window, 0.0f);
5116
+
5117
+ auto & sched = vctx->sched.sched;
5118
+
5119
+ ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
5120
+
5121
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
5122
+ WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
5123
+ return false;
5124
+ }
5125
+
5126
+ struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
5127
+ struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
5128
+
5129
+ // we are going to reuse the graph multiple times for each chunk
5130
+ const int64_t t_start_vad_us = ggml_time_us();
5131
+
5132
+ for (int i = 0; i < n_chunks; i++) {
5133
+ const int idx_start = i * vctx->n_window;
5134
+ const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
5135
+
5136
+ const int chunk_len = idx_end - idx_start;
5137
+
5138
+ if (chunk_len < vctx->n_window) {
5139
+ WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
5140
+ std::vector<float> partial_chunk(vctx->n_window, 0.0f);
5141
+ std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
5142
+
5143
+ // Copy the zero-padded chunk to the window.
5144
+ const int samples_to_copy_max = vctx->n_window;
5145
+ const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
5146
+ std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
5147
+ if (samples_to_copy_cur < samples_to_copy_max) {
5148
+ std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
5149
+ }
5150
+ } else {
5151
+ // Copy current frame samples to the window.
5152
+ const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
5153
+ std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
5154
+ }
5155
+
5156
+ // Set the frame tensor data with the samples.
5157
+ ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
5158
+
5159
+ // do not reset the scheduler - we will reuse the graph in the next chunk
5160
+ if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
5161
+ WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
5162
+ break;
5163
+ }
5164
+
5165
+ // Get the probability for this chunk.
5166
+ ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
5167
+
5168
+ //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
5169
+ }
5170
+
5171
+ vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
5172
+ WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
5173
+
5174
+ ggml_backend_sched_reset(sched);
5175
+
5176
+ return true;
5177
+ }
5178
+
5179
+ int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
5180
+ return segments->data.size();
5181
+ }
5182
+
5183
+ float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
5184
+ return segments->data[i_segment].start;
5185
+ }
5186
+
5187
+ float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
5188
+ return segments->data[i_segment].end;
5189
+ }
5190
+
5191
+ int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
5192
+ return vctx->probs.size();
5193
+ }
5194
+
5195
+ float * whisper_vad_probs(struct whisper_vad_context * vctx) {
5196
+ return vctx->probs.data();
5197
+ }
5198
+
5199
+ struct whisper_vad_segments * whisper_vad_segments_from_probs(
5200
+ struct whisper_vad_context * vctx,
5201
+ whisper_vad_params params) {
5202
+ WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));
5203
+
5204
+ int n_probs = whisper_vad_n_probs(vctx);
5205
+ float * probs = whisper_vad_probs(vctx);
5206
+ float threshold = params.threshold;
5207
+ int min_speech_duration_ms = params.min_speech_duration_ms;
5208
+ int min_silence_duration_ms = params.min_silence_duration_ms;
5209
+ float max_speech_duration_s = params.max_speech_duration_s;
5210
+ int speech_pad_ms = params.speech_pad_ms;
5211
+ int n_window = vctx->n_window;
5212
+ int sample_rate = WHISPER_SAMPLE_RATE;
5213
+ int min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
5214
+ int audio_length_samples = n_probs * n_window;
5215
+
5216
+ // Min number of samples to be considered valid speech.
5217
+ int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
5218
+ int speech_pad_samples = sample_rate * speech_pad_ms / 1000;
5219
+
5220
+ // Max number of samples that a speech segment can contain before it is
5221
+ // split into multiple segments.
5222
+ int max_speech_samples;
5223
+ if (max_speech_duration_s > 100000.0f) {
5224
+ max_speech_samples = INT_MAX / 2;
5225
+ } else {
5226
+ int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples;
5227
+ max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp;
5228
+ if (max_speech_samples < 0) {
5229
+ max_speech_samples = INT_MAX / 2;
5230
+ }
5231
+ }
5232
+ // Detect silence period that exceeds this value, then that location (sample)
5233
+ // is marked as a potential place where the segment could be split if
5234
+ // max_speech_samples is reached. The value 98 was taken from the original
5235
+ // silaro-vad python implementation:
5236
+ //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
5237
+ int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
5238
+
5239
+ // Calculate lower threshold for detecting end of speech segments.
5240
+ float neg_threshold = threshold - 0.15f;
5241
+ if (neg_threshold < 0.01f) {
5242
+ neg_threshold = 0.01f;
5243
+ }
5244
+
5245
+ struct speech_segment_t {
5246
+ int start;
5247
+ int end;
5248
+ };
5249
+
5250
+ std::vector<speech_segment_t> speeches;
5251
+ speeches.reserve(256);
5252
+
5253
+ bool is_speech_segment = false;
5254
+ int temp_end = 0;
5255
+ int prev_end = 0;
5256
+ int next_start = 0;
5257
+ int curr_speech_start = 0;
5258
+ bool has_curr_speech = false;
5259
+
5260
+ for (int i = 0; i < n_probs; i++) {
5261
+ float curr_prob = probs[i];
5262
+ int curr_sample = n_window * i;
5263
+
5264
+ // Reset temp_end when we get back to speech
5265
+ if ((curr_prob >= threshold) && temp_end) {
5266
+ temp_end = 0;
5267
+ if (next_start < prev_end) {
5268
+ next_start = curr_sample;
5269
+ }
5270
+ }
5271
+
5272
+ // Start a new speech segment when probability exceeds threshold and not already in speech
5273
+ if ((curr_prob >= threshold) && !is_speech_segment) {
5274
+ is_speech_segment = true;
5275
+ curr_speech_start = curr_sample;
5276
+ has_curr_speech = true;
5277
+ continue;
5278
+ }
5279
+
5280
+ // Handle maximum speech duration
5281
+ if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
5282
+ if (prev_end) {
5283
+ speeches.push_back({ curr_speech_start, prev_end });
5284
+ has_curr_speech = true;
5285
+
5286
+ if (next_start < prev_end) { // Previously reached silence and is still not speech
5287
+ is_speech_segment = false;
5288
+ has_curr_speech = false;
5289
+ } else {
5290
+ curr_speech_start = next_start;
5291
+ }
5292
+ prev_end = next_start = temp_end = 0;
5293
+ } else {
5294
+ speeches.push_back({ curr_speech_start, curr_sample });
5295
+
5296
+ prev_end = next_start = temp_end = 0;
5297
+ is_speech_segment = false;
5298
+ has_curr_speech = false;
5299
+ continue;
5300
+ }
5301
+ }
5302
+
5303
+ // Handle silence after speech
5304
+ if ((curr_prob < neg_threshold) && is_speech_segment) {
5305
+ if (!temp_end) {
5306
+ temp_end = curr_sample;
5307
+ }
5308
+
5309
+ // Track potential segment ends for max_speech handling
5310
+ if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
5311
+ prev_end = temp_end;
5312
+ }
5313
+
5314
+ // Check if silence is long enough to end the segment
5315
+ if ((curr_sample - temp_end) < min_silence_samples) {
5316
+ continue;
5317
+ } else {
5318
+ // End the segment if it's long enough
5319
+ if ((temp_end - curr_speech_start) > min_speech_samples) {
5320
+ speeches.push_back({ curr_speech_start, temp_end });
5321
+ }
5322
+
5323
+ prev_end = next_start = temp_end = 0;
5324
+ is_speech_segment = false;
5325
+ has_curr_speech = false;
5326
+ continue;
5327
+ }
5328
+ }
5329
+ }
5330
+
5331
+ // Handle the case if we're still in a speech segment at the end
5332
+ if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
5333
+ speeches.push_back({ curr_speech_start, audio_length_samples });
5334
+ }
5335
+
5336
+ // Merge adjacent segments with small gaps in between (post-processing)
5337
+ if (speeches.size() > 1) {
5338
+ int merged_count = 0;
5339
+ for (int i = 0; i < (int) speeches.size() - 1; i++) {
5340
+ // Define maximum gap allowed for merging (e.g., 200ms converted to samples)
5341
+ int max_merge_gap_samples = sample_rate * 200 / 1000;
5342
+
5343
+ // If the gap between this segment and the next is small enough
5344
+ if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
5345
+ // Merge by extending current segment to the end of next segment
5346
+ speeches[i].end = speeches[i+1].end;
5347
+ speeches.erase(speeches.begin() + i + 1);
5348
+
5349
+ i--;
5350
+ merged_count++;
5351
+ }
5352
+ }
5353
+ WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
5354
+ __func__, merged_count, (int) speeches.size());
5355
+ }
5356
+
5357
+ // Double-check for minimum speech duration
5358
+ for (int i = 0; i < (int) speeches.size(); i++) {
5359
+ if (speeches[i].end - speeches[i].start < min_speech_samples) {
5360
+ WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
5361
+ __func__, i, speeches[i].end - speeches[i].start);
5362
+
5363
+ speeches.erase(speeches.begin() + i);
5364
+ i--;
5365
+ }
5366
+ }
5367
+
5368
+ WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());
5369
+
5370
+ // Allocate final segments
5371
+ std::vector<whisper_vad_segment> segments;
5372
+ if (speeches.size() > 0) {
5373
+ try {
5374
+ segments.resize(speeches.size());
5375
+ } catch (const std::bad_alloc &) {
5376
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
5377
+ return nullptr;
5378
+ }
5379
+ }
5380
+
5381
+ // Apply padding to segments and copy to final segments
5382
+ for (int i = 0; i < (int) speeches.size(); i++) {
5383
+ // Apply padding to the start of the first segment
5384
+ if (i == 0) {
5385
+ speeches[i].start =
5386
+ (speeches[i].start > speech_pad_samples) ?
5387
+ (speeches[i].start - speech_pad_samples) : 0;
5388
+ }
5389
+
5390
+ // Handle spacing between segments
5391
+ if (i < (int) speeches.size() - 1) {
5392
+ int silence_duration = speeches[i+1].start - speeches[i].end;
5393
+
5394
+ if (silence_duration < 2 * speech_pad_samples) {
5395
+ // If segments are close, split the difference
5396
+ speeches[i].end += silence_duration / 2;
5397
+ speeches[i+1].start =
5398
+ (speeches[i+1].start > silence_duration / 2) ?
5399
+ (speeches[i+1].start - silence_duration / 2) : 0;
5400
+ } else {
5401
+ // Otherwise, apply full padding to both
5402
+ speeches[i].end =
5403
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5404
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5405
+ speeches[i+1].start =
5406
+ (speeches[i+1].start > speech_pad_samples) ?
5407
+ (speeches[i+1].start - speech_pad_samples) : 0;
5408
+ }
5409
+ } else {
5410
+ // Apply padding to the end of the last segment
5411
+ speeches[i].end =
5412
+ (speeches[i].end + speech_pad_samples < audio_length_samples) ?
5413
+ (speeches[i].end + speech_pad_samples) : audio_length_samples;
5414
+ }
5415
+
5416
+ // Convert from samples to seconds and copy to final segments
5417
+ segments[i].start = (float)speeches[i].start / sample_rate;
5418
+ segments[i].end = (float)speeches[i].end / sample_rate;
5419
+
5420
+ WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
5421
+ __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
5422
+ }
5423
+
5424
+ whisper_vad_segments * vad_segments = new whisper_vad_segments;
5425
+ if (vad_segments == NULL) {
5426
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__);
5427
+ return nullptr;
5428
+ }
5429
+
5430
+ vad_segments->data = std::move(segments);
5431
+
5432
+ return vad_segments;
5433
+ }
5434
+
5435
+ struct whisper_vad_segments * whisper_vad_segments_from_samples(
5436
+ whisper_vad_context * vctx,
5437
+ whisper_vad_params params,
5438
+ const float * samples,
5439
+ int n_samples) {
5440
+ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);
5441
+ if (!whisper_vad_detect_speech(vctx, samples, n_samples)) {
5442
+ WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
5443
+ return nullptr;
5444
+ }
5445
+ return whisper_vad_segments_from_probs(vctx, params);
5446
+ }
5447
+
5448
+ void whisper_vad_free(whisper_vad_context * ctx) {
5449
+ if (ctx) {
5450
+ for (ggml_context * context : ctx->model.ctxs) {
5451
+ ggml_free(context);
5452
+ }
5453
+
5454
+ for (ggml_backend_buffer_t buf : ctx->model.buffers) {
5455
+ ggml_backend_buffer_free(buf);
5456
+ }
5457
+
5458
+ ggml_backend_sched_free(ctx->sched.sched);
5459
+
5460
+ for (auto & backend : ctx->backends) {
5461
+ ggml_backend_free(backend);
5462
+ }
5463
+
5464
+
5465
+ delete ctx;
5466
+ }
5467
+ }
5468
+
5469
+ void whisper_vad_free_segments(whisper_vad_segments * segments) {
5470
+ if (segments) {
5471
+ delete segments;
5472
+ }
5473
+ }
5474
+
5475
+ //////////////////////////////////
5476
+ // Grammar - ported from llama.cpp
5477
+ //////////////////////////////////
5478
+
5479
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
5480
+ // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
5481
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
//
// `partial_start` carries the state of a sequence that was cut off at the end
// of a previous buffer: `value` holds the bits accumulated so far and
// `n_remain` the number of continuation bytes still expected. The returned
// pair is (decoded code points + trailing 0, state for the next call).
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
        const char * src,
        whisper_partial_utf8 partial_start) {
    // Maps the high nibble of a lead byte to the total sequence length:
    // 0x0-0x7 -> 1 (ASCII), 0x8-0xB -> 0 (continuation byte - invalid as a
    // lead byte), 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    const char * pos = src;
    std::vector<uint32_t> code_points;
    uint32_t value = partial_start.value;
    int n_remain = partial_start.n_remain;

    // continue previous decode, if applicable
    while (*pos != 0 && n_remain > 0) {
        uint8_t next_byte = static_cast<uint8_t>(*pos);
        if ((next_byte >> 6) != 2) {
            // invalid sequence, abort
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
        }
        // accumulate the low 6 bits of the continuation byte
        value = (value << 6) + (next_byte & 0x3F);
        ++pos;
        --n_remain;
    }

    // the carried-over sequence completed within this buffer
    if (partial_start.n_remain > 0 && n_remain == 0) {
        code_points.push_back(value);
    }

    // decode any subsequent utf-8 sequences, which may end in an incomplete one
    while (*pos != 0) {
        uint8_t first_byte = static_cast<uint8_t>(*pos);
        uint8_t highbits = first_byte >> 4;
        // continuation bytes still expected after the lead byte
        n_remain = lookup[highbits] - 1;

        if (n_remain < 0) {
            // invalid sequence, abort
            code_points.clear();
            code_points.push_back(0);
            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
        }

        // mask extracts the payload bits of the lead byte (fewer bits for
        // longer sequences)
        uint8_t mask = (1 << (7 - n_remain)) - 1;
        value = first_byte & mask;
        ++pos;
        while (*pos != 0 && n_remain > 0) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
            ++pos;
            --n_remain;
        }
        // only emit complete sequences; an incomplete tail is returned as state
        if (n_remain == 0) {
            code_points.push_back(value);
        }
    }
    // terminating 0 so the buffer can be used as a pointer/sentinel
    code_points.push_back(0);

    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
}
5536
+
5537
+ // returns true iff pos points to the end of one of the definitions of a rule
5538
+ static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
5539
+ switch (pos->type) {
5540
+ case WHISPER_GRETYPE_END: return true; // NOLINT
5541
+ case WHISPER_GRETYPE_ALT: return true; // NOLINT
5542
+ default: return false;
5543
+ }
5544
+ }
5545
+
5546
+ // returns true iff chr satisfies the char range at pos (regular or inverse range)
5547
+ // asserts that pos is pointing to a char range element
5548
+ static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
5549
+ const whisper_grammar_element * pos,
5550
+ const uint32_t chr) {
5551
+
5552
+ bool found = false;
5553
+ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
5554
+
5555
+ WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
5556
+
5557
+ do {
5558
+ if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
5559
+ // inclusive range, e.g. [a-z]
5560
+ found = found || (pos->value <= chr && chr <= pos[1].value);
5561
+ pos += 2;
4290
5562
  } else {
4291
5563
  // exact char match, e.g. [a] or "a"
4292
5564
  found = found || pos->value == chr;
@@ -4355,7 +5627,7 @@ static void whisper_grammar_advance_stack(
4355
5627
  std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
4356
5628
 
4357
5629
  if (stack.empty()) {
4358
- new_stacks.push_back(stack);
5630
+ new_stacks.emplace_back();
4359
5631
  return;
4360
5632
  }
4361
5633
 
@@ -4676,7 +5948,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4676
5948
  /*.detect_language =*/ false,
4677
5949
 
4678
5950
  /*.suppress_blank =*/ true,
4679
- /*.suppress_non_speech_tokens =*/ false,
5951
+ /*.suppress_nst =*/ false,
4680
5952
 
4681
5953
  /*.temperature =*/ 0.0f,
4682
5954
  /*.max_initial_ts =*/ 1.0f,
@@ -4716,6 +5988,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
4716
5988
  /*.n_grammar_rules =*/ 0,
4717
5989
  /*.i_start_rule =*/ 0,
4718
5990
  /*.grammar_penalty =*/ 100.0f,
5991
+
5992
+ /*.vad =*/ false,
5993
+ /*.vad_model_path =*/ nullptr,
5994
+
5995
+ /* vad_params =*/ whisper_vad_default_params(),
4719
5996
  };
4720
5997
 
4721
5998
  switch (strategy) {
@@ -4960,7 +6237,7 @@ static void whisper_process_logits(
4960
6237
 
4961
6238
  // suppress non-speech tokens
4962
6239
  // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
4963
- if (params.suppress_non_speech_tokens) {
6240
+ if (params.suppress_nst) {
4964
6241
  for (const std::string & token : non_speech_tokens) {
4965
6242
  const std::string suppress_tokens[] = {token, " " + token};
4966
6243
  for (const std::string & suppress_token : suppress_tokens) {
@@ -5332,6 +6609,121 @@ static void whisper_sequence_score(
5332
6609
  }
5333
6610
  }
5334
6611
 
6612
+ static bool whisper_vad(
6613
+ struct whisper_context * ctx,
6614
+ struct whisper_state * state,
6615
+ struct whisper_full_params params,
6616
+ const float * samples,
6617
+ int n_samples,
6618
+ std::vector<float> & filtered_samples,
6619
+ int & filtered_n_samples) {
6620
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
6621
+ filtered_n_samples = 0;
6622
+
6623
+ if (state->vad_context == nullptr) {
6624
+ struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
6625
+ struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
6626
+ if (vctx == nullptr) {
6627
+ WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
6628
+ return false;
6629
+ }
6630
+ state->vad_context = vctx;
6631
+ }
6632
+ auto vctx = state->vad_context;
6633
+
6634
+ const whisper_vad_params & vad_params = params.vad_params;
6635
+
6636
+ whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
6637
+
6638
+ if (vad_segments->data.size() > 0) {
6639
+ state->has_vad_segments = true;
6640
+ ctx->state->vad_segments.clear();
6641
+ ctx->state->vad_segments.reserve(vad_segments->data.size());
6642
+
6643
+ WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
6644
+ float overlap_seconds = vad_params.samples_overlap;
6645
+ int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
6646
+
6647
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6648
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
6649
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
6650
+
6651
+ if (i < (int)vad_segments->data.size() - 1) {
6652
+ segment_end_samples += overlap_samples;
6653
+ }
6654
+ segment_end_samples = std::min(segment_end_samples, n_samples - 1);
6655
+ filtered_n_samples += (segment_end_samples - segment_start_samples);
6656
+
6657
+ WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
6658
+ __func__, i, vad_segments->data[i].start,
6659
+ vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
6660
+ (vad_segments->data[i].end - vad_segments->data[i].start) +
6661
+ (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
6662
+ }
6663
+
6664
+ int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
6665
+ int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
6666
+ int total_samples_needed = filtered_n_samples + total_silence_samples;
6667
+
6668
+ WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
6669
+ __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
6670
+
6671
+ try {
6672
+ filtered_samples.resize(total_samples_needed);
6673
+ } catch (const std::bad_alloc & /* e */) {
6674
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
6675
+ whisper_vad_free_segments(vad_segments);
6676
+ whisper_vad_free(vctx);
6677
+ return false;
6678
+ }
6679
+
6680
+ int offset = 0;
6681
+ for (int i = 0; i < (int)vad_segments->data.size(); i++) {
6682
+ int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
6683
+ int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
6684
+
6685
+ if (i < (int)vad_segments->data.size() - 1) {
6686
+ segment_end_samples += overlap_samples;
6687
+ }
6688
+
6689
+ segment_start_samples = std::min(segment_start_samples, n_samples - 1);
6690
+ segment_end_samples = std::min(segment_end_samples, n_samples);
6691
+ int segment_length = segment_end_samples - segment_start_samples;
6692
+
6693
+ if (segment_length > 0) {
6694
+ whisper_state::vad_segment_info segment;
6695
+
6696
+ segment.orig_start = vad_segments->data[i].start;
6697
+ segment.orig_end = vad_segments->data[i].end;
6698
+
6699
+ segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
6700
+ segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
6701
+
6702
+ WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
6703
+ __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
6704
+ ctx->state->vad_segments.push_back(segment);
6705
+
6706
+ // Copy this speech segment
6707
+ memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
6708
+ offset += segment_length;
6709
+
6710
+ // Add silence after this segment (except after the last segment)
6711
+ if (i < (int)vad_segments->data.size() - 1) {
6712
+ // Fill with zeros (silence)
6713
+ memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
6714
+ offset += silence_samples;
6715
+ }
6716
+ }
6717
+ }
6718
+
6719
+ filtered_n_samples = offset;
6720
+ WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
6721
+ __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6722
+ }
6723
+
6724
+ return true;
6725
+ }
6726
+
5335
6727
  int whisper_full_with_state(
5336
6728
  struct whisper_context * ctx,
5337
6729
  struct whisper_state * state,
@@ -5343,9 +6735,27 @@ int whisper_full_with_state(
5343
6735
 
5344
6736
  result_all.clear();
5345
6737
 
5346
- if (n_samples > 0) {
6738
+ const float * process_samples = samples;
6739
+ int n_process_samples = n_samples;
6740
+ std::vector<float> vad_samples;
6741
+
6742
+ if (params.vad) {
6743
+ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6744
+ int vad_n_samples;
6745
+ if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
6746
+ WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
6747
+ return -1;
6748
+ }
6749
+ if (vad_n_samples == 0) {
6750
+ return 0;
6751
+ }
6752
+ process_samples = vad_samples.data();
6753
+ n_process_samples = vad_n_samples;
6754
+ }
6755
+
6756
+ if (n_process_samples > 0) {
5347
6757
  // compute log mel spectrogram
5348
- if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
6758
+ if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
5349
6759
  WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
5350
6760
  return -2;
5351
6761
  }
@@ -5381,11 +6791,13 @@ int whisper_full_with_state(
5381
6791
  const int seek_start = params.offset_ms/10;
5382
6792
  const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
5383
6793
 
5384
- // if length of spectrogram is less than 1.0s (100 frames), then return
5385
- // basically don't process anything that is less than 1.0s
5386
- // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
5387
- if (seek_end < seek_start + 100) {
5388
- WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
6794
+ // if length of spectrogram is less than 100ms (10 frames), then return
6795
+ // basically don't process anything that is less than 100ms
6796
+ // ref: https://github.com/ggml-org/whisper.cpp/issues/2065
6797
+ const int delta_min = 10;
6798
+
6799
+ if (seek_end < seek_start + delta_min) {
6800
+ WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
5389
6801
  return 0;
5390
6802
  }
5391
6803
 
@@ -5432,7 +6844,7 @@ int whisper_full_with_state(
5432
6844
  decoder.logprobs.resize(ctx->vocab.n_vocab);
5433
6845
  decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
5434
6846
 
5435
- decoder.rng = std::mt19937(0);
6847
+ decoder.rng = std::mt19937(j);
5436
6848
  }
5437
6849
 
5438
6850
  // the accumulated text context so far
@@ -5529,8 +6941,8 @@ int whisper_full_with_state(
5529
6941
  ctx, state, progress_cur, params.progress_callback_user_data);
5530
6942
  }
5531
6943
 
5532
- // if only 1 second left, then stop
5533
- if (seek + 100 >= seek_end) {
6944
+ // if only 100ms left, then stop
6945
+ if (seek + delta_min >= seek_end) {
5534
6946
  break;
5535
6947
  }
5536
6948
 
@@ -5877,10 +7289,10 @@ int whisper_full_with_state(
5877
7289
  // end of segment
5878
7290
  if (token.id == whisper_token_eot(ctx) || // end of text token
5879
7291
  (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
5880
- (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
7292
+ (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
5881
7293
  ) {
5882
7294
  if (result_len == 0 && !params.no_timestamps) {
5883
- if (seek + seek_delta + 100 >= seek_end) {
7295
+ if (seek + seek_delta + delta_min >= seek_end) {
5884
7296
  result_len = i + 1;
5885
7297
  } else {
5886
7298
  WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
@@ -6147,7 +7559,7 @@ int whisper_full_with_state(
6147
7559
 
6148
7560
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
6149
7561
 
6150
- result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
7562
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6151
7563
  for (int j = i0; j <= i; j++) {
6152
7564
  result_all.back().tokens.push_back(tokens_cur[j]);
6153
7565
  }
@@ -6192,7 +7604,7 @@ int whisper_full_with_state(
6192
7604
  }
6193
7605
  }
6194
7606
 
6195
- result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
7607
+ result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
6196
7608
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
6197
7609
  result_all.back().tokens.push_back(tokens_cur[j]);
6198
7610
  }
@@ -6229,7 +7641,7 @@ int whisper_full_with_state(
6229
7641
  }
6230
7642
  }
6231
7643
 
6232
- // ref: https://github.com/ggerganov/whisper.cpp/pull/2629
7644
+ // ref: https://github.com/ggml-org/whisper.cpp/pull/2629
6233
7645
  const bool single_timestamp_ending = tokens_cur.size() > 1 &&
6234
7646
  tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
6235
7647
  tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
@@ -6388,19 +7800,133 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
6388
7800
  }
6389
7801
 
6390
7802
  int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
6391
- return state->result_all[i_segment].t0;
7803
+ // If VAD wasn't used, return the original timestamp
7804
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
7805
+ return state->result_all[i_segment].t0;
7806
+ }
7807
+
7808
+ // Get the start timestamp produced by whisper_full. whisper_full processes
7809
+ // only the speech segments in this case so we need to map these timestamps
7810
+ // back to the original audio.
7811
+ float t0 = state->result_all[i_segment].t0 / 100.0f;
7812
+
7813
+ // Find which VAD segment this timestamp belongs.
7814
+ // TODO(danbev) This could be optimized by using a binary search if the number
7815
+ // of segments exceed a certain limit. Also we might be able to assume that
7816
+ // the access pattern is sequential and optimized for that too.
7817
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
7818
+ const auto & segment = state->vad_segments[i];
7819
+
7820
+ // Check if the timestamp falls within this segment.
7821
+ if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
7822
+ float proportion = 0.0f;
7823
+ if (segment.vad_end > segment.vad_start) {
7824
+ proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7825
+ }
7826
+ float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7827
+ return (int64_t)(orig_t0 * 100);
7828
+ }
7829
+ }
7830
+
7831
+ // Check if the timestamp falls between two segments.
7832
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7833
+ const auto & curr = state->vad_segments[i];
7834
+ const auto & next = state->vad_segments[i + 1];
7835
+
7836
+ if (t0 > curr.vad_end && t0 < next.vad_start) {
7837
+ // Calculate how far we are through the gap as a proportion
7838
+ float gap_proportion = 0.0f;
7839
+ if (next.vad_start > curr.vad_end) {
7840
+ gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
7841
+ }
7842
+ float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7843
+ return (int64_t)(orig_t0 * 100);
7844
+ }
7845
+ }
7846
+
7847
+ // Handle the case where the timestamp is after the last segment.
7848
+ if (t0 > state->vad_segments.back().vad_end) {
7849
+ // For timestamps after the last segment, add the extra time to the end of the last segment
7850
+ const auto& last = state->vad_segments.back();
7851
+ // Calculate how far beyond the last segment
7852
+ float extra_time = t0 - last.vad_end;
7853
+ // Add this extra time to the original end time
7854
+ float orig_t0 = last.orig_end + extra_time;
7855
+ return (int64_t)(orig_t0 * 100);
7856
+ }
7857
+
7858
+ WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
7859
+ return t0;
6392
7860
  }
6393
7861
 
6394
7862
  int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
6395
- return ctx->state->result_all[i_segment].t0;
7863
+ return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
6396
7864
  }
6397
7865
 
6398
7866
  int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
6399
- return state->result_all[i_segment].t1;
7867
+ // If VAD wasn't used, return the original timestamp
7868
+ if (!state->has_vad_segments || state->vad_segments.empty()) {
7869
+ return state->result_all[i_segment].t1;
7870
+ }
7871
+
7872
+ // Get the end timestamp produced by whisper_full. whisper_full processes
7873
+ // only the speech segments in this case so we need to map these timestamps
7874
+ // back to the original audio.
7875
+ float t1 = state->result_all[i_segment].t1 / 100.0f;
7876
+
7877
+ // Find which VAD segment this timestamp belongs.
7878
+ // TODO(danbev) This could be optimized by using a binary search if the number
7879
+ // of segments exceed a certain limit. Also we might be able to assume that
7880
+ // the access pattern is sequential and optimized for that too.
7881
+ for (size_t i = 0; i < state->vad_segments.size(); i++) {
7882
+ const auto& segment = state->vad_segments[i];
7883
+
7884
+ // Check if the timestamp falls within this segment.
7885
+ if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
7886
+ // Calculate the proportion through the filtered segment.
7887
+ float proportion = 0.0f;
7888
+ if (segment.vad_end > segment.vad_start) {
7889
+ proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
7890
+ }
7891
+ float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
7892
+ return (int64_t)(orig_t1 * 100);
7893
+ }
7894
+ }
7895
+
7896
+ // Check if the timestamp falls between two segments.
7897
+ for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
7898
+ const auto & curr = state->vad_segments[i];
7899
+ const auto & next = state->vad_segments[i + 1];
7900
+
7901
+ if (t1 > curr.vad_end && t1 < next.vad_start) {
7902
+ // Calculate how far we are through the gap as a proportion
7903
+ float gap_proportion = 0.0f;
7904
+ if (next.vad_start > curr.vad_end) {
7905
+ gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
7906
+ }
7907
+ // Map to the corresponding position in the original gap
7908
+ float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
7909
+ return (int64_t)(orig_t1 * 100);
7910
+ }
7911
+ }
7912
+
7913
+ // Handle the case where the timestamp is after the last segment
7914
+ if (t1 > state->vad_segments.back().vad_end) {
7915
+ // For the last segment, use the end of the last VAD segment
7916
+ const auto& last = state->vad_segments.back();
7917
+ // Calculate how far beyond the last segment
7918
+ float extra_time = t1 - last.vad_end;
7919
+ // Add this extra time to the original end time
7920
+ float orig_t1 = last.orig_end + extra_time;
7921
+ return (int64_t)(orig_t1 * 100);
7922
+ }
7923
+
7924
+ WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
7925
+ return t1;
6400
7926
  }
6401
7927
 
6402
7928
  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
6403
- return ctx->state->result_all[i_segment].t1;
7929
+ return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
6404
7930
  }
6405
7931
 
6406
7932
  bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
@@ -6459,6 +7985,14 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
6459
7985
  return ctx->state->result_all[i_segment].tokens[i_token].p;
6460
7986
  }
6461
7987
 
7988
+ float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
7989
+ return ctx->state->result_all[i_segment].no_speech_prob;
7990
+ }
7991
+
7992
+ float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
7993
+ return state->result_all[i_segment].no_speech_prob;
7994
+ }
7995
+
6462
7996
  // =================================================================================================
6463
7997
 
6464
7998
  //
@@ -6620,6 +8154,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
6620
8154
  }
6621
8155
 
6622
8156
  WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
8157
+ whisper_load_backends();
8158
+
6623
8159
  static std::string s;
6624
8160
  s = "";
6625
8161
  char strbuf[256];
@@ -6639,7 +8175,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
6639
8175
  // c: N*N*sizeof(float)
6640
8176
  // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
6641
8177
  std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());
6642
- std::vector<uint8_t> work;
6643
8178
 
6644
8179
  // put a bunch of random data in the buffer
6645
8180
  for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -6696,12 +8231,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
6696
8231
  double tsum = 0.0;
6697
8232
 
6698
8233
  // heat-up
6699
- ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
8234
+ ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
6700
8235
 
6701
8236
  for (int i = 0; i < n_max; ++i) {
6702
8237
  const int64_t t0 = ggml_time_us();
6703
8238
 
6704
- ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
8239
+ ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);
6705
8240
 
6706
8241
  const int64_t t1 = ggml_time_us();
6707
8242
 
@@ -6862,12 +8397,6 @@ static void whisper_exp_compute_token_level_timestamps(
6862
8397
 
6863
8398
  const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
6864
8399
 
6865
- tokens[j].id = token.id;
6866
- tokens[j].tid = token.tid;
6867
- tokens[j].p = token.p;
6868
- tokens[j].pt = token.pt;
6869
- tokens[j].ptsum = token.ptsum;
6870
-
6871
8400
  tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
6872
8401
 
6873
8402
  if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
@@ -7078,18 +8607,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
7078
8607
  struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
7079
8608
  struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
7080
8609
 
7081
- cost = ggml_set_f32(cost, INFINITY);
7082
- trace = ggml_set_f32(trace, -1);
7083
- ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
8610
+ cost = whisper_set_f32(cost, INFINITY);
8611
+ trace = whisper_set_i32(trace, -1);
8612
+ whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
7084
8613
 
7085
8614
  // dtw
7086
8615
  // supposedly can be optmized by computing diagonals in parallel ?
7087
8616
  // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
7088
8617
  for (int64_t j = 1; j < M + 1; ++j) {
7089
8618
  for (int64_t i = 1; i < N + 1; ++i) {
7090
- float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
7091
- float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
7092
- float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
8619
+ float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
8620
+ float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
8621
+ float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
7093
8622
 
7094
8623
  float c;
7095
8624
  int32_t t;
@@ -7104,9 +8633,9 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
7104
8633
  t = 2;
7105
8634
  }
7106
8635
 
7107
- c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
7108
- ggml_set_f32_nd(cost, i, j, 0, 0, c);
7109
- ggml_set_i32_nd(trace, i, j, 0, 0, t);
8636
+ c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
8637
+ whisper_set_f32_nd(cost, i, j, 0, 0, c);
8638
+ whisper_set_i32_nd(trace, i, j, 0, 0, t);
7110
8639
  }
7111
8640
  }
7112
8641
 
@@ -7115,19 +8644,19 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
7115
8644
  struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
7116
8645
  // trace[0, :] = 2;
7117
8646
  for (int64_t i = 0; i < M + 1; ++i)
7118
- ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
8647
+ whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
7119
8648
  //trace[:, 0] = 1;
7120
8649
  for (int64_t i = 0; i < N + 1; ++i)
7121
- ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
8650
+ whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
7122
8651
  int bt_row_idx = BT_MAX_ROWS - 1;
7123
8652
  int64_t i = N;
7124
8653
  int64_t j = M;
7125
8654
  while (i > 0 || j > 0) {
7126
- ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
7127
- ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
8655
+ whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
8656
+ whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
7128
8657
  --bt_row_idx;
7129
8658
 
7130
- int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
8659
+ int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
7131
8660
  if (t == 0) {
7132
8661
  --i;
7133
8662
  --j;
@@ -7148,8 +8677,8 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
7148
8677
  ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
7149
8678
  for (int64_t i = 0; i < 2; ++i) {
7150
8679
  for (int64_t j = 0; j < result_n_cols; ++j) {
7151
- int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
7152
- ggml_set_i32_nd(r, i, j, 0, 0, v);
8680
+ int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
8681
+ whisper_set_i32_nd(r, i, j, 0, 0, v);
7153
8682
  }
7154
8683
  }
7155
8684
 
@@ -7184,11 +8713,11 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
7184
8713
  idx = 2*(a->ne[2] - 1) - idx;
7185
8714
  }
7186
8715
 
7187
- filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
8716
+ filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
7188
8717
  }
7189
8718
  std::sort(filter.begin(), filter.end());
7190
8719
  const float v = filter[filter.size()/2];
7191
- ggml_set_f32_nd(dst, i, j, k, 0, v);
8720
+ whisper_set_f32_nd(dst, i, j, k, 0, v);
7192
8721
  filter.clear();
7193
8722
  }
7194
8723
  }
@@ -7310,7 +8839,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7310
8839
  // Compute
7311
8840
  struct ggml_cgraph * gf = ggml_new_graph(gctx);
7312
8841
  ggml_build_forward_expand(gf, w);
7313
- ggml_graph_compute_with_ctx(gctx, gf, n_threads);
8842
+
8843
+ ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
8844
+ ggml_backend_graph_compute(backend.get(), gf);
7314
8845
 
7315
8846
  ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
7316
8847
 
@@ -7319,9 +8850,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
7319
8850
  auto seg_i = state->result_all.begin() + i_segment;
7320
8851
  auto tok_i = seg_i->tokens.begin();
7321
8852
  for (int i = 0; i < alignment->ne[1]; ++i) {
7322
- int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
8853
+ int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
7323
8854
  if (v != last_v) {
7324
- int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
8855
+ int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
7325
8856
  int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
7326
8857
  last_v = v;
7327
8858